BatchNorm怎样解决训练和推理时batch size 不同的问题?
BatchNorm怎样解决训练和推理时batch size 不同的问题?
BatchNorm是在batch维度上计算每个相同通道上的均值和方差,通常情况下,训练阶段的batchsize较大,而推理时batchsize基本为1。这样的话,就会导致训练和推理阶段得到不同的标准化,均值和方差时靠每一个mini-batch的统计得到的,因为推理时只有一个样本,在只有1个向量的数据组上进行标准化后,成了一个全0向量,导致模型出现BUG。为了解决这个问题,不改变训练时的BatchNorm计算方式,仅仅改变推理时计算均值和方差方法。
做法就是用训练集来估计总体均值 μ \mu μ和总体标准差 σ \sigma σ。主要有两种方法:简单平均法和移动指数平均
- 简单平均法
把每个mini-batch的均值和方差都保存下来,然后训练完了求均值的均值,方差的均值即可。
- 移动指数平均(Exponential Moving Average)
本文仅以 μ \mu μ的计算为例:
μ t o t a l = d e c a y ∗ μ t o t a l + ( 1 − d e c a y ) ∗ μ \mu_{total}=decay*\mu_{total}+(1-decay)*\mu μtotal=decay∗μtotal+(1−decay)∗μ
其中decay是衰减系数。即总均值 μ t o t a l \mu_{total} μtotal是前一个mini-batch统计的总均值和本次mini-batch的 μ \mu μ加权求和。至于衰减率 decay在区间[0,1]之间,decay越接近1,结果 μ t o t a l \mu_{total} μtotal越稳定,越受较远的大范围的样本影响;decay越接近0,结果 μ t o t a l \mu_{total} μtotal越波动,越受较近的小范围的样本影响。
事实上,简单平均可能更好,简单平均本质上是平均权重,但是简单平均需要保存所有BN层在所有mini-batch上的均值向量和方差向量,如果训练数据量很大,会有较可观的存储代价。移动指数平均在实际的框架中更常见(例如tensorflow),可能的好处是EMA不需要存储每一个mini-batch的值,永远只保存着三个值:总统计值、本batch的统计值,decay系数。
在训练阶段同步获得了 μ t o t a l \mu_{total} μtotal和 σ t o t a l \sigma_{total} σtotal后,在推理时即可对样本进行BN操作。