模型保存和重载

# 模型保存和重载

​ 在众多训练好的模型中会有几个较好的模型,我们希望储存这些模型对应的参数值,避免后续难以训练出更好的结果。PyTorch提供了模型的保存与重载模块,包括torch.save()和torch.load(),以及pytorchtools中的EarlyStopping。

# 保存与重载模块

  1. 保存/加载模型的参数,不保存/加载模型的结构

    import torch
    #保存
    torch.save(model.state_dict(), 'model_params.pth')↓
    #加载
    model=init_model()#先初始化一个模型,这里是伪代码,↓
    model.load_state_dict(torch.load('model_params-pth'))
    
    1
    2
    3
    4
    5
    6

    其中state_dict为参数字典,model_params.pth为保存的文件路径。

  2. 保存/加载模型的参数和结构

    #保存
    torch.save(model, 'model_params.pth')
    #加载
    model = torch.load('model_params.pth')
    
    1
    2
    3
    4

# EarlyStopping

为了获取性能良好的神经网络,训练网络的过程中需要进行许多对于模型各部分的设置,也就是超参数的调整。

例如超参数之一:epoch。取值过小可能会导致欠拟合,取值过大可能会导致过拟合。

原理

  1. 将原数据分为训练集和验证集
  2. 只在训练集上进行训练,并每隔一个周期计算模型在验证集上的误差,如果随着周期的增加,在验证集上的测试误差也在增加,则停止训练
  3. 将停止之后的权重作为网络的最终参数

代码

from pytorchtools import EarlyStopping
early_stopping = EarlyStopping(patience = 20, verbose = False,
delta=0)

for e in range(epoch):
   ... pass
   model.eval() # 设置模型为评估/测试模式
   # 一般如果验证集不是很大的话,模型验证就不需要按批量进行了,但要注意输入参数的维度不能错
   valid_output = model(X_val)
   valid_loss = criterion(valid_output, y_val)	# 注意这里的输入参数维度要符合要求,我这里为了简单,并未考虑这一点
   
   early_stopping(valid_loss, model)
   # 若满足 early stopping 要求
   if early_stopping.early_stop:
      print("Early stopping")
      # 结束模型训练
      break
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

参数含义

  • patience(int) : 上次验证集损失值改善后等待几个epoch,默认值:7。
  • verbose(bool):如果值为True,为每个验证集损失值打印一条信息;若为False,则不打印,默认值:False。
  • delta(float):损失函数值改善的最小变化,当损失函数值的改善大于该值时,将会保存模型,默认值:0,即损失函数只要有改善即保存模型