
candle-core CUDA backend: LaunchConfig::for_num_elems(el as u32) silently truncates; ops fail on tensors with >= 2^32 elements #3604

Description

@Xueying-VirtueAI

Summary

Many launch sites in candle-core/src/cuda_backend/mod.rs compute their launch config as:

let cfg = LaunchConfig::for_num_elems(el as u32);

el is a usize element count, but for_num_elems takes a u32, so the cast keeps only the low 32 bits (el mod 2^32). For a tensor with exactly 2^32 elements, el as u32 wraps to 0, giving grid (0, 1, 1) and CUDA_ERROR_INVALID_VALUE; for counts slightly above 2^32 it silently computes on the wrong (wrapped) element count, which is worse.
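
The wraparound can be checked without a GPU. This is a minimal sketch assuming a 64-bit host; grid_for_num_elems is a hypothetical local stand-in that mirrors the ceil-divide-by-1024 grid computation for_num_elems is assumed to perform, not cudarc's actual code.

```rust
// Hypothetical stand-in for cudarc's LaunchConfig::for_num_elems, assuming it
// uses 1024 threads per block and ceil-divides the element count.
const NUM_THREADS: u32 = 1024;

fn grid_for_num_elems(n: u32) -> (u32, u32, u32) {
    // One block per 1024 elements, rounded up; n == 0 yields zero blocks.
    (n.div_ceil(NUM_THREADS), 1, 1)
}

fn main() {
    // Element count from the repro shape (requires a 64-bit usize).
    let el: usize = 128 * 8192 * 32 * 128;
    assert_eq!(el, 1usize << 32);

    // `as u32` keeps only the low 32 bits, i.e. el mod 2^32.
    let wrapped = el as u32;
    println!("el = {el}, el as u32 = {wrapped}, grid = {:?}", grid_for_num_elems(wrapped));

    // Slightly above 2^32: the grid is sized for 5 elements, not 2^32 + 5.
    let el2: usize = (1usize << 32) + 5;
    let wrapped2 = el2 as u32;
    println!("el = {el2}, el as u32 = {wrapped2}, grid = {:?}", grid_for_num_elems(wrapped2));
}
```

The first case yields grid (0, 1, 1), the invalid launch; the second yields (1, 1, 1), a single block for a tensor of more than four billion elements.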

Repro

use candle_core::{DType, Device, Tensor};
fn main() -> candle_core::Result<()> {
    // 128 * 8192 * 32 * 128 = 2^32 elements
    let dev = Device::new_cuda(0)?;
    let x = Tensor::zeros((128, 8192, 32, 128), DType::F32, &dev)?;
    let _y = x.to_dtype(DType::BF16)?; // DriverError(CUDA_ERROR_INVALID_VALUE)
    Ok(())
}

Hit in practice while benchmarking flash-attn-v3 at batch=128, seqlen=8192, 32 heads, headdim=128: a perfectly realistic H200-sized activation (~8.6 GB, i.e. 8 GiB, in f16).

Scope

Running grep -n "for_num_elems(.* as u32)" candle-core/src/cuda_backend/mod.rs currently lists 10+ call sites (to_dtype, affine, unary/binary ops, where_cond, etc.), plus similar as u32 casts in candle-nn CUDA kernels. Any op that can receive a tensor with >= 2^32 elements is affected.
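
Because the same el as u32 cast is repeated at each of these call sites, one low-effort option is to route them all through a single checked conversion. This is a sketch assuming a 64-bit host; checked_num_elems and TooManyElements are hypothetical names, not existing candle APIs, and inside candle the error would be mapped into candle_core::Error.

```rust
use std::error::Error;
use std::fmt;

/// Hypothetical error type for illustration only.
#[derive(Debug, PartialEq)]
struct TooManyElements {
    op: &'static str,
    el: usize,
}

impl fmt::Display for TooManyElements {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}: {} elements exceeds the u32 launch limit of {}",
            self.op,
            self.el,
            u32::MAX
        )
    }
}

impl Error for TooManyElements {}

/// Hypothetical helper: converts an element count to the u32 that
/// LaunchConfig::for_num_elems takes, returning an error instead of wrapping.
fn checked_num_elems(op: &'static str, el: usize) -> Result<u32, TooManyElements> {
    u32::try_from(el).map_err(|_| TooManyElements { op, el })
}

fn main() {
    // A call site would then read (sketch, assuming a From impl into the
    // caller's error type):
    //   let cfg = LaunchConfig::for_num_elems(checked_num_elems("to_dtype", el)?);
    for el in [1024usize, u32::MAX as usize, 1usize << 32] {
        match checked_num_elems("to_dtype", el) {
            Ok(n) => println!("ok: launch for {n} elements"),
            Err(e) => println!("error: {e}"),
        }
    }
}
```

If for_num_elems rounds up as (n + 1023) / 1024 rather than with a non-overflowing ceil-division, counts within 1023 of u32::MAX would also need rejecting; this sketch does not account for that.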

Possible directions

  • Bail with a clear error when el > u32::MAX (cheap, honest, no silent wrong results), or
  • chunk the launch into multiple grids / use 64-bit indexing in the kernels (real fix, larger change).
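
For the larger fix, the host side of a chunked launch reduces to splitting the element count into pieces that each fit in u32. This is a sketch assuming a 64-bit host and kernels that accept an element offset; plan_chunks and MAX_CHUNK are hypothetical, and the actual kernel and launch changes are not shown.

```rust
/// Hypothetical per-launch cap. Anything <= u32::MAX avoids the truncating
/// cast; 2^31 is chosen here because it is a multiple of the 1024-thread block.
const MAX_CHUNK: usize = 1 << 31;

/// Splits `el` elements into (offset, len) pairs where each `len` fits in u32.
fn plan_chunks(el: usize) -> Vec<(usize, u32)> {
    let mut chunks = Vec::new();
    let mut offset = 0;
    while offset < el {
        let len = (el - offset).min(MAX_CHUNK);
        // Lossless cast: len <= MAX_CHUNK <= u32::MAX.
        chunks.push((offset, len as u32));
        offset += len;
    }
    chunks
}

fn main() {
    let el: usize = (1usize << 32) + 5;
    for (offset, len) in plan_chunks(el) {
        // Hypothetical launch per chunk: LaunchConfig::for_num_elems(len), with
        // the kernel handling elements offset..offset + len. The offset must stay
        // 64-bit on the device side, or be folded into the device pointers for
        // contiguous buffers.
        println!("launch: offset = {offset}, len = {len}");
    }
}
```

Keeping the cap a multiple of the block size means every chunk except the last consists of full blocks. Kernels that compute indices in 32-bit arithmetic internally would still need the 64-bit indexing change.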

Happy to send a PR for the bail-with-error variant if that is the preferred direction.
