perf(SwitchGLU): +10% SSD-stream MoE via stacked-buffer + Gate+Up fusion (MLX_MOE_STACKED / MLX_MOE_FUSE_GATEUP)#35
Conversation
Adds a static `MAX_CACHE_SLOTS` on SwitchGLU read from `MLX_MOE_CACHE_SLOTS=N` (default 16, minimum 6). Used by SSD-streaming paths that keep experts resident across tokens to size the per-layer cache. Larger values trade unified memory for hit-rate; the default is chosen to leave headroom for top-k=8 routing plus prev-token speculative prefetch on Apple Silicon systems where the model + KV cache already consume the bulk of available RAM. No behavioral change on its own — the constant is consumed by the following commits in this PR.
…Fused (MLX_MOE_STACKED)
Adds an env-gated fast path that allocates ONE stacked weight buffer
of shape [CACHE_SLOTS, intermediate, hidden] per projection per layer
and populates slots in-place via MLXFast.preadIntoOffset, then issues
a single gatherQuantizedMM dispatch per projection per layer.
The legacy SSD-streaming path allocates one MLXArray per cached
expert and runs gatherQuantizedMM in a per-expert loop (top_k
dispatches × 3 projections × N layers per token). Each Metal
dispatch carries ~30 us of CPU->GPU encode/submit overhead on Apple
Silicon, which dominates per-token compute on SSD-streamed MoE
models. Stacking collapses the per-projection loop into a single
dispatch (rhsIndices = slotPerToken), yielding a measured ~5%
throughput gain on a 122B-A10B MoE workload at top-k=6 / slots=16
before the gate+up fusion in the next commit further extends it.
Details:
- private static let useStackedBuffers reads MLX_MOE_STACKED=1.
Off by default; consumers opt in per launch.
- SwitchGLU gains _stackedGate/_stackedUp/_stackedDown MLXArrays,
a per-slot _slotExpert/_slotLastUsed LRU array, _tokenCounter,
and _stackedBytesPerExpert (computed once at cold init).
- New private runStackedFastPath(x:indices:) -> MLXArray? handles
cold init / asyncEval(idx) / synchronous spec-prefetch of
prev-token experts / LRU resolution / miss pread / fused compute.
Returns nil for any ineligibility (non-quantized projection,
missing SSD info, idx.size > 32, no available slot) so the
caller falls through to the existing N-buffer path. No behavior
change when the env flag is unset.
- New public QuantizedSwitchLinear.computeExpertsFused(_:
stackedBuffer:slotPerToken:slotExperts:) issues the single
gatherQuantizedMM and gathers per-slot scales/biases via
MLX.take.
- SwitchGLU.callAsFunction gains a top-of-function early branch
that delegates to runStackedFastPath when applicable.
Depends on the additive MLXFast.preadIntoOffset Swift wrapper in
mlx-swift (companion PR).
When MLX_MOE_FUSE_GATEUP=1 (and the stacked-buffer path from the previous commit is enabled), allocate ONE combined buffer of shape [CACHE_SLOTS, 2*intermediate, hidden] per layer instead of separate gate/up stacked buffers. Gate weights go in the first half of each slot, up weights in the second half (offsets slot*2*bpe and slot*2*bpe + bpe). A single gatherQuantizedMM produces [..., 2*intermediate]; halves are split via .ellipsis-indexed range subscripts and fed into silu(g) * u. Dispatch reduction: collapses the two projection-level dispatches (gate, up) into one per layer per token. Identical FLOPs (the rows are independent), but each MLX dispatch carries ~30 us of CPU->GPU encode/submit overhead on Apple Silicon. On a 122B-A10B MoE workload at top-k=6 / slots=16 with the stacked-buffer path already enabled, this gives a measured +7-9% on matched-prompt sustained decode (producing ~+8% over the upstream baseline overall). Cold init pre-concatenates qGate.scales || qUp.scales along axis 1 so the runtime gather of scales is also a single MLX.take. Same for biases when both projections carry them. The down projection still uses the per-projection stacked buffer; gate+up share is the only combinable case (their input + output dims match and the SwiGLU activation needs both halves of the same matmul output). No behavior change when MLX_MOE_FUSE_GATEUP is unset; falls through to the per-projection stacked path from the prior commit. No effect on the legacy N-buffer / non-stacked paths.
There was a problem hiding this comment.
Pull request overview
Adds opt-in performance fast paths to SwitchGLU for SSD-streamed MoE inference by reducing Metal dispatches via stacked expert-weight buffers and an optional gate+up fusion, while keeping default behavior unchanged unless env flags are set.
Changes:
- Introduces
MLX_MOE_CACHE_SLOTS(default 16, min 6) to size stacked expert caches. - Adds
MLX_MOE_STACKEDstacked-buffer fast path with LRU slotting + speculative prefetch and a singlegatherQuantizedMMper projection. - Adds
MLX_MOE_FUSE_GATEUPto fuse gate+up into one combined buffer and onegatherQuantizedMM, pluscomputeExpertsFusedinQuantizedSwitchLinear.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if let cb = _combinedGateUpBiases { coldEvalList.append(cb) } | ||
| MLX.eval(coldEvalList) | ||
| _stackedGateUpBytesPerProj = _stackedGateUp!.nbytes / CACHE_SLOTS / 2 | ||
| _stackedBytesPerExpert = _stackedGateUpBytesPerProj // shared with down |
There was a problem hiding this comment.
In fused mode, _stackedBytesPerExpert is derived from the combined gate+up buffer and then reused for down-proj pread offsets (// shared with down). Down weights typically have a different per-expert slab size ([outputDims,inputDims] vs [intermediate,hidden]), so using the gate/up byte stride will write experts into the wrong offsets and corrupt _stackedDown contents. Compute and store a separate bytes-per-slot value for _stackedDown (e.g., _stackedDown!.nbytes / CACHE_SLOTS) and use it for down offsets; keep gate/up stride separate (and in fused mode, keep separate strides for “per-proj” vs “per-slot”).
| _stackedBytesPerExpert = _stackedGateUpBytesPerProj // shared with down | |
| _stackedBytesPerExpert = _stackedDown!.nbytes / CACHE_SLOTS |
| _slotLastUsed = Array(repeating: 0, count: CACHE_SLOTS) | ||
| _tokenCounter = 0 | ||
| MLX.eval([idx, _stackedGate!, _stackedUp!, _stackedDown!]) | ||
| _stackedBytesPerExpert = _stackedGate!.nbytes / CACHE_SLOTS | ||
| } |
There was a problem hiding this comment.
In the non-fused stacked-buffer cold init, _stackedBytesPerExpert is computed from _stackedGate and then used for all projections. This is incorrect for _stackedDown when inputDims != hiddenDims (common MoE FFNs), because down weights have a different [outputDims,inputDims] shape and therefore a different bytes-per-slot stride. Track bytes-per-slot separately per projection (at least gate/up vs down) and use the correct stride when computing dstOffset for preadIntoOffset.
| MLXFast.preadIntoOffset(self._stackedDown!, safetensorsPath: downSSD.path, | ||
| tensorName: downSSD.tensorName, expertIndex: UInt32(info.expertId), dstOffset: info.slot * bpe) | ||
| } |
There was a problem hiding this comment.
preadIntoOffset for the down projection uses dstOffset: info.slot * bpe, but bpe is currently derived from the gate (or gate/up) buffer. If the down weight slab size differs, this offset will be wrong and will overwrite/underfill the stacked down buffer. Use the down buffer’s own bytes-per-slot stride here (and in the analogous miss-handling loop below) instead of reusing bpe.
| /// - Parameters: | ||
| /// - x: input activations, shape `[totalTokens, ..., hidden]`. | ||
| /// - stackedBuffer: weight buffer, shape `[CACHE_SLOTS, intermediate, hidden]`. | ||
| /// Slots are populated externally via `MLXFast.preadIntoOffset`. | ||
| /// - slotPerToken: uint32 array mapping each token (along axis 0 of `x`) | ||
| /// to a slot index in `stackedBuffer`. Built from the routing. | ||
| /// - slotExperts: per-slot expert IDs (`0..<numExperts`). Used to gather | ||
| /// per-slot scales/biases from `self.scales` and `self.biases`. |
There was a problem hiding this comment.
The docstring for computeExpertsFused describes stackedBuffer as [CACHE_SLOTS, intermediate, hidden], but this helper is also used for the down projection where the logical shape is [CACHE_SLOTS, outputDims, inputDims]. Consider updating the documentation (and the inline comments that mention “intermediate”) to describe the buffer generically in terms of outputDims/inputDims to avoid confusion for future maintainers.
| public func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { | ||
| // Stacked-buffer fused-matmul fast path (env-gated MLX_MOE_STACKED=1). | ||
| // Early-out into the stacked path when applicable; otherwise fall | ||
| // through to the existing SSD-streaming / legacy code below. | ||
| if SwitchGLU.useStackedBuffers, | ||
| ExpertStreamingConfig.shared.isEnabled, |
There was a problem hiding this comment.
New env-gated behavior (MLX_MOE_STACKED / MLX_MOE_FUSE_GATEUP) introduces a distinct execution path that bypasses the existing SSD-streaming code. There are tests exercising SwitchGLU instantiation/forward passes, but nothing covering the stacked/fused paths or even the “flag enabled but ineligible → clean fallback” behavior. Please add a small unit/integration test that sets these env vars and validates either (a) the stacked path falls back without crashing for non-quantized layers, and/or (b) outputs match the legacy path for a tiny deterministic setup when the fast path is eligible.
| public func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { | |
| // Stacked-buffer fused-matmul fast path (env-gated MLX_MOE_STACKED=1). | |
| // Early-out into the stacked path when applicable; otherwise fall | |
| // through to the existing SSD-streaming / legacy code below. | |
| if SwitchGLU.useStackedBuffers, | |
| ExpertStreamingConfig.shared.isEnabled, | |
| private func shouldAttemptStackedFastPath() -> Bool { | |
| guard SwitchGLU.useStackedBuffers, | |
| ExpertStreamingConfig.shared.isEnabled | |
| else { | |
| return false | |
| } | |
| // The env flag must not be enough on its own to enter the stacked path. | |
| // Only attempt it when the stacked gate/up state has actually been | |
| // prepared; otherwise cleanly fall back to the legacy implementation. | |
| guard _stackedGateUp != nil, | |
| _combinedGateUpScales != nil | |
| else { | |
| return false | |
| } | |
| return true | |
| } | |
| public func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray { | |
| // Stacked-buffer fused-matmul fast path (env-gated MLX_MOE_STACKED=1). | |
| // Early-out into the stacked path when applicable; otherwise fall | |
| // through to the existing SSD-streaming / legacy code below. | |
| if self.shouldAttemptStackedFastPath(), |
Both ci.yml and downstream_integration.yml check out SwiftLM with submodules: recursive, which pins mlx-swift to SwiftLM's submodule commit. When mlx-swift-lm PRs depend on new mlx-swift APIs (e.g. preadIntoOffset from mlx-swift#10), the build fails because the submodule pin is stale. Fix: after replacing mlx-swift-lm with the PR checkout, also pull mlx-swift to latest origin/main so new APIs are available.
- Compute and use separate `_stackedDownBytesPerExpert` for the down projection so it doesn't incorrectly reuse the gate/up stride. - Fix docstring for `computeExpertsFused` to refer generically to outputDims/inputDims instead of intermediate/hidden. - Add `StackedMoETests.swift` unit test to verify the fast path cleanly falls back without crashing when enabled on non-quantized models.
solderzzc
left a comment
There was a problem hiding this comment.
Thanks for the review @ericjlake! I've pushed fixes for all these points in 6a57d00:
- Added
_stackedDownBytesPerExpertto track the stride for the down projection separately and updated both the speculative prefetch loop and the on-demand miss loop to use it. - Updated the
computeExpertsFuseddocstring to refer generically tooutputDimsandinputDimsinstead ofintermediateandhidden. - Added
StackedMoETests.swiftwhich sets the env flags and runs a forward pass on a non-quantized model to ensure the fast path correctly and safely falls back without crashing.
(Regarding the shouldAttemptStackedFastPath suggestion — since _stackedGateUp is allocated lazily on the first token inside runStackedFastPath, checking _stackedGateUp != nil before entering the fast path would permanently prevent it from initializing. The fallback is instead handled correctly inside runStackedFastPath by returning nil if the layer is ineligible.)
Pins `mlx-swift` and `mlx-swift-lm` to their respective latest `main` commits, pulling in the newly merged stacked-buffer MoE optimizations for SwiftLM. - SharpAI/mlx-swift#10 - SharpAI/mlx-swift-lm#35
Updates mlx-swift and mlx-swift-lm submodules to pull in the stacked-buffer MoE fast-paths and memory-safety optimizations. Submodule Updates: - `mlx-swift`: SharpAI/mlx-swift#10 - `mlx-swift-lm`: SharpAI/mlx-swift-lm#35 All downstream integration tests pass.
Summary
Cherry-pick of 3 commits from ericjlake/mlx-swift-lm@feat/stacked-moe-fastpath, needed to unblock SharpAI/SwiftLM#90.
All three flags default OFF. Behavior is identical to
mainwhen env vars are unset.Commits
a2883a2(→ fbafa95)MLX_MOE_CACHE_SLOTSenv tunable (+15 lines)f432840(→ 5c7b402)computeExpertsFused(MLX_MOE_STACKED) (+314 lines)57ec366(→ c68d007)MLX_MOE_FUSE_GATEUP=1) (+161/−27 lines)How it works
MLX_MOE_STACKED=1— allocates ONE[CACHE_SLOTS, intermediate, hidden]weight buffer per projection per layer; populates slots in-place viaMLXFast.preadIntoOffset; issues ONEgatherQuantizedMMper projection (rhsIndices = slotPerToken) instead oftop_kseparate dispatches. Each Metal dispatch carries ~30 µs of CPU→GPU encode/submit overhead on Apple Silicon, which dominates per-token compute on SSD-streamed MoE.MLX_MOE_FUSE_GATEUP=1(requiresMLX_MOE_STACKED=1) — collapses gate+up into one combined[CACHE_SLOTS, 2*intermediate, hidden]buffer; onegatherQuantizedMMproduces[..., 2*intermediate], halves split and fed intosilu(g) * u. Saves one projection-level dispatch per layer per token.MLX_MOE_CACHE_SLOTS=N(default 16, min 6) — cache-slot tunable.Eligibility: stacked path engages only when all 3 projections are quantized +
resolveSSDInfo()succeeds +idx.size <= 32(single-token decode). Ineligible layers and prompt batches returnniland fall through to the existing N-buffer path.Benchmark (Qwen3.5-122B-A10B-4bit, M1 Ultra 64 GB, top-k=6, slots=16)
Matched 600-token prompt, mean of 3 runs:
MLX_MOE_STACKED=1onlyMLX_MOE_STACKED=1 MLX_MOE_FUSE_GATEUP=1How to test
Dependency
Requires
SharpAI/mlx-swift#10(feat/preadIntoOffset —MLXFast.preadIntoOffset) to be merged first.Relationship to SwiftLM#90
Once this and mlx-swift#10 are merged into their respective
mainbranches, SwiftLM PR#90'sgit submodule updatewill resolve cleanly from canonical SharpAI URLs.Co-authored-by: Eric Lake