Metal apple silicon fixes #3640

Open
tyt2y3 wants to merge 2 commits into huggingface:main from visioncortex:metal-apple-silicon-fixes

tyt2y3 commented Jun 22, 2026

Two independent Apple-Silicon Metal fixes, one commit each:

1. MetalDevice::new panics on Apple Silicon

MetalDevice::new calls Device::all().swap_remove(ordinal). On Apple Silicon, MTLCopyAllDevices returns no devices, so swap_remove(0) panics on an empty Vec and Metal is unusable on M-series machines:

thread 'main' panicked: swap_remove index (is 0) should be < len (is 0)

Fix: use Device::system_default() (MTLCreateSystemDefaultDevice) for the default device, falling back to all() with a bounds check for explicit additional-GPU ordinals.
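
The selection logic described in the fix can be sketched without a GPU. This is a minimal illustration, not the code in this PR: `FakeDevice` and `select_device` are hypothetical names, and the two arguments stand in for the results of MTLCreateSystemDefaultDevice and MTLCopyAllDevices so the control flow can be exercised on any machine.

```rust
/// Hypothetical stand-in for a GPU device handle (not the real `metal` crate type).
#[derive(Debug, Clone, PartialEq)]
pub struct FakeDevice {
    pub name: String,
}

/// Pick a device for `ordinal`.
///
/// `system_default` models the result of MTLCreateSystemDefaultDevice and
/// `all` models the result of MTLCopyAllDevices. Both are passed in as plain
/// values (an assumption of this sketch) so no GPU is required.
pub fn select_device(
    ordinal: usize,
    system_default: Option<FakeDevice>,
    mut all: Vec<FakeDevice>,
) -> Result<FakeDevice, String> {
    // Ordinal 0 means "the default GPU": prefer the system default device.
    if ordinal == 0 {
        if let Some(device) = system_default {
            return Ok(device);
        }
    }
    // Any other ordinal indexes into the enumerated list. The bounds check
    // turns an empty or short list into an error instead of a panic.
    if ordinal < all.len() {
        Ok(all.swap_remove(ordinal))
    } else {
        Err(format!(
            "no Metal device at ordinal {} ({} enumerated)",
            ordinal,
            all.len()
        ))
    }
}
```

With an empty enumerated list, `select_device(0, Some(default), vec![])` returns the default device, while `select_device(1, Some(default), vec![])` returns an `Err` rather than panicking.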

2. conv weight-gradient is wrong on Metal — training silently diverges

The conv1d / conv2d / conv_transpose2d backward computes the kernel gradient by convolving a transposed (non-contiguous) input:

let grad_kernel = arg.transpose(0, 1)?.conv2d(&grad.transpose(0, 1)?, ...)?...

The Metal im2col reads its input as if contiguous, so the weight gradient is incorrect on Metal. The forward pass and the input gradient are correct — so inference works fine, but training diverges: every conv kernel is updated in the wrong direction and the loss climbs.

Fix: force the transposed operands contiguous before the conv. No-op on CPU.
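
To illustrate why a transposed operand breaks a kernel that assumes contiguous input, here is a toy 2-D model in plain Rust. It is not candle's tensor type or the Metal im2col kernel: `View2` and its methods are hypothetical names for this sketch only. A transpose swaps shape and strides but leaves the buffer order unchanged, so a reader that ignores strides sees the untransposed data unless the view is first copied into row-major order.

```rust
/// A 2-D strided view over a flat buffer (a toy model of a tensor layout).
#[derive(Debug, Clone)]
pub struct View2 {
    pub data: Vec<f32>,
    pub shape: [usize; 2],
    pub strides: [usize; 2],
}

impl View2 {
    /// Row-major (contiguous) view over `data`.
    pub fn new(data: Vec<f32>, rows: usize, cols: usize) -> Self {
        assert_eq!(data.len(), rows * cols);
        View2 { data, shape: [rows, cols], strides: [cols, 1] }
    }

    /// Transpose by swapping shape and strides. The element order in the
    /// buffer is unchanged (the buffer is cloned here only for simplicity).
    pub fn transpose(&self) -> Self {
        View2 {
            data: self.data.clone(),
            shape: [self.shape[1], self.shape[0]],
            strides: [self.strides[1], self.strides[0]],
        }
    }

    /// True when the strides match a row-major layout for this shape.
    pub fn is_contiguous(&self) -> bool {
        self.strides == [self.shape[1], 1]
    }

    /// Copy into a fresh row-major buffer, honoring the strides.
    pub fn contiguous(&self) -> Self {
        let [rows, cols] = self.shape;
        let mut out = Vec::with_capacity(rows * cols);
        for i in 0..rows {
            for j in 0..cols {
                out.push(self.data[i * self.strides[0] + j * self.strides[1]]);
            }
        }
        View2::new(out, rows, cols)
    }

    /// Models a kernel that ignores strides and reads the buffer in order.
    pub fn read_assuming_contiguous(&self) -> Vec<f32> {
        self.data.clone()
    }
}
```

For a 2x3 matrix holding 1 through 6, `transpose().read_assuming_contiguous()` still yields the original order, whereas `transpose().contiguous().read_assuming_contiguous()` yields the transposed order. The latter is the effect the fix relies on.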

Verification

Minimal CPU-vs-Metal check of conv2d through loss.backward() (max abs diff):

| | forward | input-grad | weight-grad |
|---|---|---|---|
| before | 2e-7 | 2e-6 | 2.95e1 |
| after | 2e-7 | 2e-6 | 7.6e-6 |
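
For reference, the "max abs diff" metric is the largest element-wise absolute difference between the CPU and Metal results. The helper below is a minimal sketch of that metric over flat `f32` slices. `max_abs_diff` is a hypothetical name and this is not the script used to produce the numbers above.

```rust
/// Largest element-wise absolute difference between two equally sized slices.
/// Returns 0.0 for empty inputs.
pub fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "inputs must have the same length");
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs())
        .fold(0.0_f32, f32::max)
}
```

In practice one would flatten the forward output, the input gradient, and the weight gradient from each backend into slices and compare them pairwise with this function.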

A small residual U-Net that diverged when trained on Metal (loss climbing, PSNR going negative) trains normally after this change, matching the CPU result step-for-step. group_norm, reshape/permute, and the conv input gradient were all already correct on Metal — only the conv weight gradient was affected.

tyt2y3 added 2 commits June 22, 2026 14:54
MetalDevice::new called Device::all().swap_remove(ordinal). On Apple Silicon MTLCopyAllDevices returns no devices, so swap_remove panics on an empty vec. Use Device::system_default() (MTLCreateSystemDefaultDevice) for the default device, falling back to all() with a bounds check for explicit additional GPU ordinals.
The conv1d/conv2d/conv_transpose2d backward computes the kernel gradient by convolving a transposed (non-contiguous) input. The Metal im2col reads the input as if contiguous, so the weight gradient is wrong on Metal (forward and input gradients are correct), and any conv net trained on Metal diverges as every kernel is updated in the wrong direction. Force the transposed operands contiguous before the conv. No-op on CPU; corrects the Metal weight gradient (verified: max abs diff vs CPU drops from ~6 to ~2e-6).