1. 从LayerNorm到RMSNorm:为什么我们需要更轻量的归一化?
如果你玩过大型语言模型,比如GPT或者LLaMA,肯定对“归一化”这个词不陌生。它就像是模型训练过程中的“稳定器”,能防止数值爆炸或消失,让训练过程更平稳。在Transformer架构里,LayerNorm(层归一化)几乎是标配,从BERT到早期的GPT,它都立下了汗马功劳。
但不知道你有没有发现,最近几年冒出来的新模型,比如Meta的LLaMA、Google的T5,还有GPT-NeoX,都悄悄换掉了一个关键零件——它们不再用传统的LayerNorm,而是改用了一个叫RMSNorm(Root Mean Square Layer Normalization)的家伙。我第一次在LLaMA的代码里看到RMSNorm时,心里也犯嘀咕:这玩意儿到底有啥魔力,能让这么多顶尖团队放弃经典的LayerNorm?
简单来说,RMSNorm是LayerNorm的一个“简化版”。它把LayerNorm公式里那个“减去均值”的操作给去掉了,只保留了“除以均方根”的缩放部分。听起来好像只是少算了一个均值,能有多大差别?但实际用起来,尤其是在动辄几百亿、上千亿参数的大模型上,这点差别带来的效率提升和内存节省,简直是天壤之别。
我打个比方,LayerNorm就像是一个做事非常严谨的管家,每次都要把数据“居中”(减去均值)再“标准化”(除以标准差),确保一切井井有条。而RMSNorm则像是一个更注重效率的实干家,它认为“居中”这一步在很多情况下不是必须的,只要把数据的“尺度”(scale)控制好就行。这个看似微小的改变,背后其实是对大规模模型训练痛点的精准打击:计算开销和内存占用。
当模型参数规模从几亿膨胀到几千亿,每一个额外的计算操作、每一个可学习的参数,都会被放大成巨大的成本。LayerNorm需要计算均值和方差,这意味着它要对数据进行两次遍历(或者一次遍历但计算更复杂),并且需要维护两个可学习参数(缩放因子γ和偏移量β)。RMSNorm砍掉了均值和偏移量,计算量直接减少了大约15%-20%,参数量也减半。别小看这个比例,在千亿模型上,这节省的可能是几GB的显存和每天数万美元的算力成本。
所以,RMSNorm的核心优势就两个字:高效。它用更少的计算、更少的参数,达到了和LayerNorm相近甚至更好的训练稳定性。接下来,我们就掰开揉碎,看看它是怎么做到的。
2. 拆解RMSNorm:公式、代码与核心思想
要理解RMSNorm,最直接的方式就是看它的公式,并和LayerNorm做个对比。咱们先来回顾一下LayerNorm是怎么做的。
对于一个输入向量 x = [x1, x2, ..., xd],LayerNorm的计算分三步:
- 计算均值:
μ = (1/d) * Σ(xi) - 计算方差:
σ² = (1/d) * Σ((xi - μ)²) - 归一化并仿射变换:
y = γ * ((x - μ) / √(σ² + ε)) + β
这里的 γ 和 β 是可学习的缩放和偏移参数,ε 是一个很小的数(比如1e-5),防止除以零。
RMSNorm做了个大胆的简化,它认为第一步“减去均值”可以省略。它的公式是这样的:
-
计算均方根(RMS):
RMS(x) = √( (1/d) * Σ(xi²) + ε )注意,这里是对xi的平方求平均再开方,没有减去任何均值。 -
归一化并缩放:
y = γ * (x / RMS(x))
看出来区别了吗?RMSNorm的公式里,没有均值 μ,也没有偏移参数 β。它只用一个RMS值来衡量数据的“尺度”,然后用这个尺度去缩放每一个元素,最后乘上一个可学习的缩放因子 γ。
为什么可以去掉均值?论文作者和后续的实践发现,在像Transformer这样的架构中,经过自注意力机制和前馈网络


1153

被折叠的 条评论
为什么被折叠?



