Sovler

模型的整个训练在solver.py文件中实现

class Solver(object):

    def __init__(self, model, data, criterion, optimizer, **kwargs):

    def _reset(self):

    def _step(self, X_batch, y_batch):

    def check_accuracy(self, X, y, num_samples=None, batch_size=8):

    def train(self):

init

  • 必选参数
    • model:网络模型
    • data:包含了训练和测试数据集
    • criterion:评价函数
    • optimizer:优化器
  • 可选参数
    • lr_scheduler:学习率调度器,默认为None
    • batch_size:单次处理大小,默认为8
    • num_epochs:迭代周期次数,默认为10
    • reg:正则化因子,默认为1e-3
    • print_every:每隔多少论打印一次信息,默认为1