Add split_sizes param-group option for per-block Newton-Schulz on fused weights#90
Add split_sizes param-group option for per-block Newton-Schulz on fused weights#90alint77 wants to merge 1 commit into
Conversation
…ed weights A fused QKV weight keeps the model on a single wide GEMM, but orthogonalizing it as one matrix blends the Q, K, and V projections. Unlike the existing num_heads option, the blocks may have unequal sizes under grouped-query attention, and num_heads' local-shard path also requires each FSDP rank to hold whole blocks. The new split_sizes option on Muon and NorMuon param groups partitions dim 0 into row blocks and runs Newton-Schulz independently per block, on the fully assembled matrices after the FSDP all-to-all, so the communication pattern is unchanged from the fused parameter. Blocks of equal size (e.g. K and V) are batched into one Newton-Schulz call. Each block receives the same update it would as a separate parameter: the learning-rate adjustment (spectral_norm / rms_norm) is converted to the per-block value via a post-NS rescale, and NorMuon's norm-preserving rescale runs per block (on the non-sharded path; the sharded path keeps the existing per-shard approximation). Dion2 and NorDion2 raise NotImplementedError for the option. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
|
This is John's reviewbot. Reviewed the full diff against the source. This is a clean, well-reasoned patch — the core insight that you can split after the all-to-all assembly (keeping comm/momentum/checkpoint layout identical to the fused param, with the only constraint being "no interspersed padding rows") is exactly right, and the divisibility guard enforces it correctly. A few substantive points: Correctness — the math checks out, including the subtle parts.
Test coverage — the headline path is untested. The whole motivation is "num_heads breaks under FSDP, split_sizes doesn't," yet every new test is single-GPU ( Design / possible simplification. The Backward compat / security: clean. All new kwargs default to Perf/throughput: FLOPs-down claim is correct ( Nits (non-blocking): the sum-mismatch validation fires at first Solid work overall — main asks are (1) the epsilon caveat in the NorMuon comment and (2) some automated coverage of the sharded path or at least the divisibility guard. |
|
This feature is mostly aimed at fp8/fp4 training. the bigger qkv matmul amortizes the quantization costs much better compared to separate q,k,v matmuls which would have their own separate quantization/amax kernels that might even cause slowdown with fp8. |
|
and the same can be applied to gated MLPs, where Wi is separated into 2 matrices (up, gate), and doing separate ortho on those also improves loss. |
|
Do you have performance impact summary (computational & statistical)? |
NoahAmsel
left a comment
There was a problem hiding this comment.
I don't know whether the runtime benefits will be significant, but many people have existing codebases with merged QKV and this could make dion easier to adopt. (on the other hand, they are already covered by split heads unless they use grouped query attention). still, i agree it would be great to get a sense of whether this is beneficial in practice (even just for aifsdk)
my only concern is for the added complexity. it is awkward to have even one system for reshaping weight matrices, let alone two totally separate systems. in terms of the implementation: in this PR, _newton_schulz_row_blocks is responsible for doing the slicing / reshaping, but for num_heads, this reshaping happens elsewhere. isn't it better to have it all in one place?
likewise is there anyway to simplify the interface? perhaps we could have a single parameter where you either give the sizes of the blocks explicitly (as in this PR) or you give a number of blocks and it divides the matrix into equal sized pieces (as with num_heads) or perhaps at least some of the rescaling logic can be shared?
|
split head doesn't work when: |
|
h1024 16layer 170M model, FP8 on Wqkv and MLP, 4xH100 (4xGH200) fsdp2: Fused Wqkv: 179ms same model with Hdim 2048: 424ms |
|
breakdown of kernels:
New split mode is naturally a bit slower when looking at the opt.step alone, but this is offset by the gains in fused Wqkv matmul as shown in the e2e training numbers in the post above |
|
At the logged matched steps, split ortho is consistently a little better than whole- Eval Loss
keep in mind this was just a quick AB test at only 8B tokens (global_BS=1M), but the improvement is consistant on every combination of features/sizes that I've tested. In short, the point of this feature is to get the best of both worlds instead of having to choose one: faster gemms due to fused qkv + better loss due to split NS |
Motivation
A fused QKV projection keeps the model on a single wide GEMM (and plays well with fp8 paths that own one fused weight), but Muon orthogonalizes it as one tall matrix, blending the Q, K, and V projections. Splitting the parameter recovers per-projection Newton-Schulz — I've measured it improving loss — at the cost of narrower GEMMs.
The existing
num_headsoption doesn't cover this case: under grouped-query attention the Q and K/V blocks have unequal sizes, and thenum_headslocal-shard path requires each FSDP rank to hold whole blocks, which fails whenkv_dimexceeds the per-rank shard.What this adds
A
split_sizesparam-group option for Muon and NorMuon that partitions dim 0 of a 2D weight into row blocks and runs Newton-Schulz independently per block:Each block receives the same update it would as a separate parameter:
spectral_norm/rms_norm) is converted from the whole-matrix value to the per-block value via a post-NS rescale (adjust(block_shape) / adjust(full_shape)), somuon_update_post_orthogonalizeis unchanged.The split happens on the fully assembled matrices after the FSDP all-to-all, so the communication pattern, momentum state layout, and checkpoint format are all unchanged from the fused parameter. Total NS FLOPs decrease relative to fused (smaller matrices).
Constraints
split_sizesmust sum to dim 0.NotImplementedErrorotherwise.num_headsandflatten=True.NotImplementedErrorfor the option (their orthogonalization operates on local shards).Testing
New
TestSplitSizesclass (21 tests): parity against separate per-block parameters for Muon and NorMuon acrossadjust_lrmodes and nesterov, megabatching, fused momentum state, and validation errors. Full suite passes on a single GH200 (229 passed, 11 skipped multi-GPU).🤖 Generated with Claude Code