Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Critsium-xy committed Oct 28, 2024
1 parent 7a29e2f commit 098c932
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
8 changes: 4 additions & 4 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mt_(&transb, &transa, &n, &m, &k,
sgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand All @@ -111,7 +111,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mt_(&transb, &transa, &n, &m, &k,
dgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand All @@ -129,7 +129,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mt_(&transb, &transa, &n, &m, &k,
cgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand All @@ -147,7 +147,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mt_(&transb, &transa, &n, &m, &k,
zgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
Expand Down
12 changes: 12 additions & 0 deletions source/module_base/module_device/memory_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,17 @@ struct resize_memory_op<FPTYPE, base_device::DEVICE_CPU>
{
if (arr != nullptr)
{
#ifdef __DSP
free_ht(arr);
#else
free(arr);
#endif
}
#ifdef __DSP
arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK);
#else
arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size);
#endif
std::string record_string;
if (record_in != nullptr)
{
Expand Down Expand Up @@ -96,7 +104,11 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
{
#ifdef __DSP
free_ht(arr);
#else
free(arr);
#endif
}
};

Expand Down

0 comments on commit 098c932

Please sign in to comment.