[feat] ULTRA-HSTU: FP8 attention via fbgemm_gpu_hstu wheel #532

Open
tiankongdeguiji wants to merge 16 commits into master from feat/ultra-hstu-fp8

tiankongdeguiji commented May 28, 2026

Collaborator

Adds FP8 attention support to the ULTRA-HSTU stack. The fbgemm_gpu_hstu wheel already quantizes q/k/v internally (forward and backward) when hstu_attn_varlen_func is called with quant_mode>=0, so this PR is a thin pass-through: one int proto field on STU plumbed down to the CUTLASS attention call, plus an opaque export op so the FP8 path survives torch.export / AOTInductor.

Design

  • Proto: optional int32 STU.fp8_quant_mode = 17 [default = -1]. -1 = off (bf16/fp16); 0..5 select an FP8 mode (per-tensor, two-direction, per-block, per-head, per-batch, global). The int mirrors the wheel's contract exactly and follows the existing scaling_seqlen=-1 / contextual_seq_len=-1 sentinel style. Lives on STU so it flows through config_to_kwargs(ch) → STULayer(**stu) with no HSTUTransducer.__init__ change, and each MoT channel can set its own.
  • Threading: STULayer.__init__/forward → hstu_preprocess_and_attention → hstu_mha → cutlass_hstu_mha → hstu_attn_varlen_func(quant_mode=…). q/k/v stay bf16/fp16; the wheel does all FP8 quantization internally. No re-port of quantize helpers, no schema extension on an in-house op.
  • Arch gate (fail-loud, no silent degradation): FP8 is accepted on SM90 (Hopper, all modes, fwd+bwd) or SM120 (Blackwell RTX, quant_mode=2 only, forward-only) — matches the wheel's per-arch dispatch. SM80/SM100 have no FP8 kernel and are rejected; SM120 quant_mode≠2 is rejected (wheel silently falls back to bf16 there). _assert_fp8_capable is is_fx_tracing-guarded so export on a non-capable box doesn't fault at trace time. hstu_mha rejects FP8 with any non-CUTLASS kernel; the fused-Triton preprocess path rejects it too. The cached/delta serving path is unchanged (bf16/fp16).
  • Export / AOTInductor: the wheel quantizes q/k/v in Python (quantize_for_block_scale etc.) before the registered fbgemm::hstu_varlen_fwd_90 op. That quantizer iterates over per-sample lengths from cu_seqlens and emits descale tensors with a data-dependent block-count dim. torch.export traces straight through it (@torch.fx.wrap is symbolic-trace-only; dynamo/export ignore it), so modes 2/3/4/5 raise GuardOnDataDependentSymNode. Fix: register tzrec::cutlass_hstu_fp8_fwd as a torch.library.custom_op wrapping the entire FP8 forward (quantization + kernel) with a fake returning v.new_empty(v.shape). The exported graph then holds a single opaque op with backed shapes — no quantizer Python, no escaping descales. cutlass_hstu_mha routes to it only under torch.compiler.is_exporting(); eager train/eval keep calling hstu_attn_varlen_func directly (autograd intact), and predict runs the real quantizer inside the op via the AOTI proxy executor. Works with the stock wheel.
  • Dep: bump fbgemm_gpu_hstu to 0.1.0+20260528.5f13f139.cu129. Version bump to 1.2.16.
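The mode numbering in the Proto bullet can be written out as a small enum. This is only an illustrative sketch: Fp8QuantMode and parse_fp8_quant_mode are hypothetical names that do not appear in this PR, and the name-to-integer mapping is assumed from the order of the list in that bullet.

```python
from enum import IntEnum


class Fp8QuantMode(IntEnum):
    """Hypothetical names for the integer values of STU.fp8_quant_mode."""

    OFF = -1            # attention stays in bf16/fp16
    PER_TENSOR = 0
    TWO_DIRECTION = 1
    PER_BLOCK = 2
    PER_HEAD = 3
    PER_BATCH = 4
    GLOBAL = 5


def parse_fp8_quant_mode(value: int) -> Fp8QuantMode:
    """Map a raw proto int to the enum, failing loudly outside [-1, 5]."""
    try:
        return Fp8QuantMode(value)
    except ValueError:
        raise ValueError(
            f"fp8_quant_mode must be in [-1, 5], got {value}"
        ) from None
```

Keeping the proto field a plain int (rather than a proto enum) is what lets it be forwarded unchanged as the wheel's quant_mode; an enum like this would only be a readability aid on the Python side.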

Verification (H20, SM90)

  • Unit (test_attn_fp8_cutlass, @mark_ci_scope("h20")): FP8 fwd + bwd vs bf16 PyTorch reference within atol/rtol≈0.1, sampling fp8_quant_mode ∈ {0, 2, 3, 4, 5} (mode 1 excluded — wheel-side vt TMA descriptor init fails for some shapes). SM90-gated, auto-skips on the local A10.
  • Integration (test_rank_ultra_hstu_cutlass_train_eval_export): fp8_quant_mode=2 end-to-end — train + eval + AOTI export + predict all pass on H20. This config also has SLA (sla_k1/sla_k2) + attention truncation + MoT, so FP8 is exercised together with the NFUNC mask path through export and the compiled binary.
  • Non-FP8 regression: fp8_quant_mode=-1 is bit-identical to omitting the kwarg, so the existing CUTLASS path is unchanged.
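For readers unfamiliar with combined atol/rtol checks, the parity bound in the unit-test bullet can be read as an elementwise criterion. The sketch below assumes the usual |actual - expected| <= atol + rtol * |expected| form (the one torch.testing.assert_close uses); the PR text only states the tolerances, not the exact comparison, and within_tolerance is a hypothetical helper.

```python
def within_tolerance(actual, expected, atol=0.1, rtol=0.1):
    """Return True if every element satisfies |a - e| <= atol + rtol * |e|."""
    if len(actual) != len(expected):
        raise ValueError("actual and expected must have the same length")
    return all(
        abs(a - e) <= atol + rtol * abs(e)
        for a, e in zip(actual, expected)
    )
```

With these defaults an FP8 output of 1.05 against a reference of 1.0 passes, while 1.5 against 1.0 does not.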

Known limitations

  • FP8 quant_mode=1 is excluded from the unit parity test due to a wheel-side vt TMA-descriptor init failure on certain shapes; the tzrec code path supports it, but it isn't covered by a test.

User-facing docs land later in the consolidated ultra_hstu.md; this PR keeps docs/source/models/dlrm_hstu.md to the install-line update only, per the ULTRA-HSTU sub-PR convention.

🤖 Generated with Claude Code

tiankongdeguiji and others added 5 commits May 27, 2026 11:03

Add an int32 `fp8_quant_mode` (default -1) to the STU message. -1 keeps
attention in bf16/fp16; 0..5 select an FP8 mode forwarded to the CUTLASS
(SM90/Hopper) kernel. Mirrors the wheel's quant_mode int and the existing
scaling_seqlen=-1 sentinel style.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Thread fp8_quant_mode (default -1) from STULayer down through
hstu_preprocess_and_attention and hstu_mha to cutlass_hstu_mha, which
forwards it as the wheel's quant_mode. The fbgemm_gpu_hstu wheel quantizes
q/k/v internally (fwd + bwd on SM90), so inputs stay bf16/fp16.

Fail loud on misconfig: cutlass_hstu_mha asserts SM90 capability and range
[-1,5] when fp8_quant_mode>=0 (gated by is_fx_tracing so export tracing on
a non-Hopper box is unaffected); hstu_mha rejects fp8_quant_mode>=0 with a
non-CUTLASS kernel; the fused-Triton preprocess path rejects it too. The
cached/delta serving path is unchanged and always runs bf16/fp16.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
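The pass-through and fail-loud behaviour described in this commit might look roughly like the following. It is a pure-Python sketch under stated assumptions: Kernel, attn_kernel_stub and hstu_mha_sketch are hypothetical stand-ins, not the tzrec functions, and the stub only reports which precision path would run.

```python
from enum import Enum


class Kernel(Enum):
    # Hypothetical kernel selector for this sketch.
    PYTORCH = "pytorch"
    TRITON = "triton"
    CUTLASS = "cutlass"


def attn_kernel_stub(q, k, v, quant_mode=-1):
    """Stand-in for the selected attention kernel; returns the precision path."""
    return "fp8" if quant_mode >= 0 else "bf16/fp16"


def hstu_mha_sketch(q, k, v, kernel, fp8_quant_mode=-1):
    """Forward fp8_quant_mode unchanged, rejecting misconfiguration loudly."""
    if not -1 <= fp8_quant_mode <= 5:
        raise ValueError(f"fp8_quant_mode must be in [-1, 5], got {fp8_quant_mode}")
    if fp8_quant_mode >= 0 and kernel is not Kernel.CUTLASS:
        raise ValueError("FP8 attention requires the CUTLASS kernel")
    # q/k/v are passed through untouched; any quantization happens downstream.
    return attn_kernel_stub(q, k, v, quant_mode=fp8_quant_mode)
```

Because -1 is the default at every layer, callers that never mention fp8_quant_mode take exactly the same path as before.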
Add test_attn_fp8_cutlass: compares CUTLASS fp8_quant_mode in {0,3} against
the bf16 PyTorch reference (fwd + bwd) with relaxed tolerances. Gated to
SM90 via _fp8_unavailable, so it is a no-op on the A10 dev box and only
runs on H20. Extend the shared test_attn helper with fp8_quant_mode
(forwarded only to the real-kernel call) and refresh the stale
"fp8 deferred" comment in test_sla_attn_cutlass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

The new build adds the FP8 attention kernels (hstu_varlen_fwd_90 with
quant_mode>=0) and the blackwell_rtx dispatcher fix needed for the .so
to link. Verified on H20 (SM90): import OK and test_attn_fp8_cutlass
passes (quant_mode 0 & 3, fwd+bwd parity vs bf16 PyTorch reference).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

tiankongdeguiji and others added 11 commits May 29, 2026 11:12

The new wheel version embeds the date+sha+cuda tag directly, so DEVICE
alone (cu126/cu129) replaces the prior DEVICE_DOTTED parameterization.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Relax the FP8 capability gate from "exactly SM90 (Hopper)" to "SM90+" so
the same fp8_quant_mode>=0 path also runs on sm100 (Blackwell) and sm120
(Blackwell RTX). The wheel dispatches to its per-arch FP8 kernel
internally (sm120/Blackwell RTX is forward-only and supports only
quant_mode=2; that constraint surfaces from the wheel's own check, not
tzrec's).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

The previous "SM90+" gate was too permissive:
- SM100 (Blackwell datacenter) has no FP8 kernel in the wheel; the
  dispatcher routes there via _sm100.hstu_varlen_fwd_100 which doesn't
  even take quant_mode (cuda_hstu_attention.py:399-403).
- SM120 (Blackwell RTX) only handles quant_mode==2 (per-block, fwd-only,
  cuda_hstu_attention.py:282); for any other mode the wheel silently
  falls into the sm80 bf16/fp16 branch (line 308's `or major_version ==
  12`) -- the user gets non-FP8 attention with no warning.

Tighten _assert_fp8_capable to accept exactly (sm90, any mode) or
(sm120, mode=2), and reject everything else loudly. Pass fp8_quant_mode
into the helper so it can mode-check on sm120.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
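The tightened rule, (sm90, any mode) or (sm120, mode=2), can be stated as a small predicate. This is a sketch and not the PR's _assert_fp8_capable: it takes the compute capability as an argument instead of querying the device, omits the is_fx_tracing guard, and the names fp8_supported and assert_fp8_capable_sketch are hypothetical.

```python
def fp8_supported(capability, quant_mode):
    """Return True if (major, minor) capability can run the given quant_mode."""
    major, _minor = capability
    if quant_mode < 0:
        return True  # -1 means FP8 is off, so any architecture is fine
    if quant_mode > 5:
        return False  # outside the wheel's 0..5 contract
    if major == 9:
        return True  # SM90 (Hopper): all modes
    if major == 12:
        return quant_mode == 2  # SM120 (Blackwell RTX): per-block only
    return False  # SM80, SM100 and everything else: no FP8 kernel


def assert_fp8_capable_sketch(capability, quant_mode):
    """Fail loudly rather than let an unsupported combination degrade silently."""
    if not fp8_supported(capability, quant_mode):
        raise RuntimeError(
            f"fp8_quant_mode={quant_mode} is not supported on sm{capability[0]}{capability[1]}"
        )
```

In real use the capability tuple would presumably come from torch.cuda.get_device_capability(), which returns a (major, minor) pair.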
Make the FP8 parity test exercise every quant_mode 0..5 (SM90/H20
supports all of them). Add @mark_ci_scope("h20") explicitly at the
method (matches the per-method pattern used in rank_integration_test).
Bump max_examples 10 -> 20 to give the wider mode space coverage.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

The previous comment claimed FP8 only runs on SM90, which is wrong --
SM120 (Blackwell RTX) also has an FP8 kernel (mode=2 forward-only).
The test gate stays SM90-strict because the test samples all six modes,
which only SM90 supports together; update the comment + skip message to
say so.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Drop the module-level _fp8_unavailable constant; put the SM90-required
check directly on the test as a @unittest.skipIf decorator with a short
explanatory comment above. Keeps the gate where it's enforced.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

mode=1 (two-direction) routes through the wheel's vt TMA descriptor
path, which fails to initialize for some shapes
(e.g. batch_size=6, heads=4, max_uih_len=100, attn_dim=128 surfaces
"Error: Failed to initialize the TMA descriptor 1"). The five remaining
modes (per-tensor, per-block, per-head, per-batch, global per-tensor)
still give good coverage of the FP8 dispatch. Revisit once the wheel
side TMA constraint is fixed upstream.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Set fp8_quant_mode=2 (per-block) on each STU in
test_rank_ultra_hstu_cutlass_train_eval_export so the H20 CI lane
exercises FP8 through train + eval + AOTI export + predict, not just
the unit-test parity check. Mode 2 is the choice the wheel supports
across both SM90 (Hopper) and SM120 (Blackwell RTX). The override is
applied in the test method (load the source config, set the field on
both MoT channels, save to a tmp config under test_dir, hand that to
test_train_eval) so the upstream
tzrec/tests/configs/ultra_hstu_cutlass_kuairand_1k.config (which is
referenced from docs/source/models/ultra_hstu.md) stays as is.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Modes 1..5 each have a wheel-side Python quant helper that's
data-dependent over per-batch lengths, which torch.export can't trace
through cleanly. Wrapping the helper as torch.library.custom_op
unblocks export, but AOTI's runtime then fails on
aoti_torch__reinterpret_tensor against the unbacked-SymInt-shaped
descale buffer -- a deeper AOTI + custom_op interaction. Mode 0 is the
only mode whose forward is pure tensor ops (q.to(fp8_e4m3fn) + dummy
1.0 descales), so it survives torch.export + AOTI + predict. Use it
for the H20 e2e until the wheel-side path is restructured.

A draft FBGEMM-nv patch that wraps quantize_for_block_scale as a
custom_op (unblocks export but not predict for mode 2) is kept under
experiments/hstu-fp8-build/ in the worktree for future iteration.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Root cause of the mode>=1 export/AOTI failure: the wheel's
hstu_attn_varlen_func quantizes q/k/v in Python (quantize_for_block_scale
etc.) before dispatching to the registered fbgemm::hstu_varlen_fwd_90 op.
That quantizer iterates over per-sample lengths from cu_seqlens and emits
descale tensors whose block-count dim is data-dependent. torch.export
traces straight through it (cutlass_hstu_mha's @torch.fx.wrap is a
symbolic_trace-only marker that dynamo/export ignore), so modes 2/3/4/5
raise GuardOnDataDependentSymNode; wrapping just the quantizer as a
custom op moves the failure to AOTI runtime (unbacked-SymInt-shaped
descale buffers it cannot reinterpret). Mode 0 happened to survive only
because its "quantizer" is pure tensor ops (q.to(fp8) + scalar descales).

Fix: register tzrec::cutlass_hstu_fp8_fwd as a torch.library.custom_op
that wraps the ENTIRE fp8 forward (quantization + kernel). Its fake
returns v.new_empty(v.shape), so the exported graph contains a single
opaque op with backed shapes -- no quantizer Python, no escaping
descales. cutlass_hstu_mha routes to it only under
torch.compiler.is_exporting(); eager train/eval still call
hstu_attn_varlen_func directly so autograd is unchanged, and predict runs
the real quantizer inside the op via the AOTI proxy executor. Works with
the stock wheel (no quantizer patch needed). Integration test back to
quant_mode=2 end-to-end.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>