Sovler¶
模型的整个训练在solver.py
文件中实现
class Solver(object):
def __init__(self, model, data, criterion, optimizer, **kwargs):
def _reset(self):
def _step(self, X_batch, y_batch):
def check_accuracy(self, X, y, num_samples=None, batch_size=8):
def train(self):
init¶
- 必选参数
model
:网络模型data
:包含了训练和测试数据集criterion
:评价函数optimizer
:优化器
- 可选参数
lr_scheduler
:学习率调度器,默认为None
batch_size
:单次处理大小,默认为8
num_epochs
:迭代周期次数,默认为10
reg
:正则化因子,默认为1e-3
print_every
:每隔多少论打印一次信息,默认为1