diff --git a/src/infiniop/ops/paged_attention/info.h b/src/infiniop/ops/paged_attention/info.h index 4b840af69..70d148ce0 100644 --- a/src/infiniop/ops/paged_attention/info.h +++ b/src/infiniop/ops/paged_attention/info.h @@ -47,7 +47,7 @@ class PagedAttentionInfo { float scale) { auto dtype = q_desc->dtype(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/paged_attention/metax/paged_attention_hd128.maca b/src/infiniop/ops/paged_attention/metax/paged_attention_hd128.maca index 131ac2343..b7222f8dc 100644 --- a/src/infiniop/ops/paged_attention/metax/paged_attention_hd128.maca +++ b/src/infiniop/ops/paged_attention/metax/paged_attention_hd128.maca @@ -797,6 +797,20 @@ infiniStatus_t launch_decode_hd128_impl( static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd128SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd128SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } return INFINI_STATUS_BAD_TENSOR_DTYPE; } @@ -923,6 +937,19 @@ infiniStatus_t launch_decode_hd128_impl( return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd128Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/paged_attention/metax/paged_attention_hd64.maca b/src/infiniop/ops/paged_attention/metax/paged_attention_hd64.maca index 2f8b95b3a..046125ce5 100644 --- a/src/infiniop/ops/paged_attention/metax/paged_attention_hd64.maca +++ b/src/infiniop/ops/paged_attention/metax/paged_attention_hd64.maca @@ -346,6 +346,20 @@ infiniStatus_t launch_decode_hd64_impl( static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd64SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd64SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } return INFINI_STATUS_BAD_TENSOR_DTYPE; } @@ -423,6 +437,19 @@ infiniStatus_t launch_decode_hd64_impl( return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd64Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/paged_attention/moore/paged_attention_hd128.mu b/src/infiniop/ops/paged_attention/moore/paged_attention_hd128.mu index 2ee720e29..229b057e3 100644 --- a/src/infiniop/ops/paged_attention/moore/paged_attention_hd128.mu +++ b/src/infiniop/ops/paged_attention/moore/paged_attention_hd128.mu @@ -793,6 +793,20 @@ infiniStatus_t launch_decode_hd128_impl( static_cast<__mt_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd128SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd128SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } return INFINI_STATUS_BAD_TENSOR_DTYPE; } @@ -919,6 +933,19 @@ infiniStatus_t launch_decode_hd128_impl( return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd128Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/paged_attention/moore/paged_attention_hd64.mu b/src/infiniop/ops/paged_attention/moore/paged_attention_hd64.mu index 11716d58e..e35f0f922 100644 --- a/src/infiniop/ops/paged_attention/moore/paged_attention_hd64.mu +++ b/src/infiniop/ops/paged_attention/moore/paged_attention_hd64.mu @@ -342,6 +342,20 @@ infiniStatus_t launch_decode_hd64_impl( static_cast<__mt_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd64SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd64SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } return INFINI_STATUS_BAD_TENSOR_DTYPE; } @@ -419,6 +433,19 @@ infiniStatus_t launch_decode_hd64_impl( return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd64Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu index c16b48e48..20855b0b4 100644 --- a/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu @@ -793,6 +793,20 @@ infiniStatus_t launch_decode_hd128_impl( static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd128SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd128SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } return INFINI_STATUS_BAD_TENSOR_DTYPE; } @@ -919,6 +933,19 @@ infiniStatus_t launch_decode_hd128_impl( return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd128Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu index 421fd22ef..e2c63f641 100644 --- a/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu @@ -342,6 +342,20 @@ infiniStatus_t launch_decode_hd64_impl( static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd64SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd64SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } return INFINI_STATUS_BAD_TENSOR_DTYPE; } @@ -419,6 +433,19 @@ infiniStatus_t launch_decode_hd64_impl( return INFINI_STATUS_SUCCESS; } + if (dtype == INFINI_DTYPE_F32) { + flashAttentionDecodeHd64Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/paged_attention_prefill/info.h b/src/infiniop/ops/paged_attention_prefill/info.h index 9f1307c3b..4ab47e3e7 100644 --- a/src/infiniop/ops/paged_attention_prefill/info.h +++ b/src/infiniop/ops/paged_attention_prefill/info.h @@ -52,7 +52,7 @@ class PagedAttentionPrefillInfo { float scale) { auto dtype = q_desc->dtype(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca b/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca index d631bdb4c..dd33ca3bf 100644 --- a/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca +++ b/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca @@ -1531,6 +1531,19 @@ infiniStatus_t Descriptor::calculate( return INFINI_STATUS_BAD_PARAM; \ } while (false) +#define DISPATCH_FLOAT_KERNEL(Tindex) \ + return launch_prefill_warp( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), static_cast(total_kv_lens_ptr), static_cast(cu_seqlens_q_ptr), alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream) + #define DISPATCH_INDEX(Tindex) \ do { \ if (_info.dtype == INFINI_DTYPE_F16) { \ @@ -1539,6 +1552,9 @@ infiniStatus_t Descriptor::calculate( if (_info.dtype == INFINI_DTYPE_BF16) { \ DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \ } \ + if (_info.dtype == INFINI_DTYPE_F32) { \ + DISPATCH_FLOAT_KERNEL(Tindex); \ + } \ return INFINI_STATUS_BAD_TENSOR_DTYPE; \ } while (false) diff --git a/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu b/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu index 0ad2e1a51..f3f276e9f 100644 --- a/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu +++ b/src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu @@ -120,6 +120,9 @@ infiniStatus_t Descriptor::calculate( if (_info.dtype == INFINI_DTYPE_BF16) { \ DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \ } \ + if (_info.dtype == INFINI_DTYPE_F32) { \ + DISPATCH_KERNEL(Tindex, float, float); \ + } \ return INFINI_STATUS_BAD_TENSOR_DTYPE; \ } while (false) diff --git a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu index 100b5bc43..08910292b 100644 --- a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu +++ b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu @@ -1532,6 +1532,19 @@ infiniStatus_t Descriptor::calculate( return INFINI_STATUS_BAD_PARAM; \ } while (false) +#define DISPATCH_FLOAT_KERNEL(Tindex) \ + return launch_prefill_warp( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), static_cast(total_kv_lens_ptr), static_cast(cu_seqlens_q_ptr), alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream) + #define DISPATCH_INDEX(Tindex) \ do { \ if (_info.dtype == INFINI_DTYPE_F16) { \ @@ -1540,6 +1553,9 @@ infiniStatus_t Descriptor::calculate( if (_info.dtype == INFINI_DTYPE_BF16) { \ DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \ } \ + if (_info.dtype == INFINI_DTYPE_F32) { \ + DISPATCH_FLOAT_KERNEL(Tindex); \ + } \ return INFINI_STATUS_BAD_TENSOR_DTYPE; \ } while (false) diff --git a/test/infiniop/paged_attention_prefill.py b/test/infiniop/paged_attention_prefill.py index 82e850bc6..a42e45f89 100644 --- a/test/infiniop/paged_attention_prefill.py +++ b/test/infiniop/paged_attention_prefill.py @@ -4,6 +4,7 @@ import torch from libinfiniop import ( LIBINFINIOP, + InfiniDeviceEnum, InfiniDeviceNames, InfiniDtype, InfiniDtypeNames, @@ -39,11 +40,12 @@ (16, 128, 128, 128, 8, 16, 4, InfiniDtype.I64), ] -_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16] +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32] _TOLERANCE_MAP = { InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2}, InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2}, + InfiniDtype.F32: {"atol": 2e-3, "rtol": 2e-3}, } DEBUG = False @@ -142,6 +144,17 @@ def test( f"index_dtype:{InfiniDtypeNames[index_dtype]}" ) + if dtype == InfiniDtype.F32 and device not in ( + InfiniDeviceEnum.NVIDIA, + InfiniDeviceEnum.METAX, + InfiniDeviceEnum.MOORE, + InfiniDeviceEnum.ILUVATAR, + ): + print( + f"Skipping F32 on {InfiniDeviceNames[device]}: backend F32 prefill is not implemented" + ) + return + # 1. Initialize persistent resources num_blocks = 8192 manager = SimpleCacheManager(num_blocks, block_size)