训练和测试模块

# 训练和测试模块

# 训练模块

只需要通过model.train()就可以将运行模式设置为训练模式。

model.train()
1

# 前向传播

image-20250507112235938

相关解释:

  • batch_id:分成的块序号
  • input、labels: 训 练 数 据X 和 标 签 Y。
  • data_loader:train data
  • target:预测值
  • loss:损失
  • zero_grad:梯度清零

# 后向传播

image-20250507112413605

optimizer:优化器,更新参数

# 测试模块

# model.eval()

要通过model.eval()就可以将运行模式设置为测试模式。保证每个参数都固定,确保每个min-batch的均值和方差都不变,尤其是针对包含Dropout和BatchNormalization的网络,更需要调整网络的模式,避免参数更新。

科普:

  • Dropout:扔掉一些神经元,为了更快
  • BN(BatchNormalization):修正分布。因为到下一层可能就不是标准正态分布了

# with torch.no_grad()

为了确保参数的梯度不进行变化,需要通过with torch.no_grad()模块改变测试状态,在该模块下,所有计算得出的tensor的requires_grad都自动设置为False,不会对模型的权重和偏差求导。

image-20250507112615787

# 测试

image-20250507112917885

# 训练和测试模块的区别

训练模块 测试模块
参数 不变
BN(BatchNormalization) ×
Dropout ×