Expert Day 121

Transformer Deep Dive: From Self-Attention to KV-Cache

Self-attention math derivation, multi-head attention, how the KV-cache works, position encoding (RoPE/ALiBi/sinusoidal)

2026-08-30
Phase 3 - LLM Fundamentals & Prompt Engineering (Day 121-134)

Date: 2026-08-30 Track: AI Systems Engineering Phase: Phase 3 - LLM Fundamentals & Prompt Engineering (Day 121-134) Tags: #LLM #Transformer #Attention #KVCache #RoPE


Today's Goals

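| Type | Content |
|---|---|
| Learn | Self-attention math derivation, multi-head attention, how the KV-cache works, position encoding (RoPE/ALiBi/sinusoidal) |
| Hands-on | Implement a mini-transformer forward pass from scratch in numpy and verify that the attention computation matches PyTorch |
| Deliverables | transformer.py (~250 lines of runnable code) + notes + engineering insights into how Claude 4.7 (likely) handles long context |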

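Why start here: to work as an "AI×Web3 PM/AI architect", you need to tell "building products with LLMs" apart from "understanding LLMs well enough to make architecture decisions". Days 121-134 are the two weeks that turn the LLM from a black box into a white box, and today lays the foundation.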


1. Theoretical Foundations

1.1 Self-Attention Math

Given a sequence $X \in \mathbb{R}^{n \times d}$ (n tokens, each d-dimensional), attention is computed as:

$$ Q = XW_Q, \quad K = XW_K, \quad V = XW_V $$

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$

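Why divide by $\sqrt{d_k}$: suppose the components of $q$ and $k$ are independent with zero mean and unit variance. Then the dot product $q \cdot k$ is a sum of $d_k$ such products and has variance $d_k$. For large $d_k$ the logits become large and softmax saturates toward one-hot, which makes the gradients vanish. Dividing by $\sqrt{d_k}$ brings the variance back to 1.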
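The sketch below checks this variance argument numerically. It assumes $q$ and $k$ have i.i.d. standard-normal components. The raw dot product's variance grows with $d_k$, while the scaled version stays near 1. `dot_product_variance` is a hypothetical helper written for this note.

```python
import numpy as np

def dot_product_variance(d_k, n_samples=10000, seed=0):
    """Empirical variance of q·k and of q·k / sqrt(d_k), with q, k ~ N(0, I)."""
    rng = np.random.default_rng(seed)
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = np.sum(q * k, axis=1)            # one dot product per sample
    return dots.var(), (dots / np.sqrt(d_k)).var()

for d_k in (16, 64, 256):
    raw, scaled = dot_product_variance(d_k)
    print(f"d_k={d_k:4d}  var(q·k)={raw:8.2f}  var(q·k/sqrt(d_k))={scaled:.2f}")
```

The raw variance tracks $d_k$ (about 16, 64, 256), and the scaled column hovers around 1.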

Complexity

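  • Time: $O(n^2 d)$, where n is the sequence length
  • Space: $O(n^2)$ for the attention matrix (naive implementation)
  • This is why long context is expensive: at 1M tokens the score matrix alone has $n^2 = 10^{12}$ entries per head per layer, and computing it takes on the order of $n^2 d$ multiply-adds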

1.2 Multi-Head Attention

Split the $d$ dimensions into $h$ heads (each of size $d/h$), and let each head compute attention independently:

$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W_O $$

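Intuition: different heads learn different "relations". Mechanistic-interpretability research has found heads that specialize in "induction" (detecting and continuing repeated patterns, studied by Anthropic in its own transformer models), "copying" spans, and "name resolution" (e.g., the name-mover heads in the IOI circuit study of GPT-2 small).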
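To make the Concat formula concrete, the sketch below computes multi-head attention in two ways. The first loops over heads and concatenates their outputs. The second reshapes everything into an `(h, n, d/h)` tensor. It assumes no mask and random weights. The function names are illustrative only. Both paths should agree, which shows that "splitting into heads" is just a reshape of the projected Q/K/V.

```python
import numpy as np

def attention(Q, K, V):
    """Single-head scaled dot-product attention (no mask)."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

def multihead_loop(X, W_Q, W_K, W_V, W_O, h):
    """MultiHead = Concat(head_1..head_h) W_O, one head at a time."""
    dh = X.shape[-1] // h
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    heads = [attention(Q[:, i*dh:(i+1)*dh], K[:, i*dh:(i+1)*dh], V[:, i*dh:(i+1)*dh])
             for i in range(h)]
    return np.concatenate(heads, axis=-1) @ W_O

def multihead_batched(X, W_Q, W_K, W_V, W_O, h):
    """Same computation with a reshape to (h, n, dh) instead of a Python loop."""
    n, d = X.shape
    dh = d // h

    def split(M):
        return M.reshape(n, h, dh).transpose(1, 0, 2)     # (h, n, dh)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(dh)       # (h, n, n)
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    out = (w @ V).transpose(1, 0, 2).reshape(n, d)        # concat heads
    return out @ W_O
```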

1.3 KV-Cache: The Decoder's Engineering Secret

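In autoregressive generation (one token at a time), the naive approach recomputes the full $K, V$ matrices at every step. That costs $O(n^2)$ per step for the attention scores, or $O(n^3)$ for the whole generation.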

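KV-cache: cache the already-computed $K_{1:t-1}, V_{1:t-1}$. At step $t$, compute only $K_t, V_t$ and append them. Each step then drops to $O(n)$, because the single new query attends over $t$ cached keys.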

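Memory cost: $$ \text{KV cache size} = 2 \times n_{layers} \times n_{kv} \times d_{head} \times L \times 2\ \text{bytes (fp16)} $$ Here the first factor 2 counts K and V, $n_{kv}$ is the number of KV heads (equal to the number of attention heads in plain multi-head attention), and $L$ is the sequence length.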

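Example: Llama 3 70B at 128k context. If each of its 64 heads kept its own K/V, the cache would be $2 \times 80 \times 64 \times 128 \times 131072 \times 2 \approx 344$ GB (320 GiB) for a single conversation! Llama 3 70B actually uses GQA with 8 KV heads, which brings this down to about 43 GB (40 GiB). That is still large, and it is why vLLM (PagedAttention) and SGLang manage the KV-cache in fixed-size pages.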
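The size formula is easy to turn into a calculator. The sketch below assumes the Llama 3 70B shape (80 layers, 64 query heads, 8 KV heads, head_dim 128) and fp16. It compares a hypothetical one-KV-head-per-query-head layout with the real GQA configuration.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """2 (K and V) x layers x KV heads x head_dim x seq_len x bytes per element."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

GiB = 1024 ** 3
mha = kv_cache_bytes(80, 64, 128, 131072)   # hypothetical: one KV head per query head
gqa = kv_cache_bytes(80, 8, 128, 131072)    # GQA: 8 shared KV heads
print(f"MHA: {mha / GiB:.0f} GiB, GQA: {gqa / GiB:.0f} GiB")   # MHA: 320 GiB, GQA: 40 GiB
```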

Anthropic prompt caching conceptually persists this KV cache across requests. Section 4 covers it in detail.

1.4 The Evolution of Position Encoding

Self-attention by itself has no notion of token order (it is permutation-equivariant), so position information must be injected explicitly:

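| Method | Paper | How it works | Pros / cons |
|---|---|---|---|
| Sinusoidal | Attention Is All You Need (2017) | $PE_{pos,2i}=\sin(pos/10000^{2i/d})$ (cos on odd dims), added to the embedding | Simple; extrapolates poorly |
| Learned absolute | BERT/GPT-2 | Learns one embedding per position | Cannot extrapolate beyond the training length |
| RoPE | Su et al. 2021 (RoFormer) | Rotates Q and K by a position-dependent angle $\theta_{pos}$ | Used by LLaMA/Qwen (and presumably Claude); extrapolates with scaling (e.g., YaRN) |
| ALiBi | Press et al. 2022 | Adds a linear distance bias to the attention scores | Used by BLOOM/MPT; zero parameters; strong extrapolation |
| NoPE | Kazemnejad et al. 2023 | Adds no position encoding at all | Works in small models; large models still generally use one |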
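The two additive schemes in the table can be sketched in a few lines. `sinusoidal_pe` follows the sin/cos formula above. `alibi_bias` assumes the ALiBi paper's geometric slopes $2^{-8k/h}$ for a power-of-two head count $h$, and it folds the causal mask in by setting $j > i$ to $-\infty$. Both helpers are illustrative sketches, not library APIs.

```python
import numpy as np

def sinusoidal_pe(seq_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(same angle)."""
    pos = np.arange(seq_len)[:, None]                  # (seq_len, 1)
    i = np.arange(d_model // 2)[None, :]               # (1, d_model/2)
    angles = pos / (10000.0 ** (2 * i / d_model))      # (seq_len, d_model/2)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe                                          # added to token embeddings

def alibi_bias(n_heads, seq_len):
    """ALiBi bias added to attention scores: -slope_h * (i - j) for j <= i.
    Assumes n_heads is a power of two; future positions (j > i) get -inf."""
    slopes = 2.0 ** (-8.0 * np.arange(1, n_heads + 1) / n_heads)      # (H,)
    dist = np.arange(seq_len)[:, None] - np.arange(seq_len)[None, :]  # i - j
    bias = -slopes[:, None, None] * dist[None, :, :]                  # (H, T, T)
    return np.where(dist[None] >= 0, bias, -np.inf)                   # causal mask
```

ALiBi needs no learned parameters. The bias depends only on the distance $i - j$, so it applies to any sequence length, which is the source of its extrapolation behavior.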

RoPE core formula (rotate each pair of dimensions): $$ R_\theta\begin{pmatrix}x_{2i} \\ x_{2i+1}\end{pmatrix} = \begin{pmatrix}\cos\theta & -\sin\theta \\ \sin\theta & \cos\theta\end{pmatrix}\begin{pmatrix}x_{2i} \\ x_{2i+1}\end{pmatrix} $$

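where $\theta = pos \cdot 10000^{-2i/d}$. Key property: for a query at position $m$ and a key at position $n$, the inner product of the rotated vectors depends only on the relative offset $m-n$, not on the absolute positions. This relative formulation is what makes RoPE amenable to context extension via frequency scaling (NTK/YaRN). Applied as-is beyond the training length, it still degrades (see Pitfall 2).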
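The relative-position property can be checked directly. The sketch below rotates a random query/key pair at positions (5, 2) and at (1003, 1000), which share the same offset of 3, and compares the two scores. `rope_vec` is a single-vector version of the rotation above, written for this check only.

```python
import numpy as np

def rope_vec(x, pos, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) by angle pos * base^(-2i/d)."""
    d = x.shape[-1]
    theta = pos * base ** (-2 * np.arange(d // 2) / d)
    cos, sin = np.cos(theta), np.sin(theta)
    out = np.empty_like(x)
    out[0::2] = x[0::2] * cos - x[1::2] * sin
    out[1::2] = x[0::2] * sin + x[1::2] * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)
s1 = rope_vec(q, 5) @ rope_vec(k, 2)         # offset m - n = 3
s2 = rope_vec(q, 1003) @ rope_vec(k, 1000)   # same offset, far larger absolute positions
print(np.isclose(s1, s2))                     # True
```

Each pair is a pure rotation, so vector norms are preserved as well. Only the angle difference $(m-n)\theta_i$ survives in the dot product.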

1.5 Modern LLM Architecture Variants

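| Model | Architecture |
|---|---|
| Claude 4.7 (speculative) | Decoder-only, GQA, RoPE+YaRN, RMSNorm, SwiGLU |
| GPT-5 | Decoder-only, MoE (speculative) |
| Gemini 2.5 Pro | Sparse MoE Transformer (per Google's technical report), trained on Pathways, optimized for long context |
| Llama 3.1 | Decoder-only, GQA (8 KV heads), RoPE θ=500000 |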

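GQA (Grouped Query Attention): groups of query heads share a single KV head. In Llama 3 70B, for example, 64 query heads share 8 KV heads (8 query heads per KV head). This shrinks the KV cache 8x and is an engineering necessity for 128k+ context.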
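A minimal GQA sketch follows. It assumes no causal mask and a query-head count divisible by the KV-head count. Only the `n_kv_heads` K/V tensors would be cached. They are expanded with `np.repeat` at compute time, so query head `j` reads KV head `j // group`.

```python
import numpy as np

def gqa_attention(Q, K, V):
    """Grouped-query attention (no mask).
    Q: (n_q_heads, T, dh); K, V: (n_kv_heads, T, dh); n_q_heads % n_kv_heads == 0."""
    group = Q.shape[0] // K.shape[0]
    K = np.repeat(K, group, axis=0)        # (n_q_heads, T, dh): head j uses KV head j // group
    V = np.repeat(V, group, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

rng = np.random.default_rng(0)
Q = rng.standard_normal((8, 5, 4))        # 8 query heads
K = rng.standard_normal((2, 5, 4))        # only 2 KV heads need to be cached (4x smaller)
V = rng.standard_normal((2, 5, 4))
print(gqa_attention(Q, K, V).shape)       # (8, 5, 4)
```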


2. Intuition

Why does attention work?

Think of attention as a soft lookup table (a differentiable hash table):

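  • Query: "what am I looking for right now"
  • Key: each token's "label"
  • Value: each token's "content"
  • Softmax(Q·K): fuzzy-match similarity
  • Weighted sum of V: retrieve the matched content and blend it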
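The lookup analogy fits in a few lines. The sketch below uses three orthogonal keys as "labels" and a query that strongly matches the second one. Nearly all of the softmax weight lands on that key, so the output is essentially its value. With a weaker or mixed query, the output would be a blend of values. `soft_lookup` is an illustrative helper.

```python
import numpy as np

def soft_lookup(query, keys, values, temperature=1.0):
    """Attention as a differentiable lookup: softmax(K·q / temperature) weights over V."""
    scores = keys @ query / temperature
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ values, w

keys = np.eye(3)                          # three orthogonal "labels"
values = np.array([[10.0], [20.0], [30.0]])
query = np.array([0.0, 5.0, 0.0])         # strongly matches key 1
out, w = soft_lookup(query, keys, values)
print(w.round(3), out)                    # weight concentrates on key 1; out ≈ 20
```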

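Induction heads (Anthropic, 2022): the model learns the rule "if [A][B] appeared earlier and I now see [A] again, the next token is probably [B]". The paper argues that this is a key microscopic mechanism behind in-context learning, which also offers a plausible explanation of why few-shot prompting works.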
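The rule itself can be written as a plain function. This is a behavioral caricature of what an induction head computes: look back for the previous occurrence of the current token and copy what came after it. It is not how the neural circuit implements the rule, and `induction_predict` is a hypothetical name.

```python
def induction_predict(tokens):
    """Toy induction rule: find the most recent earlier occurrence of the current
    token [A] and predict the token that followed it ([B]); None if [A] is new."""
    current = tokens[-1]
    for i in range(len(tokens) - 2, -1, -1):   # scan backwards, skipping the last token
        if tokens[i] == current:
            return tokens[i + 1]
    return None

print(induction_predict(["Harry", "Potter", "went", "home", ".", "Harry"]))  # Potter
```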

Why does depth matter more than width?

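Intuitively, each attention layer performs roughly one hop of retrieval. Answering "What is the name of Xiao Ming's mother's sister?" takes at least 3 hops, which suggests the model needs at least 3 layers. Chain-of-Thought can be seen as turning in-layer reasoning into explicit token-sequence reasoning, which works around the depth limit.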


3. Code Implementation

3.1 A mini-transformer in numpy (multi-head attention)

```python
# transformer.py
"""
Mini-Transformer forward pass in pure numpy.
Intended to be cross-checked against PyTorch (the comparison script is not included here).
"""
import numpy as np

np.random.seed(42)

def softmax(x, axis=-1):
    # Subtract the max before exp() for numerical stability
    x_max = np.max(x, axis=axis, keepdims=True)
    e = np.exp(x - x_max)
    return e / np.sum(e, axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    # Plain LayerNorm (no learnable gain/bias); kept for comparison, not used below
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def rms_norm(x, eps=1e-6):
    """RMSNorm as used in Llama-style models: no mean subtraction or bias,
    so it is cheaper than LayerNorm (learnable gain omitted here)."""
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return x / rms

def rope(x, base=10000.0, offset=0):
    """
    Rotary Position Embedding.
    x: (batch, n_heads, seq_len, head_dim)
    offset: absolute position of the first token in x (non-zero when decoding with a KV cache)
    """
    *_, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0
    half = head_dim // 2
    # Per-pair frequencies: base^(-2i/d)
    theta = base ** (-2 * np.arange(half) / head_dim)  # (half,)
    pos = np.arange(offset, offset + seq_len)  # (seq_len,)
    freqs = np.outer(pos, theta)  # (seq_len, half)
    cos, sin = np.cos(freqs), np.sin(freqs)  # (seq_len, half)
    # Split x into even/odd dims; each pair (x_2i, x_2i+1) is rotated together
    x1, x2 = x[..., 0::2], x[..., 1::2]  # (..., seq_len, half)
    rotated_1 = x1 * cos - x2 * sin
    rotated_2 = x1 * sin + x2 * cos
    out = np.empty_like(x)
    out[..., 0::2] = rotated_1
    out[..., 1::2] = rotated_2
    return out

class MultiHeadAttention:
    def __init__(self, d_model, n_heads, use_rope=True, causal=True):
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.use_rope = use_rope
        self.causal = causal
        # Initialize the Q, K, V, O projection weights
        scale = 1.0 / np.sqrt(d_model)
        self.W_Q = np.random.randn(d_model, d_model) * scale
        self.W_K = np.random.randn(d_model, d_model) * scale
        self.W_V = np.random.randn(d_model, d_model) * scale
        self.W_O = np.random.randn(d_model, d_model) * scale
        # KV cache (for incremental decoding)
        self.k_cache = None
        self.v_cache = None

    def _split_heads(self, x):
        # (B, T, D) -> (B, H, T, Dh)
        B, T, D = x.shape
        return x.reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)

    def _merge_heads(self, x):
        # (B, H, T, Dh) -> (B, T, D)
        B, H, T, Dh = x.shape
        return x.transpose(0, 2, 1, 3).reshape(B, T, H * Dh)

    def forward(self, x, use_cache=False):
        B, T, D = x.shape
        Q = self._split_heads(x @ self.W_Q)  # (B, H, T, Dh)
        K = self._split_heads(x @ self.W_K)
        V = self._split_heads(x @ self.W_V)

        # Number of cached tokens = absolute position of the first token in x
        past_len = self.k_cache.shape[2] if (use_cache and self.k_cache is not None) else 0

        if self.use_rope:
            # Rotate by absolute position; without the offset, every cached step
            # would be encoded as position 0 and diverge from the full pass
            Q = rope(Q, offset=past_len)
            K = rope(K, offset=past_len)

        if use_cache and self.k_cache is not None:
            # Incremental decoding: Q covers only the new token(s); K/V are appended to the history
            K = np.concatenate([self.k_cache, K], axis=2)
            V = np.concatenate([self.v_cache, V], axis=2)

        if use_cache:
            self.k_cache = K
            self.v_cache = V

        # Attention scores
        scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(self.head_dim)  # (B,H,Tq,Tk)

        # Causal mask: the Tq queries are the last Tq positions, so key j is hidden from query i if j > i + (Tk - Tq)
        if self.causal:
            Tq, Tk = scores.shape[-2], scores.shape[-1]
            mask = np.triu(np.ones((Tq, Tk)), k=Tk - Tq + 1).astype(bool)
            scores = np.where(mask, -1e9, scores)

        attn = softmax(scores, axis=-1)
        out = attn @ V  # (B, H, Tq, Dh)
        out = self._merge_heads(out) @ self.W_O  # (B, Tq, D)
        return out, attn

class TransformerBlock:
    def __init__(self, d_model, n_heads, d_ff):
        self.attn = MultiHeadAttention(d_model, n_heads)
        scale = 1.0 / np.sqrt(d_model)
        # SwiGLU FFN (Llama-style)
        self.W_gate = np.random.randn(d_model, d_ff) * scale
        self.W_up = np.random.randn(d_model, d_ff) * scale
        self.W_down = np.random.randn(d_ff, d_model) * scale

    def silu(self, x):
        # SiLU / swish: x * sigmoid(x)
        return x / (1 + np.exp(-x))

    def forward(self, x):
        # Pre-norm (modern LLMs use pre-norm; it trains more stably than post-norm)
        h = rms_norm(x)
        attn_out, _ = self.attn.forward(h)
        x = x + attn_out  # residual
        h = rms_norm(x)
        # SwiGLU: (silu(h W_gate) * (h W_up)) W_down
        ff = (self.silu(h @ self.W_gate) * (h @ self.W_up)) @ self.W_down
        x = x + ff  # residual
        return x

if __name__ == "__main__":
    # Smoke test
    B, T, D, H = 2, 16, 64, 8
    x = np.random.randn(B, T, D)
    block = TransformerBlock(d_model=D, n_heads=H, d_ff=4 * D)
    out = block.forward(x)
    print(f"Input:  {x.shape}")
    print(f"Output: {out.shape}")
    print(f"Mean: {out.mean():.4f}, Std: {out.std():.4f}")

    # KV cache demo: full forward pass vs. one token at a time with the cache
    print("\n--- KV cache test ---")
    attn = MultiHeadAttention(D, H)
    full = attn.forward(x)[0]
    # Second instance with identical weights, fed one token per step
    attn2 = MultiHeadAttention(D, H)
    attn2.W_Q, attn2.W_K, attn2.W_V, attn2.W_O = attn.W_Q, attn.W_K, attn.W_V, attn.W_O
    incremental_outputs = []
    for t in range(T):
        out_t = attn2.forward(x[:, t:t+1, :], use_cache=True)[0]
        incremental_outputs.append(out_t)
    incremental = np.concatenate(incremental_outputs, axis=1)
    diff = np.abs(full - incremental).max()
    print(f"Max diff (full vs incremental w/ KV cache): {diff:.2e}")
    # With RoPE offset by the cache length and the causal mask, a correct
    # implementation gives a difference at floating-point rounding level
```

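Expected output when run. The shapes are deterministic. The statistics and the exact diff digits may vary with numpy/BLAS versions, and the diff should be at floating-point rounding level:

```
Input:  (2, 16, 64)
Output: (2, 16, 64)
Mean: 0.0123, Std: 1.0421

--- KV cache test ---
Max diff (full vs incremental w/ KV cache): 1.42e-15
```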

4. Anthropic API Best Practices

4.1 Prompt Caching = The External Face of a Shared KV-Cache

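Conceptually, Claude 4.7's prompt caching works like this. The KV-cache for a prefix is persisted on Anthropic's infrastructure with a 5-minute TTL that is refreshed on every hit, and the next request with an identical prefix reuses it directly.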

API usage

```python
# pip install anthropic
import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

LARGE_10K_REPORT = "<full 10-K text>"  # placeholder for the ~50K-token filing

response = client.messages.create(
    model="claude-opus-4-7",
    max_tokens=1024,
    system=[
        {
            "type": "text",
            "text": "You are an expert financial analyst.",
        },
        {
            "type": "text",
            "text": LARGE_10K_REPORT,  # ~50K tokens
            "cache_control": {"type": "ephemeral"}  # <-- marks the end of the cacheable prefix
        }
    ],
    messages=[
        {"role": "user", "content": "What was Q3 net income?"}
    ]
)

# Check for a cache hit in the response usage
print(response.usage)
# First call:   Usage(..., cache_creation_input_tokens=50000, cache_read_input_tokens=0, ...)
# Second call with the same prefix (within the TTL):
#               Usage(..., cache_creation_input_tokens=0, cache_read_input_tokens=50000, ...)
```

Economics

  • Cache write: base input price × 1.25
  • Cache read: base input price × 0.1 (90% savings)
  • TTL: 5 minutes (default) or 1 hour ("ttl": "1h"; writes cost 2.0× the base input price)
  • Good fits: long system prompts, RAG over stable documents, agent loops with shared context

4.2 How Extended Thinking Interacts with the KV-Cache

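```python
# Reuses `client` from 4.1; SYSTEM_PROMPT is a placeholder for your long system prompt
SYSTEM_PROMPT = "<long system prompt>"

response = client.messages.create(
    model="claude-opus-4-7",
    max_tokens=16000,  # must be larger than budget_tokens
    thinking={"type": "enabled", "budget_tokens": 10000},
    system=[{"type": "text", "text": SYSTEM_PROMPT,
             "cache_control": {"type": "ephemeral"}}],
    messages=[{"role": "user", "content": "Analyze..."}]
)
```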

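Note: thinking blocks cannot be marked with cache_control themselves, and the model reasons afresh on each request. The system + messages prefix can still be cached. According to Anthropic's extended-thinking docs, changing the thinking parameters (enabling/disabling thinking or changing budget_tokens) invalidates message-level cache breakpoints, while cached system prompts and tool definitions are kept.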
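Anthropic's docs state two constraints on standard (non-interleaved) extended-thinking requests: budget_tokens must be at least 1024, and it must be smaller than max_tokens. A pair like max_tokens=4096 with budget_tokens=10000 is rejected. The pre-flight check below is a hypothetical helper (not part of the `anthropic` SDK) that encodes those two rules.

```python
def check_thinking_params(max_tokens, budget_tokens, min_budget=1024):
    """Return a list of problems with extended-thinking params (empty list = OK).
    Encodes: budget_tokens >= 1024 and budget_tokens < max_tokens."""
    problems = []
    if budget_tokens < min_budget:
        problems.append(f"budget_tokens={budget_tokens} is below the {min_budget} minimum")
    if budget_tokens >= max_tokens:
        problems.append(f"budget_tokens={budget_tokens} must be < max_tokens={max_tokens}")
    return problems

print(check_thinking_params(4096, 10000))    # one problem: budget exceeds max_tokens
print(check_thinking_params(16000, 10000))   # []
```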


5. Financial Applications

Scenario: a high-frequency Q&A system over financial reports

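A 10-K filing typically runs 50-200K tokens. Without prompt caching, every user question re-sends the full text. At an assumed Opus input price of $15/MTok, the cost is: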

  • 100K tokens × $15/MTok (Opus input) = $1.50 per question

With prompt caching:

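  • First request: 100K × $18.75/MTok (cache write) = $1.875 ≈ $1.88
  • Subsequent requests within 5 minutes: 100K × $1.50/MTok (cache read) = $0.15 per question
  • 90% savings on each cached question, for a monthly saving of $X. The per-question gap is 10x, so the absolute savings grow linearly with query volume once the system is in production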

Architecture diagram

```
User Q ─┐
        ├─> [System prompt cache (5min TTL)]
10-K ───┘                ↓
                  Claude 4.7 (Opus)
                         ↓
                  Structured answer
```

PM decision points:

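  1. Is a 5-minute TTL enough? For popular filings, use the 1h TTL. Writes cost 2× base input instead of 1.25×, and reads still cost 0.1×
  2. How large must a document be to be worth caching? Below ~1024 tokens (Anthropic's minimum for Sonnet/Opus), nothing is cached at all. Above it, caching pays off as soon as the prefix is reused once within the TTL
  3. How do you invalidate? There is no explicit cache key, because the cache matches on the exact prefix. When the document changes, change the prefix (e.g., put a version number in the system prompt)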
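Point 3 can be sketched as a small builder for the `system` list. It is a hypothetical helper, not part of the SDK. The version tag lives inside the cached text, so bumping it changes the prefix. The next request then writes a fresh cache entry instead of reading a stale one. The block layout and `cache_control` shape follow the `messages.create` example in 4.1.

```python
def build_system_blocks(instructions, document, doc_version, ttl=None):
    """Hypothetical helper: build the `system` list for client.messages.create().
    The version tag is part of the cached prefix, so a new version means a cache miss."""
    cache_control = {"type": "ephemeral"}
    if ttl is not None:                     # e.g. "1h"; omit for the 5-minute default
        cache_control["ttl"] = ttl
    return [
        {"type": "text", "text": instructions},
        {"type": "text",
         "text": f"[doc-version: {doc_version}]\n{document}",
         "cache_control": cache_control},   # breakpoint: everything up to here is cacheable
    ]

blocks = build_system_blocks("You are an expert financial analyst.", "<10-K text>", "2026Q3-v2", ttl="1h")
```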

6. Common Pitfalls

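  1. KV-cache memory blowup: serving a 70B model at 128k context yourself without GQA runs straight into OOM. In production, use vLLM + PagedAttention.
  2. RoPE extrapolation failure: if a model trained with max_pos=8k is fed 32k tokens at inference, it outputs garbage. You need YaRN/NTK scaling. Claude 4.7 presumably uses a similar technique to cover its 200K native context.
  3. Misplaced Anthropic cache_control: cache_control marks the end of the cached prefix (everything up to and including that block). The prefix must also be exactly identical, whitespace included, to hit. A single extra space means a cache miss.
  4. Assuming multiple heads means "multiple perspectives": in practice many heads are redundant, and a large share of them can be pruned with little quality loss (Voita et al. 2019). The interpretable heads, such as induction heads, are few and critical.
  5. Getting the causal mask wrong: if you swap the upper and lower triangles when implementing attention yourself, the model still "trains". The loss may even look suspiciously good, because the model can peek at the tokens it is supposed to predict. At inference, though, accuracy is terrible, which makes this a classic debugging nightmare.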
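Pitfall 5 is cheap to guard against with a unit check. The sketch below builds a boolean "blocked" mask with the same `np.triu(..., k=1)` convention as `transformer.py`. It then verifies that every future position $(i, j > i)$ is blocked, and shows that the classic swapped version fails the check. Both helper names are illustrative.

```python
import numpy as np

def causal_mask(T):
    """True where attention is NOT allowed: key position j > query position i."""
    return np.triu(np.ones((T, T), dtype=bool), k=1)

def leaks_future(mask):
    """A correct causal mask must block every (i, j) with j > i."""
    T = mask.shape[0]
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    return bool(np.any(future & ~mask))

good = causal_mask(4)
bad = np.tril(np.ones((4, 4), dtype=bool), k=-1)   # swapped: blocks the past instead
print(leaks_future(good), leaks_future(bad))         # False True
```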

7. Quick Reference

Attention complexity

| Operation | Time | Space |
|---|---|---|
| Standard | O(n²d) | O(n²) |
| Flash Attention | O(n²d) | O(n) |
| Linear Attention | O(nd²) | O(nd) |
| State Space (Mamba) | O(nd) | O(d) |
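The O(nd²) row comes from reordering the matrix product. The sketch below shows non-causal linear attention with the elu(x)+1 feature map from Katharopoulos et al. (2020). It is one common choice of feature map, swapped in here for illustration. It computes $\phi(K)^T V$ first, a $d \times d$ matrix that does not depend on n, and never builds the n×n score matrix. A quadratic reference implementation is included to check the result.

```python
import numpy as np

def elu_plus_one(x):
    """Positive feature map phi(x) = elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    """Non-causal linear attention: phi(Q) (phi(K)^T V) / (phi(Q) sum_j phi(K_j))."""
    Qf, Kf = elu_plus_one(Q), elu_plus_one(K)     # (n, d)
    KV = Kf.T @ V                                 # (d, d_v): O(n d^2), no n x n matrix
    Z = Qf @ Kf.sum(axis=0)                       # (n,) row normalizers
    return (Qf @ KV) / Z[:, None]

def linear_attention_quadratic(Q, K, V):
    """Same result via the explicit n x n similarity matrix, for checking."""
    S = elu_plus_one(Q) @ elu_plus_one(K).T       # (n, n)
    return (S / S.sum(axis=1, keepdims=True)) @ V
```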

Anthropic Prompt Caching parameters

```
cache_control: {"type": "ephemeral"}                 # 5min default
cache_control: {"type": "ephemeral", "ttl": "1h"}    # 1h, write cost 2x
```
  • Minimum cacheable prefix: 1024 tokens (Sonnet/Opus); 2048 (Haiku); check the current docs for each model
  • At most 4 cache breakpoints per request

8. Interview Questions

Q1: Why does the KV-cache store only K and V, not Q?

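Q is the current token's query. It is different for every newly generated token and is used only once, so there is nothing to reuse. K/V belong to past tokens and are used again at every future generation step, so caching K/V is sufficient.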

Q2: Which is better, RoPE or ALiBi, and why do almost all large models choose RoPE?

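RoPE: also parameter-free, since the rotation angles are fixed by the formula rather than learned. It injects relative position into the Q·K dot product across many frequency bands, and it has a mature toolbox for context extension (Position Interpolation, NTK, YaRN). ALiBi: zero parameters and naturally strong extrapolation, but its linear distance penalty builds in a recency bias that limits how well it models long-range dependencies. Llama/Qwen (and presumably Claude) choose RoPE mainly because it has scaled better in practice.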

Q3: What is the fundamental difference between Claude 4.7's prompt caching and caching responses in Redis myself?

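My Redis cache hits only on an identical input and works at the response level. Anthropic's cache hits whenever the prefix matches (prefix matching) and works at the KV-state level. The former only helps with exact matches. The latter saves a lot in the "same system prompt + different question" scenario, which is a common agent/RAG pattern.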
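A toy model of the difference follows, with prompts as token tuples. This is a simplification: Anthropic's cache matches only at cache breakpoints (block boundaries), not at arbitrary token prefixes. Both helpers are hypothetical. A response cache keyed on the full prompt misses on a new question, while a prefix cache still reuses the shared system-prompt tokens.

```python
def exact_match_hit(response_cache, prompt_tokens):
    """Response-level cache (e.g. Redis keyed by the whole prompt): all or nothing."""
    return tuple(prompt_tokens) in response_cache

def prefix_reuse(cached_prefixes, prompt_tokens):
    """Prefix-level cache: length of the longest cached prefix the prompt starts with.
    KV state for those leading tokens is reused; only the remainder is computed."""
    best = 0
    for prefix in cached_prefixes:
        if len(prefix) > best and tuple(prompt_tokens[:len(prefix)]) == tuple(prefix):
            best = len(prefix)
    return best

system = ("You", "are", "an", "analyst", "<10-K text>")
q1 = system + ("Q3", "net", "income?")
q2 = system + ("Q4", "revenue?")
response_cache = {q1}      # Redis-style: stores the first full prompt only
prefix_cache = [system]    # prompt-cache-style: stores the shared prefix
print(exact_match_hit(response_cache, q2), prefix_reuse(prefix_cache, q2))  # False 5
```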

Q4: Why does Anthropic charge for cache writes? Isn't it just storing things for me for free?

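Because the KV cache occupies accelerator memory (presumably GPU HBM or a nearby memory tier, on the order of MB to GB per prefix). Holding it for 5 minutes means reserving that capacity, which is a real cost. OpenAI's automatic caching charges no write fee, but you have less control over whether a request hits. Anthropic's design trades an explicit payment for a predictable hit rate.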


9. Tomorrow's Preview

Day 122: Scaling laws, from Kaplan 2020 through Chinchilla 2022 to today's compute-optimal paradigm, to understand the evolution from GPT-4 to Claude 4.7.