diff --git a/src/infinicore/nn/embedding.cc b/src/infinicore/nn/embedding.cc index 7b02f93ce..03928177a 100644 --- a/src/infinicore/nn/embedding.cc +++ b/src/infinicore/nn/embedding.cc @@ -45,7 +45,7 @@ Embedding::Embedding(size_t num_embeddings, Tensor Embedding::forward(const Tensor &indices) const { // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach auto device_type = device_.getType(); - if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::QY) { + if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::CAMBRICON || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::QY) { // Use op::embedding which supports device-side input and batch dimension return op::embedding(indices->contiguous()->to(device_), weight_); } diff --git a/src/infiniop/ops/embedding/bang/embedding_bang.h b/src/infiniop/ops/embedding/bang/embedding_bang.h new file mode 100644 index 000000000..bc665a1a3 --- /dev/null +++ b/src/infiniop/ops/embedding/bang/embedding_bang.h @@ -0,0 +1,8 @@ +#ifndef __EMBEDDING_BANG_H__ +#define __EMBEDDING_BANG_H__ + +#include "../embedding.h" + +DESCRIPTOR(bang) + +#endif // __EMBEDDING_BANG_H__ diff --git a/src/infiniop/ops/embedding/bang/embedding_bang.mlu b/src/infiniop/ops/embedding/bang/embedding_bang.mlu new file mode 100644 index 000000000..01ff3812d --- /dev/null +++ b/src/infiniop/ops/embedding/bang/embedding_bang.mlu @@ -0,0 +1,182 @@ +#include "../../../devices/bang/common_bang.h" +#include "embedding_bang.h" + +#include + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void embeddingKernel( + T *__restrict__ output, + const IndexType *__restrict__ indices, + const T *__restrict__ weight, + size_t num_indices, + size_t embedding_dim, + size_t vocab_size) { + if (num_indices == 0 || embedding_dim == 0) { + return; + } + + size_t chunk_target = 512; + size_t parts_by_dim = (embedding_dim + chunk_target - 1) / chunk_target; + size_t max_parts_by_tasks = std::max(1, taskDim / num_indices); + size_t parts_per_row = std::max(1, std::min(parts_by_dim, max_parts_by_tasks)); + size_t part_size = (embedding_dim + parts_per_row - 1) / parts_per_row; + size_t logical_tasks = num_indices * parts_per_row; + size_t max_chunk = NRAM_MAX_SIZE / sizeof(T); + T *cache = reinterpret_cast(nram_buffer); + + for (size_t logical = taskId; logical < logical_tasks; logical += taskDim) { + size_t row = logical / parts_per_row; + size_t part = logical - row * parts_per_row; + size_t start = part * part_size; + size_t end = std::min(embedding_dim, start + part_size); + if (start >= end) { + continue; + } + + IndexType index_val = indices[row]; + if (index_val < 0 || static_cast(index_val) >= vocab_size) { + continue; + } + + const T *src = weight + static_cast(index_val) * embedding_dim + start; + T *dst = output + row * embedding_dim + start; + size_t processed = 0; + size_t len = end - start; + while (processed < len) { + size_t current = std::min(max_chunk, len - processed); + __memcpy(cache, src + processed, current * sizeof(T), GDRAM2NRAM); + __memcpy(dst + processed, cache, current * sizeof(T), NRAM2GDRAM); + processed += current; + } + } +} + +template +static infiniStatus_t launchEmbedding( + int core_per_cluster, + int cluster_count, + cnrtQueue_t queue, + void *output, + const void *input, + const void *weight, + size_t num_indices, + size_t embedding_dim, + size_t vocab_size) { + cnrtDim3_t kernel_dim; + kernel_dim.x = core_per_cluster; + kernel_dim.y = cluster_count; + kernel_dim.z = 1; + + embeddingKernel<<>>( + reinterpret_cast(output), + reinterpret_cast(input), + reinterpret_cast(weight), + num_indices, + embedding_dim, + vocab_size); + cnrtQueueSync(queue); + return INFINI_STATUS_SUCCESS; +} + +namespace op::embedding::bang { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc) { + + auto input_shape = input_desc->shape(); + auto weight_shape = weight_desc->shape(); + + CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto output_shape = output_desc->shape(); + size_t embedding_dim = weight_shape[1]; + CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE); + + for (size_t i = 0; i < input_shape.size(); ++i) { + CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE); + } + + auto input_dtype = input_desc->dtype(); + auto weight_dtype = weight_desc->dtype(); + CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64, + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16, + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + size_t num_indices = 1; + for (auto dim : input_shape) { + num_indices *= dim; + } + size_t vocab_size = weight_shape[0]; + + auto handle_bang = reinterpret_cast(handle); + *desc_ptr = new Descriptor( + num_indices, + embedding_dim, + vocab_size, + input_dtype, + weight_dtype, + new Opaque{handle_bang->internal()}, + handle->device, + handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *output, + const void *input, + const void *weight, + void *stream) const { + if (_num_indices == 0) { + return INFINI_STATUS_SUCCESS; + } + + auto queue = reinterpret_cast(stream); + int core_per_cluster = _opaque->internal->getCorePerCluster(); + int cluster_count = _opaque->internal->getClusterCount(); + +#define DISPATCH(T, IndexType) \ + return launchEmbedding( \ + core_per_cluster, cluster_count, queue, output, input, weight, \ + _num_indices, _embedding_dim, _vocab_size) + + if (_input_dtype == INFINI_DTYPE_I32) { + if (_weight_dtype == INFINI_DTYPE_F32) { + DISPATCH(float, int32_t); + } else if (_weight_dtype == INFINI_DTYPE_F16) { + DISPATCH(half, int32_t); + } else if (_weight_dtype == INFINI_DTYPE_BF16) { + DISPATCH(bfloat16_t, int32_t); + } + } else if (_input_dtype == INFINI_DTYPE_I64) { + if (_weight_dtype == INFINI_DTYPE_F32) { + DISPATCH(float, int64_t); + } else if (_weight_dtype == INFINI_DTYPE_F16) { + DISPATCH(half, int64_t); + } else if (_weight_dtype == INFINI_DTYPE_BF16) { + DISPATCH(bfloat16_t, int64_t); + } + } + +#undef DISPATCH + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +} // namespace op::embedding::bang diff --git a/src/infiniop/ops/embedding/operator.cc b/src/infiniop/ops/embedding/operator.cc index 4741945c7..525ced0af 100644 --- a/src/infiniop/ops/embedding/operator.cc +++ b/src/infiniop/ops/embedding/operator.cc @@ -2,6 +2,10 @@ #include "../../handle.h" #include "infiniop/ops/embedding.h" +#ifdef ENABLE_CAMBRICON_API +#include "bang/embedding_bang.h" +#endif + #ifdef ENABLE_CPU_API #include "cpu/embedding_cpu.h" #endif @@ -57,6 +61,9 @@ __INFINI_C infiniStatus_t infiniopCreateEmbeddingDescriptor( #ifdef ENABLE_MOORE_API CREATE(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_CAMBRICON_API + CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -103,6 +110,9 @@ __INFINI_C infiniStatus_t infiniopEmbedding( #ifdef ENABLE_MOORE_API CALCULATE(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_CAMBRICON_API + CALCULATE(INFINI_DEVICE_CAMBRICON, bang); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -143,6 +153,9 @@ __INFINI_C infiniStatus_t infiniopDestroyEmbeddingDescriptor(infiniopEmbeddingDe #ifdef ENABLE_MOORE_API DESTROY(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_CAMBRICON_API + DESTROY(INFINI_DEVICE_CAMBRICON, bang); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;