
[KDA] KDA MTP decode: recurrent + KVBuffer chunkwise verify + flush#96

Open
Longxmas wants to merge 8 commits into inclusionAI:main from Longxmas:feat/kda-mtp-verify-4ops

@Longxmas

Collaborator


## 📌 Description

Adds KDA (Kimi Delta Attention) multi-token-prediction (MTP) decode — the target-side gated-delta-rule recurrence for speculative decoding — in two complementary forms:

  • Recurrent (kda_decode_mtp): a single register-resident CuTe kernel threading the recurrence over the T draft tokens per (batch, head). Two 1-CTA ops — vk (lane=K butterfly-reduce) / ws (4-warp warp-spec) — behind a unified dispatch.

  • KVBuffer chunkwise verify (kda_decode_mtp_kvbuffer): the parallel-verification path — verify emits a compact u-buffer (~2·T·d) instead of the T per-token states (T·d²), and a rank-m flush rebuilds the accepted state. Two variants — tp (token-parallel SIMT) / cute-gemm (sm90 tensor-core, flat-in-T).
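To make the two forms concrete, here is a minimal pure-Python sketch of a gated delta-rule recurrence for one (batch, head) pair, together with a toy verify/flush pair. It assumes the update S ← Diag(exp(g))·S + k·uᵀ with u = β·(v − kᵀ·Diag(exp(g))·S); the function names, the exact contents of the u-buffer, and the sequential loops are illustrative assumptions, not the PR's kernels (the real verify is chunkwise-parallel and the real buffers are tensors).

```python
import math


def gated_delta_step(S, q, k, v, g, beta):
    """One token of the (assumed) gated delta rule on a K x V state.

    S is a list of K rows of length V and is updated in place.
    Returns (output o, rank-1 update vector u).
    """
    K, V = len(S), len(S[0])
    # 1) per-channel decay: S <- Diag(exp(g)) S
    for i in range(K):
        a = math.exp(g[i])
        for j in range(V):
            S[i][j] = a * S[i][j]
    # 2) delta-rule correction: u = beta * (v - k^T S)
    u = [beta * (v[j] - sum(k[i] * S[i][j] for i in range(K))) for j in range(V)]
    # 3) rank-1 write: S <- S + k u^T
    for i in range(K):
        for j in range(V):
            S[i][j] += k[i] * u[j]
    # 4) read-out: o = S^T q
    o = [sum(q[i] * S[i][j] for i in range(K)) for j in range(V)]
    return o, u


def verify(S0, qs, ks, vs, gs, betas):
    """Run T draft tokens on a copy of S0 and keep only outputs and u vectors."""
    S = [row[:] for row in S0]  # the committed state S0 is never modified
    outs, us = [], []
    for q, k, v, g, beta in zip(qs, ks, vs, gs, betas):
        o, u = gated_delta_step(S, q, k, v, g, beta)
        outs.append(o)
        us.append(u)
    return outs, us


def flush(S0, ks, gs, us, m):
    """Rebuild the state after the first m accepted tokens via m rank-1 updates."""
    S = [row[:] for row in S0]
    for t in range(m):
        for i in range(len(S)):
            a = math.exp(gs[t][i])
            for j in range(len(S[0])):
                S[i][j] = a * S[i][j] + ks[t][i] * us[t][j]
    return S
```

In this sketch the rollback data is T vectors of length V (plus the k and g inputs) rather than T full K×V states, which is the storage trade-off the KVBuffer path relies on; `flush(S0, ks, gs, us, m)` reproduces the state the recurrence reaches after m tokens.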

## What changed

  • cula/ops/kda_decode_mtp.py — recurrent vk/ws ops + unified dispatch.

  • cula/ops/kda_decode_mtp_kvbuffer.py — tp / cute-gemm chunkwise verify + rank-m flush.

  • tests/test_kda_decode_mtp.py — unit (vs fp32 oracle) + bit-exact determinism.

  • benchmarks/bench_kda_decode_mtp.py — unified verify-chain benchmark.

## 🔍 Related Issues

Closes #17

## 🧪 Tests

pytest tests/test_kda_decode_mtp.py — recurrent (vk/ws/kv) + kvbuffer (tp/cg) verify output & rank-m flush vs the fp32 single-token recurrence oracle, plus bit-exact determinism (torch.equal).

tests/test_kda_decode_mtp.py::test_oracle_vs_loop[N16-T8-H8-HV16-randstate] PASSED
tests/test_kda_decode_mtp.py::test_ws_decode[N64-T8-H8-HV16-tvNone-ilpNone-smemNone] PASSED
tests/test_kda_decode_mtp.py::test_small_batch_decode[N16-T4-H16-HV32-vk-bv-1-ks1] PASSED
tests/test_kda_decode_mtp.py::test_small_batch_decode[N16-T4-H16-HV32-kv-bv32-ks-1] PASSED
tests/test_kda_decode_mtp.py::test_determinism[sb_vk] PASSED
tests/test_kda_decode_mtp.py::test_intermediate_vs_oracle_and_final[64-4-True] PASSED
tests/test_kda_decode_mtp.py::test_tp_kvbuffer_verify_and_flush[4-4-16-16] PASSED
tests/test_kda_decode_mtp.py::test_cg_kvbuffer_verify_and_flush[4-6-16-16] PASSED
tests/test_kda_decode_mtp.py::test_kvbuffer_dispatch_routes_by_T[4-cg] PASSED
tests/test_kda_decode_mtp.py::test_kvbuffer_verify_determinism[cg-4-6-16-16] PASSED
tests/test_kda_decode_mtp.py::test_kvbuffer_flush_determinism[tp-4-4-16-16] PASSED
============== 106 passed, 27 warnings in 287.71s ==============

## ⚡ Performance

Setup: H200 (HBM3e), K=V=128, bf16, accept length m = full (all draft tokens accepted), official SGLang scatter commit, kernel-only timing under CUDA graphs. Each cell is the speedup of the best-dispatch verify + state-update chain over the official Triton recurrent kernel (fused_sigmoid_gating_delta_rule_update, SGLang). The unified dispatch picks the fastest of {vk, ws, tp, cg} per shape: recurrent vk/ws write T·d² states and use the scatter commit; KVBuffer tp/cg write a u-buffer and use the rank-m flush. Values >1 mean faster than Triton.

Best method / Triton (HV=H=32)

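| B | T=2 | T=3 | T=4 | T=6 |
|---|---|---|---|---|
| 1 | 1.80 | 1.73 | 1.75 | 1.91 |
| 2 | 1.48 | 1.52 | 1.63 | 1.86 |
| 4 | 1.41 | 1.41 | 1.63 | 2.10 |
| 8 | 1.70 | 1.72 | 1.82 | 2.18 |
| 16 | 1.97 | 1.98 | 2.20 | 2.46 |
| 32 | 1.61 | 1.61 | 1.78 | 2.05 |
| 64 | 1.61 | 1.59 | 1.80 | 1.96 |
| 128 | 1.59 | 1.61 | 1.84 | 2.05 |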

Best method / Triton (HV=H=64)

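| B | T=2 | T=3 | T=4 | T=6 |
|---|---|---|---|---|
| 1 | 1.46 | 1.49 | 1.68 | 1.91 |
| 2 | 1.41 | 1.41 | 1.57 | 2.09 |
| 4 | 1.65 | 1.71 | 1.82 | 2.13 |
| 8 | 2.03 | 2.03 | 2.23 | 2.50 |
| 16 | 1.61 | 1.62 | 1.78 | 2.06 |
| 32 | 1.63 | 1.60 | 1.80 | 1.96 |
| 64 | 1.58 | 1.61 | 1.84 | 2.06 |
| 128 | 1.59 | 1.64 | 1.85 | 2.09 |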

Takeaways:

  • Best-dispatch chain vs Triton: 1.41× – 2.50× across all shapes (both HV), strongest at T≥4. The dispatch routes T=2 → tp (token-parallel SIMT, best at small T), T≥3 → cute-gemm (flat-in-T), and tiny B≤2 at small T → recurrent vk.

  • Flat-in-T: the cute-gemm verify kernel's latency grows only 14–20% from T=2 to T=6, while the Triton recurrent kernel grows 104–124%, reproducing Fig. 4 of the KVBuffer paper; the verify kernel alone (accept-independent) reaches up to 3.51× (B=128, T=6).

  • Memory: the u-buffer (~2·T·d) replaces the T·d² intermediate states → ~43× less rollback storage (d=128), independent of latency.

  • Correctness: tp ≤ 6.1e-5, cg ≤ 2.44e-4 max|Δ| vs Triton (bf16), well within bf16 noise.
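The routing described in the first takeaway can be sketched as a small shape-based selector. This is a hypothetical illustration only: the function name is invented, "small T" is assumed to mean T ≤ 2, the T=1 rule is taken from the recurrent commit message, and the ws op is left out because the text gives no rule for it. The actual dispatch in the PR may use different thresholds.

```python
def pick_mtp_method(batch: int, num_draft: int) -> str:
    """Hypothetical shape-based router mirroring the rules stated above.

    Returns one of "vk", "tp", "cg".
    """
    if batch < 1 or num_draft < 1:
        raise ValueError("batch and num_draft must be positive")
    if num_draft == 1:
        return "vk"  # single-token decode routes to vk regardless of batch
    if batch <= 2 and num_draft <= 2:
        return "vk"  # assumption: "tiny B, small T" means B <= 2 and T <= 2
    if num_draft == 2:
        return "tp"  # token-parallel SIMT, best at small T
    return "cg"      # cute-gemm, flat-in-T, for T >= 3
```

A table-driven dispatch tuned from benchmark results would be a natural alternative; the sketch only encodes the three rules quoted in the takeaways.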

Longxmas added 2 commits June 16, 2026 16:37
Single-kernel recurrent gated-delta-rule multi-token-prediction decode with register-resident state. vk (lane=K butterfly-reduce) and ws (4-warp warp-spec) CuTe ops behind a unified dispatch; single-token T=1 routes to vk regardless of batch.
Chunkwise parallel-verification KVBuffer ops for KDA MTP speculative decoding: tp (token-parallel SIMT) and cute-gemm (sm90 tensor-core) verify emit a compact u-buffer instead of T per-token states; rank-m flush rebuilds the accepted state. Adds unit + determinism tests and the unified decode-mtp benchmark.
@Longxmas Longxmas requested a review from icavan June 16, 2026 09:12

@Longxmas Longxmas requested a review from zheyang0825 June 16, 2026 09:30
Longxmas added 2 commits June 16, 2026 21:30
…added CTAs, empty dummies

flush kvbuffer: accept_len is now per-request and read at runtime from an
[N] int32 buffer (m_buf[i_n]) instead of a compile-time constant. The kernel
statically unrolls T and masks i_i < m_n, so it compiles exactly once per
(shape, BV) regardless of accept length; b_m uses the per-request token m_n-1.
Host API accepts an int (broadcast to all N) or a per-request [N] tensor.
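A minimal host-side sketch of this idea, using plain Python lists as a stand-in for the [N] int32 buffer and a scalar as a stand-in for the state. The helper names are hypothetical; the point is only that the loop always runs T iterations and the accept length acts as a runtime mask rather than a loop bound.

```python
def normalize_accept_len(accept_len, num_requests):
    """Return one accept length per request (stand-in for an [N] int32 buffer)."""
    if isinstance(accept_len, int):
        return [accept_len] * num_requests  # broadcast a single int to all N
    m_buf = [int(m) for m in accept_len]
    if len(m_buf) != num_requests:
        raise ValueError("accept_len must have one entry per request")
    return m_buf


def flush_masked(state, updates, m_n):
    """Apply the first m_n of T updates with a fixed trip count over T."""
    for i_i in range(len(updates)):       # always T iterations (static unroll)
        keep = 1.0 if i_i < m_n else 0.0  # runtime mask: i_i < m_n
        state += keep * updates[i_i]
    return state
```

Because the trip count depends only on T, the same compiled loop body serves every accept length, which is why the commit reports one compile per (shape, BV).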

small-batch decode (vk + kv): wrap the compute body in `if cache_idx >= 0:` so
padded slots (cache_idx < 0) skip the whole T-loop, matching the ws kernel
(~1.3x on a half-padded batch). kv hoists its k_split constexpr decisions to
top level so they stay python constants inside the guarded block.

kvbuffer verify: torch.empty instead of torch.zeros for the write_ubuf=False
dummy buffers (only ever written, never read) — drops a per-call memset.
- cg SMEM stride K+8 -> K+4: fixes MMA fragment bank conflict (-16~19% @ N>=4, bit-identical)
- drop redundant P4 doubling-chain barrier
- 3xTF32 emulation on P3 & GEMM1 GEMMs: cgkvb max|Δ| 2.44e-4 -> ~vk level/bf16 floor (+5~6% cgkvb_v @ large N)
- bench: default H=HV=32, graph-calls=20
@Longxmas Longxmas force-pushed the feat/kda-mtp-verify-4ops branch from 4fd3268 to 9bf0493 Compare June 23, 2026 10:53
Model config kda_safe_gate=true with kda_lower_bound set uses the safe gate
g = lower_bound * sigmoid(exp(A_log) * x); the MTP decode gates previously only
implemented the softplus gate g = -exp(A_log) * softplus(x). Add the lower_bound
branch (compile-time const_expr; lower_bound=None keeps the softplus path,
bit-identical) to all five gates: vk/ws/kv in kda_decode_mtp and tp/cg in
kda_decode_mtp_kvbuffer, threaded through launcher / compile cache key / host /
dispatch. Tests: oracle lower_bound branch + test_lower_bound_safe_gate
(vk/ws/kv) + test_lower_bound_kvbuffer (tp/cg).
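The two gate formulas quoted above can be written as a scalar reference in plain Python. This is a sketch under the assumption that the gate takes only x, A_log, and an optional lower_bound, exactly as in the formulas; the function name is hypothetical, and the kernels resolve the branch at compile time rather than per call.

```python
import math


def softplus(x):
    # numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    # numerically stable logistic function
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def kda_gate(x, a_log, lower_bound=None):
    """Log-space gate g for one channel.

    lower_bound=None -> softplus gate: g = -exp(A_log) * softplus(x)
    otherwise        -> safe gate:     g = lower_bound * sigmoid(exp(A_log) * x)
    """
    if lower_bound is None:
        return -math.exp(a_log) * softplus(x)
    return lower_bound * sigmoid(math.exp(a_log) * x)
```

With a negative lower_bound (an assumption about how the config is used), the safe gate stays strictly between lower_bound and 0, whereas the softplus gate is unbounded below as x grows.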
@Longxmas Longxmas force-pushed the feat/kda-mtp-verify-4ops branch from 50a60e6 to cf3c1e7 Compare June 24, 2026 06:42
Longxmas added 3 commits June 24, 2026 16:04
dynamic-N: mark the batch (and state-pool) axes dynamic via
mark_compact_shape_dynamic(mode=0, stride_order=...) and mark_layout_dynamic()
for the index tensors, and drop N + pool_size from the compile cache key, so one
cubin serves all batch sizes (removes the per-N JIT; no startup prewarm needed).
Applied to all five MTP decode gate kernels: vk/ws/kv in kda_decode_mtp and tp/cg
in kda_decode_mtp_kvbuffer. Validated: unit tests bit-exact/bf16-level, e2e gsm8k
unchanged at 0.8696 with cuda-graph capture succeeding without prewarm.

ruff: fix lint in the touched files (E402 noqa on intentional mid-file imports,
SIM comparison order, isort, F841 unused var, E702 multiple statements).
Run the ruff-format pre-commit hook on the files it flagged for pre-existing
format drift: cula/ops/kda_decode_mtp.py, cula/ops/kda_decode_mtp_kvbuffer.py,
tests/test_kda_decode_mtp.py, benchmarks/bench_kda_decode_mtp.py. Formatting
only, no logic change.
Add kda_flush_kvbuffer_all_layers: one launch over all L KDA layers (2D grid,
x = single-layer grid, y = layer) instead of the per-layer Python loop the
caller previously used for the spec-decode state commit. dyn-N: N (request
count) is not a compile constant -- the index tensors are marked
layout-dynamic and N is dropped from the cache key, so one cubin serves all
batch sizes (no per-N compile storm). cute.compile traces the real,
already-allocated tensors directly.

Bit-identical to the per-layer kda_flush_kvbuffer (MAXDIFF=0 vs the loop over
all layers); the single-layer entry point is left unchanged.
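As a CPU-side illustration of the 2D grid, the toy below compares a per-layer Python loop with a single pass over (x, y) pairs where y indexes the layer. The data layout (one list of block values per layer) and the additive "flush" are hypothetical simplifications; only the grid mapping mirrors the commit message.

```python
def flush_one_layer(states, updates, layer):
    """Single-layer flush: one 'block' x per entry of that layer's state."""
    for x in range(len(states[layer])):
        states[layer][x] += updates[layer][x]


def flush_per_layer_loop(states, updates):
    """Baseline: one launch per layer from a Python loop."""
    for layer in range(len(states)):
        flush_one_layer(states, updates, layer)


def flush_all_layers(states, updates):
    """One 'launch' over a 2D grid: x = single-layer grid, y = layer."""
    grid = (len(states[0]), len(states))
    for y in range(grid[1]):
        for x in range(grid[0]):
            states[y][x] += updates[y][x]
```

Both functions touch the same (layer, block) pairs with the same arithmetic, so their results match exactly, analogous to the MAXDIFF=0 check reported for the real kernels.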
@Longxmas Longxmas force-pushed the feat/kda-mtp-verify-4ops branch from 68e8f3f to 0da49c7 Compare June 25, 2026 04:43
Development

Successfully merging this pull request may close these issues.

KDA MTP (Multi-Token Prediction) support
