diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 1de321ca99..30b3b93d40 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -93,7 +93,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - sgemm_mt_(&transb, &transa, &n, &m, &k, + sgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -111,7 +111,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice){ - dgemm_mt_(&transb, &transa, &n, &m, &k, + dgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -129,7 +129,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - cgemm_mt_(&transb, &transa, &n, &m, &k, + cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } @@ -147,7 +147,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons } #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - zgemm_mt_(&transb, &transa, &n, &m, &k, + zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK); } diff --git a/source/module_base/module_device/memory_op.cpp b/source/module_base/module_device/memory_op.cpp index 00c4a36ad7..8f74c016bb 100644 --- a/source/module_base/module_device/memory_op.cpp +++ b/source/module_base/module_device/memory_op.cpp @@ -22,9 +22,17 @@ struct resize_memory_op { if (arr != nullptr) { +#ifdef __DSP + free_ht(arr); +#else free(arr); +#endif } +#ifdef __DSP + arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK); +#else arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size); +#endif std::string record_string; if (record_in != nullptr) { @@ -96,7 +104,11 @@ struct delete_memory_op { void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr) { +#ifdef __DSP + free_ht(arr); +#else free(arr); +#endif } };