看到了什么现象?

Hahami (2025) 观察到扰动投影呈指数衰减。我之前假设这是 LayerNorm 谱范数 < 1 导致的,但 Xiong (2020) 证明谱范数 = O(1),不满足 < 1 的条件。

现在发现 Kedia (2024) “Transformers Get Stable” 提供了一个完整的信号传播理论,可能解释残差衰减的真正机制。

为什么这重要?

信号传播理论可以统一解释:

  1. 前向激活的增长
  2. 反向梯度的爆炸/消失
  3. Rank collapse 现象
  4. 残差衰减的机制

这篇文章解决什么问题?

用 Kedia (2024) 的信号传播理论重新理解残差衰减,并验证它是否能解释 Hahami 的观察。


Kedia (2024) 的核心框架

动量传播公式

Kedia 提出了一个完整的框架,计算 Transformer 各组件的一阶和二阶矩(均值和方差):

组件 前向方差 σ²_xout 反向方差 σ²_gin 相关性 r_l
LayerNorm(d) 1 σ²_gout / σ²_xin r_l_xin · r_l_gout
ReLU (1 - π/2) σ²_xin 0.7r_l_xin + 0.3r²_l_xin · (1/2 + sin⁻¹(π r_l_xin)/π)
Attention Block d²_v σ²_v σ²_xin (1-p)
FFN Block 2 d²_w σ²_w1 σ²_w2 σ²_xin (1-p)

关键发现:LayerNorm 将方差归一化为 1,但这不意味着谱范数 < 1

Pre-LN vs Post-LN 的动量传播

Pre-LN(残差连接绕过 LayerNorm):

  • 前向方差:线性增长 σ² ~ O(N)
  • 反向梯度:双曲增长 σ² ~ O(log N)

Post-LN(LayerNorm 在残差连接之后):

  • 前向方差:保持稳定
  • 反向梯度:指数消失/爆炸 σ² ~ O(exp(±N))

Table 3 的关键对比

方法 Post-LN 反向 Post-LN 前向 Pre-LN 反向 Pre-LN 前向
Vanilla O(exp(±N)) O(1) O(log N) O(N)
DSLM O(1) O(1) O(1) O(1)

对残差衰减的启示

为什么 Hahami 观察到指数衰减?

Kedia 的理论暗示:残差衰减不是 LayerNorm 的效应,而是残差连接的动量累积效应

Pre-LN 的前向传播

1
h_{L+1} = h_L + f_LN(h_L)

扰动 δ 的演化:

1
δ_{L+1} = δ_L + δ_block

关键洞察:如果 δ_block 被 LayerNorm “归一化”(方差变为 1),而 δ_L 保持累积,那么:

1
Var(δ_L) ∝ L  (线性增长)

投影到注入方向:

1
proj(δ_L) ∝ 1/√L  (因为扰动被分散到多个方向)

这就是指数衰减的来源!不是 LayerNorm 的谱范数 < 1,而是残差累积导致的扰动分散

数学推导

设注入方向为 v,扰动 δ_L 被 LayerNorm 分散到多个正交方向:

1
δ_L = Σ_i α_i · e_i  (e_i 为正交基)

其中 α_1 是在 v 方向的分量。由于 LayerNorm 的"去中心化"效应:

1
Σ_i α_i² = Var(δ_L) ∝ L

但每个方向的方差被均匀化:

1
α_i² ≈ Var(δ_L) / d = O(L / d)

投影到 v 方向:

1
proj(δ_L) = α_1 ≈ √(Var(δ_L) / d) ∝ √(L / d)

等等,这是线性增长,不是衰减!


重新思考:为什么是衰减而不是增长?

可能的解释

关键修正:Hahami 测量的不是扰动的绝对大小,而是扰动投影与注入方向的相似度

Cosine similarity:

1
cos(δ_L, v) = (δ_L · v) / (||δ_L|| · ||v||)

如果 δ_L 被 LayerNorm 分散到多个方向,||δ_L|| 增长,但 δ_L · v 保持不变或增长更慢:

1
cos(δ_L, v) ∝ 1 / ||δ_L|| ∝ 1 / √L

这就是指数衰减的来源:扰动的相对贡献下降,而不是绝对大小下降。

与 Hahami 的观察对应

Hahami 的三个指标:

  1. Cosine similarity → 恢复到 1.0 ✓(方向恢复)
  2. Projection → 指数衰减 ✓(相对贡献下降)
  3. Norm ratio → ?(需要验证)

关键区分:衰减 vs 恢复

表面上矛盾的现象

  • Cosine similarity 恢复到 1.0(扰动方向与基线对齐)
  • Projection 衰减到 0(扰动投影消失)

Kedia 理论的解释

  • 恢复:LayerNorm 的归一化效应使 h_injecth_baseline 的方向对齐
  • 衰减:残差累积使扰动被稀释(相对贡献下降)

数学表达

1
2
3
cos(h_inject, h_baseline) → 1  (方向对齐)
||h_inject - h_baseline|| / ||h_baseline|| → const (扰动被稀释)
proj(h_inject - h_baseline, v) / ||h_baseline|| → 0 (投影衰减)

对内省窗口的意义

修正后的假说

原假说:内省窗口的边界由 LayerNorm 谱范数 < 1 决定。

修正假说:内省窗口的边界由残差累积稀释决定。

关键参数

  • 模型深度 N
  • 隐藏维度 d
  • 扰动分散率(取决于 LayerNorm 的 Jacobian 结构)

预测

  • 更深的模型有更长的内省窗口(因为 N 更大)
  • 更大的模型有更长的内省窗口(因为 d 更大)
  • 移除 LayerNorm 会加速衰减(因为没有归一化,扰动更快分散)

验证方向

验证 Kedia 的理论

  1. 测量前向方差增长

    1
    var_forward = torch.var(h_L, dim=-1)

    预期:Pre-LN 模型 var_forward ∝ L

  2. 测量扰动投影衰减

    1
    proj_L = (h_inject_L - h_baseline_L) @ v / ||v||

    预期:proj_L ∝ 1/√L

  3. 对比不同深度模型
    预期:更深模型的内省窗口更大


批判性反思

理论缺口

  1. Kedia 理论是初始化理论

    • 在初始化时成立
    • 训练后是否仍然成立?
  2. 扰动不是小量

    • Kedia 假设小扰动
    • Hahami 的注入可能较大
  3. 缺少动态分析

    • Kedia 分析的是静态方差
    • Hahami 观察的是动态传播

与其他理论的关系

  • Xiong (2020):LayerNorm 谱范数 = O(1),不能解释衰减
  • Kedia (2024):残差累积导致扰动稀释,可以解释衰减
  • TaperNorm:反向梯度移除径向分量,与衰减无关

结论

核心发现

  • 残差衰减不是 LayerNorm 谱范数 < 1 导致的
  • 而是残差累积 + LayerNorm 归一化的复合效应
  • 扰动的相对贡献下降(投影衰减),但绝对大小可能增长

对内省窗口的启示

  • 内省窗口大小 ∝ √(N · d)(推测)
  • 可以通过增加深度或宽度来延长窗口

验证方向

  1. 测量前向方差增长
  2. 测量扰动投影衰减
  3. 对比不同架构的内省窗口

关键引用


最后更新: 2026-03-16 18:00
核心发现: Kedia (2024) 的信号传播理论表明,残差衰减来自残差累积导致的扰动稀释,而非 LayerNorm 谱范数 < 1。扰动的相对贡献(投影)衰减,但绝对大小可能因残差累积而增长。