Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/infinicore/nn/embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Embedding::Embedding(size_t num_embeddings,
Tensor Embedding::forward(const Tensor &indices) const {
// TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
auto device_type = device_.getType();
if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::QY) {
if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::CAMBRICON || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::QY) {
// Use op::embedding which supports device-side input and batch dimension
return op::embedding(indices->contiguous()->to(device_), weight_);
}
Expand Down
8 changes: 8 additions & 0 deletions src/infiniop/ops/embedding/bang/embedding_bang.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __EMBEDDING_BANG_H__
#define __EMBEDDING_BANG_H__

#include "../embedding.h"

// Declares the embedding Descriptor for the BANG (Cambricon) backend; its
// members are defined in embedding_bang.mlu inside op::embedding::bang.
// NOTE(review): DESCRIPTOR is presumably supplied by ../embedding.h, which is
// not visible here — confirm the exact members it generates.
DESCRIPTOR(bang)

#endif // __EMBEDDING_BANG_H__
182 changes: 182 additions & 0 deletions src/infiniop/ops/embedding/bang/embedding_bang.mlu
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#include "../../../devices/bang/common_bang.h"
#include "embedding_bang.h"

#include <algorithm>

// Scratch buffer of NRAM_MAX_SIZE bytes placed in NRAM via the __nram__
// qualifier. embeddingKernel below reinterprets it as an array of T and uses
// it to stage row data between the weight table and the output.
__nram__ char nram_buffer[NRAM_MAX_SIZE];

/**
 * Embedding lookup kernel: for every entry of `indices`, copies the selected
 * row of `weight` into the matching row of `output`.
 *
 * The work is expressed as "logical tasks". A row may be cut into several
 * column slices (targeting about 512 elements each, and never more slices per
 * row than taskDim / num_indices) so that more tasks have work when there are
 * few indices. Each task starts at logical task `taskId` and advances by
 * `taskDim`.
 *
 * Rows whose index is negative or >= vocab_size are skipped: nothing is
 * written to the corresponding output row.
 *
 * @param output         destination, addressed as [num_indices][embedding_dim]
 * @param indices        lookup indices, num_indices entries
 * @param weight         table, addressed as [index][embedding_dim]
 * @param num_indices    number of lookups
 * @param embedding_dim  number of T elements per row
 * @param vocab_size     exclusive upper bound for a valid index
 */
template <typename T, typename IndexType>
__mlu_global__ void embeddingKernel(
    T *__restrict__ output,
    const IndexType *__restrict__ indices,
    const T *__restrict__ weight,
    size_t num_indices,
    size_t embedding_dim,
    size_t vocab_size) {
    if (num_indices == 0 || embedding_dim == 0) {
        return;
    }

    // Decide how many column slices each row is split into.
    const size_t target_slice_elems = 512;
    const size_t slices_wanted = (embedding_dim + target_slice_elems - 1) / target_slice_elems;
    const size_t slices_allowed = std::max<size_t>(1, taskDim / num_indices);
    const size_t slices_per_row = std::max<size_t>(1, std::min(slices_wanted, slices_allowed));
    const size_t slice_elems = (embedding_dim + slices_per_row - 1) / slices_per_row;
    const size_t total_tasks = num_indices * slices_per_row;

    // The whole NRAM buffer is used as a staging area of T elements.
    T *const staging = reinterpret_cast<T *>(nram_buffer);
    const size_t staging_capacity = NRAM_MAX_SIZE / sizeof(T);

    for (size_t task = taskId; task < total_tasks; task += taskDim) {
        const size_t row = task / slices_per_row;
        const size_t slice = task % slices_per_row;
        const size_t col_begin = slice * slice_elems;
        const size_t col_end = std::min(embedding_dim, col_begin + slice_elems);
        // Rounding up slice_elems can leave trailing slices with no columns.
        if (col_begin >= col_end) {
            continue;
        }

        const IndexType idx = indices[row];
        const bool in_range = !(idx < 0) && static_cast<size_t>(idx) < vocab_size;
        if (!in_range) {
            continue;
        }

        const T *row_src = weight + static_cast<size_t>(idx) * embedding_dim + col_begin;
        T *row_dst = output + row * embedding_dim + col_begin;
        const size_t span = col_end - col_begin;

        // Move the slice through NRAM in pieces no larger than the buffer.
        for (size_t done = 0; done < span;) {
            const size_t piece = std::min(staging_capacity, span - done);
            __memcpy(staging, row_src + done, piece * sizeof(T), GDRAM2NRAM);
            __memcpy(row_dst + done, staging, piece * sizeof(T), NRAM2GDRAM);
            done += piece;
        }
    }
}

/**
 * Launches embeddingKernel<T, IndexType> on `queue` and waits for the queue
 * to finish.
 *
 * The launch dimensions are (core_per_cluster, cluster_count, 1) and the
 * kernel is launched with cnrtFuncTypeUnion1.
 *
 * @param core_per_cluster  x dimension of the launch
 * @param cluster_count     y dimension of the launch
 * @param queue             CNRT queue the kernel is enqueued on
 * @param output            output buffer, reinterpreted as T*
 * @param input             index buffer, reinterpreted as const IndexType*
 * @param weight            weight table, reinterpreted as const T*
 * @param num_indices       number of lookups
 * @param embedding_dim     elements per row
 * @param vocab_size        exclusive upper bound for a valid index
 * @return INFINI_STATUS_SUCCESS when the queue synchronizes successfully,
 *         INFINI_STATUS_INTERNAL_ERROR otherwise.
 */
template <typename T, typename IndexType>
static infiniStatus_t launchEmbedding(
    int core_per_cluster,
    int cluster_count,
    cnrtQueue_t queue,
    void *output,
    const void *input,
    const void *weight,
    size_t num_indices,
    size_t embedding_dim,
    size_t vocab_size) {
    cnrtDim3_t kernel_dim;
    kernel_dim.x = core_per_cluster;
    kernel_dim.y = cluster_count;
    kernel_dim.z = 1;

    embeddingKernel<T, IndexType><<<kernel_dim, cnrtFuncTypeUnion1, queue>>>(
        reinterpret_cast<T *>(output),
        reinterpret_cast<const IndexType *>(input),
        reinterpret_cast<const T *>(weight),
        num_indices,
        embedding_dim,
        vocab_size);

    // The sync result was previously discarded, so a failed launch or a
    // device-side fault was reported to the caller as success.
    if (cnrtQueueSync(queue) != cnrtSuccess) {
        return INFINI_STATUS_INTERNAL_ERROR;
    }
    return INFINI_STATUS_SUCCESS;
}

namespace op::embedding::bang {

// Backend-private state of the descriptor: a shared reference to the BANG
// handle's internal object, which calculate() queries for the core-per-cluster
// and cluster counts.
struct Descriptor::Opaque {
std::shared_ptr<device::bang::Handle::Internal> internal;
};

// Frees the backend-private Opaque. create() allocates one with `new` and
// passes it to the Descriptor constructor (presumably stored as _opaque).
Descriptor::~Descriptor() {
delete _opaque;
}

// Validates the tensor descriptors of an embedding lookup and, on success,
// stores a newly allocated Descriptor in *desc_ptr.
//
// Checks, in the order they are performed (the first failing one is returned):
//   - weight has exactly 2 dimensions            -> INFINI_STATUS_BAD_TENSOR_SHAPE
//   - output rank == input rank + 1              -> INFINI_STATUS_BAD_TENSOR_SHAPE
//   - last output dim == weight_shape[1]         -> INFINI_STATUS_BAD_TENSOR_SHAPE
//   - leading output dims == input dims          -> INFINI_STATUS_BAD_TENSOR_SHAPE
//   - input dtype is I32 or I64                  -> INFINI_STATUS_BAD_TENSOR_DTYPE
//   - weight dtype is F32, F16 or BF16           -> INFINI_STATUS_BAD_TENSOR_DTYPE
//   - output dtype == weight dtype               -> INFINI_STATUS_BAD_TENSOR_DTYPE
//
// Strides are not inspected here.
// NOTE(review): the kernel in this file addresses all three buffers as dense
// row-major arrays, so callers presumably must pass contiguous tensors — confirm.
//
// Returns INFINI_STATUS_SUCCESS when the descriptor was created.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc) {

auto input_shape = input_desc->shape();
auto weight_shape = weight_desc->shape();

// Rank checks come first so the indexing below (weight_shape[1], back()) is safe.
CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE);

auto output_shape = output_desc->shape();
size_t embedding_dim = weight_shape[1];
CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE);

// Output must be the input shape with embedding_dim appended.
for (size_t i = 0; i < input_shape.size(); ++i) {
CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE);
}

auto input_dtype = input_desc->dtype();
auto weight_dtype = weight_desc->dtype();
// Only the dtype combinations that calculate() dispatches on are accepted.
CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64,
INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16,
INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);

// Total number of lookups = product of all input dimensions
// (1 for a zero-dimensional input, 0 if any dimension is 0).
size_t num_indices = 1;
for (auto dim : input_shape) {
num_indices *= dim;
}
size_t vocab_size = weight_shape[0];

// The Opaque keeps the handle's internal object alive for calculate().
auto handle_bang = reinterpret_cast<device::bang::Handle *>(handle);
*desc_ptr = new Descriptor(
num_indices,
embedding_dim,
vocab_size,
input_dtype,
weight_dtype,
new Opaque{handle_bang->internal()},
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}

/**
 * Runs the embedding lookup described by this descriptor.
 *
 * Selects the launchEmbedding instantiation matching the stored index dtype
 * (I32 / I64) and weight dtype (F32 / F16 / BF16) and returns its status.
 *
 * @param output  output buffer
 * @param input   index buffer
 * @param weight  weight table
 * @param stream  queue to launch on, reinterpreted as cnrtQueue_t
 * @return INFINI_STATUS_SUCCESS immediately when there are no indices;
 *         otherwise the status of launchEmbedding, or
 *         INFINI_STATUS_BAD_TENSOR_DTYPE for an unhandled dtype combination.
 */
infiniStatus_t Descriptor::calculate(
    void *output,
    const void *input,
    const void *weight,
    void *stream) const {
    // Nothing to look up: succeed without launching anything.
    if (_num_indices == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    auto queue = reinterpret_cast<cnrtQueue_t>(stream);
    const int core_per_cluster = _opaque->internal->getCorePerCluster();
    const int cluster_count = _opaque->internal->getClusterCount();

// Expands to a `return` of the launch instantiated for the given value/index types.
#define LAUNCH_AS(VALUE_T, INDEX_T)                                     \
    return launchEmbedding<VALUE_T, INDEX_T>(                           \
        core_per_cluster, cluster_count, queue, output, input, weight, \
        _num_indices, _embedding_dim, _vocab_size)

    switch (_input_dtype) {
    case INFINI_DTYPE_I32:
        switch (_weight_dtype) {
        case INFINI_DTYPE_F32:
            LAUNCH_AS(float, int32_t);
        case INFINI_DTYPE_F16:
            LAUNCH_AS(half, int32_t);
        case INFINI_DTYPE_BF16:
            LAUNCH_AS(bfloat16_t, int32_t);
        default:
            break;
        }
        break;
    case INFINI_DTYPE_I64:
        switch (_weight_dtype) {
        case INFINI_DTYPE_F32:
            LAUNCH_AS(float, int64_t);
        case INFINI_DTYPE_F16:
            LAUNCH_AS(half, int64_t);
        case INFINI_DTYPE_BF16:
            LAUNCH_AS(bfloat16_t, int64_t);
        default:
            break;
        }
        break;
    default:
        break;
    }

#undef LAUNCH_AS

    // No supported (index dtype, weight dtype) pair matched.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

} // namespace op::embedding::bang
13 changes: 13 additions & 0 deletions src/infiniop/ops/embedding/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#include "../../handle.h"
#include "infiniop/ops/embedding.h"

#ifdef ENABLE_CAMBRICON_API
#include "bang/embedding_bang.h"
#endif

#ifdef ENABLE_CPU_API
#include "cpu/embedding_cpu.h"
#endif
Expand Down Expand Up @@ -57,6 +61,9 @@ __INFINI_C infiniStatus_t infiniopCreateEmbeddingDescriptor(
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -103,6 +110,9 @@ __INFINI_C infiniStatus_t infiniopEmbedding(
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -143,6 +153,9 @@ __INFINI_C infiniStatus_t infiniopDestroyEmbeddingDescriptor(infiniopEmbeddingDe
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_CAMBRICON_API
DESTROY(INFINI_DEVICE_CAMBRICON, bang);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down
Loading