
perf(SwitchGLU): +10% SSD-stream MoE via stacked-buffer + Gate+Up fusion (MLX_MOE_STACKED / MLX_MOE_FUSE_GATEUP)#35

Merged: solderzzc merged 5 commits into main from feat/stacked-moe-fastpath on Apr 27, 2026


@solderzzc solderzzc commented Apr 27, 2026

Summary

Cherry-pick of 3 commits from ericjlake/mlx-swift-lm@feat/stacked-moe-fastpath, needed to unblock SharpAI/SwiftLM#90.

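Both fast-path flags (MLX_MOE_STACKED, MLX_MOE_FUSE_GATEUP) default OFF, and MLX_MOE_CACHE_SLOTS has no effect unless a fast path is enabled. Behavior is identical to main when the env vars are unset.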

Commits

| SHA | Summary |
|---|---|
| a2883a2 (→ fbafa95) | feat(SwitchGLU): MLX_MOE_CACHE_SLOTS env tunable (+15 lines) |
| f432840 (→ 5c7b402) | feat(SwitchGLU): stacked-buffer SSD-stream fast path + computeExpertsFused (MLX_MOE_STACKED) (+314 lines) |
| 57ec366 (→ c68d007) | feat(SwitchGLU): Gate+Up SwiGLU matmul fusion (MLX_MOE_FUSE_GATEUP=1) (+161/−27 lines) |

How it works

MLX_MOE_STACKED=1: allocates ONE [CACHE_SLOTS, intermediate, hidden] weight buffer per projection per layer, populates slots in place via MLXFast.preadIntoOffset, and issues ONE gatherQuantizedMM per projection (rhsIndices = slotPerToken) instead of top_k separate dispatches. Each Metal dispatch carries ~30 µs of CPU→GPU encode/submit overhead on Apple Silicon, and that overhead dominates per-token compute on SSD-streamed MoE.
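
As a rough illustration of why the dispatch count matters, the sketch below tallies per-token dispatch overhead for the two paths using the ~30 µs figure quoted above. The helper name and the layer count of 48 are hypothetical placeholders, not values from this PR, and the model counts only expert-projection dispatches.

```swift
/// Back-of-envelope model of per-token CPU->GPU dispatch overhead.
/// Hypothetical helper for illustration; not part of SwitchGLU.
func dispatchOverheadMicros(
    dispatchesPerProjection: Int,
    projections: Int = 3,           // gate, up, down
    layers: Int,
    microsPerDispatch: Double = 30  // figure quoted in this PR
) -> Double {
    return Double(dispatchesPerProjection * projections * layers) * microsPerDispatch
}

let layers = 48 // ASSUMED layer count, chosen only for the example

// Legacy path: one dispatch per routed expert (top_k = 6) per projection.
let legacy = dispatchOverheadMicros(dispatchesPerProjection: 6, layers: layers)
// Stacked path: one dispatch per projection.
let stacked = dispatchOverheadMicros(dispatchesPerProjection: 1, layers: layers)

print("legacy: \(legacy) µs, stacked: \(stacked) µs per token")
```

Under these assumptions the overhead scales linearly with top_k, so the stacked path pays one sixth of the legacy dispatch cost at top_k=6.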

MLX_MOE_FUSE_GATEUP=1 (requires MLX_MOE_STACKED=1): collapses gate+up into one combined [CACHE_SLOTS, 2*intermediate, hidden] buffer. One gatherQuantizedMM produces [..., 2*intermediate]; the two halves are split and fed into silu(g) * u. Saves one projection-level dispatch per layer per token.
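
The split-and-activate step can be sketched on plain Swift arrays. This is a stand-in for the MLXArray code, assuming the fused output row is laid out as the gate half followed by the up half; `swiGLUFromFused` is a hypothetical name.

```swift
import Foundation

/// SiLU activation: x * sigmoid(x).
func silu(_ x: Double) -> Double {
    return x / (1.0 + exp(-x))
}

/// Given one fused output row of length 2*intermediate laid out as
/// [gate half | up half], compute silu(gate) * up element-wise.
/// Plain-array stand-in for the MLXArray version; hypothetical helper.
func swiGLUFromFused(_ fused: [Double]) -> [Double] {
    precondition(fused.count % 2 == 0, "fused row must have even length")
    let intermediate = fused.count / 2
    let gate = fused[..<intermediate]   // first half: gate projection output
    let up = fused[intermediate...]     // second half: up projection output
    return zip(gate, up).map { pair in silu(pair.0) * pair.1 }
}
```

The result has length `intermediate`, the same shape the separate gate and up matmuls would have produced after the activation.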

MLX_MOE_CACHE_SLOTS=N (default 16, min 6): number of expert slots in each per-layer cache. Larger values trade unified memory for hit rate.
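
A minimal sketch of how such a tunable could be resolved. The default of 16 and minimum of 6 come from this description; clamping low values up to the minimum and falling back to the default on unparsable input are assumptions, and `resolveCacheSlots` is a hypothetical name.

```swift
import Foundation

/// Resolve the cache-slot count from an environment dictionary.
/// ASSUMPTIONS: values below the minimum are clamped (not rejected),
/// and unparsable input falls back to the default.
func resolveCacheSlots(
    from env: [String: String] = ProcessInfo.processInfo.environment
) -> Int {
    let defaultSlots = 16
    let minimumSlots = 6
    guard let raw = env["MLX_MOE_CACHE_SLOTS"], let parsed = Int(raw) else {
        return defaultSlots
    }
    return max(parsed, minimumSlots)
}
```

Taking the environment as a parameter keeps the function testable without mutating the process environment.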

Eligibility: the stacked path engages only when all three projections are quantized, resolveSSDInfo() succeeds, and idx.size <= 32 (single-token decode). For ineligible layers and prompt batches the fast path returns nil and execution falls through to the existing N-buffer path.
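
The eligibility rule reduces to a conjunction of plain conditions. The sketch below uses hypothetical field names in place of the real SwitchGLU state; the actual check lives inside runStackedFastPath.

```swift
/// Inputs to the eligibility decision, reduced to plain values.
/// Field names are hypothetical stand-ins for the real layer state.
struct StackedPathInputs {
    var gateQuantized: Bool
    var upQuantized: Bool
    var downQuantized: Bool
    var ssdInfoResolved: Bool   // stands in for resolveSSDInfo() succeeding
    var indexCount: Int         // stands in for idx.size
}

/// True when every condition listed above holds.
func isStackedPathEligible(_ i: StackedPathInputs, maxIndexCount: Int = 32) -> Bool {
    return i.gateQuantized && i.upQuantized && i.downQuantized
        && i.ssdInfoResolved
        && i.indexCount <= maxIndexCount
}
```

With top_k=6, a single decode token has indexCount 6 and passes, while a 600-token prompt batch (3600 indices) does not.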

Benchmark (Qwen3.5-122B-A10B-4bit, M1 Ultra 64 GB, top-k=6, slots=16)

Matched 600-token prompt, mean of 3 runs:

| Config | t/s | Δ vs legacy |
|---|---|---|
| upstream baseline (legacy N-buffer) | ~5.12 | |
| MLX_MOE_STACKED=1 only | ~5.92 | +15.6% |
| MLX_MOE_STACKED=1 MLX_MOE_FUSE_GATEUP=1 | ~5.64 | +10.2% |
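
For reference, the Δ column follows from the t/s column as (measured − baseline) / baseline. A small sketch of that arithmetic, using the rounded t/s values from the table (the helper name is hypothetical):

```swift
/// Relative throughput change versus a baseline, in percent.
func percentDelta(baseline: Double, measured: Double) -> Double {
    return (measured - baseline) / baseline * 100.0
}

let stackedOnly = percentDelta(baseline: 5.12, measured: 5.92)   // about 15.6
let stackedFused = percentDelta(baseline: 5.12, measured: 5.64)  // about 10.2
```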

How to test

# From SwiftLM PR#90 branch (after both submodule PRs are merged):
git submodule update --init --recursive
swift build -c release --product SwiftLM
MLX_MOE_STACKED=1 MLX_MOE_FUSE_GATEUP=1 \
  MLX_MOE_CACHE_SLOTS=16 SWIFTLM_TOP_K=6 \
  .build/arm64-apple-macosx/release/SwiftLM \
    --model <path>/Qwen3.5-122B-A10B-4bit \
    --port 8002 --stream-experts

# Verify legacy path unchanged (env unset or =0)
.build/arm64-apple-macosx/release/SwiftLM --model <path>/Qwen3.5-122B-A10B-4bit --port 8002 --stream-experts

Dependency

Requires SharpAI/mlx-swift#10 (feat/preadIntoOffset — MLXFast.preadIntoOffset) to be merged first.

Relationship to SwiftLM#90

Once this and mlx-swift#10 are merged into their respective main branches, SwiftLM PR#90's git submodule update will resolve cleanly from canonical SharpAI URLs.

Co-authored-by: Eric Lake

Adds a static `MAX_CACHE_SLOTS` on SwitchGLU read from
`MLX_MOE_CACHE_SLOTS=N` (default 16, minimum 6). Used by
SSD-streaming paths that keep experts resident across tokens to
size the per-layer cache. Larger values trade unified memory for
hit-rate; the default is chosen to leave headroom for top-k=8
routing plus prev-token speculative prefetch on Apple Silicon
systems where the model + KV cache already consume the bulk of
available RAM.

No behavioral change on its own — the constant is consumed by the
following commits in this PR.
…Fused (MLX_MOE_STACKED)

Adds an env-gated fast path that allocates ONE stacked weight buffer
of shape [CACHE_SLOTS, intermediate, hidden] per projection per layer
and populates slots in-place via MLXFast.preadIntoOffset, then issues
a single gatherQuantizedMM dispatch per projection per layer.

The legacy SSD-streaming path allocates one MLXArray per cached
expert and runs gatherQuantizedMM in a per-expert loop (top_k
dispatches × 3 projections × N layers per token). Each Metal
dispatch carries ~30 us of CPU->GPU encode/submit overhead on Apple
Silicon, which dominates per-token compute on SSD-streamed MoE
models. Stacking collapses the per-projection loop into a single
dispatch (rhsIndices = slotPerToken), yielding a measured ~5%
throughput gain on a 122B-A10B MoE workload at top-k=6 / slots=16
before the gate+up fusion in the next commit further extends it.

Details:
  - private static let useStackedBuffers reads MLX_MOE_STACKED=1.
    Off by default; consumers opt in per launch.
  - SwitchGLU gains _stackedGate/_stackedUp/_stackedDown MLXArrays,
    a per-slot _slotExpert/_slotLastUsed LRU array, _tokenCounter,
    and _stackedBytesPerExpert (computed once at cold init).
  - New private runStackedFastPath(x:indices:) -> MLXArray? handles
    cold init / asyncEval(idx) / synchronous spec-prefetch of
    prev-token experts / LRU resolution / miss pread / fused compute.
    Returns nil for any ineligibility (non-quantized projection,
    missing SSD info, idx.size > 32, no available slot) so the
    caller falls through to the existing N-buffer path. No behavior
    change when the env flag is unset.
  - New public QuantizedSwitchLinear.computeExpertsFused(_:
    stackedBuffer:slotPerToken:slotExperts:) issues the single
    gatherQuantizedMM and gathers per-slot scales/biases via
    MLX.take.
  - SwitchGLU.callAsFunction gains a top-of-function early branch
    that delegates to runStackedFastPath when applicable.

Depends on the additive MLXFast.preadIntoOffset Swift wrapper in
mlx-swift (companion PR).
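
The LRU slot resolution described in the details above can be sketched without MLX. The property names echo the commit message (_slotExpert, _slotLastUsed, _tokenCounter), but the struct and its eviction details are an illustrative assumption, not the SwitchGLU implementation.

```swift
/// Minimal LRU slot table mirroring the per-slot arrays described above.
/// Illustrative stand-in only; eviction details are assumptions.
struct SlotTable {
    private(set) var slotExpert: [Int?]   // expert resident in each slot, nil if empty
    private(set) var slotLastUsed: [Int]  // token counter value at last use
    private(set) var tokenCounter = 0

    init(slots: Int) {
        slotExpert = Array(repeating: nil, count: slots)
        slotLastUsed = Array(repeating: 0, count: slots)
    }

    /// Map the experts routed for one token to slots.
    /// Returns nil when there are more distinct experts than slots (the
    /// "no available slot" case, where the caller would fall back).
    /// `misses` lists (expert, slot) pairs whose weights must be read from SSD.
    mutating func resolve(experts: [Int]) -> (slots: [Int], misses: [(expert: Int, slot: Int)])? {
        guard Set(experts).count <= slotExpert.count else { return nil }
        tokenCounter += 1
        var slots: [Int] = []
        var misses: [(expert: Int, slot: Int)] = []
        for expert in experts {
            if let hit = slotExpert.firstIndex(of: expert) {
                slotLastUsed[hit] = tokenCounter   // refresh recency on a hit
                slots.append(hit)
                continue
            }
            // Evict the least recently used slot not already claimed this token.
            let victim = slotLastUsed.indices
                .filter { slotLastUsed[$0] < tokenCounter }
                .min { slotLastUsed[$0] < slotLastUsed[$1] }!
            slotExpert[victim] = expert
            slotLastUsed[victim] = tokenCounter
            slots.append(victim)
            misses.append((expert: expert, slot: victim))
        }
        return (slots: slots, misses: misses)
    }
}
```

The returned `slots` array plays the role of slotPerToken, and each entry in `misses` corresponds to one pread into the stacked buffer at that slot.
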
When MLX_MOE_FUSE_GATEUP=1 (and the stacked-buffer path from the
previous commit is enabled), allocate ONE combined buffer of shape
[CACHE_SLOTS, 2*intermediate, hidden] per layer instead of separate
gate/up stacked buffers. Gate weights go in the first half of each
slot, up weights in the second half (offsets slot*2*bpe and
slot*2*bpe + bpe). A single gatherQuantizedMM produces
[..., 2*intermediate]; halves are split via .ellipsis-indexed range
subscripts and fed into silu(g) * u.

Dispatch reduction: collapses the two projection-level dispatches
(gate, up) into one per layer per token. Identical FLOPs (the rows
are independent), but each MLX dispatch carries ~30 us of CPU->GPU
encode/submit overhead on Apple Silicon. On a 122B-A10B MoE workload
at top-k=6 / slots=16 with the stacked-buffer path already enabled,
this gives a measured +7-9% on matched-prompt sustained decode
(producing ~+8% over the upstream baseline overall).

Cold init pre-concatenates qGate.scales || qUp.scales along axis 1
so the runtime gather of scales is also a single MLX.take. Same for
biases when both projections carry them. The down projection still
uses its own per-projection stacked buffer; gate and up are the only
combinable pair (their input and output dims match, and the SwiGLU
activation consumes both halves of the same matmul output).

No behavior change when MLX_MOE_FUSE_GATEUP is unset; falls through
to the per-projection stacked path from the prior commit. No effect
on the legacy N-buffer / non-stacked paths.
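
The slot layout of the combined buffer described in this commit message can be written out as offset arithmetic. A minimal sketch with hypothetical helper names, where `bpe` is the bytes-per-expert of a single projection:

```swift
/// Bytes occupied by one projection of one expert inside the combined
/// gate+up buffer: total bytes / slots / 2 halves.
func bytesPerProjection(combinedBytes: Int, cacheSlots: Int) -> Int {
    return combinedBytes / cacheSlots / 2
}

/// Byte offsets of the gate and up halves inside one slot of the combined
/// buffer, following the layout described above:
/// gate at slot*2*bpe, up at slot*2*bpe + bpe.
func fusedSlotOffsets(slot: Int, bpe: Int) -> (gate: Int, up: Int) {
    let slotBase = slot * 2 * bpe
    return (gate: slotBase, up: slotBase + bpe)
}
```

Each slot therefore spans 2*bpe bytes, with the up half starting exactly where the gate half ends.
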
Copilot AI review requested due to automatic review settings April 27, 2026 01:45

Copilot AI left a comment


Pull request overview

Adds opt-in performance fast paths to SwitchGLU for SSD-streamed MoE inference by reducing Metal dispatches via stacked expert-weight buffers and an optional gate+up fusion, while keeping default behavior unchanged unless env flags are set.

Changes:

  • Introduces MLX_MOE_CACHE_SLOTS (default 16, min 6) to size stacked expert caches.
  • Adds MLX_MOE_STACKED stacked-buffer fast path with LRU slotting + speculative prefetch and a single gatherQuantizedMM per projection.
  • Adds MLX_MOE_FUSE_GATEUP to fuse gate+up into one combined buffer and one gatherQuantizedMM, plus computeExpertsFused in QuantizedSwitchLinear.


if let cb = _combinedGateUpBiases { coldEvalList.append(cb) }
MLX.eval(coldEvalList)
_stackedGateUpBytesPerProj = _stackedGateUp!.nbytes / CACHE_SLOTS / 2
_stackedBytesPerExpert = _stackedGateUpBytesPerProj // shared with down

Copilot AI Apr 27, 2026


In fused mode, _stackedBytesPerExpert is derived from the combined gate+up buffer and then reused for down-proj pread offsets (// shared with down). Down weights typically have a different per-expert slab size ([outputDims,inputDims] vs [intermediate,hidden]), so using the gate/up byte stride will write experts into the wrong offsets and corrupt _stackedDown contents. Compute and store a separate bytes-per-slot value for _stackedDown (e.g., _stackedDown!.nbytes / CACHE_SLOTS) and use it for down offsets; keep gate/up stride separate (and in fused mode, keep separate strides for “per-proj” vs “per-slot”).

Suggested change
- _stackedBytesPerExpert = _stackedGateUpBytesPerProj // shared with down
+ _stackedBytesPerExpert = _stackedDown!.nbytes / CACHE_SLOTS

Comment on lines +219 to +223
_slotLastUsed = Array(repeating: 0, count: CACHE_SLOTS)
_tokenCounter = 0
MLX.eval([idx, _stackedGate!, _stackedUp!, _stackedDown!])
_stackedBytesPerExpert = _stackedGate!.nbytes / CACHE_SLOTS
}

Copilot AI Apr 27, 2026


In the non-fused stacked-buffer cold init, _stackedBytesPerExpert is computed from _stackedGate and then used for all projections. This is incorrect for _stackedDown when inputDims != hiddenDims (common MoE FFNs), because down weights have a different [outputDims,inputDims] shape and therefore a different bytes-per-slot stride. Track bytes-per-slot separately per projection (at least gate/up vs down) and use the correct stride when computing dstOffset for preadIntoOffset.
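
A sketch of the per-projection stride idea: derive each buffer's bytes-per-slot from that buffer's own size, so the destination offset stays correct whether or not the gate/up and down slab sizes happen to coincide. The helper names are hypothetical, `totalBytes` stands in for MLXArray.nbytes, and the buffer sizes are invented solely to make a mismatch visible.

```swift
/// Bytes-per-slot for a stacked buffer, derived from that buffer's own size.
func bytesPerSlot(totalBytes: Int, cacheSlots: Int) -> Int {
    precondition(cacheSlots > 0 && totalBytes % cacheSlots == 0,
                 "buffer must divide evenly into slots")
    return totalBytes / cacheSlots
}

/// Destination byte offset for writing one expert into `slot`.
func dstOffset(slot: Int, stride: Int) -> Int {
    return slot * stride
}

let cacheSlots = 16
// INVENTED sizes: 4096 bytes per gate slot, 6144 bytes per down slot.
let gateStride = bytesPerSlot(totalBytes: 16 * 4096, cacheSlots: cacheSlots)
let downStride = bytesPerSlot(totalBytes: 16 * 6144, cacheSlots: cacheSlots)

let correctDownOffset = dstOffset(slot: 5, stride: downStride)
let reusedGateOffset = dstOffset(slot: 5, stride: gateStride) // lands inside slot 3
```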

Comment on lines +297 to +299
MLXFast.preadIntoOffset(self._stackedDown!, safetensorsPath: downSSD.path,
tensorName: downSSD.tensorName, expertIndex: UInt32(info.expertId), dstOffset: info.slot * bpe)
}

Copilot AI Apr 27, 2026


preadIntoOffset for the down projection uses dstOffset: info.slot * bpe, but bpe is currently derived from the gate (or gate/up) buffer. If the down weight slab size differs, this offset will be wrong and will overwrite/underfill the stacked down buffer. Use the down buffer’s own bytes-per-slot stride here (and in the analogous miss-handling loop below) instead of reusing bpe.

Comment on lines +1185 to +1192
/// - Parameters:
/// - x: input activations, shape `[totalTokens, ..., hidden]`.
/// - stackedBuffer: weight buffer, shape `[CACHE_SLOTS, intermediate, hidden]`.
/// Slots are populated externally via `MLXFast.preadIntoOffset`.
/// - slotPerToken: uint32 array mapping each token (along axis 0 of `x`)
/// to a slot index in `stackedBuffer`. Built from the routing.
/// - slotExperts: per-slot expert IDs (`0..<numExperts`). Used to gather
/// per-slot scales/biases from `self.scales` and `self.biases`.

Copilot AI Apr 27, 2026


The docstring for computeExpertsFused describes stackedBuffer as [CACHE_SLOTS, intermediate, hidden], but this helper is also used for the down projection where the logical shape is [CACHE_SLOTS, outputDims, inputDims]. Consider updating the documentation (and the inline comments that mention “intermediate”) to describe the buffer generically in terms of outputDims/inputDims to avoid confusion for future maintainers.

Comment on lines 483 to +488
public func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray {
// Stacked-buffer fused-matmul fast path (env-gated MLX_MOE_STACKED=1).
// Early-out into the stacked path when applicable; otherwise fall
// through to the existing SSD-streaming / legacy code below.
if SwitchGLU.useStackedBuffers,
ExpertStreamingConfig.shared.isEnabled,

Copilot AI Apr 27, 2026


New env-gated behavior (MLX_MOE_STACKED / MLX_MOE_FUSE_GATEUP) introduces a distinct execution path that bypasses the existing SSD-streaming code. There are tests exercising SwitchGLU instantiation/forward passes, but nothing covering the stacked/fused paths or even the “flag enabled but ineligible → clean fallback” behavior. Please add a small unit/integration test that sets these env vars and validates either (a) the stacked path falls back without crashing for non-quantized layers, and/or (b) outputs match the legacy path for a tiny deterministic setup when the fast path is eligible.

Suggested change
- public func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray {
-     // Stacked-buffer fused-matmul fast path (env-gated MLX_MOE_STACKED=1).
-     // Early-out into the stacked path when applicable; otherwise fall
-     // through to the existing SSD-streaming / legacy code below.
-     if SwitchGLU.useStackedBuffers,
-        ExpertStreamingConfig.shared.isEnabled,
+ private func shouldAttemptStackedFastPath() -> Bool {
+     guard SwitchGLU.useStackedBuffers,
+           ExpertStreamingConfig.shared.isEnabled
+     else {
+         return false
+     }
+     // The env flag must not be enough on its own to enter the stacked path.
+     // Only attempt it when the stacked gate/up state has actually been
+     // prepared; otherwise cleanly fall back to the legacy implementation.
+     guard _stackedGateUp != nil,
+           _combinedGateUpScales != nil
+     else {
+         return false
+     }
+     return true
+ }
+ public func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray {
+     // Stacked-buffer fused-matmul fast path (env-gated MLX_MOE_STACKED=1).
+     // Early-out into the stacked path when applicable; otherwise fall
+     // through to the existing SSD-streaming / legacy code below.
+     if self.shouldAttemptStackedFastPath(),

Aegis-AI added 2 commits April 26, 2026 19:26
Both ci.yml and downstream_integration.yml check out SwiftLM with
submodules: recursive, which pins mlx-swift to SwiftLM's submodule
commit. When mlx-swift-lm PRs depend on new mlx-swift APIs (e.g.
preadIntoOffset from mlx-swift#10), the build fails because the
submodule pin is stale.

Fix: after replacing mlx-swift-lm with the PR checkout, also pull
mlx-swift to latest origin/main so new APIs are available.
- Compute and use separate `_stackedDownBytesPerExpert` for the down
  projection so it doesn't incorrectly reuse the gate/up stride.
- Fix docstring for `computeExpertsFused` to refer generically to
  outputDims/inputDims instead of intermediate/hidden.
- Add `StackedMoETests.swift` unit test to verify the fast path cleanly
  falls back without crashing when enabled on non-quantized models.

@solderzzc solderzzc left a comment


Thanks for the review @ericjlake! I've pushed fixes for all these points in 6a57d00:

  1. Added _stackedDownBytesPerExpert to track the stride for the down projection separately and updated both the speculative prefetch loop and the on-demand miss loop to use it.
  2. Updated the computeExpertsFused docstring to refer generically to outputDims and inputDims instead of intermediate and hidden.
  3. Added StackedMoETests.swift which sets the env flags and runs a forward pass on a non-quantized model to ensure the fast path correctly and safely falls back without crashing.

(Regarding the shouldAttemptStackedFastPath suggestion — since _stackedGateUp is allocated lazily on the first token inside runStackedFastPath, checking _stackedGateUp != nil before entering the fast path would permanently prevent it from initializing. The fallback is instead handled correctly inside runStackedFastPath by returning nil if the layer is ineligible.)
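
The lazy-initialization point can be shown with a toy model. All names below are hypothetical stand-ins for the SwitchGLU members; the sketch only demonstrates why a "buffer already exists" pre-check never lets cold init run, while "always try, fall back on nil" does.

```swift
/// Toy model of a layer whose fast-path state is allocated lazily.
final class ToyLayer {
    private var stackedBuffer: [Float]?   // allocated on first eligible call
    let eligible: Bool

    init(eligible: Bool) { self.eligible = eligible }

    /// Fast path: returns nil when ineligible so the caller can fall back.
    private func runFastPath(_ x: Float) -> Float? {
        guard eligible else { return nil }
        if stackedBuffer == nil { stackedBuffer = [2.0] } // cold init
        return x * stackedBuffer![0]
    }

    private func runLegacyPath(_ x: Float) -> Float { return x * 2.0 }

    /// Pre-check variant: requires the buffer to exist before entering the
    /// fast path, so cold init inside runFastPath is never reached.
    func forwardWithPrecheck(_ x: Float) -> (value: Float, usedFastPath: Bool) {
        if stackedBuffer != nil, let y = runFastPath(x) {
            return (value: y, usedFastPath: true)
        }
        return (value: runLegacyPath(x), usedFastPath: false)
    }

    /// Behaviour described in the reply: always try, fall back on nil.
    func forward(_ x: Float) -> (value: Float, usedFastPath: Bool) {
        if let y = runFastPath(x) {
            return (value: y, usedFastPath: true)
        }
        return (value: runLegacyPath(x), usedFastPath: false)
    }
}
```

Both variants produce the same value here; the difference is only which path computed it.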

@solderzzc solderzzc merged commit 4c7301d into main Apr 27, 2026
6 checks passed
@solderzzc solderzzc deleted the feat/stacked-moe-fastpath branch April 27, 2026 03:36
solderzzc pushed a commit to ericjlake/SwiftLM that referenced this pull request Apr 27, 2026
Pins `mlx-swift` and `mlx-swift-lm` to their respective latest `main`
commits, pulling in the newly merged stacked-buffer MoE optimizations
for SwiftLM.

- SharpAI/mlx-swift#10
- SharpAI/mlx-swift-lm#35
solderzzc pushed a commit to SharpAI/SwiftLM that referenced this pull request Apr 27, 2026
Updates mlx-swift and mlx-swift-lm submodules to pull in the stacked-buffer MoE fast-paths and memory-safety optimizations.

Submodule Updates:
- `mlx-swift`: SharpAI/mlx-swift#10
- `mlx-swift-lm`: SharpAI/mlx-swift-lm#35

All downstream integration tests pass.