Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions dion/normuon.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import warnings

import torch
import torch.distributed as dist
from collections import defaultdict
from torch import Tensor
from torch.distributed import ProcessGroup
Expand Down Expand Up @@ -44,6 +47,15 @@ class NorMuon(DistributedOrthoBase):
use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided.
newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization.
Signature is ``func(input: Tensor, epsilon: float) -> Tensor``.
nan_guard_fallback: If True, after each megabatch Newton-Schulz call,
check for non-finite values in the orthogonalized update and skip
the update on all ranks if any rank sees one. When enabled, adds one
device-to-host sync per shape group per step, plus a single
``all_reduce(MAX)`` of one byte per shape group per step when the
optimizer is distributed. Defensive guard for
issues like microsoft/dion#76 where a Newton-Schulz backend
intermittently produces NaNs; safer than letting the corrupt
update propagate into the parameter and poisoning the run.
Default ``False`` (no extra collective, no extra check).

Muon optimizer algorithm by Keller Jordan: https://kellerjordan.github.io/posts/muon/
FSDP2 Muon uses all-to-all communications: https://www.essential.ai/blog/infra
Expand All @@ -68,6 +80,7 @@ def __init__(
use_triton: bool = False,
use_polar_express: bool = True,
newton_schulz_func: Optional[Callable] = None,
nan_guard_fallback: bool = False,
):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
Expand Down Expand Up @@ -104,6 +117,7 @@ def __init__(
use_polar_express=use_polar_express,
newton_schulz_func=newton_schulz_func,
)
self._nan_guard_fallback = nan_guard_fallback

def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict:
state = super()._get_or_initialize_state(param, algo)
Expand Down Expand Up @@ -152,6 +166,14 @@ def _create_ortho_tasks(
process_group=self._process_group,
newton_schulz_func=self._newton_schulz_func,
cautious_wd=group["cautious_wd"],
nan_guard_fallback=self._nan_guard_fallback,
# Sync the nan flag on the optimizer's full process group,
# not the megabatch's: the head-split and batch-sharded
# paths below override ``process_group`` to ``None`` (no
# alltoall is needed there), but the nan-skip decision
# still has to agree across all ranks or DDP drifts and
# sharded params end up torn.
nan_sync_process_group=self._process_group,
)

shape_groups: dict[tuple, list] = defaultdict(list)
Expand Down Expand Up @@ -215,6 +237,8 @@ def normuon_update_megabatch_async(
process_group: Optional[ProcessGroup] = None,
newton_schulz_func: Optional[Callable] = None,
cautious_wd: bool = False,
nan_guard_fallback: bool = False,
nan_sync_process_group: Optional[ProcessGroup] = None,
) -> Generator[None, None, None]:
"""
Mega-batched NorMuon update: processes ALL same-shape parameters in one
Expand Down Expand Up @@ -263,6 +287,39 @@ def normuon_update_megabatch_async(
V_local = to_local(V)
U_stacked = torch.stack(U)
V_stacked = torch.stack(V_local)

if nan_guard_fallback:
# Defensive guard: if any rank's NS output went non-finite, all ranks
# must take the same fallback or DDP will diverge / a future
# collective will mismatch. Sync a single-byte ``ReduceOp.MAX`` flag
# across the process group, then on detection bail out of the rest
# of this step's update on all ranks. Skipping the post-ortho path
# (rather than zeroing U) keeps the variance buffer V and the param
# X strictly unchanged, so the run state is identical to "this batch
# never happened" for these params. Momentum M was already updated
# in pre-orthogonalize from the (clean) gradient and is left alone.
# Cost when enabled (paid every step, not only on detection): one tiny
# allreduce (only when a process group is present) + one device->host
# sync per shape group per step. Cost when off: a single Python branch.
local_nonfinite = (~torch.isfinite(U_stacked).all()).to(torch.uint8).reshape(1)
# Sync on the optimizer's process group rather than the megabatch's:
# head-split and FSDP2 batch-sharded paths run with no megabatch
# alltoall (process_group=None) but the nan-skip decision must still
# agree across all ranks, or DDP-replicated params drift and
# sharded params end up torn (some shards stepped, some not).
sync_pg = nan_sync_process_group if nan_sync_process_group is not None else process_group
if sync_pg is not None:
dist.all_reduce(local_nonfinite, op=dist.ReduceOp.MAX, group=sync_pg)
if bool(local_nonfinite.item()):
if device_rank == 0:
warnings.warn(
f"[NorMuon nan_guard_fallback] non-finite Newton-Schulz "
f"output detected; skipping update for this step on all "
f"ranks (shape={tuple(U_stacked.shape)}). See microsoft/"
f"dion#76.",
RuntimeWarning,
stacklevel=2,
)
return
U_stacked, V_stacked = normuon_normalization_stacked(U_stacked, V_stacked, muon_beta2)
for i in range(N):
V_local[i].copy_(V_stacked[i])
Expand Down
115 changes: 115 additions & 0 deletions tests/test_nan_guard_fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Tests for NorMuon's defensive ``nan_guard_fallback`` option.

Background. microsoft/dion#76 reports that NorMuon + gram-newton-schulz +
quack-kernels 0.4.1 occasionally produces NaN parameters on certain
hardware/shapes; gradients are clean entering ``optimizer.step()`` but the
post-step parameter is all-NaN. ``nan_guard_fallback`` adds a single
``all_reduce(MAX)`` of one byte per shape group per step so all ranks
agree on whether to skip the update; on detection the entire post-ortho
path is bypassed so the parameter, the variance buffer, and the run state
remain strictly unchanged for these params.

These tests run single-rank (no NCCL/MPI) and use a custom Newton-Schulz
function that injects NaN, isolating the guard logic from any specific
NS backend.
"""

import warnings

import pytest
import torch

CUDA_AVAILABLE = torch.cuda.is_available()
DEVICE = "cuda" if CUDA_AVAILABLE else "cpu"


def _nan_ns(X, epsilon=None):
return torch.full_like(X, float("nan"))


def _identity_ns(X, epsilon=None):
# Bypass torch.compile and shape gymnastics; a plain identity is enough
# to validate the guard and the no-fallback baseline.
return X.clone()


def _make_param(shape, seed=42):
    """Build a randomly initialised ``torch.nn.Parameter`` of ``shape`` on ``DEVICE``.

    The global torch RNG is re-seeded with ``seed`` first, so the parameter
    values (and any random draws the caller makes afterwards) are repeatable.
    """
    torch.manual_seed(seed)
    values = torch.randn(shape, device=DEVICE)
    parameter = torch.nn.Parameter(values)
    return parameter


@pytest.mark.skipif(not CUDA_AVAILABLE, reason="NorMuon's compiled kernels target CUDA")
def test_fallback_off_lets_nan_propagate_to_params():
    """With the guard disabled, a NaN Newton-Schulz output must reach the parameter."""
    from dion import NorMuon

    weight = _make_param((8, 16))
    weight.grad = torch.randn_like(weight)

    optimizer = NorMuon(
        [weight],
        lr=0.01,
        newton_schulz_func=_nan_ns,
        nan_guard_fallback=False,
    )
    optimizer.step()

    # Baseline for the guard tests: with no guard in place the NaN produced
    # by the Newton-Schulz stand-in is carried through normalization and the
    # post-ortho update into the parameter (issue #76 as a unit test).
    param_has_nan = torch.isnan(weight.data).any()
    assert param_has_nan, (
        "expected NaN to propagate to params with the guard off; if this "
        "asserts, NorMuon's NaN behavior changed and the fallback test "
        "below may no longer be exercising the right path."
    )


@pytest.mark.skipif(not CUDA_AVAILABLE, reason="NorMuon's compiled kernels target CUDA")
def test_fallback_on_skips_step_when_ns_returns_nan():
    """With the guard enabled, a NaN Newton-Schulz output must leave all state untouched."""
    from dion import NorMuon

    weight = _make_param((8, 16))
    snapshot = weight.data.clone()
    weight.grad = torch.randn_like(weight)

    optimizer = NorMuon(
        [weight],
        lr=0.01,
        newton_schulz_func=_nan_ns,
        nan_guard_fallback=True,
    )

    # Record every warning raised during the step so the guard's
    # RuntimeWarning can be checked afterwards.
    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter("always")
        optimizer.step()

    assert torch.isfinite(weight.data).all()
    # The guard returns before the post-ortho path, so the parameter has to
    # be bit-for-bit what it was before the step.
    assert torch.equal(weight.data, snapshot)
    # A skipped step means "this batch never happened": the variance buffer
    # must still hold its initial zeros.
    variance = optimizer.state[weight]["variance_neuron"]
    assert torch.all(variance == 0)

    runtime_messages = []
    for entry in recorded:
        if issubclass(entry.category, RuntimeWarning):
            runtime_messages.append(str(entry.message))
    assert any("nan_guard_fallback" in text for text in runtime_messages), runtime_messages


@pytest.mark.skipif(not CUDA_AVAILABLE, reason="NorMuon's compiled kernels target CUDA")
def test_fallback_on_does_not_block_normal_step():
    """With the guard enabled and a finite Newton-Schulz output, the step must still apply."""
    from dion import NorMuon

    weight = _make_param((8, 16))
    snapshot = weight.data.clone()
    weight.grad = torch.randn_like(weight)

    # The identity stand-in yields finite output, so the guard has nothing to
    # flag and the update should go through and move the parameter.
    optimizer = NorMuon(
        [weight],
        lr=0.01,
        newton_schulz_func=_identity_ns,
        nan_guard_fallback=True,
    )
    optimizer.step()

    all_finite = torch.isfinite(weight.data).all()
    assert all_finite
    unchanged = torch.equal(weight.data, snapshot)
    assert not unchanged
Loading