[KDA] KDA MTP decode: recurrent + KVBuffer chunkwise verify + flush#96
Open
Longxmas wants to merge 8 commits into
Open
[KDA] KDA MTP decode: recurrent + KVBuffer chunkwise verify + flush#96Longxmas wants to merge 8 commits into
Longxmas wants to merge 8 commits into
Conversation
Single-kernel recurrent gated-delta-rule multi-token-prediction decode with register-resident state. vk (lane=K butterfly-reduce) and ws (4-warp warp-spec) CuTe ops behind a unified dispatch; single-token T=1 routes to vk regardless of batch.
Chunkwise parallel-verification KVBuffer ops for KDA MTP speculative decoding: tp (token-parallel SIMT) and cute-gemm (sm90 tensor-core) verify emit a compact u-buffer instead of T per-token states; rank-m flush rebuilds the accepted state. Adds unit + determinism tests and the unified decode-mtp benchmark.
…added CTAs, empty dummies flush kvbuffer: accept_len is now per-request and read at runtime from an [N] int32 buffer (m_buf[i_n]) instead of a compile-time constant. The kernel statically unrolls T and masks i_i < m_n, so it compiles exactly once per (shape, BV) regardless of accept length; b_m uses the per-request token m_n-1. Host API accepts an int (broadcast to all N) or a per-request [N] tensor. small-batch decode (vk + kv): wrap the compute body in `if cache_idx >= 0:` so padded slots (cache_idx < 0) skip the whole T-loop, matching the ws kernel (~1.3x on a half-padded batch). kv hoists its k_split constexpr decisions to top level so they stay python constants inside the guarded block. kvbuffer verify: torch.empty instead of torch.zeros for the write_ubuf=False dummy buffers (only ever written, never read) — drops a per-call memset.
- cg SMEM stride K+8 -> K+4: fixes MMA fragment bank conflict (-16~19% @ N>=4, bit-identical) - drop redundant P4 doubling-chain barrier - 3xTF32 emulation on P3 & GEMM1 GEMMs: cgkvb max|Δ| 2.44e-4 -> ~vk level/bf16 floor (+5~6% cgkvb_v @ large N) - bench: default H=HV=32, graph-calls=20
4fd3268 to
9bf0493
Compare
Model config kda_safe_gate=true with kda_lower_bound set uses the safe gate g = lower_bound * sigmoid(exp(A_log) * x); the MTP decode gates previously only implemented the softplus gate g = -exp(A_log) * softplus(x). Add the lower_bound branch (compile-time const_expr; lower_bound=None keeps the softplus path, bit-identical) to all five gates: vk/ws/kv in kda_decode_mtp and tp/cg in kda_decode_mtp_kvbuffer, threaded through launcher / compile cache key / host / dispatch. Tests: oracle lower_bound branch + test_lower_bound_safe_gate (vk/ws/kv) + test_lower_bound_kvbuffer (tp/cg).
50a60e6 to
cf3c1e7
Compare
dynamic-N: mark the batch (and state-pool) axes dynamic via mark_compact_shape_dynamic(mode=0, stride_order=...) and mark_layout_dynamic() for the index tensors, and drop N + pool_size from the compile cache key, so one cubin serves all batch sizes (removes the per-N JIT; no startup prewarm needed). Applied to all five MTP decode gate kernels: vk/ws/kv in kda_decode_mtp and tp/cg in kda_decode_mtp_kvbuffer. Validated: unit tests bit-exact/bf16-level, e2e gsm8k unchanged at 0.8696 with cuda-graph capture succeeding without prewarm. ruff: fix lint in the touched files (E402 noqa on intentional mid-file imports, SIM comparison order, isort, F841 unused var, E702 multiple statements).
Run the ruff-format pre-commit hook on the files it flagged for pre-existing format drift: cula/ops/kda_decode_mtp.py, cula/ops/kda_decode_mtp_kvbuffer.py, tests/test_kda_decode_mtp.py, benchmarks/bench_kda_decode_mtp.py. Formatting only, no logic change.
Add kda_flush_kvbuffer_all_layers: one launch over all L KDA layers (2D grid, x = single-layer grid, y = layer) instead of the per-layer Python loop the caller previously used for the spec-decode state commit. dyn-N: N (request count) is not a compile constant -- the index tensors are marked layout-dynamic and N is dropped from the cache key, so one cubin serves all batch sizes (no per-N compile storm). cute.compile traces the real, already-allocated tensors directly. Bit-identical to the per-layer kda_flush_kvbuffer (MAXDIFF=0 vs the loop over all layers); the single-layer entry point is left unchanged.
68e8f3f to
0da49c7
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
📌 Description
Adds KDA (Kimi Delta Attention) multi-token-prediction (MTP) decode — the target-side gated-delta-rule recurrence for speculative decoding — in two complementary forms:
Recurrent (
kda_decode_mtp): a single register-resident CuTe kernel threading the recurrence over the T draft tokens per (batch, head). Two 1-CTA ops —vk(lane=K butterfly-reduce) /ws(4-warp warp-spec) — behind a unified dispatch.KVBuffer chunkwise verify (
kda_decode_mtp_kvbuffer): the parallel-verification path — verify emits a compact u-buffer (~2·T·d) instead of the T per-token states (T·d²), and a rank-m flush rebuilds the accepted state. Two variants —tp(token-parallel SIMT) /cute-gemm(sm90 tensor-core, flat-in-T).## What changed
cula/ops/kda_decode_mtp.py— recurrent vk/ws ops + unified dispatch.cula/ops/kda_decode_mtp_kvbuffer.py— tp / cute-gemm chunkwise verify + rank-m flush.tests/test_kda_decode_mtp.py— unit (vs fp32 oracle) + bit-exact determinism.benchmarks/bench_kda_decode_mtp.py— unified verify-chain benchmark.🔍 Related Issues
Closes #17
🧪 Tests
pytest tests/test_kda_decode_mtp.py— recurrent (vk/ws/kv) + kvbuffer (tp/cg) verify output & rank-m flush vs the fp32 single-token recurrence oracle, plus bit-exact determinism (torch.equal).⚡ Performance
H200 (HBM3e), K=V=128, bf16, accept m=full, official sglang scatter commit, CUDA-graph kernel-only. Each cell = the best-dispatch verify + state-update chain speedup vs official Triton recurrent (
fused_sigmoid_gating_delta_rule_update, SGLang). The unified dispatch picks the fastest of {vk, ws, tp, cg} per shape — recurrent vk/ws write T·d² states + scatter commit; KVBuffer tp/cg write a u-buffer + rank-m flush. >1 means faster than Triton.Best method / Triton (HV=H=32)
Best method / Triton (HV=H=64)
Takeaways:
Best-dispatch chain vs Triton: 1.41× – 2.50× across all shapes (both HV), strongest at T≥4. The dispatch routes T=2 → tp (token-parallel SIMT, best at small T), T≥3 → cute-gemm (flat-in-T), and tiny B≤2 at small T → recurrent vk.
flat-in-T: the cute-gemm verify kernel grows only +14–20% from T=2→6 while Triton recurrent grows +104–124%, reproducing the KVBuffer paper's Fig.4; the verify kernel alone (accept-independent) reaches up to 3.51× (B=128, T=6).
Memory: the u-buffer (~2·T·d) replaces the T·d² intermediate states → ~43× less rollback storage (d=128), independent of latency.
Correctness: tp ≤ 6.1e-5, cg ≤ 2.44e-4 max|Δ| vs Triton (bf16), well within bf16 noise.