[cherry-pick] fix SVE backend bugs (matmul_v2 & conv) (#9696)
* add A510 to the sdot-supported CPU list (#9537)

* [ARM] add matmul_v2 SVE2 backend and a 5x5s1p2 max-pooling kernel (#9653)

* [SVE] add matmul_v2 SVE backend

* [ARM] add 5x5s1p2 max-pooling kernel

* [sve] fix leaky_relu fusion in conv (#9670)

* [OpMakerClean] fix flatten op for xshape removal
mjp9527 authored Nov 17, 2022
1 parent 6bc3164 commit f294964
Showing 13 changed files with 788 additions and 167 deletions.
110 changes: 110 additions & 0 deletions lite/backends/arm/math/gemm_s8.cc
@@ -13,6 +13,9 @@
// limitations under the License.

#include "lite/backends/arm/math/gemm_s8.h"
#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
#include "lite/backends/arm/math/sve/gemm_sve_i8mm.h"
#endif

namespace paddle {
namespace lite {
@@ -112,6 +115,113 @@ template void gemm_s8<int8_t>(bool is_transA,
const operators::ActivationParam act_param,
ARMContext* ctx);

#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
template <typename Dtype>
void gemm_sve(bool is_transA,
bool is_transB,
int M,
int N,
int K,
const int8_t* A,
const int8_t* B,
Dtype* C,
const float* bias,
bool is_bias,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx) {
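// N == 1: the whole GEMM collapses to a single matrix-vector product.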
if (N == 1) {
gemv_int8(A, B, C, is_transA, M, K, scale, is_bias, bias, act_param, ctx);
return;
}
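// M == 1: fall back to GEMV on B with swapped dims and transpose flag;
// the scalar bias and scale are broadcast across the N outputs.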
if (M == 1) {
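// iOS builds avoid C-style VLAs, so heap buffers are used there.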
#ifdef TARGET_IOS
float* bias_ptr = new float[N];
float* scale_ptr = new float[N];
#else
float bias_ptr[N]; // NOLINT
float scale_ptr[N]; // NOLINT
#endif
if (is_bias) {
for (int i = 0; i < N; i++) {
bias_ptr[i] = bias[0];
}
}
for (int i = 0; i < N; i++) {
scale_ptr[i] = scale[0];
}
gemv_int8(B,
A,
C,
!is_transB,
N,
K,
scale_ptr,
is_bias,
bias_ptr,
act_param,
ctx);
#ifdef TARGET_IOS
delete[] bias_ptr;
delete[] scale_ptr;
#endif
return;
}

//! prepack A: rows padded to a multiple of the SVE block height, depth
//! padded to a multiple of 8
Tensor tpackedA_sve;
int hblock_sve = paddle::lite::arm::math::sve::get_hblock_int8_sve(ctx);
int round_up_a_sve = ((hblock_sve + M - 1) / hblock_sve) * hblock_sve;
int round_up_k_sve = 8 * ((K + 7) / 8);
tpackedA_sve.Resize({round_up_a_sve * round_up_k_sve});
int lda = is_transA ? M : K;
paddle::lite::arm::math::sve::prepackA_int8_sve(
tpackedA_sve.mutable_data<int8_t>(), A, lda, 0, M, 0, K, is_transA, ctx);
// run the SVE int8 GEMM kernel on the packed A block
lite::arm::math::sve::gemm_prepack_int8_sve<Dtype>(
tpackedA_sve.data<int8_t>(),
B,
bias,
C,
M,
N,
K,
is_bias,
is_transB,
scale,
act_param,
ctx);
}

template void gemm_sve<float>(bool is_transA,
bool is_transB,
int M,
int N,
int K,
const int8_t* A,
const int8_t* B,
float* C,
const float* bias,
bool is_bias,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx);

template void gemm_sve<int8_t>(bool is_transA,
bool is_transB,
int M,
int N,
int K,
const int8_t* A,
const int8_t* B,
int8_t* C,
const float* bias,
bool is_bias,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx);
#endif

} // namespace math
} // namespace arm
} // namespace lite
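For reference, a minimal standalone sketch (not part of the commit) of the packing-size arithmetic in gemm_sve: rows of A are rounded up to a multiple of the SVE block height, and the depth K to a multiple of 8. The hblock value of 8 below is an assumption; the real value comes from sve::get_hblock_int8_sve(ctx).

```cpp
#include <cstdio>

int main() {
  const int hblock = 8;  // assumed block height; real value is hardware-dependent
  for (int M : {1, 7, 8, 9}) {
    // same expression as round_up_a_sve above
    int padded_rows = ((M + hblock - 1) / hblock) * hblock;
    std::printf("M=%d -> padded rows=%d\n", M, padded_rows);
  }
  const int K = 30;
  // same expression as round_up_k_sve above
  std::printf("K=%d -> padded depth=%d\n", K, 8 * ((K + 7) / 8));
  return 0;
}
```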
16 changes: 16 additions & 0 deletions lite/backends/arm/math/gemm_s8.h
@@ -39,6 +39,22 @@ void gemm_s8(bool is_transA,
const operators::ActivationParam act_param,
ARMContext* ctx);

#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
template <typename Dtype>
void gemm_sve(bool is_transA,
bool is_transB,
int M,
int N,
int K,
const int8_t* A,
const int8_t* B,
Dtype* C,
const float* bias,
bool is_bias,
const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx);
#endif
} // namespace math
} // namespace arm
} // namespace lite
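A hedged call-site sketch (not part of this commit) of how a caller might route through the new declaration. The helper name dispatch_gemm_s8 is illustrative only; a real caller would typically also gate the SVE path on a runtime CPU-feature check rather than the compile-time guard alone.

```cpp
#include "lite/backends/arm/math/gemm_s8.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

// Illustrative helper (assumption, not in the tree): use the SVE2 GEMM when
// it was compiled in, otherwise fall back to the regular int8 GEMM.
template <typename Dtype>
void dispatch_gemm_s8(bool is_transA, bool is_transB, int M, int N, int K,
                      const int8_t* A, const int8_t* B, Dtype* C,
                      const float* bias, bool is_bias, const float* scale,
                      const operators::ActivationParam act_param,
                      ARMContext* ctx) {
#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
  gemm_sve<Dtype>(is_transA, is_transB, M, N, K, A, B, C,
                  bias, is_bias, scale, act_param, ctx);
#else
  gemm_s8<Dtype>(is_transA, is_transB, M, N, K, A, B, C,
                 bias, is_bias, scale, act_param, ctx);
#endif
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
```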