From ac57a3d10bdcc9fc967860dc52be9e5fd77c8763 Mon Sep 17 00:00:00 2001 From: John Langford Date: Fri, 8 May 2026 12:37:09 -0700 Subject: [PATCH 1/2] NorMuon: opt-in nan_guard_fallback skips step on non-finite NS output Issue #76 reports intermittent NaN parameters with NorMuon + gram-newton-schulz + quack-kernels 0.4.1 on multi-GPU DDP. Once a single rank's NS output goes non-finite, the post-ortho update poisons the parameter, the bad value gets all-reduced into the gradient/state on the next step, and the run is dead. Add an opt-in ``nan_guard_fallback`` arg to NorMuon. After the megabatch orthogonalization completes, all ranks check ``isfinite(U_stacked).all()`` and exchange the result via a single-byte ``all_reduce(MAX)``. If any rank flagged non-finite, every rank early-returns from ``normuon_update_megabatch_async``: V (variance buffer) and X (param) stay strictly unchanged, so the run state is bit-identical to "this batch never happened" for these params. M (momentum) was already updated in pre-orthogonalize from the clean gradient and is left alone. Why early-return rather than zero U + run normalization: - normuon_normalization_stacked on zero U decays V toward 0, contaminating future steps' normalization. - weight_decay applies in the post-ortho path; when we know the optimizer step is junk, it's cleaner to skip everything than to apply a half-update. Cost when enabled (paid on every step, whether or not a non-finite value is found): one tiny allreduce + one device->host sync per shape group per step. Cost when ``nan_guard_fallback=False`` (default): zero - just one Python-level branch in the optimizer step. Sync is required for correctness in the distributed case: if rank 0 takes the fallback and rank 1 doesn't, DDP weights drift or a future collective deadlocks because rank 1 advances to the next step's megabatch alltoall while rank 0 hasn't. 
Tests (single-rank CUDA, no NCCL): - ``test_fallback_off_lets_nan_propagate_to_params`` baselines that the bug exists when the guard is off, so the next test isn't trivially green. - ``test_fallback_on_skips_step_when_ns_returns_nan`` verifies bit-exact param + zero variance buffer + RuntimeWarning emission on rank 0. - ``test_fallback_on_does_not_block_normal_step`` verifies the guard is inert when NS output is finite (params change as usual). Existing NorMuon tests in test_optimizers.py still pass. --- dion/normuon.py | 43 ++++++++++++ tests/test_nan_guard_fallback.py | 115 +++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 tests/test_nan_guard_fallback.py diff --git a/dion/normuon.py b/dion/normuon.py index 76e7d7c..d77588c 100644 --- a/dion/normuon.py +++ b/dion/normuon.py @@ -1,4 +1,7 @@ +import warnings + import torch +import torch.distributed as dist from collections import defaultdict from torch import Tensor from torch.distributed import ProcessGroup @@ -44,6 +47,15 @@ class NorMuon(DistributedOrthoBase): use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. Signature is ``func(input: Tensor, epsilon: float) -> Tensor``. + nan_guard_fallback: If True, after each megabatch Newton-Schulz call, + check for non-finite values in the orthogonalized update and skip + the update on all ranks if any rank sees one. Adds a single + ``all_reduce(MAX)`` of one byte per shape group per step when + enabled and the optimizer is distributed. Defensive guard for + issues like microsoft/dion#76 where a Newton-Schulz backend + intermittently produces NaNs; safer than letting the corrupt + update propagate into the parameter and poisoning the run. + Default ``False`` (no extra collective, no extra check). 
Muon optimizer algorithm by Keller Jordan: https://kellerjordan.github.io/posts/muon/ FSDP2 Muon uses all-to-all communications: https://www.essential.ai/blog/infra @@ -68,6 +80,7 @@ def __init__( use_triton: bool = False, use_polar_express: bool = True, newton_schulz_func: Optional[Callable] = None, + nan_guard_fallback: bool = False, ): if lr < 0.0: raise ValueError(f"Invalid learning rate: {lr}") @@ -104,6 +117,7 @@ def __init__( use_polar_express=use_polar_express, newton_schulz_func=newton_schulz_func, ) + self._nan_guard_fallback = nan_guard_fallback def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: state = super()._get_or_initialize_state(param, algo) @@ -152,6 +166,7 @@ def _create_ortho_tasks( process_group=self._process_group, newton_schulz_func=self._newton_schulz_func, cautious_wd=group["cautious_wd"], + nan_guard_fallback=self._nan_guard_fallback, ) shape_groups: dict[tuple, list] = defaultdict(list) @@ -215,6 +230,7 @@ def normuon_update_megabatch_async( process_group: Optional[ProcessGroup] = None, newton_schulz_func: Optional[Callable] = None, cautious_wd: bool = False, + nan_guard_fallback: bool = False, ) -> Generator[None, None, None]: """ Mega-batched NorMuon update: processes ALL same-shape parameters in one @@ -263,6 +279,33 @@ def normuon_update_megabatch_async( V_local = to_local(V) U_stacked = torch.stack(U) V_stacked = torch.stack(V_local) + + if nan_guard_fallback: + # Defensive guard: if any rank's NS output went non-finite, all ranks + # must take the same fallback or DDP will diverge / a future + # collective will mismatch. Sync a single-byte ``ReduceOp.MAX`` flag + # across the process group, then on detection bail out of the rest + # of this step's update on all ranks. Skipping the post-ortho path + # (rather than zeroing U) keeps the variance buffer V and the param + # X strictly unchanged, so the run state is identical to "this batch + # never happened" for these params. 
Momentum M was already updated + # in pre-orthogonalize from the (clean) gradient and is left alone. + # Cost when triggered: one tiny allreduce + one device->host sync per + # shape group per step. Cost when off: a single Python branch. + local_nonfinite = (~torch.isfinite(U_stacked).all()).to(torch.uint8).reshape(1) + if process_group is not None: + dist.all_reduce(local_nonfinite, op=dist.ReduceOp.MAX, group=process_group) + if bool(local_nonfinite.item()): + if device_rank == 0: + warnings.warn( + f"[NorMuon nan_guard_fallback] non-finite Newton-Schulz " + f"output detected; skipping update for this step on all " + f"ranks (shape={tuple(U_stacked.shape)}). See microsoft/" + f"dion#76.", + RuntimeWarning, + stacklevel=2, + ) + return U_stacked, V_stacked = normuon_normalization_stacked(U_stacked, V_stacked, muon_beta2) for i in range(N): V_local[i].copy_(V_stacked[i]) diff --git a/tests/test_nan_guard_fallback.py b/tests/test_nan_guard_fallback.py new file mode 100644 index 0000000..2415486 --- /dev/null +++ b/tests/test_nan_guard_fallback.py @@ -0,0 +1,115 @@ +"""Tests for NorMuon's defensive ``nan_guard_fallback`` option. + +Background. microsoft/dion#76 reports that NorMuon + gram-newton-schulz + +quack-kernels 0.4.1 occasionally produces NaN parameters on certain +hardware/shapes; gradients are clean entering ``optimizer.step()`` but the +post-step parameter is all-NaN. ``nan_guard_fallback`` adds a single +``all_reduce(MAX)`` of one byte per shape group per step so all ranks +agree on whether to skip the update; on detection the entire post-ortho +path is bypassed so the parameter, the variance buffer, and the run state +remain strictly unchanged for these params. + +These tests run single-rank (no NCCL/MPI) and use a custom Newton-Schulz +function that injects NaN, isolating the guard logic from any specific +NS backend. 
+""" + +import warnings + +import pytest +import torch + +CUDA_AVAILABLE = torch.cuda.is_available() +DEVICE = "cuda" if CUDA_AVAILABLE else "cpu" + + +def _nan_ns(X, epsilon=None): + return torch.full_like(X, float("nan")) + + +def _identity_ns(X, epsilon=None): + # Bypass torch.compile and shape gymnastics; a plain identity is enough + # to validate the guard and the no-fallback baseline. + return X.clone() + + +def _make_param(shape, seed=42): + torch.manual_seed(seed) + return torch.nn.Parameter(torch.randn(shape, device=DEVICE)) + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="NorMuon's compiled kernels target CUDA") +def test_fallback_off_lets_nan_propagate_to_params(): + from dion import NorMuon + + p = _make_param((8, 16)) + p.grad = torch.randn_like(p) + + opt = NorMuon( + [p], + lr=0.01, + newton_schulz_func=_nan_ns, + nan_guard_fallback=False, + ) + opt.step() + + # Without the guard, the NaN NS output flows through normalization and + # the post-ortho update, poisoning the parameter. This is the bug from + # issue #76 in unit-test form. + assert torch.isnan(p.data).any(), ( + "expected NaN to propagate to params with the guard off; if this " + "asserts, NorMuon's NaN behavior changed and the fallback test " + "below may no longer be exercising the right path." + ) + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="NorMuon's compiled kernels target CUDA") +def test_fallback_on_skips_step_when_ns_returns_nan(): + from dion import NorMuon + + p = _make_param((8, 16)) + before = p.data.clone() + p.grad = torch.randn_like(p) + + opt = NorMuon( + [p], + lr=0.01, + newton_schulz_func=_nan_ns, + nan_guard_fallback=True, + ) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + opt.step() + + assert torch.isfinite(p.data).all() + # ``return`` before post-ortho means the parameter is bit-exact unchanged. 
+ assert torch.equal(p.data, before) + # Variance buffer must also be left at its initial zero state - skipping + # the step means "this batch never happened" for these params. + assert torch.all(opt.state[p]["variance_neuron"] == 0) + + msgs = [str(w.message) for w in caught if issubclass(w.category, RuntimeWarning)] + assert any("nan_guard_fallback" in m for m in msgs), msgs + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="NorMuon's compiled kernels target CUDA") +def test_fallback_on_does_not_block_normal_step(): + from dion import NorMuon + + p = _make_param((8, 16)) + before = p.data.clone() + p.grad = torch.randn_like(p) + + # Identity NS gives finite output -> guard sees nothing wrong, step + # proceeds normally and the parameter changes. + opt = NorMuon( + [p], + lr=0.01, + newton_schulz_func=_identity_ns, + nan_guard_fallback=True, + ) + opt.step() + + assert torch.isfinite(p.data).all() + assert not torch.equal(p.data, before) From a85fbb6c77cbca1cb5cae979c24df71ef9ad1aa1 Mon Sep 17 00:00:00 2001 From: John Langford Date: Fri, 8 May 2026 12:48:09 -0700 Subject: [PATCH 2/2] nan_guard_fallback: sync on the optimizer's process group, not the megabatch's The head-split and FSDP2 batch-sharded paths in `_create_ortho_tasks` deliberately override `process_group=None` for the megabatch because no alltoall is needed there. Gating the nan-flag allreduce on that same `process_group` silently disabled the sync in those cases: - Head-split + DDP: params replicated across ranks, but only some ranks may take the fallback. DDP weights drift -- the exact failure mode this guard was supposed to prevent. - Batch-sharded FSDP2: each rank owns a different shard of the same logical param. Divergent skips leave the logical tensor torn (some shards stepped, some not), violating the "this batch never happened" invariant. 
Thread the optimizer's full `self._process_group` through as a separate `nan_sync_process_group` argument so the nan-skip decision agrees across all ranks regardless of the megabatch's local-vs-collective config. --- dion/normuon.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/dion/normuon.py b/dion/normuon.py index d77588c..923d843 100644 --- a/dion/normuon.py +++ b/dion/normuon.py @@ -167,6 +167,13 @@ def _create_ortho_tasks( newton_schulz_func=self._newton_schulz_func, cautious_wd=group["cautious_wd"], nan_guard_fallback=self._nan_guard_fallback, + # Sync the nan flag on the optimizer's full process group, + # not the megabatch's: the head-split and batch-sharded + # paths below override ``process_group`` to ``None`` (no + # alltoall is needed there), but the nan-skip decision + # still has to agree across all ranks or DDP drifts and + # sharded params end up torn. + nan_sync_process_group=self._process_group, ) shape_groups: dict[tuple, list] = defaultdict(list) @@ -231,6 +238,7 @@ def normuon_update_megabatch_async( newton_schulz_func: Optional[Callable] = None, cautious_wd: bool = False, nan_guard_fallback: bool = False, + nan_sync_process_group: Optional[ProcessGroup] = None, ) -> Generator[None, None, None]: """ Mega-batched NorMuon update: processes ALL same-shape parameters in one @@ -293,8 +301,14 @@ def normuon_update_megabatch_async( # Cost when triggered: one tiny allreduce + one device->host sync per # shape group per step. Cost when off: a single Python branch. 
local_nonfinite = (~torch.isfinite(U_stacked).all()).to(torch.uint8).reshape(1) - if process_group is not None: - dist.all_reduce(local_nonfinite, op=dist.ReduceOp.MAX, group=process_group) + # Sync on the optimizer's process group rather than the megabatch's: + # head-split and FSDP2 batch-sharded paths run with no megabatch + # alltoall (process_group=None) but the nan-skip decision must still + # agree across all ranks, or DDP-replicated params drift and + # sharded params end up torn (some shards stepped, some not). + sync_pg = nan_sync_process_group if nan_sync_process_group is not None else process_group + if sync_pg is not None: + dist.all_reduce(local_nonfinite, op=dist.ReduceOp.MAX, group=sync_pg) if bool(local_nonfinite.item()): if device_rank == 0: warnings.warn(