From 798eee462a9d559ea101dd37940eeb27119e5912 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 10 May 2026 10:29:39 -0700 Subject: [PATCH 1/4] Add Aurora optimizer for non-square matrices Aurora (https://blog.tilderesearch.com/blog/aurora) approximates a projection onto the intersection of the row oblique and Stiefel manifolds via diagonal preconditioning, producing leverage-uniform updates. Standard polar (Newton-Schulz) inherits the non-uniform left-singular row norms of the gradient; Aurora iteratively rescales rows so that all neurons receive comparably-sized updates. For square matrices Aurora reduces bit-for-bit to standard polar. For non-square matrices it transposes to tall, runs pp_iterations rounds of row-norm preconditioning around the existing polar function, and applies Aurora's max(1, m/n)**0.5 aspect-ratio scaling. Implementation reuses Muon's pre/post-orthogonalize stages and the shared megabatch infrastructure; the new logic lives entirely in a newton_schulz_func wrapper, so all dion features (mega-batching, FSDP2 sharding, num_heads, mixed param groups) work transparently. --- dion/__init__.py | 1 + dion/aurora.py | 369 +++++++++++++++++++++++++++++++++++++++ tests/test_optimizers.py | 105 +++++++++++ 3 files changed, 475 insertions(+) create mode 100644 dion/aurora.py diff --git a/dion/__init__.py b/dion/__init__.py index 34894e6..2f1f00b 100644 --- a/dion/__init__.py +++ b/dion/__init__.py @@ -1,3 +1,4 @@ +from .aurora import Aurora from .dion import Dion from .dion import DionMixedPrecisionConfig from .dion_simple import Dion as DionSimple diff --git a/dion/aurora.py b/dion/aurora.py new file mode 100644 index 0000000..24b2707 --- /dev/null +++ b/dion/aurora.py @@ -0,0 +1,369 @@ +import torch +from collections import defaultdict +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.optim.optimizer import ParamsT +from typing import Callable, Generator, List, Optional, Tuple, Union + +from .megabatch_base import ( + DistributedOrthoBase, + megabatch_orthogonalize_async, + adjust_lr_spectral_norm, + adjust_lr_rms_norm, +) +from .muon import muon_update_pre_orthogonalize, muon_update_post_orthogonalize +from .newton_schulz_triton import ( + TRITON_AVAILABLE, + newton_schulz_triton, + zeropower_via_newtonschulz5, +) +from .opt_utils import AsyncTask, to_local +from .polar_express import polar_express, polar_express_triton + + +class Aurora(DistributedOrthoBase): + """ + Distributed Aurora optimizer for PyTorch FSDP2. Also compatible with DDP. + + Aurora is an optimizer for non-square weight matrices that achieves more + balanced neuron utilization than standard Muon. Instead of applying the + polar (Newton-Schulz) factor directly, which inherits non-uniform + left-singular row norms, Aurora iteratively approximates a projection onto + the intersection of the row oblique and Stiefel manifolds via diagonal + preconditioning. The result is a leverage-uniform update. + + For square matrices Aurora reduces to standard Muon. + + Args: + params: Parameters for the optimizer. + distributed_mesh: DeviceMesh or ProcessGroup for distributed training. + Use DeviceMesh for FSDP2 and ProcessGroup for DistributedDataParallel. + lr: Base learning rate. Aurora bakes the tall-aspect-ratio scaling + ``max(1, m/n)**0.5`` into the update, so ``adjust_lr`` defaults to None. + mu: Momentum factor. + betas: Tuple of (beta1, beta2) for AdamW and Lion algorithms. + weight_decay: Weight decay factor. + cautious_wd: Whether to apply weight decay only where update and parameter signs align. + epsilon: Small value to avoid division by zero. + nesterov: Whether to use Nesterov momentum. + adjust_lr: Optional Muon-style LR adjustment ("spectral_norm", "rms_norm", or None). + None is the Aurora default; the algorithm already applies its own + ``max(1, m/n)**0.5`` aspect-ratio scaling inside the orthogonalization. + flatten: Whether to flatten 3D+ tensors to 2D for the orthogonalization step. + pp_iterations: Number of preconditioned-polar iterations. Each iteration + calls the base polar (Newton-Schulz) once. ``pp_iterations=2`` is the + Aurora paper default; ``pp_iterations=1`` is single-shot row-norm + preconditioning. + pp_beta: Exponent for the diagonal update between iterations. + use_triton: Whether to use the Triton Newton-Schulz kernel. + use_polar_express: Whether to use Polar Express for the base polar. + newton_schulz_func: Optional custom base polar function. Aurora wraps + this with its diagonal preconditioning loop. + + Aurora: https://blog.tilderesearch.com/blog/aurora + Reference implementation: https://github.com/tilde-research/aurora-release + """ + + def __init__( + self, + params: ParamsT, + distributed_mesh: Optional[Union[DeviceMesh, ProcessGroup]] = None, + lr: float = 0.01, + mu: float = 0.95, + betas: Tuple[float, float] = (0.9, 0.95), + weight_decay: float = 0.01, + cautious_wd: bool = False, + epsilon: float = 1e-8, + nesterov: bool = True, + adjust_lr: Optional[str] = None, + flatten: bool = False, + pp_iterations: int = 2, + pp_beta: float = 0.5, + use_gram_newton_schulz: bool = False, + use_triton: bool = False, + use_polar_express: bool = True, + newton_schulz_func: Optional[Callable] = None, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if mu < 0.0: + raise ValueError(f"Invalid momentum factor (mu): {mu}") + if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0: + raise ValueError(f"Invalid betas: {betas}") + if adjust_lr not in ("spectral_norm", "rms_norm", None): + raise ValueError( + f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None." + ) + if not isinstance(pp_iterations, int) or pp_iterations < 1: + raise ValueError( + f"Invalid pp_iterations: {pp_iterations}. Must be a positive integer." + ) + if pp_beta < 0.0: + raise ValueError(f"Invalid pp_beta: {pp_beta}") + + # Resolve the base polar function (the one wrapped by Aurora's + # diagonal preconditioning). Mirrors DistributedOrthoBase resolution. + if newton_schulz_func is not None: + if not callable(newton_schulz_func): + raise TypeError( + f"newton_schulz_func must be a callable function, got {type(newton_schulz_func)}" + ) + base_polar = newton_schulz_func + elif use_gram_newton_schulz: + try: + from gram_newton_schulz import GramNewtonSchulz + except ImportError: + raise ImportError( + "use_gram_newton_schulz=True requires the 'gram-newton-schulz' package, " + "which is not installed. " + "Install it with: pip install gram-newton-schulz" + ) + _gns = GramNewtonSchulz( + ns_use_kernels=use_triton, + use_gram_newton_schulz=True, + gram_newton_schulz_reset_iterations=[2], + compile_kwargs=dict(fullgraph=True, mode="default"), + ) + base_polar = lambda X, epsilon=None: _gns(X) + elif use_polar_express and use_triton: + base_polar = polar_express_triton + elif use_polar_express: + base_polar = polar_express + elif use_triton: + if not TRITON_AVAILABLE: + raise ImportError( + "use_triton=True requires the 'triton' package, which is not installed. " + "Install it with: pip install dion[triton] (or: pip install triton)" + ) + base_polar = newton_schulz_triton + else: + base_polar = zeropower_via_newtonschulz5 + + aurora_polar_func = make_aurora_polar( + base_polar=base_polar, + pp_iterations=pp_iterations, + pp_beta=pp_beta, + ) + + defaults = dict( + lr=lr, + mu=mu, + beta1=betas[0], + beta2=betas[1], + weight_decay=weight_decay, + cautious_wd=cautious_wd, + algorithm="aurora", + step=0, + epsilon=epsilon, + nesterov=nesterov, + flatten=flatten, + adjust_lr=adjust_lr, + pp_iterations=pp_iterations, + pp_beta=pp_beta, + ) + # Pass the Aurora-wrapped polar through the base class's + # ``newton_schulz_func`` slot so the existing megabatch path uses it. + super().__init__( + params, distributed_mesh, "aurora", defaults, + newton_schulz_func=aurora_polar_func, + ) + + def _create_ortho_tasks( + self, param_groups: List[dict] + ) -> Generator["AsyncTask", None, None]: + """ + Mega-batched Aurora task creation: groups ALL same-shape parameters + into a single task to minimize communication rounds and kernel launches. + """ + for group in param_groups: + assert group["algorithm"] == "aurora" + assert all( + p.ndim >= 2 for p in group["params"] + ), "Aurora optimizer only supports matrix parameters." + + group_params = [p for p in group["params"] if p.grad is not None] + if not group_params: + continue + + update_args = dict( + lr=torch.tensor(group["lr"]), + momentum=torch.tensor(group["mu"]), + weight_decay=torch.tensor(group["weight_decay"]), + epsilon=torch.tensor(group["epsilon"]), + nesterov=group["nesterov"], + flatten=group["flatten"], + adjust_lr=group["adjust_lr"], + device_rank=self._device_rank, + world_size=self._world_size, + process_group=self._process_group, + newton_schulz_func=self._newton_schulz_func, + cautious_wd=group["cautious_wd"], + ) + + shape_groups: dict[tuple, list] = defaultdict(list) + for p in group_params: + sharding = p.placements if isinstance(p, DTensor) else None + shape_groups[(p.shape, sharding, p.dtype)].append(p) + + num_heads = self._resolve_num_heads(group) + + for (_shape, _sharding, _dtype), params in shape_groups.items(): + gradients = [p.grad for p in params] + states = [self._get_or_initialize_state(p, "aurora") for p in params] + momentums = [s["momentum"] for s in states] + + if num_heads is not None: + params, gradients, momentums = self._prepare_head_split( + num_heads, params, gradients, momentums + ) + megabatch_args = {**update_args, "process_group": None} + shard_dim = None + else: + is_batch_sharded, is_matrix_sharded, sharded_tensor_dim = ( + self._get_shard_info(params[0], group) + ) + megabatch_args = update_args + if is_batch_sharded and not is_matrix_sharded: + megabatch_args = {**update_args, "process_group": None} + shard_dim = sharded_tensor_dim + + yield AsyncTask( + aurora_update_megabatch_async( + X=params, + G=gradients, + M=momentums, + shard_dim=shard_dim, + **megabatch_args, + ) + ) + + +def aurora_update_megabatch_async( + X: List[Tensor], + G: List[Tensor], + M: List[Tensor], + lr: Tensor, + momentum: Tensor, + weight_decay: Tensor, + epsilon: Tensor, + nesterov: bool, + flatten: bool, + adjust_lr: Optional[str], + device_rank: int, + world_size: int, + shard_dim: Optional[int] = None, + process_group: Optional[ProcessGroup] = None, + newton_schulz_func: Optional[Callable] = None, + cautious_wd: bool = False, +) -> Generator[None, None, None]: + """ + Mega-batched Aurora update. Reuses Muon's pre/post-orthogonalize stages + and the shared megabatch communication; ``newton_schulz_func`` is the + Aurora-wrapped polar (see ``make_aurora_polar``). + """ + N = len(X) + assert N == len(G) == len(M) + + U = muon_update_pre_orthogonalize( + G=to_local(G), M=to_local(M), momentum=momentum, nesterov=nesterov, + ) + + comm_dim = (shard_dim - X[0].ndim) if shard_dim is not None else None + + if comm_dim is not None: + if not isinstance(X[0], DTensor): + raise TypeError( + "Sharded path requires X[0] to be a DTensor so .shape gives " + f"the global size; got {type(X[0]).__name__}." + ) + global_comm_dim_size = X[0].shape[comm_dim] + else: + global_comm_dim_size = None + + U = yield from megabatch_orthogonalize_async( + U, + comm_dim=comm_dim, + device_rank=device_rank, + world_size=world_size, + process_group=process_group, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + global_comm_dim_size=global_comm_dim_size, + ) + + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) + else: + raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + + muon_update_post_orthogonalize( + X=to_local(X), + U=U, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + cautious_wd=cautious_wd, + ) + + +def make_aurora_polar( + base_polar: Callable, + pp_iterations: int = 2, + pp_beta: float = 0.5, +) -> Callable: + """ + Build an Aurora-flavored polar function that has the same signature as a + standard Newton-Schulz / polar function (``func(X, epsilon) -> Tensor``) + and can be plugged into ``megabatch_orthogonalize_async``. + + For square matrices this is just ``base_polar(X, epsilon)``. For + non-square matrices it transposes to tall, then runs ``pp_iterations`` + rounds of diagonal row-preconditioning, calling ``base_polar`` once per + round. The output is multiplied by ``max(1, m/n)**0.5`` to apply the + same aspect-ratio scaling as in the Aurora reference (this replaces + Muon's per-LR ``sqrt(m/n)`` scaling, hence the optimizer's + ``adjust_lr=None`` default). + + Reference: https://github.com/tilde-research/aurora-release/blob/main/aurora.py + """ + def aurora_polar(X: Tensor, epsilon=1e-7) -> Tensor: + m, n = X.size(-2), X.size(-1) + + if m == n: + U = base_polar(X, epsilon=epsilon) + else: + transposed = m < n + X_t = X.mT if transposed else X + mm = max(m, n) + nn = min(m, n) + # Use a Python float for clamp(min=...) to avoid device-mismatch + # when ``epsilon`` is a CPU Tensor (the megabatch path). + eps_f = float(epsilon) if isinstance(epsilon, Tensor) else float(epsilon) + X32 = X_t.to(torch.float32) + target_row_sq = nn / mm + row_norm = X32.norm(dim=-1, keepdim=True).clamp(min=eps_f) + D = 1.0 / row_norm + eps_sq = eps_f * eps_f + U = base_polar(D * X32, epsilon=epsilon) + for k in range(1, pp_iterations): + row_sq = U.to(torch.float32).pow(2).sum(dim=-1, keepdim=True).clamp(min=eps_sq) + D = D * (target_row_sq / row_sq).pow(pp_beta) + U = base_polar(D * X32, epsilon=epsilon) + if transposed: + U = U.mT + + # Aurora aspect-ratio scaling: max(1, m/n)**0.5. Baked here rather + # than in the LR so callers can use adjust_lr=None and still get the + # correct tall scaling regardless of input orientation. + scale = max(1.0, m / n) ** 0.5 + if scale != 1.0: + U = U * scale + return U + + return aurora_polar diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 8d2d98f..7035e14 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -93,6 +93,111 @@ def test_mixed_shapes(self): _run_steps(Muon, params, dict(lr=0.01)) +# --------------------------------------------------------------------------- +# Aurora +# --------------------------------------------------------------------------- + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA required") +class TestAurora: + def test_basic(self): + from dion import Aurora + params = _make_params([(64, 128), (128, 64)]) + _run_steps(Aurora, params, dict(lr=0.01)) + + def test_determinism(self): + from dion import Aurora + p1 = _make_params([(64, 128)]) + r1 = _run_steps(Aurora, p1, dict(lr=0.01)) + p2 = _make_params([(64, 128)]) + r2 = _run_steps(Aurora, p2, dict(lr=0.01)) + assert torch.equal(r1[0], r2[0]) + + def test_params_change(self): + from dion import Aurora + params = _make_params([(64, 128)]) + before = params[0].data.clone() + _run_steps(Aurora, params, dict(lr=0.01), n_steps=1) + assert not torch.equal(params[0].data, before) + + def test_nesterov(self): + from dion import Aurora + params = _make_params([(64, 128)]) + _run_steps(Aurora, params, dict(lr=0.01, nesterov=True)) + + def test_cautious_wd(self): + from dion import Aurora + params = _make_params([(64, 128)]) + _run_steps(Aurora, params, dict(lr=0.01, cautious_wd=True)) + + def test_pp_iterations(self): + from dion import Aurora + for pp_iterations in [1, 2, 3]: + params = _make_params([(64, 128)]) + _run_steps(Aurora, params, dict(lr=0.01, pp_iterations=pp_iterations)) + + def test_megabatch_same_shape(self): + from dion import Aurora + params = _make_params([(64, 128)] * 5) + _run_steps(Aurora, params, dict(lr=0.01)) + + def test_mixed_shapes(self): + from dion import Aurora + params = _make_params([(64, 128), (128, 64), (32, 32)]) + _run_steps(Aurora, params, dict(lr=0.01)) + + def test_invalid_pp_iterations(self): + from dion import Aurora + with pytest.raises(ValueError, match="pp_iterations"): + Aurora(_make_params([(32, 64)]), pp_iterations=0) + with pytest.raises(ValueError, match="pp_iterations"): + Aurora(_make_params([(32, 64)]), pp_iterations=-1) + + def test_square_matches_muon_polar(self): + """For square matrices, Aurora's orthogonalization should equal the + underlying polar function bit-for-bit (the diagonal preconditioning + loop is bypassed when m == n).""" + from dion.aurora import make_aurora_polar + from dion.polar_express import polar_express + + torch.manual_seed(0) + G = torch.randn(128, 128, device=DEVICE) + ap = make_aurora_polar(polar_express, pp_iterations=2, pp_beta=0.5) + u_std = polar_express(G, epsilon=1e-7) + u_aur = ap(G, epsilon=1e-7) + # Aurora's aspect-ratio scaling is max(1, m/n)**0.5 = 1 for square, + # so the outputs should be exactly equal. + assert torch.equal(u_std, u_aur) + + def test_row_norms_more_uniform_than_polar(self): + """The defining property of Aurora: row norms of the orthogonalized + update should be more uniform than standard polar on a non-square + matrix. Compares max/min row-norm ratio (target: 1.0).""" + from dion.aurora import make_aurora_polar + from dion.polar_express import polar_express + + torch.manual_seed(0) + # Tall: rows >> cols + G = torch.randn(512, 128, device=DEVICE, dtype=torch.float32) + u_std = polar_express(G, epsilon=1e-7).to(torch.float32) + ap = make_aurora_polar(polar_express, pp_iterations=2, pp_beta=0.5) + # Strip Aurora's max(1, m/n)**0.5 scaling to compare just the polar shape. + scale = max(1.0, 512 / 128) ** 0.5 + u_aur = (ap(G, epsilon=1e-7).to(torch.float32)) / scale + + rn_std = u_std.norm(dim=-1) + rn_aur = u_aur.norm(dim=-1) + ratio_std = (rn_std.max() / rn_std.min()).item() + ratio_aur = (rn_aur.max() / rn_aur.min()).item() + # Aurora should noticeably tighten the ratio toward 1. + assert ratio_aur < ratio_std, ( + f"Aurora row-norm ratio {ratio_aur:.3f} should be tighter than " + f"standard polar {ratio_std:.3f}" + ) + assert ratio_aur < 1.15, ( + f"Aurora row-norm ratio {ratio_aur:.3f} should be close to 1" + ) + + # --------------------------------------------------------------------------- # NorMuon # --------------------------------------------------------------------------- From d4d4fd545e5b7c978e44957f4bc7dedbe414df09 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 10 May 2026 10:45:41 -0700 Subject: [PATCH 2/4] Aurora: standard adjust_lr pathway, reference impl, README - Switch Aurora's aspect-ratio scaling to dion's standard ``adjust_lr`` pathway: ``adjust_lr="spectral_norm"`` is now the default (matching Muon/NorMuon), and the baked-in ``max(1, m/n)**0.5`` multiplication is removed from ``aurora_polar``. Note: this differs from the Aurora reference for wide matrices (dion uses ``sqrt(m/n)``, the reference uses ``max(1, m/n)**0.5``); on tall and square matrices the two agree. - Add ``dion/aurora_reference.py``: single-file readable port of ``tilde-research/aurora-release`` (simple-quintic Newton-Schulz + diag-preconditioned polar + AdamW/Lion fallback), exported as ``dion.AuroraReference``. Mirrors the existing ``muon_reference.py`` / ``dion_reference.py`` pattern. - README: list Aurora alongside Muon/Dion2/NorMuon in the optimizer table, the ``1D Sharding Configuration`` heading, the imports example, and the per-file descriptions. - Tests: drop the now-unneeded aspect-ratio rescale in the row-norm uniformity assertion; add two AuroraReference tests. --- README.md | 23 ++-- dion/__init__.py | 1 + dion/aurora.py | 27 ++--- dion/aurora_reference.py | 248 +++++++++++++++++++++++++++++++++++++++ tests/test_optimizers.py | 21 +++- 5 files changed, 289 insertions(+), 31 deletions(-) create mode 100644 dion/aurora_reference.py diff --git a/README.md b/README.md index e15dc86..7a344b4 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,8 @@ This repository provides efficient implementations of orthonormal optimizers for You can find the following optimizers: * [Muon](https://kellerjordan.github.io/posts/muon/) * [Dion2](https://arxiv.org/abs/2512.16928) and [Dion](https://arxiv.org/pdf/2504.05295) (Dion is a legacy optimizer; we recommend using Dion2) -* [NorMuon](https://arxiv.org/abs/2510.05491) +* [NorMuon](https://arxiv.org/abs/2510.05491) +* [Aurora](https://blog.tilderesearch.com/blog/aurora) ## Table of Contents @@ -53,7 +54,7 @@ pip install git+https://github.com/microsoft/dion.git Then in your code, you can use: ```python -from dion import Dion2, Muon, NorMuon, Dion +from dion import Dion2, Muon, NorMuon, Dion, Aurora ``` Please carefully go through this readme for detailed instructions on using our optimizers. There are major differences compared to PyTorch built-in optimizers, such as `Adam`/`AdamW`. @@ -146,12 +147,12 @@ The practical effectiveness of orthonormal optimizers was first demonstrated by Our current implementations support the following parallelization techniques: -| Parallelization | Dion | Dion2 | Muon | NorMuon | -|--------------------|------|-------|------|---------| -| Single device | Yes | Yes | Yes | Yes | -| PyTorch DDP | Yes | Yes | Yes | Yes | -| PyTorch FSDP2 | Yes | Yes | Yes | Yes | -| PyTorch FSDP2 + TP | Yes | No | No | No | +| Parallelization | Dion | Dion2 | Muon | NorMuon | Aurora | +|--------------------|------|-------|------|---------|--------| +| Single device | Yes | Yes | Yes | Yes | Yes | +| PyTorch DDP | Yes | Yes | Yes | Yes | Yes | +| PyTorch FSDP2 | Yes | Yes | Yes | Yes | Yes | +| PyTorch FSDP2 + TP | Yes | No | No | No | No | For faster performance, these optimizers will process parameters in batches and interleave multiple batches to overlap compute with communication. @@ -161,12 +162,14 @@ We include optimizer implementations in the `dion/` directory of this repo. * `muon.py`: High-performance version of Muon. For sharded matrices, all-to-all communication is used to simultaneously unshard and distribute a batch of matrices. For replicated matrices, Muon will distribute work across all devices and all-gather final results. * **`dion2.py`**: High-performance implementation of Dion2, using a similar all-to-all communication pattern for distributed orthonormalization. Only an α-fraction of the momentum matrix is communicated and orthonormalized, significantly reducing both communication overhead and computation cost. * `normuon.py`: A variant of the Muon optimizer that introduces neuron-wise normalization to improve stability and convergence efficiency, modified to take similar arguments as `muon.py`. See [the paper](https://arxiv.org/abs/2510.05491) for more details. +* `aurora.py`: An optimizer for non-square matrices that produces leverage-uniform updates by iteratively row-preconditioning the polar (Newton-Schulz) factorization. For square matrices it reduces to standard Muon; for non-square ones it tightens the row-norm distribution of the orthogonalized update so all neurons receive comparably-sized steps. See [the Aurora blog post](https://blog.tilderesearch.com/blog/aurora) for the algorithm; uses the same `muon.py` mega-batch infrastructure. We also provide some reference implementations: * `dion_reference.py`: An implementation without batching, communication overlapping, or split all-reduce. This version of Dion is intended to closely follow the algorithms as described in our [Dion paper](https://arxiv.org/pdf/2504.05295). * `dion_simple.py`: A simplified illustration of the Dion update rule in a single Python function, provided for educational value. * `muon_reference.py`: A version of Muon by [Moonshot AI](https://github.com/MoonshotAI/Moonlight), modified to take similar arguments as `muon.py`. +* `aurora_reference.py`: A single-file readable port of [tilde-research/aurora-release](https://github.com/tilde-research/aurora-release), using the simple-quintic Newton-Schulz from the original Aurora repo. @@ -264,9 +267,9 @@ Requirements: the parameter must be 2D, `num_heads` must divide dim 0, and when For our efficient distributed optimizers to work correctly, they need information about the model's parallelization scheme. This is provided by passing `DeviceMesh` objects during optimizer construction. -### 1D Sharding Configuration (Dion2, Muon, NorMuon) +### 1D Sharding Configuration (Dion2, Muon, NorMuon, Aurora) -Most optimizers in this codebase (Dion2, Muon, NorMuon) currently support only 1D sharding. They accept a single 1D device mesh via the `distributed_mesh` argument and adapt their behavior based on how this mesh is used: +Most optimizers in this codebase (Dion2, Muon, NorMuon, Aurora) currently support only 1D sharding. They accept a single 1D device mesh via the `distributed_mesh` argument and adapt their behavior based on how this mesh is used: - **If the mesh is used for parameter sharding**: The optimizer efficiently unshards parameters using all-to-all communication - **If the mesh is not used for sharding**: The optimizer distributes work across devices and all-gathers the final results diff --git a/dion/__init__.py b/dion/__init__.py index 2f1f00b..123be38 100644 --- a/dion/__init__.py +++ b/dion/__init__.py @@ -1,4 +1,5 @@ from .aurora import Aurora +from .aurora_reference import Aurora as AuroraReference from .dion import Dion from .dion import DionMixedPrecisionConfig from .dion_simple import Dion as DionSimple diff --git a/dion/aurora.py b/dion/aurora.py index 24b2707..8b13904 100644 --- a/dion/aurora.py +++ b/dion/aurora.py @@ -39,17 +39,19 @@ class Aurora(DistributedOrthoBase): params: Parameters for the optimizer. distributed_mesh: DeviceMesh or ProcessGroup for distributed training. Use DeviceMesh for FSDP2 and ProcessGroup for DistributedDataParallel. - lr: Base learning rate. Aurora bakes the tall-aspect-ratio scaling - ``max(1, m/n)**0.5`` into the update, so ``adjust_lr`` defaults to None. + lr: Base learning rate. Scaled by ``adjust_lr`` to convert from spectral + norm 1 to a comparable RMS operator norm, same as Muon/NorMuon. mu: Momentum factor. betas: Tuple of (beta1, beta2) for AdamW and Lion algorithms. weight_decay: Weight decay factor. cautious_wd: Whether to apply weight decay only where update and parameter signs align. epsilon: Small value to avoid division by zero. nesterov: Whether to use Nesterov momentum. - adjust_lr: Optional Muon-style LR adjustment ("spectral_norm", "rms_norm", or None). - None is the Aurora default; the algorithm already applies its own - ``max(1, m/n)**0.5`` aspect-ratio scaling inside the orthogonalization. + adjust_lr: How to adjust the learning rate ("spectral_norm" or "rms_norm" or None). + Same semantics and default as Muon/NorMuon. Note that this differs + slightly from the Aurora reference (which uses ``max(1, m/n)^0.5`` + and so leaves wide matrices unscaled); dion's ``spectral_norm`` + applies ``sqrt(m/n)`` regardless of orientation, matching Muon. flatten: Whether to flatten 3D+ tensors to 2D for the orthogonalization step. pp_iterations: Number of preconditioned-polar iterations. Each iteration calls the base polar (Newton-Schulz) once. ``pp_iterations=2`` is the @@ -76,7 +78,7 @@ def __init__( cautious_wd: bool = False, epsilon: float = 1e-8, nesterov: bool = True, - adjust_lr: Optional[str] = None, + adjust_lr: Optional[str] = "spectral_norm", flatten: bool = False, pp_iterations: int = 2, pp_beta: float = 0.5, @@ -325,10 +327,9 @@ def make_aurora_polar( For square matrices this is just ``base_polar(X, epsilon)``. For non-square matrices it transposes to tall, then runs ``pp_iterations`` rounds of diagonal row-preconditioning, calling ``base_polar`` once per - round. The output is multiplied by ``max(1, m/n)**0.5`` to apply the - same aspect-ratio scaling as in the Aurora reference (this replaces - Muon's per-LR ``sqrt(m/n)`` scaling, hence the optimizer's - ``adjust_lr=None`` default). + round. Aspect-ratio scaling is left to the optimizer's ``adjust_lr`` + pathway (the same one Muon/NorMuon use), so the output here has + spectral norm at most 1 and unit row-norm structure. Reference: https://github.com/tilde-research/aurora-release/blob/main/aurora.py """ @@ -358,12 +359,6 @@ def aurora_polar(X: Tensor, epsilon=1e-7) -> Tensor: if transposed: U = U.mT - # Aurora aspect-ratio scaling: max(1, m/n)**0.5. Baked here rather - # than in the LR so callers can use adjust_lr=None and still get the - # correct tall scaling regardless of input orientation. - scale = max(1.0, m / n) ** 0.5 - if scale != 1.0: - U = U * scale return U return aurora_polar diff --git a/dion/aurora_reference.py b/dion/aurora_reference.py new file mode 100644 index 0000000..334355d --- /dev/null +++ b/dion/aurora_reference.py @@ -0,0 +1,248 @@ +"""Reference Aurora implementation. + +Single-file readable port of https://github.com/tilde-research/aurora-release. +Mirrors ``src/aurora.py`` (preconditioned polar) and ``src/polar.py`` +(simple-quintic Newton-Schulz) byte-for-byte in their math, wrapped in a +PyTorch ``Optimizer`` that follows the same param-group conventions as +``muon_reference.Muon``. + +This module is for clarity and reproducibility. For training, prefer +``dion.Aurora``, which integrates with FSDP2 / DDP and uses +``polar_express`` (faster than the simple-quintic polar in this file). + +Aurora: https://blog.tilderesearch.com/blog/aurora +Reference: https://github.com/tilde-research/aurora-release +""" + +import math +import torch +from torch.distributed.tensor import DTensor +from torch.optim.optimizer import Optimizer, ParamsT +from typing import Optional, Tuple + + +@torch.no_grad() +def polar(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: + """Polar factor via 12-step simple-quintic Newton-Schulz. + + p(sigma) = 2*sigma - 1.5*sigma^3 + 0.5*sigma^5; sigma=1 is super-attracting. + Matches ``src/polar.py`` from the Aurora reference repo. + """ + assert G.ndim >= 2 + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + a, b, c = 2, -1.5, 0.5 + for _ in range(12): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +@torch.no_grad() +def aurora_polar( + G: torch.Tensor, + pp_iterations: int = 2, + pp_beta: float = 0.5, + eps: float = 1e-7, +) -> torch.Tensor: + """Aurora's leverage-uniform polar via diagonal preconditioning. + + For square ``G`` reduces to ``polar(G)``. For non-square ``G`` transposes + to tall, runs ``pp_iterations`` rounds of polar with row-norm + preconditioning, then applies the ``max(1, m/n)^0.5`` aspect-ratio + scaling. Matches ``src/aurora.py`` from the Aurora reference repo. + """ + m, n = G.size(-2), G.size(-1) + if m == n: + update = polar(G, eps=eps) + else: + transposed = m < n + if transposed: + G = G.mT + m, n = n, m + G32 = G.to(torch.float32) + target_row_sq = n / m + row_norm = G32.norm(dim=-1, keepdim=True).clamp_(min=eps) + D = 1.0 / row_norm + for k in range(pp_iterations): + U = polar(D * G32, eps=eps) + if k < pp_iterations - 1: + row_sq = ( + U.to(torch.float32) + .pow(2) + .sum(dim=-1, keepdim=True) + .clamp_(min=eps * eps) + ) + D = D * (target_row_sq / row_sq).pow(pp_beta) + update = U.mT if transposed else U + update = update * (max(1.0, m / n) ** 0.5) + return update + + +class Aurora(Optimizer): + """Reference Aurora optimizer (single-file, no FSDP/DDP integration). + + Mirrors the param-group style of ``dion.muon_reference.Muon``: the + ``algorithm`` key on a param group selects ``aurora`` (default for + matrix params), ``adamw``, or ``lion``. + + For distributed training, use ``dion.Aurora`` instead. + """ + + def __init__( + self, + params: ParamsT, + lr: float = 0.05, + mu: float = 0.95, + betas: Tuple[float, float] = (0.95, 0.95), + weight_decay: float = 0.025, + epsilon: float = 1e-7, + nesterov: bool = True, + pp_iterations: int = 2, + pp_beta: float = 0.5, + ): + defaults = dict( + lr=lr, + momentum=mu, + betas=betas, + weight_decay=weight_decay, + epsilon=epsilon, + nesterov=nesterov, + pp_iterations=pp_iterations, + pp_beta=pp_beta, + ) + super().__init__(params, defaults) + + if isinstance(params, dict): + params = [params] + + for param_or_param_group in params: + if isinstance(param_or_param_group, dict): + algo = param_or_param_group.get("algorithm", "aurora") + if algo not in ("aurora", "adamw", "lion"): + raise ValueError(f"Unknown algorithm: {algo}") + for p in param_or_param_group["params"]: + self.state[p]["algorithm"] = algo + if algo == "aurora" and p.ndim != 2: + raise ValueError( + f"Aurora requires 2D parameters, but got {p.ndim}D" + ) + else: + p = ( + param_or_param_group[1] + if isinstance(param_or_param_group, tuple) + and len(param_or_param_group) == 2 + else param_or_param_group + ) + if not isinstance(p, torch.Tensor): + raise ValueError( + f"Invalid parameter type: {type(param_or_param_group)}" + ) + self.state[p]["algorithm"] = "aurora" + if p.ndim != 2: + raise ValueError( + f"Aurora requires 2D parameters, but got {p.ndim}D" + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + eps = group["epsilon"] + pp_iterations = group["pp_iterations"] + pp_beta = group["pp_beta"] + + aurora_params = [ + p for p in group["params"] if self.state[p]["algorithm"] == "aurora" + ] + for p in aurora_params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + # SGD momentum (Nesterov by default), matching the reference + # ``aurora()`` function's lerp-style EMA momentum. + buf.lerp_(g, 1 - momentum) + u = g.lerp(buf, momentum) if group["nesterov"] else buf.clone() + + if isinstance(u, DTensor): + u_local = u.full_tensor() + u_local = aurora_polar( + u_local, pp_iterations=pp_iterations, pp_beta=pp_beta, eps=eps + ) + u = DTensor.from_local( + u_local, + device_mesh=u.device_mesh, + placements=None, + run_check=False, + ).redistribute(placements=p.placements) + else: + u = aurora_polar( + u, pp_iterations=pp_iterations, pp_beta=pp_beta, eps=eps + ) + + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-lr) + + adamw_params = [ + p for p in group["params"] if self.state[p]["algorithm"] == "adamw" + ] + beta1, beta2 = group["betas"] + for p in adamw_params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + g = buf1 / (eps + buf2.sqrt()) + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.mul_(1 - lr * weight_decay) + p.add_(g, alpha=-lr / scale) + + lion_params = [ + p for p in group["params"] if self.state[p]["algorithm"] == "lion" + ] + for p in lion_params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(p) + buf = state["momentum_buffer"] + update = buf.lerp(g, 1 - beta1).sign_() + buf.lerp_(g, 1 - beta2) + p.mul_(1 - lr * weight_decay) + p.add_(update, alpha=-lr) + + return loss diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 7035e14..8ca9ce4 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -152,6 +152,21 @@ def test_invalid_pp_iterations(self): with pytest.raises(ValueError, match="pp_iterations"): Aurora(_make_params([(32, 64)]), pp_iterations=-1) + def test_reference_runs(self): + """Single-file AuroraReference should run on tall/square/wide.""" + from dion import AuroraReference + params = _make_params([(128, 64), (64, 64), (64, 128)]) + _run_steps(AuroraReference, params, dict(lr=0.05)) + + def test_reference_matches_polar_reference_on_square(self): + """AuroraReference's polar should equal its bundled simple-quintic + polar on square inputs (the diag-preconditioning loop is bypassed).""" + from dion.aurora_reference import polar, aurora_polar + torch.manual_seed(0) + G = torch.randn(128, 128, device=DEVICE) + # max(1, m/n)**0.5 = 1 for square, so the aspect-ratio scaling is 1. + assert torch.equal(aurora_polar(G), polar(G)) + def test_square_matches_muon_polar(self): """For square matrices, Aurora's orthogonalization should equal the underlying polar function bit-for-bit (the diagonal preconditioning @@ -164,8 +179,6 @@ def test_square_matches_muon_polar(self): ap = make_aurora_polar(polar_express, pp_iterations=2, pp_beta=0.5) u_std = polar_express(G, epsilon=1e-7) u_aur = ap(G, epsilon=1e-7) - # Aurora's aspect-ratio scaling is max(1, m/n)**0.5 = 1 for square, - # so the outputs should be exactly equal. assert torch.equal(u_std, u_aur) def test_row_norms_more_uniform_than_polar(self): @@ -180,9 +193,7 @@ def test_row_norms_more_uniform_than_polar(self): G = torch.randn(512, 128, device=DEVICE, dtype=torch.float32) u_std = polar_express(G, epsilon=1e-7).to(torch.float32) ap = make_aurora_polar(polar_express, pp_iterations=2, pp_beta=0.5) - # Strip Aurora's max(1, m/n)**0.5 scaling to compare just the polar shape. - scale = max(1.0, 512 / 128) ** 0.5 - u_aur = (ap(G, epsilon=1e-7).to(torch.float32)) / scale + u_aur = ap(G, epsilon=1e-7).to(torch.float32) rn_std = u_std.norm(dim=-1) rn_aur = u_aur.norm(dim=-1) From 15923e23eb01c54e75f86f45f805a98132739b4f Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 10 May 2026 11:13:44 -0700 Subject: [PATCH 3/4] Aurora: respect runtime mutations to pp_iterations / pp_beta The init-time wrapper closure captured ``pp_iterations`` and ``pp_beta`` as constants, so an LR scheduler or warmup that mutates the param group's ``pp_iterations`` was silently ignored. Every other Aurora hyperparameter (``lr``, ``mu``, ``weight_decay``, ``epsilon``, ``flatten``, ``adjust_lr``, ``cautious_wd``, ``nesterov``) is re-read from the group each step in ``_create_ortho_tasks``; the new hyperparameters now follow the same convention. Stash the unwrapped base polar as ``self._aurora_base_polar`` and rebuild the Aurora wrapper inside ``_create_ortho_tasks`` from the group's current values. Validate at the same point so bad runtime mutations fail fast. Also drop unused ``import math`` from ``aurora_reference.py``. --- dion/aurora.py | 36 +++++++++++++++++++++++++++--------- dion/aurora_reference.py | 1 - tests/test_optimizers.py | 19 +++++++++++++++++++ 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/dion/aurora.py b/dion/aurora.py index 8b13904..2b5b4d6 100644 --- a/dion/aurora.py +++ b/dion/aurora.py @@ -142,11 +142,11 @@ def __init__( else: base_polar = zeropower_via_newtonschulz5 - aurora_polar_func = make_aurora_polar( - base_polar=base_polar, - pp_iterations=pp_iterations, - pp_beta=pp_beta, - ) + # Stash the unwrapped base polar so ``_create_ortho_tasks`` can rebuild + # the Aurora wrapper each step using the param group's current + # ``pp_iterations`` / ``pp_beta`` (which an LR scheduler or warmup + # might mutate, just like ``lr``/``mu``). + self._aurora_base_polar = base_polar defaults = dict( lr=lr, @@ -164,11 +164,13 @@ def __init__( pp_iterations=pp_iterations, pp_beta=pp_beta, ) - # Pass the Aurora-wrapped polar through the base class's - # ``newton_schulz_func`` slot so the existing megabatch path uses it. + # Pass an init-time wrapper as the base ``newton_schulz_func`` so the + # parent class is happy; ``_create_ortho_tasks`` overrides it per-step. super().__init__( params, distributed_mesh, "aurora", defaults, - newton_schulz_func=aurora_polar_func, + newton_schulz_func=make_aurora_polar( + base_polar=base_polar, pp_iterations=pp_iterations, pp_beta=pp_beta, + ), ) def _create_ortho_tasks( @@ -188,6 +190,18 @@ def _create_ortho_tasks( if not group_params: continue + # Re-read pp_iterations / pp_beta from the group every step so an + # LR scheduler or warmup can mutate them (matching how lr/mu/etc. + # are re-read here). Validate to fail fast on bad runtime values. + pp_iterations = group["pp_iterations"] + pp_beta = group["pp_beta"] + if not isinstance(pp_iterations, int) or pp_iterations < 1: + raise ValueError( + f"Invalid pp_iterations: {pp_iterations}. Must be a positive integer." + ) + if pp_beta < 0.0: + raise ValueError(f"Invalid pp_beta: {pp_beta}") + update_args = dict( lr=torch.tensor(group["lr"]), momentum=torch.tensor(group["mu"]), @@ -199,7 +213,11 @@ def _create_ortho_tasks( device_rank=self._device_rank, world_size=self._world_size, process_group=self._process_group, - newton_schulz_func=self._newton_schulz_func, + newton_schulz_func=make_aurora_polar( + base_polar=self._aurora_base_polar, + pp_iterations=pp_iterations, + pp_beta=pp_beta, + ), cautious_wd=group["cautious_wd"], ) diff --git a/dion/aurora_reference.py b/dion/aurora_reference.py index 334355d..28dc1a4 100644 --- a/dion/aurora_reference.py +++ b/dion/aurora_reference.py @@ -14,7 +14,6 @@ Reference: https://github.com/tilde-research/aurora-release """ -import math import torch from torch.distributed.tensor import DTensor from torch.optim.optimizer import Optimizer, ParamsT diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 8ca9ce4..d114e85 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -152,6 +152,25 @@ def test_invalid_pp_iterations(self): with pytest.raises(ValueError, match="pp_iterations"): Aurora(_make_params([(32, 64)]), pp_iterations=-1) + def test_pp_iterations_mutation_takes_effect(self): + """Mutating ``pp_iterations`` on a param group at runtime must change + the update — the wrapper should be rebuilt each step (mirroring how + ``lr`` is re-read from the group).""" + from dion import Aurora + torch.manual_seed(0) + # Tall non-square so pp_iterations actually changes the answer. + params = [torch.nn.Parameter(torch.randn(128, 32, device=DEVICE))] + opt = Aurora(params, lr=0.01, pp_iterations=1) + params[0].grad = torch.randn_like(params[0]) + opt.step() + + opt.param_groups[0]["pp_iterations"] = 3 + params[0].grad = torch.randn_like(params[0]) + # Bad value should now raise from the mutated group state. + opt.param_groups[0]["pp_iterations"] = 0 + with pytest.raises(ValueError, match="pp_iterations"): + opt.step() + def test_reference_runs(self): """Single-file AuroraReference should run on tall/square/wide.""" from dion import AuroraReference From a7e185a1dce6eb5eb9134180634e8d86cdeb7fea Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 10 May 2026 17:21:10 -0700 Subject: [PATCH 4/4] Aurora: add adjust_lr='aurora_aspect' for reference-faithful LR scaling Adds ``adjust_lr_aurora_aspect`` (``lr * sqrt(max(1, m/n))``) alongside ``spectral_norm`` and ``rms_norm``. The Aurora reference uses one-sided ``max(1, m/n)^0.5`` aspect-ratio scaling, which differs from dion's default ``spectral_norm`` (= ``sqrt(m/n)``) on wide matrices: dion shrinks wide-matrix LR by < 1 while the reference leaves it unscaled. This lets ``dion.Aurora(adjust_lr='aurora_aspect')`` reproduce the reference's LR conventions exactly while keeping the megabatch all-to-all communication path. Previously, matching the reference's wide-matrix behavior required falling back to ``dion.AuroraReference``, which uses slower per-param ``full_tensor()`` + ``redistribute()``. --- dion/aurora.py | 22 +++++++++++++++------- dion/megabatch_base.py | 17 +++++++++++++++++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/dion/aurora.py b/dion/aurora.py index 2b5b4d6..28aa133 100644 --- a/dion/aurora.py +++ b/dion/aurora.py @@ -11,6 +11,7 @@ megabatch_orthogonalize_async, adjust_lr_spectral_norm, adjust_lr_rms_norm, + adjust_lr_aurora_aspect, ) from .muon import muon_update_pre_orthogonalize, muon_update_post_orthogonalize from .newton_schulz_triton import ( @@ -47,11 +48,15 @@ class Aurora(DistributedOrthoBase): cautious_wd: Whether to apply weight decay only where update and parameter signs align. epsilon: Small value to avoid division by zero. nesterov: Whether to use Nesterov momentum. - adjust_lr: How to adjust the learning rate ("spectral_norm" or "rms_norm" or None). - Same semantics and default as Muon/NorMuon. Note that this differs - slightly from the Aurora reference (which uses ``max(1, m/n)^0.5`` - and so leaves wide matrices unscaled); dion's ``spectral_norm`` - applies ``sqrt(m/n)`` regardless of orientation, matching Muon. + adjust_lr: How to adjust the learning rate ("spectral_norm", "rms_norm", + "aurora_aspect", or None). + "spectral_norm" (default, same as Muon/NorMuon): scales by sqrt(m/n) + regardless of orientation. + "aurora_aspect": scales by ``max(1, m/n)^0.5``, matching the Aurora + reference exactly (wide matrices unscaled). Use this for + reference-faithful Aurora. + "rms_norm": Adam-comparable RMS scaling. + None: no LR adjustment. flatten: Whether to flatten 3D+ tensors to 2D for the orthogonalization step. pp_iterations: Number of preconditioned-polar iterations. Each iteration calls the base polar (Newton-Schulz) once. ``pp_iterations=2`` is the @@ -93,9 +98,10 @@ def __init__( raise ValueError(f"Invalid momentum factor (mu): {mu}") if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0: raise ValueError(f"Invalid betas: {betas}") - if adjust_lr not in ("spectral_norm", "rms_norm", None): + if adjust_lr not in ("spectral_norm", "rms_norm", "aurora_aspect", None): raise ValueError( - f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None." + f"Invalid adjust_lr value: {adjust_lr}. " + "Must be 'spectral_norm', 'rms_norm', 'aurora_aspect', or None." ) if not isinstance(pp_iterations, int) or pp_iterations < 1: raise ValueError( @@ -319,6 +325,8 @@ def aurora_update_megabatch_async( adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) elif adjust_lr == "rms_norm": adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) + elif adjust_lr == "aurora_aspect": + adjusted_lr = adjust_lr_aurora_aspect(lr, X[0].shape, flatten=flatten) else: raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") diff --git a/dion/megabatch_base.py b/dion/megabatch_base.py index 4fbe88f..bc2aa70 100644 --- a/dion/megabatch_base.py +++ b/dion/megabatch_base.py @@ -559,3 +559,20 @@ def adjust_lr_spectral_norm(lr, param_shape, flatten): else: fan_out, fan_in = param_shape[-2:] return lr * math.sqrt(fan_out / fan_in) + + +def adjust_lr_aurora_aspect(lr, param_shape, flatten): + """One-sided tall aspect-ratio scaling: ``max(1, m/n)^0.5``. + + Matches the Aurora reference's scaling + (https://github.com/tilde-research/aurora-release/blob/main/src/aurora.py): + tall matrices get ``sqrt(m/n) > 1`` boost, wide matrices are unscaled. + Differs from ``adjust_lr_spectral_norm``, which applies ``sqrt(m/n)`` + regardless of orientation (so wide matrices get LR shrunk by ``< 1``). + """ + if flatten: + fan_out = param_shape[0] + fan_in = math.prod(param_shape[1:]) + else: + fan_out, fan_in = param_shape[-2:] + return lr * math.sqrt(max(1.0, fan_out / fan_in))