Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ This repository provides efficient implementations of orthonormal optimizers for
You can find the following optimizers:
* [Muon](https://kellerjordan.github.io/posts/muon/)
* [Dion2](https://arxiv.org/abs/2512.16928) and [Dion](https://arxiv.org/pdf/2504.05295) (Dion is a legacy optimizer; we recommend using Dion2)
* [NorMuon](https://arxiv.org/abs/2510.05491)
* [NorMuon](https://arxiv.org/abs/2510.05491)
* [Aurora](https://blog.tilderesearch.com/blog/aurora)


## Table of Contents
Expand Down Expand Up @@ -53,7 +54,7 @@ pip install git+https://github.com/microsoft/dion.git
Then in your code, you can use:

```python
from dion import Dion2, Muon, NorMuon, Dion
from dion import Dion2, Muon, NorMuon, Dion, Aurora
```

Please carefully go through this readme for detailed instructions on using our optimizers. There are major differences compared to PyTorch built-in optimizers, such as `Adam`/`AdamW`.
Expand Down Expand Up @@ -146,12 +147,12 @@ The practical effectiveness of orthonormal optimizers was first demonstrated by

Our current implementations support the following parallelization techniques:

| Parallelization | Dion | Dion2 | Muon | NorMuon |
|--------------------|------|-------|------|---------|
| Single device | Yes | Yes | Yes | Yes |
| PyTorch DDP | Yes | Yes | Yes | Yes |
| PyTorch FSDP2 | Yes | Yes | Yes | Yes |
| PyTorch FSDP2 + TP | Yes | No | No | No |
| Parallelization | Dion | Dion2 | Muon | NorMuon | Aurora |
|--------------------|------|-------|------|---------|--------|
| Single device | Yes | Yes | Yes | Yes | Yes |
| PyTorch DDP | Yes | Yes | Yes | Yes | Yes |
| PyTorch FSDP2 | Yes | Yes | Yes | Yes | Yes |
| PyTorch FSDP2 + TP | Yes | No | No | No | No |

For faster performance, these optimizers will process parameters in batches and interleave multiple batches to overlap compute with communication.

Expand All @@ -161,12 +162,14 @@ We include optimizer implementations in the `dion/` directory of this repo.
* `muon.py`: High-performance version of Muon. For sharded matrices, all-to-all communication is used to simultaneously unshard and distribute a batch of matrices. For replicated matrices, Muon will distribute work across all devices and all-gather final results.
* **`dion2.py`**: High-performance implementation of Dion2, using a similar all-to-all communication pattern for distributed orthonormalization. Only an α-fraction of the momentum matrix is communicated and orthonormalized, significantly reducing both communication overhead and computation cost.
* `normuon.py`: A variant of the Muon optimizer that introduces neuron-wise normalization to improve stability and convergence efficiency, modified to take similar arguments as `muon.py`. See [the paper](https://arxiv.org/abs/2510.05491) for more details.
* `aurora.py`: An optimizer for non-square matrices that produces leverage-uniform updates by iteratively row-preconditioning the polar (Newton-Schulz) factorization. For square matrices it reduces to standard Muon; for non-square ones it tightens the row-norm distribution of the orthogonalized update so all neurons receive comparably-sized steps. See [the Aurora blog post](https://blog.tilderesearch.com/blog/aurora) for the algorithm; uses the same `muon.py` mega-batch infrastructure.

We also provide some reference implementations:

* `dion_reference.py`: An implementation without batching, communication overlapping, or split all-reduce. This version of Dion is intended to closely follow the algorithms as described in our [Dion paper](https://arxiv.org/pdf/2504.05295).
* `dion_simple.py`: A simplified illustration of the Dion update rule in a single Python function, provided for educational value.
* `muon_reference.py`: A version of Muon by [Moonshot AI](https://github.com/MoonshotAI/Moonlight), modified to take similar arguments as `muon.py`.
* `aurora_reference.py`: A single-file readable port of [tilde-research/aurora-release](https://github.com/tilde-research/aurora-release), using the simple-quintic Newton-Schulz from the original Aurora repo.



Expand Down Expand Up @@ -264,9 +267,9 @@ Requirements: the parameter must be 2D, `num_heads` must divide dim 0, and when

For our efficient distributed optimizers to work correctly, they need information about the model's parallelization scheme. This is provided by passing `DeviceMesh` objects during optimizer construction.

### 1D Sharding Configuration (Dion2, Muon, NorMuon)
### 1D Sharding Configuration (Dion2, Muon, NorMuon, Aurora)

Most optimizers in this codebase (Dion2, Muon, NorMuon) currently support only 1D sharding. They accept a single 1D device mesh via the `distributed_mesh` argument and adapt their behavior based on how this mesh is used:
Most optimizers in this codebase (Dion2, Muon, NorMuon, Aurora) currently support only 1D sharding. They accept a single 1D device mesh via the `distributed_mesh` argument and adapt their behavior based on how this mesh is used:

- **If the mesh is used for parameter sharding**: The optimizer efficiently unshards parameters using all-to-all communication
- **If the mesh is not used for sharding**: The optimizer distributes work across devices and all-gathers the final results
Expand Down
2 changes: 2 additions & 0 deletions dion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .aurora import Aurora
from .aurora_reference import Aurora as AuroraReference
from .dion import Dion
from .dion import DionMixedPrecisionConfig
from .dion_simple import Dion as DionSimple
Expand Down
Loading