四种归一化方式介绍
上面的图非常直观地展示了深度学习中四种常见的归一化(Normalization)方法(RMSNorm 可视为 LayerNorm 的变体)。为了便于理解,我们假设输入的是一个图像特征张量(Tensor),其形状为 ( N , C , H , W ) (N, C, H, W) ( N , C , H , W ) ,其中:
N N N (Batch Size) :批量大小(样本数量)。
C C C (Channel) :特征通道数。
H , W H, W H , W (Height, Width) :特征图的空间维度(高和宽)。在图中,这两个维度被合并为了一个轴。
图中的蓝色区域 代表了在进行一次 归一化操作时,计算均值(μ \mu μ )和方差(σ 2 \sigma^2 σ 2 )所涵盖的数据范围。
以下是四种归一化方法的详细理论、操作次数以及具体例子:
1. 批量归一化 (Batch Normalization, BatchNorm)
理论说明:
BatchNorm 是在通道(Channel)维度 上保持独立,而跨越所有的样本(N N N )和空间维度(H , W H, W H , W )进行归一化。换句话说,它将同一个批次中所有图像的同一个通道的特征像素放在一起进行标准化。这种方法在卷积神经网络(CNN)中最为常见,但在 Batch Size 非常小的情况下效果会变差。
归一化操作次数:
一共进行 C C C 次归一化(每个通道一次)。
具体例子:
假设输入张量的形状为 ( 4 , 64 , 32 , 32 ) (4, 64, 32, 32) ( 4 , 64 , 32 , 32 ) ,即 4 张图片,64 个通道,特征图大小为 32 × 32 32 \times 32 32 × 32 。
BatchNorm 会分别对这 64 个通道计算均值和方差。
操作次数:64 次。
每次计算涵盖的元素个数:4 × 32 × 32 = 4 \times 32 \times 32 = 4 × 32 × 32 = 4096 个元素。
2. 层归一化 (Layer Normalization, LayerNorm)
理论说明:
LayerNorm 是在样本(Batch)维度 上保持独立,而跨越所有的通道(C C C )和空间维度(H , W H, W H , W )进行归一化。它将单张图片的所有通道和所有像素放在一起标准化。这种方法不受 Batch Size 大小的影响,常用于循环神经网络(RNN)和 Transformer 架构(如自然语言处理任务)中。
归一化操作次数:
一共进行 N N N 次归一化(每个样本一次)。
具体例子:
假设输入张量形状依然为 ( 4 , 64 , 32 , 32 ) (4, 64, 32, 32) ( 4 , 64 , 32 , 32 ) 。
LayerNorm 会对这 4 张图片,每张图片单独计算一个均值和方差。
操作次数:4 次。
每次计算涵盖的元素个数:64 × 32 × 32 = 64 \times 32 \times 32 = 64 × 32 × 32 = 65536 个元素。
3. 实例归一化 (Instance Normalization, InstanceNorm)
理论说明:
InstanceNorm 在样本(Batch)和通道(Channel)维度 上都保持独立,仅仅在空间维度(H , W H, W H , W )上进行归一化。也就是对单张图片的单个通道的特征图进行标准化。它常用于图像风格迁移(Style Transfer)任务中,因为它能过滤掉图像特定实例的对比度信息,保留内容信息。
归一化操作次数:
一共进行 N × C N \times C N × C 次归一化(每个样本的每个通道一次)。
具体例子:
假设输入张量形状为 ( 4 , 64 , 32 , 32 ) (4, 64, 32, 32) ( 4 , 64 , 32 , 32 ) 。
InstanceNorm 会对每一张图片的每一个通道单独计算均值和方差。
操作次数:4 × 64 = 4 \times 64 = 4 × 64 = 256 次。
每次计算涵盖的元素个数:32 × 32 = 32 \times 32 = 32 × 32 = 1024 个元素。
4. 分组归一化 (Group Normalization, GroupNorm)
理论说明:
GroupNorm 是 LayerNorm 和 InstanceNorm 的折中方案。它首先将所有的通道(C C C )分成 G G G 个组(Group),然后在单个样本(Batch)内 ,跨越各个组内的通道和空间维度(H , W H, W H , W )进行归一化。当 Batch Size 太小无法使用 BatchNorm,且 LayerNorm 或 InstanceNorm 效果不佳时,GroupNorm 是很好的替代方案。
归一化操作次数:
一共进行 N × G N \times G N × G 次归一化(每个样本的每个通道组一次)。
具体例子:
假设输入张量形状为 ( 4 , 64 , 32 , 32 ) (4, 64, 32, 32) ( 4 , 64 , 32 , 32 ) ,我们设定分组数 G = 8 G = 8 G = 8 (此时每个组有 8 个通道)。
GroupNorm 会对 4 张图片,每张图片的 8 个分组分别计算均值和方差。
操作次数:4 × 8 = 4 \times 8 = 4 × 8 = 32 次。
每次计算涵盖的元素个数:8 (组内通道数) × 32 × 32 = 8 \text{ (组内通道数)} \times 32 \times 32 = 8 ( 组内通道数 ) × 32 × 32 = 8192 个元素。
5. 均方根归一化 (Root Mean Square Normalization, RMSNorm)
理论说明:
RMSNorm 是 LayerNorm 的一种高效变体,在近年的大型语言模型(如 LLaMA, Gemma 等)中占据统治地位。研究发现,LayerNorm 成功的关键在于缩放特征的方差,而不是均值中心化(即减去均值)。因此,RMSNorm 极其激进地去掉了计算均值和减去均值的步骤,直接利用数据的均方根 (Root Mean Square) 进行缩放。这样不仅保留了归一化的能力,还减少了约 10%-50% 的计算开销。
归一化操作次数:
与 LayerNorm 类似,通常是在特征维度上进行计算。
具体例子:
在 NLP 任务中,输入通常是 ( N , S e q _ L e n , D ) (N, Seq\_Len, D) ( N , S e q _ L e n , D ) 。RMSNorm 会针对每一个词向量(维度为 D D D )单独计算均方根并进行除法操作。
通用的归一化公式
前面四种归一化方法,底层执行的数学公式是完全相同的,唯一的区别在于参与计算的集合 S i S_i S i 中包含了哪些元素 (即图中蓝色方块的划分方式)。
假设 x x x 为蓝色集合 S i S_i S i 中的某一个元素,m m m 为集合 S i S_i S i 中元素的总个数。归一化步骤如下:
计算均值 (Mean):
μ = 1 m ∑ x ∈ S i x \mu = \frac{1}{m} \sum_{x \in S_i} x
μ = m 1 x ∈ S i ∑ x
计算方差 (Variance):
σ 2 = 1 m ∑ x ∈ S i ( x − μ ) 2 \sigma^2 = \frac{1}{m} \sum_{x \in S_i} (x - \mu)^2
σ 2 = m 1 x ∈ S i ∑ ( x − μ ) 2
标准化 (Normalize):
x ^ = x − μ σ 2 + ϵ \hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}
x ^ = σ 2 + ϵ x − μ
(其中 ϵ \epsilon ϵ 是一个非常小的常数,用于防止分母为零)
缩放和平移 (Scale and Shift - 仿射变换):
y = γ x ^ + β y = \gamma \hat{x} + \beta
y = γ x ^ + β
(其中 γ \gamma γ 和 β \beta β 是模型在训练过程中通过反向传播学习到的可学习参数,用于恢复网络原本的表达能力)
RMSNorm 的特例公式
RMSNorm 省略了上述的第 1 步和第 2 步,不再计算均值中心化,公式简化为:
计算均方根 (RMS):
R M S ( x ) = 1 m ∑ x ∈ S i x 2 + ϵ RMS(x) = \sqrt{\frac{1}{m} \sum_{x \in S_i} x^2 + \epsilon}
R M S ( x ) = m 1 x ∈ S i ∑ x 2 + ϵ
直接缩放:
y = γ x R M S ( x ) y = \gamma \frac{x}{RMS(x)}
y = γ R M S ( x ) x
(注:通常 RMSNorm 连偏置项 β \beta β 也会省略掉,仅保留缩放参数 γ \gamma γ )
PyTorch 实现
为了加深理解,下面我们不调用 torch.nn 中现成的 BatchNorm2d 等 API,而是利用基本的张量操作(如 mean, var)来手动实现这五种归一化方法。
注:在 PyTorch 中使用 .var() 时,必须设置 unbiased=False,因为归一化算法中的方差计算使用的是有偏估计(分母为 m m m 而不是 m − 1 m-1 m − 1 )。
1. BatchNorm2d 实现 (CV视角)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 import torchimport torch.nn as nnclass MyBatchNorm2d (nn.Module): def __init__ (self, num_features, eps=1e-5 ): super ().__init__() self .gamma = nn.Parameter(torch.ones(1 , num_features, 1 , 1 )) self .beta = nn.Parameter(torch.zeros(1 , num_features, 1 , 1 )) self .eps = eps def forward (self, x ): mean = x.mean(dim=(0 , 2 , 3 ), keepdim=True ) var = x.var(dim=(0 , 2 , 3 ), unbiased=False , keepdim=True ) x_hat = (x - mean) / torch.sqrt(var + self .eps) return self .gamma * x_hat + self .beta
2. LayerNorm 实现 (NLP视角)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 class MyLayerNormNLP (nn.Module): def __init__ (self, dim, eps=1e-5 ): super ().__init__() self .gamma = nn.Parameter(torch.ones(dim)) self .beta = nn.Parameter(torch.zeros(dim)) self .eps = eps def forward (self, x ): mean = x.mean(dim=-1 , keepdim=True ) var = x.var(dim=-1 , unbiased=False , keepdim=True ) x_hat = (x - mean) / torch.sqrt(var + self .eps) return self .gamma * x_hat + self .beta
3. InstanceNorm 实现 (CV视角)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 class MyInstanceNorm2d (nn.Module): def __init__ (self, num_features, eps=1e-5 ): super ().__init__() self .gamma = nn.Parameter(torch.ones(1 , num_features, 1 , 1 )) self .beta = nn.Parameter(torch.zeros(1 , num_features, 1 , 1 )) self .eps = eps def forward (self, x ): mean = x.mean(dim=(2 , 3 ), keepdim=True ) var = x.var(dim=(2 , 3 ), unbiased=False , keepdim=True ) x_hat = (x - mean) / torch.sqrt(var + self .eps) return self .gamma * x_hat + self .beta
4. GroupNorm 实现 (CV视角)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 class MyGroupNorm (nn.Module): def __init__ (self, num_groups, num_channels, eps=1e-5 ): super ().__init__() self .num_groups = num_groups self .gamma = nn.Parameter(torch.ones(1 , num_channels, 1 , 1 )) self .beta = nn.Parameter(torch.zeros(1 , num_channels, 1 , 1 )) self .eps = eps def forward (self, x ): N, C, H, W = x.shape x_reshaped = x.view(N, self .num_groups, C // self .num_groups, H, W) mean = x_reshaped.mean(dim=(2 , 3 , 4 ), keepdim=True ) var = x_reshaped.var(dim=(2 , 3 , 4 ), unbiased=False , keepdim=True ) x_hat = (x_reshaped - mean) / torch.sqrt(var + self .eps) x_hat = x_hat.view(N, C, H, W) return self .gamma * x_hat + self .beta
5. RMSNorm 实现 (NLP视角)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 class MyRMSNorm (nn.Module): def __init__ (self, dim, eps=1e-6 ): super ().__init__() self .weight = nn.Parameter(torch.ones(dim)) self .eps = eps def forward (self, x ): rms = torch.sqrt(x.pow (2 ).mean(dim=-1 , keepdim=True ) + self .eps) return self .weight * (x / rms)