Trying to push sparse attention further, and what didn't work

The previous post ended on a question. Calibrated Sparse Attention (CSA) reduces attention compute in pretrained transformers by ~10× with no fine-tuning. The win lives in the attention pattern: most heads only use a small subset of keys, and you can skip the rest cheaply. So could the same thinking apply to the projection machinery, the W_Q, W_K, W_V, W_O matrices that produce queries, keys, values, and outputs in the first place? Different widths for different heads. Variable ranks per token. Heads that turn off when they’re not needed. The kind of intervention you’d hope a learned router could orchestrate.

I tried it. Most of what I tried lost to a static baseline at matched cost. The signal exists (every head’s importance varies meaningfully per token), but extracting it with a small learned router didn’t beat simply dropping the least-useful heads with a single static mask. Four progressive fixes took the router from roughly 10× dense perplexity on the first attempt to 15% behind a static baseline with the same number of active heads. Still losing.

That’s a negative result. It’s also the most interesting finding from this thread, and I’m writing it up even though it doesn’t end where I hoped.

What the experiment was trying to show

The hypothesis is structurally identical to the one CSA validated, just applied to a different resource. Dense Q/K/V/O projections give every head the same width, every layer the same allocation, and every token access to every projection direction. That uniform allocation is wasteful if projection demand is actually heterogeneous — across layers, across heads, and across tokens. A learned projection policy that decides per (layer, head, token) which heads to run and how much rank to spend should, by analogy to CSA’s attention-reads result, match dense quality at substantially lower projection cost.

There are two cost axes that matter, and they matter in different regimes: projection FLOPs (the compute spent in W_Q, W_K, W_V, and W_O) and KV-cache memory.

The unified claim was: per-token routing should beat the existing uniform-static recipes on at least one of these axes at matched quality.
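To make the projection-cost side concrete: skipping a head for a token means skipping that head's slices of W_Q, W_K, W_V (d_model × d_head each) and of W_O (d_head × d_model). A minimal sketch, assuming GPT-2-small shapes (d_model = 768, 12 heads of width 64 per layer, 12 layers) and counting only projection multiply-accumulates (no biases, layer norms, or attention-score work). It is an illustration, not the cost accounting used in these experiments.

```python
D_MODEL = 768
N_HEADS = 12      # heads per layer
N_LAYERS = 12
D_HEAD = D_MODEL // N_HEADS  # 64


def head_projection_macs(d_model: int = D_MODEL, d_head: int = D_HEAD) -> int:
    """Per-token MACs for one head: its d_model x d_head slices of W_Q, W_K,
    W_V plus its d_head x d_model slice of W_O."""
    return 3 * d_model * d_head + d_head * d_model


def projection_macs(active_heads: int) -> int:
    """Per-token projection MACs when `active_heads` heads, summed over all
    layers, are computed and the rest are skipped."""
    return active_heads * head_projection_macs()


dense = projection_macs(N_HEADS * N_LAYERS)  # all 144 heads
static_greedy = projection_macs(128)         # 16 heads dropped
router_k10 = projection_macs(120)            # K=10 heads in each of 12 layers

print(dense)                  # 28311552
print(static_greedy / dense)  # 8/9 of dense
print(router_k10 / dense)     # 5/6 of dense
```

Under this accounting, projection cost is linear in the active-head count, so 120 active heads cost 5/6 of dense and static greedy's 128 cost 8/9.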

What looked promising on the way in

Two things gave me reason to expect this would work.

Existing literature on head importance. Voita 2019 and Michel 2019 are pretty clear that many heads in pretrained transformers can be pruned with little quality loss — 30–40% with training-time gates. GPT-2 small specifically — 144 heads across 12 layers — gives up 16–18 heads (~11–12% of 144) at ~5% perplexity cost on WikiText-2 with post-hoc iterative-greedy ablation, no retraining. The gap between 11% post-hoc and 40% trained-in is the cost of not retraining, and is a plausible floor for how much room a learned policy could find.
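Post-hoc iterative-greedy ablation can be sketched as below. `eval_ppl` is a hypothetical callable that would run the pretrained model on a calibration set with the given heads masked out and return perplexity; the toy cost table stands in for it and its numbers are invented. Each step re-evaluates every remaining head, so dropping 16 of 144 heads takes 144 + 143 + ... + 129 = 2,184 evaluations.

```python
from typing import Callable, FrozenSet, List, Tuple

Head = Tuple[int, int]  # (layer, head index)


def iterative_greedy_prune(
    heads: List[Head],
    eval_ppl: Callable[[FrozenSet[Head]], float],
    n_drop: int,
) -> List[Head]:
    """Return heads in the order they are dropped.

    At each step, try removing every remaining head on top of those already
    dropped, commit the single removal with the lowest perplexity, repeat.
    """
    dropped: FrozenSet[Head] = frozenset()
    order: List[Head] = []
    for _ in range(n_drop):
        candidates = [hd for hd in heads if hd not in dropped]
        best = min(candidates, key=lambda hd: eval_ppl(dropped | {hd}))
        dropped = dropped | {best}
        order.append(best)
    return order


# Toy stand-in for eval_ppl (invented numbers): perplexity rises additively
# with each dropped head, and layer-0 heads are very expensive to drop.
toy_cost = {(0, 0): 70.0, (0, 1): 70.0, (1, 0): 1.0, (1, 1): 2.0, (2, 0): 0.5, (2, 1): 3.0}


def toy_eval_ppl(dropped: FrozenSet[Head]) -> float:
    return 26.49 + sum(toy_cost[hd] for hd in dropped)


print(iterative_greedy_prune(list(toy_cost), toy_eval_ppl, n_drop=3))
# [(2, 0), (1, 0), (1, 1)]
```

With an additive toy cost, greedy reduces to sorting by single-head cost; on a real model, ablations interact, which is why each step re-evaluates against the heads already dropped.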

Per-token importance is real, not a measurement artifact. Before training any router, I instrumented per-(token, layer, head) gradient magnitudes on 1022 tokens of WikiText-2. Every head had coefficient of variation ≥ 1 across tokens; the top-16 head set’s Jaccard similarity across adjacent tokens was 0.17 — much closer to random (0.06) than static (1.0). A static mask captures essentially none of the token-conditional structure. If a router can extract this signal, it should beat static pruning.
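The statistics in the previous paragraph can be computed along these lines. A sketch, assuming `importance` is a list of per-token rows, each holding one gradient magnitude per (layer, head) flattened to length 144 (a hypothetical layout). `random_jaccard` gives the exact expected Jaccard of two independent random 16-of-144 subsets (their overlap is hypergeometric), which reproduces the ~0.06 random baseline.

```python
import math
import statistics
from typing import List, Sequence, Set


def coefficient_of_variation(xs: Sequence[float]) -> float:
    """Population std / mean of one head's importance across tokens."""
    return statistics.pstdev(xs) / statistics.fmean(xs)


def top_k(scores: Sequence[float], k: int) -> Set[int]:
    """Indices of the k largest scores."""
    return set(sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k])


def jaccard(a: Set[int], b: Set[int]) -> float:
    return len(a & b) / len(a | b)


def mean_adjacent_jaccard(importance: List[Sequence[float]], k: int) -> float:
    """Average Jaccard similarity between the top-k head sets of tokens t and t+1."""
    sets = [top_k(row, k) for row in importance]
    return statistics.fmean([jaccard(a, b) for a, b in zip(sets, sets[1:])])


def random_jaccard(n: int, k: int) -> float:
    """Exact expected Jaccard of two independent uniform random k-subsets of n items."""
    total = math.comb(n, k)
    expected = 0.0
    for i in range(1, k + 1):  # i = size of the overlap; i = 0 contributes 0
        p = math.comb(k, i) * math.comb(n - k, k - i) / total
        expected += p * i / (2 * k - i)
    return expected


print(random_jaccard(144, 16))  # close to the 0.06 random baseline
```

A static mask scores 1.0 on `mean_adjacent_jaccard` by construction, which is why 0.17 reads as "mostly token-varying".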

So: per-head importance is non-uniform with measurable static slack, and per-token importance is genuinely token-varying. The remaining question was whether a cheap learned router could turn the per-token variance into an active-head budget that beat the static recipe.

Where it went sideways

The first router was as naive as possible: an nn.Linear(d_model=768, n_heads=12) per layer, fit in closed form by OLS on the gradient magnitudes captured above. About 110k parameters total across all 12 layers. Per-layer R² ranged from 0.69 to 0.84 (mean 0.79), so the router predicts the gradient targets accurately. At inference, each layer keeps its top-K heads per token.
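A pure-Python sketch of this closed-form router for a single layer: fit a linear map from hidden states (plus a bias) to per-head gradient targets by solving the normal equations, then keep the top-K heads per token. The function names are hypothetical, the small ridge term is an added assumption for numerical stability, and at d_model = 768 you would reach for numpy.linalg.lstsq rather than this hand-rolled solver.

```python
from typing import List

Matrix = List[List[float]]


def fit_linear_router(X: Matrix, Y: Matrix, ridge: float = 1e-6) -> Matrix:
    """Least-squares fit of Y ~ [X, 1] @ W for one layer.

    X: tokens x d_model hidden states; Y: tokens x n_heads gradient targets.
    Returns W with d_model + 1 rows; the last row is the bias.
    """
    Xb = [list(row) + [1.0] for row in X]
    d, h, n = len(Xb[0]), len(Y[0]), len(Xb)
    # Normal equations: (Xb^T Xb + ridge * I) W = Xb^T Y
    A = [[sum(Xb[t][i] * Xb[t][j] for t in range(n)) + (ridge if i == j else 0.0)
          for j in range(d)] for i in range(d)]
    B = [[sum(Xb[t][i] * Y[t][c] for t in range(n)) for c in range(h)] for i in range(d)]
    # Gauss-Jordan elimination with partial pivoting on the augmented [A | B]
    M = [A[i] + B[i] for i in range(d)]
    for col in range(d):
        piv = max(range(col, d), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        scale = M[col][col]
        M[col] = [v / scale for v in M[col]]
        for r in range(d):
            if r != col:
                f = M[r][col]
                M[r] = [a - f * b for a, b in zip(M[r], M[col])]
    return [row[d:] for row in M]


def router_scores(x: List[float], W: Matrix) -> List[float]:
    """Predicted per-head importance for one token's hidden state."""
    xb = list(x) + [1.0]
    return [sum(xb[i] * W[i][c] for i in range(len(xb))) for c in range(len(W[0]))]


def top_k_heads(scores: List[float], k: int) -> List[int]:
    """Per-token inference: indices of the k highest-scoring heads, ascending."""
    return sorted(sorted(range(len(scores)), key=lambda i: -scores[i])[:k])


# A 768 -> 12 linear map with bias in each of GPT-2 small's 12 layers.
params = 12 * (768 * 12 + 12)
print(params)  # 110736, i.e. "about 110k"
```

The parameter count matches the ~110k quoted above, which is a reminder of how small this router is relative to the 144-head model it gates.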

Policy | Active heads/token | Perplexity | Δ vs dense
Dense | 144 | 26.49 | (baseline)
Static greedy, 16 dropped | 128 | 27.51 | +3.87%
Router K=10/layer (120 active) | 120.0 | 287 | +985%
Router K=8/layer (96 active) | 96.0 | 432 | +1529%
Router K=6/layer (72 active) | 72.0 | 599 | +2160%

At 120 active heads (24 fewer than dense, 8 fewer than static greedy’s 128), perplexity is already more than 10× dense. The router doesn’t just lose; it collapses.

Two failures compound:

  1. Per-layer uniform K is layer-blind. Static greedy drops 16 heads heterogeneously: layer 0 keeps all 12 of its heads (dropping a single layer-0 head can cost +70 ppl on its own), while layer 11 loses 4. Per-layer K=10 forces every layer to drop exactly two, including layer 0 where the static policy keeps all. The router’s output shape can’t represent the right answer.

  2. Gradient ≠ ablation impact. Pearson correlation between gradient magnitude and single-shot ablation Δppl is only 0.27. Training the router on gradients gives high R² on the gradient target but poor ranking on the thing we actually care about. The lowest-gradient heads inside a layer aren’t the safest to drop.
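Item 2 comes down to two checks on per-head numbers: a Pearson correlation between gradient magnitude and single-shot ablation Δppl, and whether the lowest-gradient head in each layer is also the one whose ablation hurts least. A minimal sketch; `grad` and `ablation` are hypothetical layers × heads tables, and the toy values are invented.

```python
import math
from typing import List, Sequence


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def argmin(xs: Sequence[float]) -> int:
    return min(range(len(xs)), key=lambda i: xs[i])


def cheapest_head_agreement(grad: List[List[float]], ablation: List[List[float]]) -> float:
    """Fraction of layers where the lowest-gradient head is also the head
    whose single-shot ablation raises perplexity the least."""
    hits = sum(argmin(g) == argmin(a) for g, a in zip(grad, ablation))
    return hits / len(grad)


# Toy numbers (invented): two layers, three heads each.
grad = [[0.2, 0.5, 0.9], [0.1, 0.4, 0.3]]
ablation = [[1.5, 0.1, 2.0], [0.2, 3.0, 0.5]]
flat_g = [v for row in grad for v in row]
flat_a = [v for row in ablation for v in row]
print(pearson(flat_g, flat_a))
print(cheapest_head_agreement(grad, ablation))  # 0.5
```

A router can fit the gradient target with high R² and still fail the second check, because R² rewards matching the proxy, not ranking the true ablation cost.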

This was diagnosable, so I tried fixes.

Four progressive improvements

Fix 1 — Force-keep critical heads. Use static iterative-greedy ranking to mark the top-N most-important heads as always-on. The router only decides among the rest. This eliminates the layer-0 catastrophe without changing the routing architecture. Improved noticeably but still well behind static greedy.
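Fix 1 changes only how a layer's mask is assembled. A sketch under stated assumptions: `forced` is the set of always-on heads for this layer (taken from the static iterative-greedy ranking), and, as an assumption, forced heads count toward the layer's budget K, with the router's scores filling whatever slots remain.

```python
from typing import List, Set


def force_keep_mask(scores: List[float], forced: Set[int], k: int) -> List[bool]:
    """Mask for one layer and one token.

    Heads in `forced` are always on. The router's scores pick the best of the
    remaining heads for the k - len(forced) leftover slots (assumption: forced
    heads count toward k). If forced heads alone exceed k, all of them still
    run and nothing else does, so the layer ends up over budget.
    """
    free = max(k - len(forced), 0)
    rest = sorted((i for i in range(len(scores)) if i not in forced), key=lambda i: -scores[i])
    keep = set(forced) | set(rest[:free])
    return [i in keep for i in range(len(scores))]


print(force_keep_mask([0.9, 0.1, 0.5, 0.3], forced={1}, k=2))
# [True, True, False, False]
```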

Fix 2 — Soft mask end-to-end with cost penalty. Replace the closed-form OLS step with end-to-end training. Use a soft sigmoid gate per (layer, head, token), add a λ · projection_cost term to the loss, train on WikiText-2. Lets the router learn what to drop directly from LM loss rather than from a gradient-target proxy. Better, but the soft-vs-binary distribution gap (soft training, binary inference) damaged the model: post-train dense perplexity got worse, not better.
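The cost term in Fix 2 can be sketched for the heads of one (layer, token). `head_cost` is a hypothetical normalized per-head projection cost and `lam` plays the role of λ. The function returns the soft gates, the penalty to add to the LM loss, and the penalty's gradient with respect to each gate logit; the LM-loss gradient would come from autograd in practice.

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def soft_gate_penalty(
    logits: List[float], lam: float, head_cost: float
) -> Tuple[List[float], float, List[float]]:
    """Soft gates g = sigmoid(logit), penalty = lam * head_cost * sum(g),
    and d(penalty)/d(logit) = lam * head_cost * g * (1 - g) per head."""
    gates = [sigmoid(z) for z in logits]
    penalty = lam * head_cost * sum(gates)
    grads = [lam * head_cost * g * (1.0 - g) for g in gates]
    return gates, penalty, grads


gates, penalty, grads = soft_gate_penalty([0.0, 2.0, -2.0], lam=0.1, head_cost=2.0)
print(gates, penalty, grads)
```

During training each head's output is scaled by its gate, so a gate of 0.6 contributes 60% of that head; at inference the same head is either fully on or fully off. That mismatch is the soft/binary gap described above.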

Fix 3 — Straight-through estimator (STE). Apply a binary top-K mask in the forward pass and route gradients through the soft sigmoid in the backward pass. This eliminates the soft/binary distribution gap. Post-train dense perplexity went from 26.49 to 23.31 (better than pre-train; the extra LM training helped), and the router converged monotonically.
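The straight-through estimator, written out by hand for one (layer, token): the forward pass uses the hard top-K mask, and the backward pass treats the mask as if it were sigmoid(logits). This is a sketch, not the experiment code; in PyTorch the usual idiom is `hard + soft - soft.detach()`, whose value equals `hard` while its gradient is that of `soft`.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def ste_forward(logits: List[float], k: int) -> List[float]:
    """Forward: hard top-k mask with exactly k ones."""
    keep = set(sorted(range(len(logits)), key=lambda i: -logits[i])[:k])
    return [1.0 if i in keep else 0.0 for i in range(len(logits))]


def ste_backward(logits: List[float], grad_mask: List[float]) -> List[float]:
    """Backward: pass the incoming gradient on the mask through the
    derivative of sigmoid(logits), as if the mask had been soft."""
    out = []
    for z, g in zip(logits, grad_mask):
        s = sigmoid(z)
        out.append(g * s * (1.0 - s))
    return out


logits = [0.1, 2.0, -1.0, 0.5]
print(ste_forward(logits, k=2))                    # [0.0, 1.0, 0.0, 1.0]
print(ste_backward(logits, [1.0, 1.0, 1.0, 1.0]))
```

Because the forward pass is always binary, the model trains on exactly the masks it will see at inference, while the sigmoid surrogate still gives every head, kept or dropped, a nonzero gradient.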

Fix 4 — Longer STE training. 5× the compute (500 steps vs 100), same setup.

Technique | Router perplexity at K=10 (120 active)
Post-hoc OLS | 287
Post-hoc OLS + force-keep | 105
E2E soft mask (50 steps) | 66
STE, 100 steps | 35
STE, 500 steps | 28.5
Static greedy/24 (post-train baseline) | 24.8
[Figure: router perplexity at K=10 (120 active heads, lower is better) for each technique, with static greedy/24 = 24.8 shown as a vertical dashed line.]
Five progressive router techniques. Each technical fix narrowed the gap to the static greedy baseline (vertical dashed line), but even after STE end-to-end training at 500 steps the router lands 15% short. The naive linear router's bar runs off the legible scale at 287 ppl.

Each fix produced real, monotone progress. The router is learning. It just isn’t beating the bar. At 500 STE steps, the trained router is +15% behind static greedy at matched active-head count.
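The headline gaps are simple ratios of perplexities. A quick check using the rounded values from the tables:

```python
def rel_gap_pct(ppl: float, baseline: float) -> float:
    """Percent by which `ppl` exceeds `baseline` (positive means worse)."""
    return (ppl / baseline - 1.0) * 100.0


# STE, 500 steps vs static greedy/24, both at 120 active heads
print(round(rel_gap_pct(28.5, 24.8), 1))  # 14.9
# Naive linear router vs dense, as a multiple
print(round(287 / 26.49, 1))  # 10.8
```

The rounded inputs give 14.9%, which is the "+15%" quoted throughout.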

The training loss plateaued from step ~250 onward: the last 200 steps moved mean LM loss by 0.06. The router has converged for this architecture and setup.

What’s actually going on

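A few plausible (non-exclusive) reasons the router still loses, each tied to a choice in this design: a single linear layer may be too weak to map hidden states to head importance; per-layer uniform top-K may still be too rigid compared with the layer-heterogeneous budget static greedy finds; and GPT-2-small scale with a few hundred training steps may leave too little exploitable signal for a learned policy to beat a mask found by direct ablation.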

Probably some combination. The natural next experiments — MLP router, global top-K, longer training — each add complexity without guaranteed payoff at this scale. The cheap experiments have been done.
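Global top-K, one of the next experiments named above, differs from the per-layer version only in where the budget is enforced. A sketch with hypothetical names and toy scores; nothing here shows that global top-K would actually close the gap.

```python
from typing import List


def per_layer_topk(scores: List[List[float]], k: int) -> List[List[bool]]:
    """Keep the k best heads in every layer: each layer drops exactly n_heads - k."""
    mask = []
    for row in scores:
        keep = set(sorted(range(len(row)), key=lambda h: -row[h])[:k])
        mask.append([h in keep for h in range(len(row))])
    return mask


def global_topk(scores: List[List[float]], total: int) -> List[List[bool]]:
    """Keep the `total` best (layer, head) pairs overall, so layers can differ."""
    flat = [(s, l, h) for l, row in enumerate(scores) for h, s in enumerate(row)]
    ranked = sorted(flat, key=lambda t: t[0], reverse=True)
    keep = {(l, h) for _, l, h in ranked[:total]}
    return [[(l, h) in keep for h in range(len(row))] for l, row in enumerate(scores)]


# Toy scores: layer 0 matters everywhere, layer 1 mostly does not.
scores = [[9.0, 8.0, 7.0], [1.0, 2.0, 3.0]]
print(per_layer_topk(scores, k=2))   # drops a layer-0 head regardless
print(global_topk(scores, total=4))  # keeps all of layer 0
```

Both masks keep four heads, but only the global version can spend the whole budget on the layer that needs it, which is the shape of the static greedy solution.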

What this is not

This isn’t a falsification of the hypothesis. The hypothesis was: dense projections are overallocated, and a learned per-token policy can match dense quality at lower projection cost. Per-token signal exists; static head pruning recovers ~12% slack; SVD-with-fine-tune recovery probably opens more room. None of those went away.

What is falsified is the specific design — per-layer linear router, per-layer uniform top-K, GPT-2-small scale, hundreds of training steps. That design converges to a result the simpler static baseline beats by ~15%. The signal the gradient-based instrumentation said was there is, at this combination of architecture and training budget, sub-margin.

What it suggests

Two things come back to me from running this:

Static heuristics keep winning at small scale. CSA’s result was the same shape: a learned router matched but didn’t beat a static per-(layer, head) allocation rule derived from a single calibration pass. Here the static baseline isn’t even beaten. The pattern is consistent: at the scales a laptop can run, the predictable structure in transformer compute is dominated by static per-(layer, head) properties, with per-token routing adding less than the cost of evaluating it.

This may not generalize to larger models, longer training, or richer router architectures, and the natural next experiments would scale up exactly those things. But it should at least make us suspicious of architecture work that assumes per-token routing will pay off without first checking it against a static baseline.

The right intervention may not be dynamic routing of dense capacity. It may be reshaping the dense capacity itself. If the options for cutting projection FLOPs in pretrained models come down to per-projection SVD (no), per-head pruning (yes, modestly), and learned-rank fine-tuning (untested here, expensive), then the practical recipe is probably closer to “calibrate, prune, optionally refit” than “train a router to choose dynamically.” Boring. Plausibly correct.
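One piece of arithmetic that bears on the SVD option: factoring a d_in × d_out projection into rank-r factors costs r · (d_in + d_out) MACs per token instead of d_in · d_out, so it only saves compute below the break-even rank d_in · d_out / (d_in + d_out). A sketch with GPT-2-small shapes; it suggests one possible contributing reason, not the analysis behind the "(no)" verdict.

```python
def dense_macs(d_in: int, d_out: int) -> int:
    """Per-token MACs of an unfactored d_in x d_out projection."""
    return d_in * d_out


def low_rank_macs(d_in: int, d_out: int, r: int) -> int:
    """Per-token MACs after replacing W with U (d_in x r) @ V (r x d_out),
    for example from a truncated SVD."""
    return r * (d_in + d_out)


def break_even_rank(d_in: int, d_out: int) -> float:
    """Rank at which the factored form costs exactly as much as the dense one."""
    return d_in * d_out / (d_in + d_out)


print(break_even_rank(768, 768))  # 384.0, a full 768 x 768 projection
print(break_even_rank(768, 64))   # ~59.1, one head's slice (max rank 64)
```

For a single head's 768 × 64 slice, any rank above about 59 of a possible 64 costs more than the dense slice, so per-head factorization has very little room to save anything.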

The KV-memory axis I never got to. MLA-style (multi-head latent attention) caching of a shared low-rank K/V latent is the strongest existing static baseline there; beating it would require an MLA reproduction at the same model scale, then showing that learned heterogeneity helps on top. That’s a real piece of work, and it’s the natural place to go if anyone has GPU budget pointed at this problem.
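For scale on the KV-memory axis: standard multi-head attention caches one key and one value vector per head per layer, while MLA caches a single compressed latent per layer. A sketch with GPT-2-small shapes; `d_latent = 256` is a hypothetical choice rather than a tuned value, and the count ignores the small decoupled RoPE key that DeepSeek-V2's MLA also caches.

```python
def kv_elements_per_token(n_layers: int, n_heads: int, d_head: int) -> int:
    """Standard MHA: one key and one value vector per head per layer."""
    return n_layers * 2 * n_heads * d_head


def mla_elements_per_token(n_layers: int, d_latent: int) -> int:
    """MLA-style: one shared compressed K/V latent per layer
    (decoupled RoPE key not counted)."""
    return n_layers * d_latent


dense_kv = kv_elements_per_token(12, 12, 64)  # GPT-2-small shapes
latent_kv = mla_elements_per_token(12, 256)   # hypothetical d_latent
print(dense_kv, latent_kv, dense_kv / latent_kv)  # 18432 3072 6.0
```

Any learned per-token scheme on this axis would have to beat a reduction of that kind, which is fixed per layer and costs nothing to route.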