
2026/03/08 12:57
**TPU 上で Flash Attention を無理やり導入し、試行錯誤で学び取る**
RSS: https://news.ycombinator.com/rss
要約▶
Japanese Translation:
記事では、GPU向けに Triton で書かれた Flash‑Attention カーネルを TPU v5e 上で動作する JAX/XLA に移植することを調査しています。TPU ハードウェアはすでに Flash Attention が最適化しているタイル化とメモリアクセスパターンを実装しているため、アルゴリズム自体は変更する必要がなく、プログラミングモデルだけが変わります。
JAX ではコードが抽象値で一度トレースされ、不変の HLO グラフが生成されます。このグラフではミューテーションを純粋関数(、lax.dynamic_update_slice)で表現する必要があります。jax.lax.fori_loopを使った単純実装は 4096 トークン時に XLA の融合済み標準注意機構より約 35 倍遅くなることが判明しました。外側ループをfori_loopに切り替え、反復の独立性を宣言すると劇的な速度向上(16384 トークンで約 45 倍速)が得られ、さらにスコア行列が VMEM に収まらなくなるとき(≈8192 トークン)には融合済みベースラインを上回りました。vmap
TPU v5e アーキテクチャは、128 × 128 MXU を 4 個備えた TensorCore、1 台の VPU、スカラー単位、および約 128 MB のオンチップ VMEM を持ちます。重みが固定される重量定常型のシステマティック配列により活性化をストリームしながら重みを保持します。エミュレータ結果では、標準注意機構は単一大規模行列積で約 94 % の MXU 利用率を達成しますが、Flash Attention はタイルごとに 2 回の小規模行列積しか実行せず、20–30 % に留まります。これは XLA が正しく融合できない場合の性能ギャップを説明しています。
JAX の組み込みはすでに試験済みサイズ全てで融合済み標準注意機構と同等の最適化実装をトリガーします。Google の Splash Attention(Pallas カーネル)は、非常に長いシーケンスやマルチヘッド設定でこのベースラインを上回るために必要です。記事は結論として、TPU ハードウェアは GPU 上のソフトウェアで Flash Attention が行うタイル化を実質的に実装していると述べています。単一ヘッド、d = 64 のワークロードでは手動タイル化は不要ですが、極端なシーケンス長やマルチヘッド構成ではカスタム Pallas カーネルが有利になることがあります。主要な洞察は、jax.nn.dot_product_attentionを通じて反復の独立性を明示することで、XLA が低レベル最適化を JAX で手動再実装するよりも遥かに効果的に演算を融合できるという点です。vmap
本文
LLM内部構造シリーズ第5回:TPUでFlash Attentionを実装する
1. はじめに
前回(Part 1)では注意機構、Part 2は生成、Part 3はFlash Attention アルゴリズム、そしてPart 4はTriton カーネルをGPU上に移植しました。今回の投稿では、Part 4で作った Triton カーネルを TPU へ持ち込みます。
Colab の無料 TPU ランタイム(Runtime → Change runtime type → TPU)を使えば、JAX と XLA を通じてコードを書き、コンパイルし、実行できます。
目次
- JAX/XLA:TPU プログラミングモデル
- 標準的な因果注意(causal attention)
- JAX による Flash Attention の実装
- ベンチマーク ― 真相の瞬間
- 実際に起きたことは?
- vmap の洞察
- それでも TPU は何なのか?
- システロリック・アレイを通じてデータが流れる仕組み
- システロリック・アレイエミュレーターの構築
- エミュレーターから得られた知見
- Pallas:コンパイラに勝つためには何が必要か?
- 実際に学んだこと
JAX/XLA:TPU プログラミングモデル
Part 4 では Triton カーネル(
program_id, tl.load / tl.store 等)を手書きして、どのバイトがどこへ移動するかを正確に制御しました。JAX はその上位レイヤーで、
matmul, exp, where などの演算を直接表現し、XLA コンパイラがハードウェアへのマッピングを決定します。
jax.jit が呼ばれると次のことが起きます:
- トレース – Python 関数を抽象値で一度実行して、どの演算が発生するかを記録。
- HLO 生成 – これを High‑Level Operations(dot, reduce, broadcast 等)というグラフに変換。
- 最適化 – 主に要素ごとの演算列を単一カーネルへ融合し、中間結果が HBM に流れないようにする。
- デバイスコード生成 – GPU 用は PTX、TPU 用は VLIW 命令。
Python は TPU 上で実行されるわけではなく、静的なバイナリへコンパイルされる仕様書です。
代償:可変性が失われる
Triton では
tl.store(ptr, val) のようにポインタを書き換えられます。JAX 配列は不変で、
out[i] = val は存在しません。理由は
jax.jit が関数を純粋計算グラフへトレースするためです。副作用があるとトレースが壊れてしまうからです。
| Triton (Part 4) | JAX (本稿) |
|---|---|
| , 可変 acc → 新状態: |
は Python の for 文をトレース時に展開します。jax.lax.fori_loop
はスライス位置が実行時値になり得る配列を返すだけです。dynamic_update_slice
標準的な因果注意
Part 3・4 と同じ基礎(全
(n, n) スコア行列を生成):
def standard_causal_attention(Q: jax.Array, K: jax.Array, V: jax.Array) -> jax.Array: """標準的な因果注意。Q, K, V は (n, d)、出力は (n, d)。""" assert Q.ndim == K.ndim == V.ndim == 2 assert Q.shape == K.shape == V.shape n, d = Q.shape scale = jnp.float32(1.0 / math.sqrt(d)) q = Q.astype(jnp.float32) k = K.astype(jnp.float32) v = V.astype(jnp.float32) scores = (q @ k.T) * scale # (n, n) causal_mask = jnp.triu(jnp.ones((n, n), dtype=bool), k=1) scores = jnp.where(causal_mask, -jnp.inf, scores) # (n, n) weights = jax.nn.softmax(scores, axis=-1) # (n, n) out = weights @ v # (n, d) return out.astype(Q.dtype)
standard_causal_attention_jit = jax.jit(standard_causal_attention)。
XLA はこの式全体を見て、単一の最適化カーネルへ融合します。中間行列が HBM へ流れないのでベースラインです。
JAX による Flash Attention
Part 3 の NumPy バージョンと Part 4 の Triton カーネルと同じアルゴリズム(
running_max, running_sum, acc を持つオンライン softmax)を、JAX で実装します。
@partial(jax.jit, static_argnames=("block_m", "block_n")) def flash_attention_tiled(Q: jax.Array, K: jax.Array, V: jax.Array, block_m: int = 128, block_n: int = 128) -> jax.Array: """因果 Flash Attention(タイル付きオンライン softmax)""" assert Q.ndim == K.ndim == V.ndim == 2 assert Q.shape == K.shape == V.shape assert block_m > 0 and block_n > 0 n, d = Q.shape q = Q.astype(jnp.float32) k_all = K.astype(jnp.float32) v_all = V.astype(jnp.float32) scale = jnp.float32(1.0 / math.sqrt(d)) SoftmaxState = tuple[jax.Array, jax.Array, jax.Array] num_q_blocks = math.ceil(n / block_m) num_k_blocks = math.ceil(n / block_n) n_pad = num_q_blocks * block_m out = jnp.zeros((n_pad, d), dtype=jnp.float32) q_offsets = jnp.arange(block_m) k_offsets = jnp.arange(block_n) def q_body(q_block: int, out_buf: jax.Array) -> jax.Array: q_start = q_block * block_m q_idx = q_start + q_offsets # (block_m,) q_mask = q_idx < n q_safe = jnp.minimum(q_idx, n - 1) q_tile = jnp.where(q_mask[:, None], q[q_safe, :], 0.0) # (block_m, d) running_max = jnp.full((block_m,), -jnp.inf, dtype=jnp.float32) running_sum = jnp.zeros((block_m,), dtype=jnp.float32) acc = jnp.zeros((block_m, d), dtype=jnp.float32) def k_body(k_block: int, state: SoftmaxState) -> SoftmaxState: running_max, running_sum, acc = state k_start = k_block * block_n k_idx = k_start + k_offsets # (block_n,) k_mask = k_idx < n k_safe = jnp.minimum(k_idx, n - 1) k_tile = jnp.where(k_mask[:, None], k_all[k_safe, :], 0.0) v_tile = jnp.where(k_mask[:, None], v_all[k_safe, :], 0.0) scores = (q_tile @ k_tile.T) * scale # (block_m, block_n) causal = q_idx[:, None] >= k_idx[None, :] valid = q_mask[:, None] & k_mask[None, :] & causal scores = jnp.where(valid, scores, -jnp.inf) tile_max = jnp.max(scores, axis=1) # (block_m,) new_max = jnp.maximum(running_max, tile_max) rescale = jnp.where( jnp.isfinite(running_max), jnp.exp(running_max - new_max), 0.0, ) weights = jnp.where( jnp.isfinite(new_max)[:, None], jnp.exp(scores - new_max[:, None]), 0.0, ) running_sum = rescale * running_sum + jnp.sum(weights, axis=1) acc = rescale[:, None] * acc + weights @ v_tile return new_max, running_sum, acc running_max, running_sum, acc = jax.lax.fori_loop( 0, num_k_blocks, k_body, (running_max, running_sum, acc)) out_tile = jnp.where(running_sum[:, None] > 0, acc / running_sum[:, None], 0.0) out_buf = jax.lax.dynamic_update_slice(out_buf, out_tile, (q_start, 0)) return out_buf out = jax.lax.fori_loop(0, num_q_blocks, q_body, out) return out[:n, :].astype(Q.dtype)
実装で気づいた点
| Triton(Part 4) | JAX(本稿) |
|---|---|
ポインタ演算と でメモリアクセスを手動で計算 | 配列インデックスだけで済む(コンパイラが最適化) |
のような in‑place 更新 | 本体は純粋関数:状態を入力し、更新後の状態を返す |
| ループ展開は手動で行う | を使い、XLA に「実際にループがある」ことを知らせる |
| SRAM/HBM の配置を明示できる | XLA が計算グラフからオンチップ化のヒントを推測する |
でタイルインデックスを作成 | , を JAX 配列として作成し、境界チェックとマスクに利用 |
Triton は細かい制御が可能だがコード量も増える。
JAX はポータビリティ(CPU・GPU・TPU で同じコード)が得られ、コンパイラに任せることで最適化を期待できる。
正確性チェック
n=257, d=64, blocks=(64,64) match=True max_abs=0.004399 n=513, d=64, blocks=(128,128) match=True max_abs=0.003483 n=777, d=80, blocks=(128,64) match=True max_abs=0.005013
ベンチマーク ― 真相の瞬間
- 環境:Colab TPU v5e(単一チップ)、JAX 0.7.2、float32、1 ヘッド
。(n, 64) - 測定方法:10 回平均(warmup 1 回)で
を呼び、非同期ディスパッチを除外。block_until_ready() - 結果:
| n | Standard (fused) 時間 | Flash Attention 時間 | 速度差 |
|---|---|---|---|
| 4096 | 0.1 s | 3.5 s | 35×遅い |
XLA の融合が HBM ラウンドトリップを 14.5 倍速化していることも確認できた。
vmap の洞察
jax.vmap は「一つのアイテムを処理する関数」をバッチに拡張し、XLA に各バッチが独立であると伝える。外側の
fori_loop を vmap へ置き換えることで、Q ブロックごとの計算を同時並列化できるようになる。
@partial(jax.jit, static_argnames=("block_m", "block_n")) def flash_attention_vmap(Q, K, V, block_m=128, block_n=128) -> jax.Array: # ... 省略(前述の実装と同じが vmap でラップ)
ベンチマークでは、
n=16384 の時に 45× の高速化を確認。アルゴリズム・タイルは変わらず、コンパイラへのヒントだけが違った。