diff --git a/dion/normuon.py b/dion/normuon.py index 76e7d7c..923d843 100644 --- a/dion/normuon.py +++ b/dion/normuon.py @@ -1,4 +1,7 @@ +import warnings + import torch +import torch.distributed as dist from collections import defaultdict from torch import Tensor from torch.distributed import ProcessGroup @@ -44,6 +47,15 @@ class NorMuon(DistributedOrthoBase): use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. Signature is ``func(input: Tensor, epsilon: float) -> Tensor``. + nan_guard_fallback: If True, after each megabatch Newton-Schulz call, + check for non-finite values in the orthogonalized update and skip + the update on all ranks if any rank sees one. Adds a single + ``all_reduce(MAX)`` of one byte per shape group per step when + enabled and the optimizer is distributed. Defensive guard for + issues like microsoft/dion#76 where a Newton-Schulz backend + intermittently produces NaNs; safer than letting the corrupt + update propagate into the parameter and poisoning the run. + Default ``False`` (no extra collective, no extra check). Muon optimizer algorithm by Keller Jordan: https://kellerjordan.github.io/posts/muon/ FSDP2 Muon uses all-to-all communications: https://www.essential.ai/blog/infra @@ -68,6 +80,7 @@ def __init__( use_triton: bool = False, use_polar_express: bool = True, newton_schulz_func: Optional[Callable] = None, + nan_guard_fallback: bool = False, ): if lr < 0.0: raise ValueError(f"Invalid learning rate: {lr}") @@ -104,6 +117,7 @@ def __init__( use_polar_express=use_polar_express, newton_schulz_func=newton_schulz_func, ) + self._nan_guard_fallback = nan_guard_fallback def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: state = super()._get_or_initialize_state(param, algo) @@ -152,6 +166,14 @@ def _create_ortho_tasks( process_group=self._process_group, newton_schulz_func=self._newton_schulz_func, cautious_wd=group["cautious_wd"], + nan_guard_fallback=self._nan_guard_fallback, + # Sync the nan flag on the optimizer's full process group, + # not the megabatch's: the head-split and batch-sharded + # paths below override ``process_group`` to ``None`` (no + # alltoall is needed there), but the nan-skip decision + # still has to agree across all ranks or DDP drifts and + # sharded params end up torn. + nan_sync_process_group=self._process_group, ) shape_groups: dict[tuple, list] = defaultdict(list) @@ -215,6 +237,8 @@ def normuon_update_megabatch_async( process_group: Optional[ProcessGroup] = None, newton_schulz_func: Optional[Callable] = None, cautious_wd: bool = False, + nan_guard_fallback: bool = False, + nan_sync_process_group: Optional[ProcessGroup] = None, ) -> Generator[None, None, None]: """ Mega-batched NorMuon update: processes ALL same-shape parameters in one @@ -263,6 +287,39 @@ def normuon_update_megabatch_async( V_local = to_local(V) U_stacked = torch.stack(U) V_stacked = torch.stack(V_local) + + if nan_guard_fallback: + # Defensive guard: if any rank's NS output went non-finite, all ranks + # must take the same fallback or DDP will diverge / a future + # collective will mismatch. Sync a single-byte ``ReduceOp.MAX`` flag + # across the process group, then on detection bail out of the rest + # of this step's update on all ranks. Skipping the post-ortho path + # (rather than zeroing U) keeps the variance buffer V and the param + # X strictly unchanged, so the run state is identical to "this batch + # never happened" for these params. Momentum M was already updated + # in pre-orthogonalize from the (clean) gradient and is left alone. + # Cost when triggered: one tiny allreduce + one device->host sync per + # shape group per step. Cost when off: a single Python branch. + local_nonfinite = (~torch.isfinite(U_stacked).all()).to(torch.uint8).reshape(1) + # Sync on the optimizer's process group rather than the megabatch's: + # head-split and FSDP2 batch-sharded paths run with no megabatch + # alltoall (process_group=None) but the nan-skip decision must still + # agree across all ranks, or DDP-replicated params drift and + # sharded params end up torn (some shards stepped, some not). + sync_pg = nan_sync_process_group if nan_sync_process_group is not None else process_group + if sync_pg is not None: + dist.all_reduce(local_nonfinite, op=dist.ReduceOp.MAX, group=sync_pg) + if bool(local_nonfinite.item()): + if device_rank == 0: + warnings.warn( + f"[NorMuon nan_guard_fallback] non-finite Newton-Schulz " + f"output detected; skipping update for this step on all " + f"ranks (shape={tuple(U_stacked.shape)}). See microsoft/" + f"dion#76.", + RuntimeWarning, + stacklevel=2, + ) + return U_stacked, V_stacked = normuon_normalization_stacked(U_stacked, V_stacked, muon_beta2) for i in range(N): V_local[i].copy_(V_stacked[i]) diff --git a/tests/test_nan_guard_fallback.py b/tests/test_nan_guard_fallback.py new file mode 100644 index 0000000..2415486 --- /dev/null +++ b/tests/test_nan_guard_fallback.py @@ -0,0 +1,115 @@ +"""Tests for NorMuon's defensive ``nan_guard_fallback`` option. + +Background. microsoft/dion#76 reports that NorMuon + gram-newton-schulz + +quack-kernels 0.4.1 occasionally produces NaN parameters on certain +hardware/shapes; gradients are clean entering ``optimizer.step()`` but the +post-step parameter is all-NaN. ``nan_guard_fallback`` adds a single +``all_reduce(MAX)`` of one byte per shape group per step so all ranks +agree on whether to skip the update; on detection the entire post-ortho +path is bypassed so the parameter, the variance buffer, and the run state +remain strictly unchanged for these params. + +These tests run single-rank (no NCCL/MPI) and use a custom Newton-Schulz +function that injects NaN, isolating the guard logic from any specific +NS backend. +""" + +import warnings + +import pytest +import torch + +CUDA_AVAILABLE = torch.cuda.is_available() +DEVICE = "cuda" if CUDA_AVAILABLE else "cpu" + + +def _nan_ns(X, epsilon=None): + return torch.full_like(X, float("nan")) + + +def _identity_ns(X, epsilon=None): + # Bypass torch.compile and shape gymnastics; a plain identity is enough + # to validate the guard and the no-fallback baseline. + return X.clone() + + +def _make_param(shape, seed=42): + torch.manual_seed(seed) + return torch.nn.Parameter(torch.randn(shape, device=DEVICE)) + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="NorMuon's compiled kernels target CUDA") +def test_fallback_off_lets_nan_propagate_to_params(): + from dion import NorMuon + + p = _make_param((8, 16)) + p.grad = torch.randn_like(p) + + opt = NorMuon( + [p], + lr=0.01, + newton_schulz_func=_nan_ns, + nan_guard_fallback=False, + ) + opt.step() + + # Without the guard, the NaN NS output flows through normalization and + # the post-ortho update, poisoning the parameter. This is the bug from + # issue #76 in unit-test form. + assert torch.isnan(p.data).any(), ( + "expected NaN to propagate to params with the guard off; if this " + "asserts, NorMuon's NaN behavior changed and the fallback test " + "below may no longer be exercising the right path." + ) + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="NorMuon's compiled kernels target CUDA") +def test_fallback_on_skips_step_when_ns_returns_nan(): + from dion import NorMuon + + p = _make_param((8, 16)) + before = p.data.clone() + p.grad = torch.randn_like(p) + + opt = NorMuon( + [p], + lr=0.01, + newton_schulz_func=_nan_ns, + nan_guard_fallback=True, + ) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + opt.step() + + assert torch.isfinite(p.data).all() + # ``return`` before post-ortho means the parameter is bit-exact unchanged. + assert torch.equal(p.data, before) + # Variance buffer must also be left at its initial zero state - skipping + # the step means "this batch never happened" for these params. + assert torch.all(opt.state[p]["variance_neuron"] == 0) + + msgs = [str(w.message) for w in caught if issubclass(w.category, RuntimeWarning)] + assert any("nan_guard_fallback" in m for m in msgs), msgs + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="NorMuon's compiled kernels target CUDA") +def test_fallback_on_does_not_block_normal_step(): + from dion import NorMuon + + p = _make_param((8, 16)) + before = p.data.clone() + p.grad = torch.randn_like(p) + + # Identity NS gives finite output -> guard sees nothing wrong, step + # proceeds normally and the parameter changes. + opt = NorMuon( + [p], + lr=0.01, + newton_schulz_func=_identity_ns, + nan_guard_fallback=True, + ) + opt.step() + + assert torch.isfinite(p.data).all() + assert not torch.equal(p.data, before)