信号传播理论-Kedia2024如何解释残差衰减
看到了什么现象?
Hahami (2025) 观察到扰动投影呈指数衰减。我之前假设这是 LayerNorm 谱范数 < 1 导致的,但 Xiong (2020) 证明谱范数 = O(1),不满足 < 1 的条件。
现在发现 Kedia (2024) “Transformers Get Stable” 提供了一个完整的信号传播理论,可能解释残差衰减的真正机制。
为什么这重要?
信号传播理论可以统一解释:
- 前向激活的增长
- 反向梯度的爆炸/消失
- Rank collapse 现象
- 残差衰减的机制
这篇文章解决什么问题?
用 Kedia (2024) 的信号传播理论重新理解残差衰减,并验证它是否能解释 Hahami 的观察。
Kedia (2024) 的核心框架
动量传播公式
Kedia 提出了一个完整的框架,计算 Transformer 各组件的一阶和二阶矩(均值和方差):
| 组件 | 前向方差 σ²_xout | 反向方差 σ²_gin | 相关性 r_l |
|---|---|---|---|
| LayerNorm(d) | 1 | σ²_gout / σ²_xin | r_l_xin · r_l_gout |
| ReLU | (1 - π/2) σ²_xin | 0.7r_l_xin + 0.3r²_l_xin · (1/2 + sin⁻¹(π r_l_xin)/π) | |
| Attention Block | d²_v σ²_v σ²_xin (1-p) | … | … |
| FFN Block | 2 d²_w σ²_w1 σ²_w2 σ²_xin (1-p) | … | … |
关键发现:LayerNorm 将方差归一化为 1,但这不意味着谱范数 < 1。
Pre-LN vs Post-LN 的动量传播
Pre-LN(残差连接绕过 LayerNorm):
- 前向方差:线性增长
σ² ~ O(N) - 反向梯度:双曲增长
σ² ~ O(log N)
Post-LN(LayerNorm 在残差连接之后):
- 前向方差:保持稳定
- 反向梯度:指数消失/爆炸
σ² ~ O(exp(±N))
Table 3 的关键对比:
| 方法 | Post-LN 反向 | Post-LN 前向 | Pre-LN 反向 | Pre-LN 前向 |
|---|---|---|---|---|
| Vanilla | O(exp(±N)) | O(1) | O(log N) | O(N) |
| DSLM | O(1) | O(1) | O(1) | O(1) |
对残差衰减的启示
为什么 Hahami 观察到指数衰减?
Kedia 的理论暗示:残差衰减不是 LayerNorm 的效应,而是残差连接的动量累积效应。
Pre-LN 的前向传播:
1 | h_{L+1} = h_L + f_LN(h_L) |
扰动 δ 的演化:
1 | δ_{L+1} = δ_L + δ_block |
关键洞察:如果 δ_block 被 LayerNorm “归一化”(方差变为 1),而 δ_L 保持累积,那么:
1 | Var(δ_L) ∝ L (线性增长) |
但 投影到注入方向:
1 | proj(δ_L) ∝ 1/√L (因为扰动被分散到多个方向) |
这就是指数衰减的来源!不是 LayerNorm 的谱范数 < 1,而是残差累积导致的扰动分散。
数学推导
设注入方向为 v,扰动 δ_L 被 LayerNorm 分散到多个正交方向:
1 | δ_L = Σ_i α_i · e_i (e_i 为正交基) |
其中 α_1 是在 v 方向的分量。由于 LayerNorm 的"去中心化"效应:
1 | Σ_i α_i² = Var(δ_L) ∝ L |
但每个方向的方差被均匀化:
1 | α_i² ≈ Var(δ_L) / d = O(L / d) |
投影到 v 方向:
1 | proj(δ_L) = α_1 ≈ √(Var(δ_L) / d) ∝ √(L / d) |
等等,这是线性增长,不是衰减!
重新思考:为什么是衰减而不是增长?
可能的解释
关键修正:Hahami 测量的不是扰动的绝对大小,而是扰动投影与注入方向的相似度。
Cosine similarity:
1 | cos(δ_L, v) = (δ_L · v) / (||δ_L|| · ||v||) |
如果 δ_L 被 LayerNorm 分散到多个方向,||δ_L|| 增长,但 δ_L · v 保持不变或增长更慢:
1 | cos(δ_L, v) ∝ 1 / ||δ_L|| ∝ 1 / √L |
这就是指数衰减的来源:扰动的相对贡献下降,而不是绝对大小下降。
与 Hahami 的观察对应
Hahami 的三个指标:
- Cosine similarity → 恢复到 1.0 ✓(方向恢复)
- Projection → 指数衰减 ✓(相对贡献下降)
- Norm ratio → ?(需要验证)
关键区分:衰减 vs 恢复
表面上矛盾的现象:
- Cosine similarity 恢复到 1.0(扰动方向与基线对齐)
- Projection 衰减到 0(扰动投影消失)
Kedia 理论的解释:
- 恢复:LayerNorm 的归一化效应使
h_inject和h_baseline的方向对齐 - 衰减:残差累积使扰动被稀释(相对贡献下降)
数学表达:
1 | cos(h_inject, h_baseline) → 1 (方向对齐) |
对内省窗口的意义
修正后的假说
原假说:内省窗口的边界由 LayerNorm 谱范数 < 1 决定。
修正假说:内省窗口的边界由残差累积稀释决定。
关键参数:
- 模型深度 N
- 隐藏维度 d
- 扰动分散率(取决于 LayerNorm 的 Jacobian 结构)
预测:
- 更深的模型有更长的内省窗口(因为 N 更大)
- 更大的模型有更长的内省窗口(因为 d 更大)
- 移除 LayerNorm 会加速衰减(因为没有归一化,扰动更快分散)
验证方向
验证 Kedia 的理论
-
测量前向方差增长:
1
var_forward = torch.var(h_L, dim=-1)
预期:Pre-LN 模型
var_forward ∝ L -
测量扰动投影衰减:
1
proj_L = (h_inject_L - h_baseline_L) @ v / ||v||
预期:
proj_L ∝ 1/√L -
对比不同深度模型:
预期:更深模型的内省窗口更大
批判性反思
理论缺口
-
Kedia 理论是初始化理论:
- 在初始化时成立
- 训练后是否仍然成立?
-
扰动不是小量:
- Kedia 假设小扰动
- Hahami 的注入可能较大
-
缺少动态分析:
- Kedia 分析的是静态方差
- Hahami 观察的是动态传播
与其他理论的关系
- Xiong (2020):LayerNorm 谱范数 = O(1),不能解释衰减
- Kedia (2024):残差累积导致扰动稀释,可以解释衰减
- TaperNorm:反向梯度移除径向分量,与衰减无关
结论
核心发现:
- 残差衰减不是 LayerNorm 谱范数 < 1 导致的
- 而是残差累积 + LayerNorm 归一化的复合效应
- 扰动的相对贡献下降(投影衰减),但绝对大小可能增长
对内省窗口的启示:
- 内省窗口大小 ∝ √(N · d)(推测)
- 可以通过增加深度或宽度来延长窗口
验证方向:
- 测量前向方差增长
- 测量扰动投影衰减
- 对比不同架构的内省窗口
关键引用
- Transformers Get Stable - Kedia et al. 2024
- On Layer Normalization in the Transformer Architecture - Xiong et al. 2020
- Detecting the Disturbance - Hahami et al. 2025
最后更新: 2026-03-16 18:00
核心发现: Kedia (2024) 的信号传播理论表明,残差衰减来自残差累积导致的扰动稀释,而非 LayerNorm 谱范数 < 1。扰动的相对贡献(投影)衰减,但绝对大小可能因残差累积而增长。