diff --git a/src/infiniop/ops/add/ascend/add_ascend.cc b/src/infiniop/ops/add/ascend/add_ascend.cc
index daef41673..1a9465f5f 100644
--- a/src/infiniop/ops/add/ascend/add_ascend.cc
+++ b/src/infiniop/ops/add/ascend/add_ascend.cc
@@ -1,6 +1,8 @@
 #include "add_ascend.h"
 #include "../../../devices/ascend/common_ascend.h"
 #include <aclnnop/aclnn_add.h>
+#include <cstdint>
+#include <cstdlib>
 
 namespace op::add::ascend {
 
@@ -10,18 +12,20 @@ struct Descriptor::Opaque {
     aclnnTensorDescriptor_t b;
     aclnnTensorDescriptor_t c;
     aclnnScalarDescriptor_t alpha;
+    void *alpha_value;
     size_t workspaceSize;
     aclOpExecutor *executor;
 
     Opaque(aclnnTensorDescriptor_t a_, aclnnTensorDescriptor_t b_, aclnnTensorDescriptor_t c_,
-           aclnnScalarDescriptor_t alpha_, size_t ws, aclOpExecutor *exec)
-        : a(a_), b(b_), c(c_), alpha(alpha_), workspaceSize(ws), executor(exec) {}
+           aclnnScalarDescriptor_t alpha_, void *alpha_value_, size_t ws, aclOpExecutor *exec)
+        : a(a_), b(b_), c(c_), alpha(alpha_), alpha_value(alpha_value_), workspaceSize(ws), executor(exec) {}
 
     ~Opaque() {
         delete a;
         delete b;
         delete c;
         delete alpha;
+        std::free(alpha_value);
         aclDestroyAclOpExecutor(executor);
     }
 };
@@ -53,10 +57,37 @@ infiniStatus_t Descriptor::create(
     aclnnTensorDescriptor_t b = new aclnnTensorDescriptor(b_desc);
     aclnnTensorDescriptor_t c = new aclnnTensorDescriptor(c_desc);
 
-    // Default alpha = 1.0
-    float alpha_value = 1.0f;
+    void *alpha_value = nullptr;
+    size_t alpha_value_size = 0;
+    infiniDtype_t alpha_dtype = INFINI_DTYPE_F32;
+
+#define SET_ALPHA(TYPE, DTYPE, VALUE) \
+    do { \
+        alpha_value_size = sizeof(TYPE); \
+        alpha_value = std::malloc(alpha_value_size); \
+        if (alpha_value == nullptr) { \
+            return INFINI_STATUS_INSUFFICIENT_WORKSPACE; \
+        } \
+        *static_cast<TYPE *>(alpha_value) = (TYPE)(VALUE); \
+        alpha_dtype = DTYPE; \
+    } while (0)
+
+    switch (c_desc->dtype()) {
+    case INFINI_DTYPE_I32:
+        SET_ALPHA(int32_t, INFINI_DTYPE_I32, 1);
+        break;
+    case INFINI_DTYPE_I64:
+        SET_ALPHA(int64_t, INFINI_DTYPE_I64, 1);
+        break;
+    default:
+        SET_ALPHA(float, INFINI_DTYPE_F32, 1.0f);
+        break;
+    }
+
+#undef SET_ALPHA
+
     aclnnScalarDescriptor_t alpha = new aclnnScalarDescriptor(
-        INFINI_DTYPE_F32, &alpha_value, sizeof(float));
+        alpha_dtype, alpha_value, alpha_value_size);
 
     size_t workspace_size = 0;
     aclOpExecutor *executor = nullptr;
@@ -72,7 +103,7 @@ infiniStatus_t Descriptor::create(
     aclSetAclOpExecutorRepeatable(executor);
 
     *desc_ptr = new Descriptor(
-        new Opaque{a, b, c, alpha, workspace_size, executor},
+        new Opaque{a, b, c, alpha, alpha_value, workspace_size, executor},
         result.take(),
         workspace_size,
         handle_ascend->device,
diff --git a/src/infiniop/ops/add/info.h b/src/infiniop/ops/add/info.h
index 0d69bd9ae..ca2736222 100644
--- a/src/infiniop/ops/add/info.h
+++ b/src/infiniop/ops/add/info.h
@@ -23,7 +23,7 @@ class AddInfo {
         auto dtype = c_desc->dtype();
 
         // Check dtype compatibility
-        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
+        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
 
         // Check shape compatibility (broadcast)
         auto c_shape = c_desc->shape();