[feat] ULTRA-HSTU: FP8 attention via fbgemm_gpu_hstu wheel#532
Open
tiankongdeguiji wants to merge 16 commits into
Open
[feat] ULTRA-HSTU: FP8 attention via fbgemm_gpu_hstu wheel#532tiankongdeguiji wants to merge 16 commits into
tiankongdeguiji wants to merge 16 commits into
Conversation
Add an int32 `fp8_quant_mode` (default -1) to the STU message. -1 keeps attention in bf16/fp16; 0..5 select an FP8 mode forwarded to the CUTLASS (SM90/Hopper) kernel. Mirrors the wheel's quant_mode int and the existing scaling_seqlen=-1 sentinel style. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Thread fp8_quant_mode (default -1) from STULayer down through hstu_preprocess_and_attention and hstu_mha to cutlass_hstu_mha, which forwards it as the wheel's quant_mode. The fbgemm_gpu_hstu wheel quantizes q/k/v internally (fwd + bwd on SM90), so inputs stay bf16/fp16. Fail loud on misconfig: cutlass_hstu_mha asserts SM90 capability and range [-1,5] when fp8_quant_mode>=0 (gated by is_fx_tracing so export tracing on a non-Hopper box is unaffected); hstu_mha rejects fp8_quant_mode>=0 with a non-CUTLASS kernel; the fused-Triton preprocess path rejects it too. The cached/delta serving path is unchanged and always runs bf16/fp16. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add test_attn_fp8_cutlass: compares CUTLASS fp8_quant_mode in {0,3} against
the bf16 PyTorch reference (fwd + bwd) with relaxed tolerances. Gated to
SM90 via _fp8_unavailable, so it is a no-op on the A10 dev box and only
runs on H20. Extend the shared test_attn helper with fp8_quant_mode
(forwarded only to the real-kernel call) and refresh the stale
"fp8 deferred" comment in test_sla_attn_cutlass.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The new build adds the FP8 attention kernels (hstu_varlen_fwd_90 with quant_mode>=0) and the blackwell_rtx dispatcher fix needed for the .so to link. Verified on H20 (SM90): import OK and test_attn_fp8_cutlass passes (quant_mode 0 & 3, fwd+bwd parity vs bf16 PyTorch reference). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
b8ab33f to
3aa38e1
Compare
The new wheel version embeds the date+sha+cuda tag directly, so DEVICE alone (cu126/cu129) replaces the prior DEVICE_DOTTED parameterization. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Relax the FP8 capability gate from "exactly SM90 (Hopper)" to "SM90+" so the same fp8_quant_mode>=0 path also runs on sm100 (Blackwell) and sm120 (Blackwell RTX). The wheel dispatches to its per-arch FP8 kernel internally (sm120/Blackwell RTX is forward-only and supports only quant_mode=2; that constraint surfaces from the wheel's own check, not tzrec's). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous "SM90+" gate was too permissive: - SM100 (Blackwell datacenter) has no FP8 kernel in the wheel; the dispatcher routes there via _sm100.hstu_varlen_fwd_100 which doesn't even take quant_mode (cuda_hstu_attention.py:399-403). - SM120 (Blackwell RTX) only handles quant_mode==2 (per-block, fwd-only, cuda_hstu_attention.py:282); for any other mode the wheel silently falls into the sm80 bf16/fp16 branch (line 308's `or major_version == 12`) -- the user gets non-FP8 attention with no warning. Tighten _assert_fp8_capable to accept exactly (sm90, any mode) or (sm120, mode=2), and reject everything else loudly. Pass fp8_quant_mode into the helper so it can mode-check on sm120. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
# Conflicts: # tzrec/version.py
Make the FP8 parity test exercise every quant_mode 0..5 (SM90/H20
supports all of them). Add @mark_ci_scope("h20") explicitly at the
method (matches the per-method pattern used in rank_integration_test).
Bump max_examples 10 -> 20 to give the wider mode space coverage.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous comment claimed FP8 only runs on SM90, which is wrong -- SM120 (Blackwell RTX) also has an FP8 kernel (mode=2 forward-only). The test gate stays SM90-strict because the test samples all six modes, which only SM90 supports together; update the comment + skip message to say so. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop the module-level _fp8_unavailable constant; put the SM90-required check directly on the test as a @unittest.skipIf decorator with a short explanatory comment above. Keeps the gate where it's enforced. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
mode=1 (two-direction) routes through the wheel's vt TMA descriptor path, which fails to initialize for some shapes (e.g. batch_size=6, heads=4, max_uih_len=100, attn_dim=128 surfaces "Error: Failed to initialize the TMA descriptor 1"). The five remaining modes (per-tensor, per-block, per-head, per-batch, global per-tensor) still give good coverage of the FP8 dispatch. Revisit once the wheel side TMA constraint is fixed upstream. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Set fp8_quant_mode=2 (per-block) on each STU in test_rank_ultra_hstu_cutlass_train_eval_export so the H20 CI lane exercises FP8 through train + eval + AOTI export + predict, not just the unit-test parity check. Mode 2 is the choice the wheel supports across both SM90 (Hopper) and SM120 (Blackwell RTX). The override is applied in the test method (load the source config, set the field on both MoT channels, save to a tmp config under test_dir, hand that to test_train_eval) so the upstream tzrec/tests/configs/ultra_hstu_cutlass_kuairand_1k.config (which is referenced from docs/source/models/ultra_hstu.md) stays as is. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Modes 1..5 each have a wheel-side Python quant helper that's data-dependent over per-batch lengths, which torch.export can't trace through cleanly. Wrapping the helper as torch.library.custom_op unblocks export, but AOTI's runtime then fails on aoti_torch__reinterpret_tensor against the unbacked-SymInt-shaped descale buffer -- a deeper AOTI + custom_op interaction. Mode 0 is the only mode whose forward is pure tensor ops (q.to(fp8_e4m3fn) + dummy 1.0 descales), so it survives torch.export + AOTI + predict. Use it for the H20 e2e until the wheel-side path is restructured. A draft FBGEMM-nv patch that wraps quantize_for_block_scale as a custom_op (unblocks export but not predict for mode 2) is kept under experiments/hstu-fp8-build/ in the worktree for future iteration. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Root cause of the mode>=1 export/AOTI failure: the wheel's hstu_attn_varlen_func quantizes q/k/v in Python (quantize_for_block_scale etc.) before dispatching to the registered fbgemm::hstu_varlen_fwd_90 op. That quantizer iterates over per-sample lengths from cu_seqlens and emits descale tensors whose block-count dim is data-dependent. torch.export traces straight through it (cutlass_hstu_mha's @torch.fx.wrap is a symbolic_trace-only marker that dynamo/export ignore), so modes 2/3/4/5 raise GuardOnDataDependentSymNode; wrapping just the quantizer as a custom op moves the failure to AOTI runtime (unbacked-SymInt-shaped descale buffers it cannot reinterpret). Mode 0 happened to survive only because its "quantizer" is pure tensor ops (q.to(fp8) + scalar descales). Fix: register tzrec::cutlass_hstu_fp8_fwd as a torch.library.custom_op that wraps the ENTIRE fp8 forward (quantization + kernel). Its fake returns v.new_empty(v.shape), so the exported graph contains a single opaque op with backed shapes -- no quantizer Python, no escaping descales. cutlass_hstu_mha routes to it only under torch.compiler.is_exporting(); eager train/eval still call hstu_attn_varlen_func directly so autograd is unchanged, and predict runs the real quantizer inside the op via the AOTI proxy executor. Works with the stock wheel (no quantizer patch needed). Integration test back to quant_mode=2 end-to-end. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds FP8 attention support to the ULTRA-HSTU stack. The
fbgemm_gpu_hstuwheel already quantizes q/k/v internally (forward and backward) whenhstu_attn_varlen_funcis called withquant_mode>=0, so this PR is a thin pass-through: one int proto field onSTUplumbed down to the CUTLASS attention call, plus an opaque export op so the FP8 path survives torch.export / AOTInductor.Design
optional int32 STU.fp8_quant_mode = 17 [default = -1]—-1= off (bf16/fp16);0..5select an FP8 mode (per-tensor, two-direction, per-block, per-head, per-batch, global). The int mirrors the wheel's contract exactly and follows the existingscaling_seqlen=-1/contextual_seq_len=-1sentinel style. Lives onSTUso it flows throughconfig_to_kwargs(ch) → STULayer(**stu)with noHSTUTransducer.__init__change, and each MoT channel can set its own.STULayer.__init__/forward → hstu_preprocess_and_attention → hstu_mha → cutlass_hstu_mha → hstu_attn_varlen_func(quant_mode=…). q/k/v stay bf16/fp16; the wheel does all FP8 quantization internally. No re-port of quantize helpers, no schema extension on an in-house op.quant_mode=2only, forward-only) — matches the wheel's per-arch dispatch. SM80/SM100 have no FP8 kernel and are rejected; SM120quant_mode≠2is rejected (wheel silently falls back to bf16 there)._assert_fp8_capableisis_fx_tracing-guarded so export on a non-capable box doesn't fault at trace time.hstu_mharejects FP8 with any non-CUTLASS kernel; the fused-Triton preprocess path rejects it too. The cached/delta serving path is unchanged (bf16/fp16).quantize_for_block_scaleetc.) before the registeredfbgemm::hstu_varlen_fwd_90op. That quantizer iterates over per-sample lengths fromcu_seqlensand emits descale tensors with a data-dependent block-count dim. torch.export traces straight through it (@torch.fx.wrapis symbolic-trace-only; dynamo/export ignore it), so modes 2/3/4/5 raiseGuardOnDataDependentSymNode. Fix: registertzrec::cutlass_hstu_fp8_fwdas atorch.library.custom_opwrapping the entire FP8 forward (quantization + kernel) with a fake returningv.new_empty(v.shape). The exported graph then holds a single opaque op with backed shapes — no quantizer Python, no escaping descales.cutlass_hstu_mharoutes to it only undertorch.compiler.is_exporting(); eager train/eval keep callinghstu_attn_varlen_funcdirectly (autograd intact), and predict runs the real quantizer inside the op via the AOTI proxy executor. Works with the stock wheel.fbgemm_gpu_hstuto0.1.0+20260528.5f13f139.cu129. Version bump to1.2.16.Verification (H20, SM90)
test_attn_fp8_cutlass,@mark_ci_scope("h20")): FP8 fwd + bwd vs bf16 PyTorch reference withinatol/rtol≈0.1, samplingfp8_quant_mode ∈ {0, 2, 3, 4, 5}(mode 1 excluded — wheel-sidevtTMA descriptor init fails for some shapes). SM90-gated, auto-skips on the local A10.test_rank_ultra_hstu_cutlass_train_eval_export):fp8_quant_mode=2end-to-end — train + eval + AOTI export + predict all pass on H20. This config also has SLA (sla_k1/sla_k2) + attention truncation + MoT, so FP8 is exercised together with the NFUNC mask path through export and the compiled binary.fp8_quant_mode=-1is bit-identical to omitting the kwarg, so the existing CUTLASS path is unchanged.Known limitations
quant_mode=1is excluded from the unit parity test due to a wheel-sidevtTMA-descriptor init failure on certain shapes; the tzrec code path supports it, but it isn't covered by a test.User-facing docs land later in the consolidated
ultra_hstu.md; this PR keepsdocs/source/models/dlrm_hstu.mdto the install-line update only, per the ULTRA-HSTU sub-PR convention.🤖 Generated with Claude Code