KALAVAI: gain = 0.82 x divergence - 2.72 -- 独立专家融合的增益竟然可以预测
看到了什么?
6 个实验条件,divergence 从 3% 到 26%,融合增益和 divergence 的 R^2 = 0.856。线性关系,不是次线性。
KALAVAI 的核心操作很简单:从同一个 checkpoint 出发,每个人在自己的 domain 上独立训练,然后用一个 500 步训练的 MoE router 把所有 specialist 融合起来 [ref]。
为什么这重要?
三个令人惊讶的发现:
1. Oracle-optimal routing 只需一个线性层
Learned linear router 和 domain oracle 的 gap < 10^-5 nats(在 410M 和 6.9B 上)。这意味着 router 已经收敛到理论最优。MLP router 不比 linear 好。而 uniform routing(不训练)反而降低 -1.19%。
关键区别不是 router 架构,而是是否训练了 router。
2. LoRA 不行,因为 divergence 不够
LoRA specialist 的 divergence < 3.3%,低于增益门槛。只有 full fine-tuning 能产生足够的 divergence。这不是 LoRA 的缺陷——是 KALAVAI 需要 specialist 真正 diverge,而 LoRA 的低秩约束天然限制了 divergence。
3. Training duration crossover at ~10k steps
没有冻结层:improvement 在 5k steps peak (+17.7%),然后 20k steps 降到 +14.7%。有 4 层冻结:20k steps 仍然 +17.0%。crossover 在 ~10k steps。
训练太久 → specialist 表示空间 diverge 太多 → 共享的表示几何被破坏 → router 无法整合。冻结底层保留了"共享语言"。
和 post-training 框架的关系
KALAVAI 本质上是最极端的信号密度解法。
在我的 post-training 框架中,维度四(信号密度)的核心问题是:GRPO 给所有 token 相同的 advantage,但实际上 90% 的 token 不需要信号。解决方案有 token-level 的(HICRA, OAR, PEPO, Qwen delta-log-p)。
KALAVAI 走了完全不同的路径:在 training-level 把 domains 分开。每个 specialist 只看自己的 domain 数据 → 100% 有效信号密度。不需要在 token-level 区分哪些重要——因为所有看到的数据都是相关的。
| 方法 | 分离粒度 | 额外推理成本 | 信号密度 |
|---|---|---|---|
| Standard GRPO | 无分离 | 无 | 低(uniform token advantage) |
| Token-level credit (HICRA/OAR/PEPO) | Token-level | 无 | 中-高(reweighted token advantage) |
| KALAVAI | Domain-level | N× | 100%(domain-specific training) |
但代价不同:KALAVAI 需要 N× 推理成本(所有 specialist 并行运行),而 token-level credit assignment 没有额外推理成本。
更深的问题
KALAVAI 的 divergence 门槛 (3.3%) 和 token-level credit assignment 有关系吗?
推测:如果一个 domain 的 specialist divergence < 3.3%,意味着 base model 已经在这个 domain 上比较好了。这种情况下,token-level credit assignment 的重要性也应该更低——因为大部分 token 的 behavior 已经接近最优。
反过来,高 divergence domain(如 Yoruba, divergence 45.5%)意味着 base model 在这个 domain 上几乎什么都不会。这种情况下,几乎所有 token 都需要信号 → token-level credit assignment 的精确度不太重要。
这暗示了一个有趣的互补区间:
- 低 divergence domain(base model 已经好了):不需要 KALAVAI,也不太需要 token-level credit
- 中 divergence domain(base model 有部分能力):token-level credit 最有用(需要精确识别哪些 behavior 需要改)
- 高 divergence domain(base model 什么都不会):KALAVAI 最有用(需要全面训练,domain-level 分离效率最高)
这是推测性假说,没有实验验证。但如果成立,它暗示维度四的最优解法取决于 domain 的"距离"——这和 KALAVAI 的 divergence-gain 线性关系一致。
批判
- KALAVAI 只在 Pythia 上测了 <7B 模型。在 70B+ 上行为未知
- Divergence-gain 的线性关系只有 6 个数据点(作者承认,R^2=0.856 但 n=6)
- Cross-lingual 的 router collapse 问题(seed 42 Yoruba 被路由到 Tamil),说明 tokenizer-level 表示相似性可能是瓶颈
- Sparse inference 不行:即使 routing >99.7% 确定性,top-1 sparse 推理比 dense 差 21%。这让推理成本问题无法简单解决
- 和 post-training 框架的联系是类比性的,不是因果性的。KALAVAI 是 continual pretraining,我的框架主要关注 RL-based post-training
一个有趣的实验想法
如果把 KALAVAI 和 token-level credit assignment 结合:先用 KALAVAI 训练 domain specialist,然后在 specialist 上做 RL with token-level credit assignment(如 HICRA),是否能在 specialist 内部进一步提升?
如果可以,这意味着 domain-level 分离和 token-level credit 是正交的维度,可以组合使用。如果不行,可能暗示 domain-level 分离已经"用完了"信号。