Summary
Many launch sites in candle-core/src/cuda_backend/mod.rs compute their launch config as
let cfg = LaunchConfig::for_num_elems(el as u32);
el is a usize element count. For a tensor with exactly 2^32 elements, el as u32 wraps to 0, giving grid (0, 1, 1) and CUDA_ERROR_INVALID_VALUE; for counts slightly above 2^32 it silently computes on the wrong (wrapped) element count, which is worse.
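The wrap can be reproduced on the host without a GPU. The snippet below is a minimal sketch: `grid_dim_x` is a hypothetical stand-in for the grid computation (a ceiling division by an assumed block size of 1024 threads), not the actual launch-config code, and it assumes a 64-bit target so that `usize` can hold 2^32.

```rust
// Assumed block size for this sketch; the real launch config may differ.
const BLOCK: u32 = 1024;

/// Hypothetical stand-in for the grid computation: ceil(n / BLOCK),
/// written so that the division itself cannot overflow u32.
fn grid_dim_x(n: u32) -> u32 {
    n / BLOCK + u32::from(n % BLOCK != 0)
}

fn main() {
    // Assumes a 64-bit target, where usize can represent 2^32.
    let exact: usize = 1usize << 32;
    let above: usize = (1usize << 32) + 5;

    // `as u32` keeps only the low 32 bits.
    assert_eq!(exact as u32, 0);
    assert_eq!(above as u32, 5);

    // Exactly 2^32 elements: zero blocks, which the driver rejects.
    assert_eq!(grid_dim_x(exact as u32), 0);
    // Slightly above 2^32: a valid-looking launch sized for 5 elements.
    assert_eq!(grid_dim_x(above as u32), 1);
}
```

The second case is the dangerous one: nothing fails at launch time, so the wrong element count goes unnoticed.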
Repro
// 128 * 8192 * 32 * 128 = 2^32 elements
let dev = Device::new_cuda(0)?;
let x = Tensor::zeros((128, 8192, 32, 128), DType::F32, &dev)?;
let y = x.to_dtype(DType::BF16)?; // DriverError(CUDA_ERROR_INVALID_VALUE)
Hit in practice while benchmarking flash-attn-v3 at batch=128, seqlen=8192, 32 heads, headdim=128, which is a realistic H200-sized activation (~8.6 GB in f16).
Scope
grep -n "for_num_elems(.* as u32)" candle-core/src/cuda_backend/mod.rs currently shows 10+ call sites (to_dtype, affine, unary/binary ops, where_cond, etc.), plus similar as u32 casts in candle-nn CUDA kernels. Anything that can see a >= 2^32-element tensor is affected.
Possible directions
- Bail with a clear error when el > u32::MAX (cheap, honest, no silent wrong results), or
- chunk the launch into multiple grids / use 64-bit indexing in the kernels (real fix, larger change).
Happy to send a PR for the bail-with-error variant if that is the preferred direction.
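For discussion, the bail-with-error variant could look roughly like the sketch below. `checked_num_elems` and `TooManyElements` are hypothetical names invented for this example; a real patch would return candle's own error type and pass the resulting `u32` to `LaunchConfig::for_num_elems`. It assumes a 64-bit target.

```rust
use std::convert::TryFrom;
use std::fmt;

/// Hypothetical error type for this sketch only.
#[derive(Debug, PartialEq)]
struct TooManyElements {
    el: usize,
}

impl fmt::Display for TooManyElements {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "tensor has {} elements, more than the {} a single launch can index",
            self.el,
            u32::MAX
        )
    }
}

impl std::error::Error for TooManyElements {}

/// Hypothetical helper: converts the element count without wrapping,
/// returning an error when el > u32::MAX instead of truncating.
fn checked_num_elems(el: usize) -> Result<u32, TooManyElements> {
    u32::try_from(el).map_err(|_| TooManyElements { el })
}

fn main() {
    // Counts that fit are passed through unchanged.
    assert_eq!(checked_num_elems(1024), Ok(1024));
    assert_eq!(checked_num_elems(u32::MAX as usize), Ok(u32::MAX));

    // 2^32 elements (the repro shape) now fails loudly instead of wrapping to 0.
    let err = checked_num_elems(1usize << 32).unwrap_err();
    println!("{err}");
}
```

Each call site would then replace `el as u32` with something like `checked_num_elems(el)?`, which keeps the change mechanical and easy to review.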