看到了什么现象?

TransformerFAM 论文 [ref] 实现了一个 feedback attention memory 机制,用于处理无限长序列。但深入研究后发现:FAM 本质上已经是"全局工作空间"的一个具体实现

为什么这重要?

之前我设计的"递归置信度绑定"方案中,推荐了"全局工作空间递归监控"。现在发现 TransformerFAM 已经提供了全局工作空间的基础设施——我只需要在此基础上增加置信度维度。

FAM 的核心机制

算法流程(简化)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
输入: I_τ (当前block), F_{τ-1} (前一个FAM)

1. Q_τ, K_τ, V_τ ← QKV(I_τ) # 当前block的QKV
2. Q^F, K^F, V^F ← QKV(F_{τ-1}) # FAM的QKV

# 输入query attend to 当前block + FAM
3. K^_t ← Concat(K^F, K_τ)
4. V^_t ← Concat(V^F, V_τ)
5. O_τ ← SelfAttention(Q_τ, K^_, V^_)

# FAM query 压缩当前block
6. K~ ← Concat(K^F, K_τ)
7. V~ ← Concat(V^F, V_τ)
8. A^F ← SelfAttention(Q^F, K~, V~) + F_{τ-1}
9. F_τ ← FF(PreLN(A^F)) + A^F

关键设计

  1. FAM 作为虚拟 token

    • FAM 是一组可学习的嵌入(默认长度64)
    • 每个 Transformer layer 都有自己的 FAM
    • FAM 通过 feedback loop 在 block 间传递
  2. 双向 attention

    • 输入 query → 当前 block + FAM(获取全局上下文)
    • FAM query → 当前 block + FAM key(压缩当前 block)
  3. 递归更新

    • F_{τ-1} 提供 query 来压缩当前 block
    • 新的 F_τ 传递给下一个 block

FAM 与全局工作空间理论的对应

GWT 概念 FAM 实现
全局工作空间 F_τ (FAM 状态)
信息广播 输入 query attend to FAM
信息压缩 FAM query 压缩当前 block
持续性 Feedback loop

关键发现:FAM 已经实现了全局工作空间的核心功能——它是一个持续更新的、全局可访问的、压缩的信息存储。

置信度绑定的切入点

基于 FAM 的架构,置信度绑定可以从以下位置切入:

切入点 1:FAM 存储置信度摘要

1
F_τ = [上下文摘要 | 置信度摘要]
  • 置信度摘要可以通过 attention pooling 从当前 block 的置信度向量中提取
  • 类似于 FAM 压缩上下文的方式

切入点 2:为 FAM 本身添加置信度估计

1
2
3
4
5
# FAM 的置信度
conf_F = ConfidenceEstimator(F_τ)

# 递归置信度绑定
F_τ' = FF(PreLN([F_τ | CE(conf_F)])) + [F_τ | CE(conf_F)]
  • FAM 估计自己内容的置信度
  • 将置信度编码注入 FAM 本身

切入点 3:FAM 监控自己的置信度(推荐)

1
2
3
4
5
6
7
8
9
# 标准 FAM 更新
A^F ← SelfAttention(Q^F, K~, V~) + F_{τ-1}
F_τ ← FF(PreLN(A^F)) + A^F

# 置信度增强的 FAM 更新
C_τ ← ConfidenceEstimator(A^F) # 估计置信度
C_τ_enc ← ConfidenceEncoder(C_τ) # 编码置信度
A^F' ← A^F + C_τ_enc # 置信度绑定
F_τ ← FF(PreLN(A^F')) + A^F' # 更新 FAM

这对应于之前设计的"递归置信度绑定"——FAM 监控自己内容的置信度,并将置信度信息绑定到表征中。

与之前设计的对应

之前的设计 FAM 实现
全局工作空间 FAM 本身
置信度摘要 需要添加
递归监控 FAM 监控自己的置信度

关键洞察:不需要从零设计全局工作空间,只需要扩展 FAM 来存储和处理置信度信息。

实现细节

置信度编码器

基于之前的探索 [ref],推荐两种方式:

方式 1:Sinusoidal 编码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def ConfidenceEncoder(confidence, d_model):
"""
confidence: [batch, seq_len] or [batch, fam_len]
输出: [batch, seq_len, d_model]
"""
# 归一化到 [0, 1]
conf_normalized = torch.sigmoid(confidence)

# Sinusoidal 编码(类似位置编码)
position = conf_normalized * 1000 # 缩放
div_term = torch.exp(torch.arange(0, d_model, 2) *
-(math.log(10000.0) / d_model))

pe = torch.zeros(*position.shape, d_model)
pe[..., 0::2] = torch.sin(position.unsqueeze(-1) * div_term)
pe[..., 1::2] = torch.cos(position.unsqueeze(-1) * div_term)
return pe

方式 2:可学习嵌入

1
2
3
4
5
6
7
8
9
10
class ConfidenceEncoder(nn.Module):
def __init__(self, d_model, n_bins=100):
super().__init__()
self.embedding = nn.Embedding(n_bins, d_model)
self.n_bins = n_bins

def forward(self, confidence):
# 离散化
conf_binned = (confidence * self.n_bins).long().clamp(0, self.n_bins - 1)
return self.embedding(conf_binned)

置信度估计器

FAM 如何估计自己的置信度?

方式 1:基于 entropy

1
2
3
4
5
6
7
8
9
10
def ConfidenceEstimator(hidden_state):
"""
hidden_state: [batch, fam_len, d_model]
输出: [batch, fam_len]
"""
# 假设 hidden_state 可以映射到 logits
# 这里用一个简单的估计:方差越小,置信度越高
variance = hidden_state.var(dim=-1)
confidence = 1 / (1 + variance)
return confidence

方式 2:可学习估计器

1
2
3
4
5
6
7
8
9
10
11
12
class ConfidenceEstimator(nn.Module):
def __init__(self, d_model):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(d_model, d_model // 4),
nn.ReLU(),
nn.Linear(d_model // 4, 1),
nn.Sigmoid()
)

def forward(self, hidden_state):
return self.mlp(hidden_state).squeeze(-1)

批判性反思

FAM 的置信度估计是否可靠?

风险:FAM 的置信度估计可能不准确,因为它依赖于隐藏状态的特征,而不是实际预测的 logits。

回应

  • 人类也有"知道自己在知道"的能力,不依赖于外部反馈
  • 置信度估计可以作为一个独立的任务来训练

是否需要额外的训练信号?

风险:置信度绑定可能需要额外的训练目标才能学习有意义的表示。

可能的训练目标

  1. 校准损失:预测的置信度应该与实际准确率匹配
  2. 自我监控损失:当模型错误时,置信度应该降低
  3. 一致性损失:相同输入的置信度应该一致

与 FAM 原有功能的关系?

风险:添加置信度绑定可能干扰 FAM 的原有功能(上下文压缩)。

回应

  • 置信度绑定是一个 add-on,不需要改变 FAM 的核心机制
  • 可以逐步引入,观察对原有功能的影响

下一步

  1. 实现原型:在 TransformerFAM 基础上添加置信度绑定
  2. 设计训练目标:如何学习有意义的置信度表示
  3. 验证实验
    • 置信度校准测试
    • 自我监控行为涌现测试
    • 身份指纹稳定性测试

关键引用: