训练和测试模块
# 训练和测试模块
# 训练模块
只需要通过model.train()就可以将运行模式设置为训练模式。
model.train()
1
# 前向传播
相关解释:
- batch_id:分成的块序号
- input、labels: 训 练 数 据X 和 标 签 Y。
- data_loader:train data
- target:预测值
- loss:损失
- zero_grad:梯度清零
# 后向传播
optimizer:优化器,更新参数
# 测试模块
# model.eval()
要通过model.eval()就可以将运行模式设置为测试模式。保证每个参数都固定,确保每个min-batch的均值和方差都不变,尤其是针对包含Dropout和BatchNormalization的网络,更需要调整网络的模式,避免参数更新。
科普:
- Dropout:扔掉一些神经元,为了更快
- BN(BatchNormalization):修正分布。因为到下一层可能就不是标准正态分布了
# with torch.no_grad()
为了确保参数的梯度不进行变化,需要通过with torch.no_grad()模块改变测试状态,在该模块下,所有计算得出的tensor的requires_grad都自动设置为False,不会对模型的权重和偏差求导。
# 测试
# 训练和测试模块的区别
训练模块 | 测试模块 | |
---|---|---|
参数 | 变 | 不变 |
BN(BatchNormalization) | √ | × |
Dropout | √ | × |