模型保存和重载
# 模型保存和重载
在众多训练好的模型中会有几个较好的模型,我们希望储存这些模型对应的参数值,避免后续难以训练出更好的结果。PyTorch提供了模型的保存与重载模块,包括torch.save()和torch.load(),以及pytorchtools中的EarlyStopping。
# 保存与重载模块
保存/加载模型的参数,不保存/加载模型的结构
import torch #保存 torch.save(model.state_dict(), 'model_params.pth')↓ #加载 model=init_model()#先初始化一个模型,这里是伪代码,↓ model.load_state_dict(torch.load('model_params-pth'))
1
2
3
4
5
6其中state_dict为参数字典,model_params.pth为保存的文件路径。
保存/加载模型的参数和结构
#保存 torch.save(model, 'model_params.pth') #加载 model = torch.load('model_params.pth')
1
2
3
4
# EarlyStopping
为了获取性能良好的神经网络,训练网络的过程中需要进行许多对于模型各部分的设置,也就是超参数的调整。
例如超参数之一:epoch。取值过小可能会导致欠拟合,取值过大可能会导致过拟合。
原理
- 将原数据分为训练集和验证集;
- 只在训练集上进行训练,并每隔一个周期计算模型在验证集上的误差,如果随着周期的增加,在验证集上的测试误差也在增加,则停止训练;
- 将停止之后的权重作为网络的最终参数
代码
from pytorchtools import EarlyStopping
early_stopping = EarlyStopping(patience = 20, verbose = False,
delta=0)
for e in range(epoch):
... pass
model.eval() # 设置模型为评估/测试模式
# 一般如果验证集不是很大的话,模型验证就不需要按批量进行了,但要注意输入参数的维度不能错
valid_output = model(X_val)
valid_loss = criterion(valid_output, y_val) # 注意这里的输入参数维度要符合要求,我这里为了简单,并未考虑这一点
early_stopping(valid_loss, model)
# 若满足 early stopping 要求
if early_stopping.early_stop:
print("Early stopping")
# 结束模型训练
break
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
参数含义
- patience(int) : 上次验证集损失值改善后等待几个epoch,默认值:7。
- verbose(bool):如果值为True,为每个验证集损失值打印一条信息;若为False,则不打印,默认值:False。
- delta(float):损失函数值改善的最小变化,当损失函数值的改善大于该值时,将会保存模型,默认值:0,即损失函数只要有改善即保存模型