Skip to content

Add split_sizes param-group option for per-block Newton-Schulz on fused weights#90

Open
alint77 wants to merge 1 commit into
microsoft:mainfrom
alint77:feat/split-sizes-group-option
Open

Add split_sizes param-group option for per-block Newton-Schulz on fused weights#90
alint77 wants to merge 1 commit into
microsoft:mainfrom
alint77:feat/split-sizes-group-option

Conversation

@alint77

@alint77 alint77 commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Motivation

A fused QKV projection keeps the model on a single wide GEMM (and plays well with fp8 paths that own one fused weight), but Muon orthogonalizes it as one tall matrix, blending the Q, K, and V projections. Splitting the parameter recovers per-projection Newton-Schulz — I've measured it improving loss — at the cost of narrower GEMMs.

The existing num_heads option doesn't cover this case: under grouped-query attention the Q and K/V blocks have unequal sizes, and the num_heads local-shard path requires each FSDP rank to hold whole blocks, which fails when kv_dim exceeds the per-rank shard.

What this adds

A split_sizes param-group option for Muon and NorMuon that partitions dim 0 of a 2D weight into row blocks and runs Newton-Schulz independently per block:

param_groups = [
    dict(params=fused_qkv_params, split_sizes=(q_dim, kv_dim, kv_dim)),
    dict(params=other_matrix_params),
]

Each block receives the same update it would as a separate parameter:

  • Newton-Schulz runs per block; blocks of equal size (e.g. K and V) are stacked into one batched NS call.
  • LR adjustment (spectral_norm / rms_norm) is converted from the whole-matrix value to the per-block value via a post-NS rescale (adjust(block_shape) / adjust(full_shape)), so muon_update_post_orthogonalize is unchanged.
  • NorMuon's norm-preserving rescale runs per block on the non-sharded path. The sharded all-to-all path keeps the existing per-shard normalization (block boundaries aren't locally available there), consistent with current distributed behavior.

The split happens on the fully assembled matrices after the FSDP all-to-all, so the communication pattern, momentum state layout, and checkpoint format are all unchanged from the fused parameter. Total NS FLOPs decrease relative to fused (smaller matrices).

Constraints

  • 2D parameters only; split_sizes must sum to dim 0.
  • With FSDP, dim 0 must be divisible by the world size (otherwise shard padding would intersperse zero rows between rank segments of the assembled matrices); raises NotImplementedError otherwise.
  • Mutually exclusive with num_heads and flatten=True.
  • Dion2 / NorDion2 raise NotImplementedError for the option (their orthogonalization operates on local shards).

Testing

New TestSplitSizes class (21 tests): parity against separate per-block parameters for Muon and NorMuon across adjust_lr modes and nesterov, megabatching, fused momentum state, and validation errors. Full suite passes on a single GH200 (229 passed, 11 skipped multi-GPU).

🤖 Generated with Claude Code

…ed weights

A fused QKV weight keeps the model on a single wide GEMM, but
orthogonalizing it as one matrix blends the Q, K, and V projections.
Unlike the existing num_heads option, the blocks may have unequal sizes
under grouped-query attention, and num_heads' local-shard path also
requires each FSDP rank to hold whole blocks.

The new split_sizes option on Muon and NorMuon param groups partitions
dim 0 into row blocks and runs Newton-Schulz independently per block, on
the fully assembled matrices after the FSDP all-to-all, so the
communication pattern is unchanged from the fused parameter. Blocks of
equal size (e.g. K and V) are batched into one Newton-Schulz call.

Each block receives the same update it would as a separate parameter:
the learning-rate adjustment (spectral_norm / rms_norm) is converted to
the per-block value via a post-NS rescale, and NorMuon's norm-preserving
rescale runs per block (on the non-sharded path; the sharded path keeps
the existing per-shard approximation).

Dion2 and NorDion2 raise NotImplementedError for the option.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@JohnLangford

Copy link
Copy Markdown
Contributor

This is John's reviewbot.

Reviewed the full diff against the source. This is a clean, well-reasoned patch — the core insight that you can split after the all-to-all assembly (keeping comm/momentum/checkpoint layout identical to the fused param, with the only constraint being "no interspersed padding rows") is exactly right, and the divisibility guard enforces it correctly. A few substantive points:

Correctness — the math checks out, including the subtle parts.

  • Block reassembly under FSDP is sound: contiguous chunking + dim0 % world_size == 0padded_local_size == local_sizecat(output_chunks, dim=-2) reproduces the original row order, so block boundaries land where they should. The comm_dim == -2 guard at megabatch_base.py:531 is the right gate.
  • The split_scales = adjust(block)/adjust(full) trick (pre-divide so the single whole-matrix adjusted_lr applied in post_orthogonalize lands per-block) is neat and keeps muon_update_post_orthogonalize untouched. Verified net LR per block = base_lr · adjust(block), matching a separate param.
  • NorMuon parity is approximate, not exact, even on a single GPU at fp32 — and not for the bf16 reason the test comment gives. With V init=0 and a constant per-block scale c, induction gives V' = c²·V exactly, but denom = sqrt(V) + 1e-8 doesn't commute with c (c·sqrt(V)+1e-8 ≠ c·(sqrt(V)+1e-8)). It's negligible for well-scaled updates (and rtol=1e-2 absorbs it), but the normuon.py comment claiming the scales "commute through the normalization self-consistently" overstates it. Worth a one-line honest caveat about the epsilon term.

Test coverage — the headline path is untested. The whole motivation is "num_heads breaks under FSDP, split_sizes doesn't," yet every new test is single-GPU (CUDA_AVAILABLE, no world_size>1). The divisibility NotImplementedError, the assembled-matrix block alignment, and NorMuon's documented per-shard fallback all rest purely on reasoning. The 21 parity tests are good for the local path, but a 2-GPU test (even one in the multi-GPU-skipped tier) exercising the sharded reassembly would meaningfully de-risk the feature that justifies the PR. At minimum, add a unit test asserting the divisibility guard fires (construct the sharded condition directly), since that branch currently has zero coverage.

Design / possible simplification. The split_scales ratio exists only because post_orthogonalize will later multiply by adjusted_lr(full). An alternative that avoids the ratio (and the /full_adjust division, and the per-block Python multiply loop): build a single [rows,1] per-row LR vector = adjust(block_i) broadcast over each block, and apply it as one vectorized multiply on the assembled matrix — no dependence on the full-matrix adjustment at all. It's arguably cleaner and composes with NorMuon identically, at the cost of touching the post-ortho LR application. The current minimal-touch approach is a defensible trade; just noting it. Threading split_sizes/split_scales through 5 call sites is verbose but localized and all-optional.

Backward compat / security: clean. All new kwargs default to None; the new param-group key is opt-in; Dion2/NorDion2 reject it explicitly. No security surface (pure numerics, no external input). Bool-rejection in validation is a nice touch.

Perf/throughput: FLOPs-down claim is correct (Σ rᵢ² ≤ (Σ rᵢ)²), and batching equal-height blocks (K+V) into one NS call is valid since NS is per-batch-element independent. Net wall-clock may not improve on small matrices (more kernel launches + stack/cat/split allocs each step), but the patch sells this on loss quality, not speed, so that's fine.

Nits (non-blocking): the sum-mismatch validation fires at first step() rather than construction — consistent with num_heads, fine; a group carrying split_sizes with algorithm=adamw/lion silently ignores it (a warning would be friendlier).

Solid work overall — main asks are (1) the epsilon caveat in the NorMuon comment and (2) some automated coverage of the sharded path or at least the divisibility guard.

@alint77

alint77 commented Jun 12, 2026

Copy link
Copy Markdown
Contributor Author

This feature is mostly aimed at fp8/fp4 training. the bigger qkv matmul amortizes the quantization costs much better compared to separate q,k,v matmuls which would have their own separate quantization/amax kernels that might even cause slowdown with fp8.

@alint77

alint77 commented Jun 12, 2026

Copy link
Copy Markdown
Contributor Author

and the same can be applied to gated MLPs, where Wi is separated into 2 matrices (up, gate), and doing separate ortho on those also improves loss.

@JohnLangford

Copy link
Copy Markdown
Contributor

Do you have performance impact summary (computational & statistical)?

@NoahAmsel NoahAmsel left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know whether the runtime benefits will be significant, but many people have existing codebases with merged QKV and this could make dion easier to adopt. (on the other hand, they are already covered by split heads unless they use grouped query attention). still, i agree it would be great to get a sense of whether this is beneficial in practice (even just for aifsdk)

my only concern is for the added complexity. it is awkward to have even one system for reshaping weight matrices, let alone two totally separate systems. in terms of the implementation: in this PR, _newton_schulz_row_blocks is responsible for doing the slicing / reshaping, but for num_heads, this reshaping happens elsewhere. isn't it better to have it all in one place?

likewise is there anyway to simplify the interface? perhaps we could have a single parameter where you either give the sizes of the blocks explicitly (as in this PR) or you give a number of blocks and it divides the matrix into equal sized pieces (as with num_heads) or perhaps at least some of the rescaling logic can be shared?

@alint77

alint77 commented Jun 16, 2026

Copy link
Copy Markdown
Contributor Author

split head doesn't work when:
1- world_size%3 != 0 in fsdp mode
2- gqa, pairedheadattn or diffv2 where q,k,v have different sizes

@alint77

alint77 commented Jun 16, 2026

Copy link
Copy Markdown
Contributor Author

h1024 16layer 170M model, FP8 on Wqkv and MLP, 4xH100 (4xGH200) fsdp2:

Fused Wqkv: 179ms
Unfused Wqkv: 186ms

same model with Hdim 2048:

424ms
451ms

@alint77

alint77 commented Jun 16, 2026

Copy link
Copy Markdown
Contributor Author

breakdown of kernels:

run NorMuon kernel launches summed kernel time NCCL launches NVJit GEMM launches Triton launches
split 1064 33.03 ms 32 384 440
whole 788 29.38 ms 24 288 332

New split mode is naturally a bit slower when looking at the opt.step alone, but this is offset by the gains in fused Wqkv matmul as shown in the e2e training numbers in the post above

@alint77

alint77 commented Jun 16, 2026

Copy link
Copy Markdown
Contributor Author

At the logged matched steps, split ortho is consistently a little better than whole-Wqkv ortho.

Eval Loss

comparison step split whole whole - split
4-node 1000 2.5908 2.5919 +0.0011
4-node 2000 2.5226 2.5261 +0.0035
4-node 3000 2.4818 2.4858 +0.0040
4-node 4000 2.4501 2.4552 +0.0051
4-node 8000 2.3863 2.3917 +0.0054

keep in mind this was just a quick AB test at only 8B tokens (global_BS=1M), but the improvement is consistant on every combination of features/sizes that I've tested.

In short, the point of this feature is to get the best of both worlds instead of having to choose one: faster gemms due to fused qkv + better loss due to split NS

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants