From 19994360608ddedcd4524b110f530ac15d8705e0 Mon Sep 17 00:00:00 2001 From: Zhang Jun Date: Sun, 7 Feb 2021 20:35:08 +0800 Subject: [PATCH] [X86] optimize depthwise conv2d3x3 (#5434) (#5477) --- .../backends/x86/math/conv_depthwise_pack8.cc | 752 +++++++++--------- lite/backends/x86/math/conv_utils.cc | 118 +++ lite/backends/x86/math/conv_utils.h | 5 + lite/kernels/x86/conv_compute.cc | 6 +- lite/kernels/x86/conv_compute.h | 20 +- lite/kernels/x86/conv_depthwise.cc | 10 +- 6 files changed, 529 insertions(+), 382 deletions(-) diff --git a/lite/backends/x86/math/conv_depthwise_pack8.cc b/lite/backends/x86/math/conv_depthwise_pack8.cc index f72f26082a0..bb722981257 100644 --- a/lite/backends/x86/math/conv_depthwise_pack8.cc +++ b/lite/backends/x86/math/conv_depthwise_pack8.cc @@ -57,389 +57,391 @@ void conv_depthwise_3x3s1_m256(lite::Tensor* input, const int filter_channel_step = kernel_h * kernel_w * 8; - for (int bs = 0; bs < batch_size; ++bs) { - for (int ic = 0; ic < channel_num; ++ic) { - __m256 _bias0 = bias ? _mm256_loadu_ps(bias->data() + ic * 8) - : _mm256_set1_ps(0.f); - - const float* k0 = filter_data + ic * filter_channel_step; - - const float* r0 = - input_data + bs * input_batch_step + ic * input_channel_step; - const float* r1 = r0 + input_group_step; - const float* r2 = r1 + input_group_step; - - __m256 _k00 = _mm256_loadu_ps(k0); - __m256 _k01 = _mm256_loadu_ps(k0 + 8); - __m256 _k02 = _mm256_loadu_ps(k0 + 16); - __m256 _k10 = _mm256_loadu_ps(k0 + 24); - __m256 _k11 = _mm256_loadu_ps(k0 + 32); - __m256 _k12 = _mm256_loadu_ps(k0 + 40); - __m256 _k20 = _mm256_loadu_ps(k0 + 48); - __m256 _k21 = _mm256_loadu_ps(k0 + 56); - __m256 _k22 = _mm256_loadu_ps(k0 + 64); - - for (int i = 0; i < output_height; ++i) { - int j = 0; - for (; j + 7 < output_width; j += 8) { - __m256 _sum0 = _bias0; - - __m256 _r00 = _mm256_loadu_ps(r0); - __m256 _r01 = _mm256_loadu_ps(r0 + 8); - __m256 _r02 = _mm256_loadu_ps(r0 + 16); - __m256 _r10 = _mm256_loadu_ps(r1); - __m256 _r11 = _mm256_loadu_ps(r1 + 8); - __m256 _r12 = _mm256_loadu_ps(r1 + 16); - __m256 _r20 = _mm256_loadu_ps(r2); - __m256 _r21 = _mm256_loadu_ps(r2 + 8); - __m256 _r22 = _mm256_loadu_ps(r2 + 16); - - _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0); - _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0); - _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0); - _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0); - _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0); - _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0); - _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0); - _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0); - _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0); - - if (has_act) { - _sum0 = activation8_m256(_sum0, act_type); - } - - _mm256_storeu_ps(output_data, _sum0); - - __m256 _sum1 = _bias0; - __m256 _r03 = _mm256_loadu_ps(r0 + 24); - __m256 _r13 = _mm256_loadu_ps(r1 + 24); - __m256 _r23 = _mm256_loadu_ps(r2 + 24); - - _sum1 = _mm256_fmadd_ps(_k00, _r01, _sum1); - _sum1 = _mm256_fmadd_ps(_k01, _r02, _sum1); - _sum1 = _mm256_fmadd_ps(_k02, _r03, _sum1); - _sum1 = _mm256_fmadd_ps(_k10, _r11, _sum1); - _sum1 = _mm256_fmadd_ps(_k11, _r12, _sum1); - _sum1 = _mm256_fmadd_ps(_k12, _r13, _sum1); - _sum1 = _mm256_fmadd_ps(_k20, _r21, _sum1); - _sum1 = _mm256_fmadd_ps(_k21, _r22, _sum1); - _sum1 = _mm256_fmadd_ps(_k22, _r23, _sum1); - - if (has_act) { - _sum1 = activation8_m256(_sum1, act_type); - } - _mm256_storeu_ps(output_data + 8, _sum1); - - __m256 _sum2 = _bias0; - __m256 _r04 = _mm256_loadu_ps(r0 + 32); - __m256 _r14 = _mm256_loadu_ps(r1 + 32); - __m256 _r24 = _mm256_loadu_ps(r2 + 32); - - 
_sum2 = _mm256_fmadd_ps(_k00, _r02, _sum2); - _sum2 = _mm256_fmadd_ps(_k01, _r03, _sum2); - _sum2 = _mm256_fmadd_ps(_k02, _r04, _sum2); - _sum2 = _mm256_fmadd_ps(_k10, _r12, _sum2); - _sum2 = _mm256_fmadd_ps(_k11, _r13, _sum2); - _sum2 = _mm256_fmadd_ps(_k12, _r14, _sum2); - _sum2 = _mm256_fmadd_ps(_k20, _r22, _sum2); - _sum2 = _mm256_fmadd_ps(_k21, _r23, _sum2); - _sum2 = _mm256_fmadd_ps(_k22, _r24, _sum2); - - if (has_act) { - _sum2 = activation8_m256(_sum2, act_type); - } - _mm256_storeu_ps(output_data + 16, _sum2); - - __m256 _sum3 = _bias0; - __m256 _r05 = _mm256_loadu_ps(r0 + 40); - __m256 _r15 = _mm256_loadu_ps(r1 + 40); - __m256 _r25 = _mm256_loadu_ps(r2 + 40); - - _sum3 = _mm256_fmadd_ps(_k00, _r03, _sum3); - _sum3 = _mm256_fmadd_ps(_k01, _r04, _sum3); - _sum3 = _mm256_fmadd_ps(_k02, _r05, _sum3); - _sum3 = _mm256_fmadd_ps(_k10, _r13, _sum3); - _sum3 = _mm256_fmadd_ps(_k11, _r14, _sum3); - _sum3 = _mm256_fmadd_ps(_k12, _r15, _sum3); - _sum3 = _mm256_fmadd_ps(_k20, _r23, _sum3); - _sum3 = _mm256_fmadd_ps(_k21, _r24, _sum3); - _sum3 = _mm256_fmadd_ps(_k22, _r25, _sum3); - - if (has_act) { - _sum3 = activation8_m256(_sum3, act_type); - } - _mm256_storeu_ps(output_data + 24, _sum3); - - __m256 _sum4 = _bias0; - __m256 _r06 = _mm256_loadu_ps(r0 + 48); - __m256 _r16 = _mm256_loadu_ps(r1 + 48); - __m256 _r26 = _mm256_loadu_ps(r2 + 48); - - _sum4 = _mm256_fmadd_ps(_k00, _r04, _sum4); - _sum4 = _mm256_fmadd_ps(_k01, _r05, _sum4); - _sum4 = _mm256_fmadd_ps(_k02, _r06, _sum4); - _sum4 = _mm256_fmadd_ps(_k10, _r14, _sum4); - _sum4 = _mm256_fmadd_ps(_k11, _r15, _sum4); - _sum4 = _mm256_fmadd_ps(_k12, _r16, _sum4); - _sum4 = _mm256_fmadd_ps(_k20, _r24, _sum4); - _sum4 = _mm256_fmadd_ps(_k21, _r25, _sum4); - _sum4 = _mm256_fmadd_ps(_k22, _r26, _sum4); - - if (has_act) { - _sum4 = activation8_m256(_sum4, act_type); - } - _mm256_storeu_ps(output_data + 32, _sum4); - - __m256 _sum5 = _bias0; - __m256 _r07 = _mm256_loadu_ps(r0 + 56); - __m256 _r17 = _mm256_loadu_ps(r1 + 56); - __m256 _r27 = _mm256_loadu_ps(r2 + 56); - - _sum5 = _mm256_fmadd_ps(_k00, _r05, _sum5); - _sum5 = _mm256_fmadd_ps(_k01, _r06, _sum5); - _sum5 = _mm256_fmadd_ps(_k02, _r07, _sum5); - _sum5 = _mm256_fmadd_ps(_k10, _r15, _sum5); - _sum5 = _mm256_fmadd_ps(_k11, _r16, _sum5); - _sum5 = _mm256_fmadd_ps(_k12, _r17, _sum5); - _sum5 = _mm256_fmadd_ps(_k20, _r25, _sum5); - _sum5 = _mm256_fmadd_ps(_k21, _r26, _sum5); - _sum5 = _mm256_fmadd_ps(_k22, _r27, _sum5); - - if (has_act) { - _sum5 = activation8_m256(_sum5, act_type); - } - _mm256_storeu_ps(output_data + 40, _sum5); - - __m256 _sum6 = _bias0; - __m256 _r08 = _mm256_loadu_ps(r0 + 64); - __m256 _r18 = _mm256_loadu_ps(r1 + 64); - __m256 _r28 = _mm256_loadu_ps(r2 + 64); - - _sum6 = _mm256_fmadd_ps(_k00, _r06, _sum6); - _sum6 = _mm256_fmadd_ps(_k01, _r07, _sum6); - _sum6 = _mm256_fmadd_ps(_k02, _r08, _sum6); - _sum6 = _mm256_fmadd_ps(_k10, _r16, _sum6); - _sum6 = _mm256_fmadd_ps(_k11, _r17, _sum6); - _sum6 = _mm256_fmadd_ps(_k12, _r18, _sum6); - _sum6 = _mm256_fmadd_ps(_k20, _r26, _sum6); - _sum6 = _mm256_fmadd_ps(_k21, _r27, _sum6); - _sum6 = _mm256_fmadd_ps(_k22, _r28, _sum6); - - if (has_act) { - _sum6 = activation8_m256(_sum6, act_type); - } - _mm256_storeu_ps(output_data + 48, _sum6); - - __m256 _sum7 = _bias0; - __m256 _r09 = _mm256_loadu_ps(r0 + 72); - __m256 _r19 = _mm256_loadu_ps(r1 + 72); - __m256 _r29 = _mm256_loadu_ps(r2 + 72); - - _sum7 = _mm256_fmadd_ps(_k00, _r07, _sum7); - _sum7 = _mm256_fmadd_ps(_k01, _r08, _sum7); - _sum7 = _mm256_fmadd_ps(_k02, _r09, _sum7); - _sum7 
= _mm256_fmadd_ps(_k10, _r17, _sum7); - _sum7 = _mm256_fmadd_ps(_k11, _r18, _sum7); - _sum7 = _mm256_fmadd_ps(_k12, _r19, _sum7); - _sum7 = _mm256_fmadd_ps(_k20, _r27, _sum7); - _sum7 = _mm256_fmadd_ps(_k21, _r28, _sum7); - _sum7 = _mm256_fmadd_ps(_k22, _r29, _sum7); - - if (has_act) { - _sum7 = activation8_m256(_sum7, act_type); - } - _mm256_storeu_ps(output_data + 56, _sum7); - - r0 += 64; - r1 += 64; - r2 += 64; - output_data += 64; + int total_count = batch_size * channel_num; + + // #pragma omp parallel for collapse(1) + for (int idx = 0; idx < total_count; ++idx) { + __m256 _bias0 = + bias ? _mm256_loadu_ps(bias->data() + (idx % channel_num) * 8) + : _mm256_set1_ps(0.f); + + const float* k0 = filter_data + (idx % channel_num) * filter_channel_step; + + const float* r0 = input_data + (idx / channel_num) * input_batch_step + + (idx % channel_num) * input_channel_step; + const float* r1 = r0 + input_group_step; + const float* r2 = r1 + input_group_step; + + __m256 _k00 = _mm256_loadu_ps(k0); + __m256 _k01 = _mm256_loadu_ps(k0 + 8); + __m256 _k02 = _mm256_loadu_ps(k0 + 16); + __m256 _k10 = _mm256_loadu_ps(k0 + 24); + __m256 _k11 = _mm256_loadu_ps(k0 + 32); + __m256 _k12 = _mm256_loadu_ps(k0 + 40); + __m256 _k20 = _mm256_loadu_ps(k0 + 48); + __m256 _k21 = _mm256_loadu_ps(k0 + 56); + __m256 _k22 = _mm256_loadu_ps(k0 + 64); + + for (int i = 0; i < output_height; ++i) { + int j = 0; + for (; j + 7 < output_width; j += 8) { + __m256 _sum0 = _bias0; + + __m256 _r00 = _mm256_loadu_ps(r0); + __m256 _r01 = _mm256_loadu_ps(r0 + 8); + __m256 _r02 = _mm256_loadu_ps(r0 + 16); + __m256 _r10 = _mm256_loadu_ps(r1); + __m256 _r11 = _mm256_loadu_ps(r1 + 8); + __m256 _r12 = _mm256_loadu_ps(r1 + 16); + __m256 _r20 = _mm256_loadu_ps(r2); + __m256 _r21 = _mm256_loadu_ps(r2 + 8); + __m256 _r22 = _mm256_loadu_ps(r2 + 16); + + _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0); + _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0); + _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0); + _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0); + _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0); + _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0); + _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0); + _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0); + _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0); + + if (has_act) { + _sum0 = activation8_m256(_sum0, act_type); } - for (; j + 3 < output_width; j += 4) { - __m256 _sum0 = _bias0; - - __m256 _r00 = _mm256_loadu_ps(r0); - __m256 _r01 = _mm256_loadu_ps(r0 + 8); - __m256 _r02 = _mm256_loadu_ps(r0 + 16); - __m256 _r10 = _mm256_loadu_ps(r1); - __m256 _r11 = _mm256_loadu_ps(r1 + 8); - __m256 _r12 = _mm256_loadu_ps(r1 + 16); - __m256 _r20 = _mm256_loadu_ps(r2); - __m256 _r21 = _mm256_loadu_ps(r2 + 8); - __m256 _r22 = _mm256_loadu_ps(r2 + 16); - _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0); - _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0); - _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0); - _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0); - _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0); - _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0); - _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0); - _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0); - _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0); - - if (has_act) { - _sum0 = activation8_m256(_sum0, act_type); - } - _mm256_storeu_ps(output_data, _sum0); - - __m256 _sum1 = _bias0; - __m256 _r03 = _mm256_loadu_ps(r0 + 24); - __m256 _r13 = _mm256_loadu_ps(r1 + 24); - __m256 _r23 = _mm256_loadu_ps(r2 + 24); - - _sum1 = _mm256_fmadd_ps(_k00, _r01, _sum1); - _sum1 = _mm256_fmadd_ps(_k01, _r02, _sum1); - _sum1 = _mm256_fmadd_ps(_k02, _r03, 
_sum1); - _sum1 = _mm256_fmadd_ps(_k10, _r11, _sum1); - _sum1 = _mm256_fmadd_ps(_k11, _r12, _sum1); - _sum1 = _mm256_fmadd_ps(_k12, _r13, _sum1); - _sum1 = _mm256_fmadd_ps(_k20, _r21, _sum1); - _sum1 = _mm256_fmadd_ps(_k21, _r22, _sum1); - _sum1 = _mm256_fmadd_ps(_k22, _r23, _sum1); - - if (has_act) { - _sum1 = activation8_m256(_sum1, act_type); - } - _mm256_storeu_ps(output_data + 8, _sum1); - - __m256 _sum2 = _bias0; - __m256 _r04 = _mm256_loadu_ps(r0 + 32); - __m256 _r14 = _mm256_loadu_ps(r1 + 32); - __m256 _r24 = _mm256_loadu_ps(r2 + 32); - - _sum2 = _mm256_fmadd_ps(_k00, _r02, _sum2); - _sum2 = _mm256_fmadd_ps(_k01, _r03, _sum2); - _sum2 = _mm256_fmadd_ps(_k02, _r04, _sum2); - _sum2 = _mm256_fmadd_ps(_k10, _r12, _sum2); - _sum2 = _mm256_fmadd_ps(_k11, _r13, _sum2); - _sum2 = _mm256_fmadd_ps(_k12, _r14, _sum2); - _sum2 = _mm256_fmadd_ps(_k20, _r22, _sum2); - _sum2 = _mm256_fmadd_ps(_k21, _r23, _sum2); - _sum2 = _mm256_fmadd_ps(_k22, _r24, _sum2); - - if (has_act) { - _sum2 = activation8_m256(_sum2, act_type); - } - _mm256_storeu_ps(output_data + 16, _sum2); - - __m256 _sum3 = _bias0; - __m256 _r05 = _mm256_loadu_ps(r0 + 40); - __m256 _r15 = _mm256_loadu_ps(r1 + 40); - __m256 _r25 = _mm256_loadu_ps(r2 + 40); - - _sum3 = _mm256_fmadd_ps(_k00, _r03, _sum3); - _sum3 = _mm256_fmadd_ps(_k01, _r04, _sum3); - _sum3 = _mm256_fmadd_ps(_k02, _r05, _sum3); - _sum3 = _mm256_fmadd_ps(_k10, _r13, _sum3); - _sum3 = _mm256_fmadd_ps(_k11, _r14, _sum3); - _sum3 = _mm256_fmadd_ps(_k12, _r15, _sum3); - _sum3 = _mm256_fmadd_ps(_k20, _r23, _sum3); - _sum3 = _mm256_fmadd_ps(_k21, _r24, _sum3); - _sum3 = _mm256_fmadd_ps(_k22, _r25, _sum3); - - if (has_act) { - _sum3 = activation8_m256(_sum3, act_type); - } - _mm256_storeu_ps(output_data + 24, _sum3); - - r0 += 32; - r1 += 32; - r2 += 32; - output_data += 32; + _mm256_storeu_ps(output_data, _sum0); + + __m256 _sum1 = _bias0; + __m256 _r03 = _mm256_loadu_ps(r0 + 24); + __m256 _r13 = _mm256_loadu_ps(r1 + 24); + __m256 _r23 = _mm256_loadu_ps(r2 + 24); + + _sum1 = _mm256_fmadd_ps(_k00, _r01, _sum1); + _sum1 = _mm256_fmadd_ps(_k01, _r02, _sum1); + _sum1 = _mm256_fmadd_ps(_k02, _r03, _sum1); + _sum1 = _mm256_fmadd_ps(_k10, _r11, _sum1); + _sum1 = _mm256_fmadd_ps(_k11, _r12, _sum1); + _sum1 = _mm256_fmadd_ps(_k12, _r13, _sum1); + _sum1 = _mm256_fmadd_ps(_k20, _r21, _sum1); + _sum1 = _mm256_fmadd_ps(_k21, _r22, _sum1); + _sum1 = _mm256_fmadd_ps(_k22, _r23, _sum1); + + if (has_act) { + _sum1 = activation8_m256(_sum1, act_type); } - for (; j + 1 < output_width; j += 2) { - __m256 _sum0 = _bias0; - - __m256 _r00 = _mm256_loadu_ps(r0); - __m256 _r01 = _mm256_loadu_ps(r0 + 8); - __m256 _r02 = _mm256_loadu_ps(r0 + 16); - __m256 _r10 = _mm256_loadu_ps(r1); - __m256 _r11 = _mm256_loadu_ps(r1 + 8); - __m256 _r12 = _mm256_loadu_ps(r1 + 16); - __m256 _r20 = _mm256_loadu_ps(r2); - __m256 _r21 = _mm256_loadu_ps(r2 + 8); - __m256 _r22 = _mm256_loadu_ps(r2 + 16); - - _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0); - _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0); - _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0); - _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0); - _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0); - _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0); - _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0); - _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0); - _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0); - - if (has_act) { - _sum0 = activation8_m256(_sum0, act_type); - } - _mm256_storeu_ps(output_data, _sum0); - - __m256 _sum1 = _bias0; - __m256 _r03 = _mm256_loadu_ps(r0 + 24); - __m256 _r13 = _mm256_loadu_ps(r1 + 
24); - __m256 _r23 = _mm256_loadu_ps(r2 + 24); - - _sum1 = _mm256_fmadd_ps(_k00, _r01, _sum1); - _sum1 = _mm256_fmadd_ps(_k01, _r02, _sum1); - _sum1 = _mm256_fmadd_ps(_k02, _r03, _sum1); - _sum1 = _mm256_fmadd_ps(_k10, _r11, _sum1); - _sum1 = _mm256_fmadd_ps(_k11, _r12, _sum1); - _sum1 = _mm256_fmadd_ps(_k12, _r13, _sum1); - _sum1 = _mm256_fmadd_ps(_k20, _r21, _sum1); - _sum1 = _mm256_fmadd_ps(_k21, _r22, _sum1); - _sum1 = _mm256_fmadd_ps(_k22, _r23, _sum1); - - if (has_act) { - _sum1 = activation8_m256(_sum1, act_type); - } - _mm256_storeu_ps(output_data + 8, _sum1); - - r0 += 16; - r1 += 16; - r2 += 16; - output_data += 16; + _mm256_storeu_ps(output_data + 8, _sum1); + + __m256 _sum2 = _bias0; + __m256 _r04 = _mm256_loadu_ps(r0 + 32); + __m256 _r14 = _mm256_loadu_ps(r1 + 32); + __m256 _r24 = _mm256_loadu_ps(r2 + 32); + + _sum2 = _mm256_fmadd_ps(_k00, _r02, _sum2); + _sum2 = _mm256_fmadd_ps(_k01, _r03, _sum2); + _sum2 = _mm256_fmadd_ps(_k02, _r04, _sum2); + _sum2 = _mm256_fmadd_ps(_k10, _r12, _sum2); + _sum2 = _mm256_fmadd_ps(_k11, _r13, _sum2); + _sum2 = _mm256_fmadd_ps(_k12, _r14, _sum2); + _sum2 = _mm256_fmadd_ps(_k20, _r22, _sum2); + _sum2 = _mm256_fmadd_ps(_k21, _r23, _sum2); + _sum2 = _mm256_fmadd_ps(_k22, _r24, _sum2); + + if (has_act) { + _sum2 = activation8_m256(_sum2, act_type); } - for (; j < output_width; ++j) { - __m256 _sum0 = _bias0; - - __m256 _r00 = _mm256_loadu_ps(r0); - __m256 _r01 = _mm256_loadu_ps(r0 + 8); - __m256 _r02 = _mm256_loadu_ps(r0 + 16); - __m256 _r10 = _mm256_loadu_ps(r1); - __m256 _r11 = _mm256_loadu_ps(r1 + 8); - __m256 _r12 = _mm256_loadu_ps(r1 + 16); - __m256 _r20 = _mm256_loadu_ps(r2); - __m256 _r21 = _mm256_loadu_ps(r2 + 8); - __m256 _r22 = _mm256_loadu_ps(r2 + 16); + _mm256_storeu_ps(output_data + 16, _sum2); + + __m256 _sum3 = _bias0; + __m256 _r05 = _mm256_loadu_ps(r0 + 40); + __m256 _r15 = _mm256_loadu_ps(r1 + 40); + __m256 _r25 = _mm256_loadu_ps(r2 + 40); + + _sum3 = _mm256_fmadd_ps(_k00, _r03, _sum3); + _sum3 = _mm256_fmadd_ps(_k01, _r04, _sum3); + _sum3 = _mm256_fmadd_ps(_k02, _r05, _sum3); + _sum3 = _mm256_fmadd_ps(_k10, _r13, _sum3); + _sum3 = _mm256_fmadd_ps(_k11, _r14, _sum3); + _sum3 = _mm256_fmadd_ps(_k12, _r15, _sum3); + _sum3 = _mm256_fmadd_ps(_k20, _r23, _sum3); + _sum3 = _mm256_fmadd_ps(_k21, _r24, _sum3); + _sum3 = _mm256_fmadd_ps(_k22, _r25, _sum3); + + if (has_act) { + _sum3 = activation8_m256(_sum3, act_type); + } + _mm256_storeu_ps(output_data + 24, _sum3); + + __m256 _sum4 = _bias0; + __m256 _r06 = _mm256_loadu_ps(r0 + 48); + __m256 _r16 = _mm256_loadu_ps(r1 + 48); + __m256 _r26 = _mm256_loadu_ps(r2 + 48); + + _sum4 = _mm256_fmadd_ps(_k00, _r04, _sum4); + _sum4 = _mm256_fmadd_ps(_k01, _r05, _sum4); + _sum4 = _mm256_fmadd_ps(_k02, _r06, _sum4); + _sum4 = _mm256_fmadd_ps(_k10, _r14, _sum4); + _sum4 = _mm256_fmadd_ps(_k11, _r15, _sum4); + _sum4 = _mm256_fmadd_ps(_k12, _r16, _sum4); + _sum4 = _mm256_fmadd_ps(_k20, _r24, _sum4); + _sum4 = _mm256_fmadd_ps(_k21, _r25, _sum4); + _sum4 = _mm256_fmadd_ps(_k22, _r26, _sum4); + + if (has_act) { + _sum4 = activation8_m256(_sum4, act_type); + } + _mm256_storeu_ps(output_data + 32, _sum4); + + __m256 _sum5 = _bias0; + __m256 _r07 = _mm256_loadu_ps(r0 + 56); + __m256 _r17 = _mm256_loadu_ps(r1 + 56); + __m256 _r27 = _mm256_loadu_ps(r2 + 56); + + _sum5 = _mm256_fmadd_ps(_k00, _r05, _sum5); + _sum5 = _mm256_fmadd_ps(_k01, _r06, _sum5); + _sum5 = _mm256_fmadd_ps(_k02, _r07, _sum5); + _sum5 = _mm256_fmadd_ps(_k10, _r15, _sum5); + _sum5 = _mm256_fmadd_ps(_k11, _r16, _sum5); + _sum5 = 
_mm256_fmadd_ps(_k12, _r17, _sum5); + _sum5 = _mm256_fmadd_ps(_k20, _r25, _sum5); + _sum5 = _mm256_fmadd_ps(_k21, _r26, _sum5); + _sum5 = _mm256_fmadd_ps(_k22, _r27, _sum5); + + if (has_act) { + _sum5 = activation8_m256(_sum5, act_type); + } + _mm256_storeu_ps(output_data + 40, _sum5); + + __m256 _sum6 = _bias0; + __m256 _r08 = _mm256_loadu_ps(r0 + 64); + __m256 _r18 = _mm256_loadu_ps(r1 + 64); + __m256 _r28 = _mm256_loadu_ps(r2 + 64); + + _sum6 = _mm256_fmadd_ps(_k00, _r06, _sum6); + _sum6 = _mm256_fmadd_ps(_k01, _r07, _sum6); + _sum6 = _mm256_fmadd_ps(_k02, _r08, _sum6); + _sum6 = _mm256_fmadd_ps(_k10, _r16, _sum6); + _sum6 = _mm256_fmadd_ps(_k11, _r17, _sum6); + _sum6 = _mm256_fmadd_ps(_k12, _r18, _sum6); + _sum6 = _mm256_fmadd_ps(_k20, _r26, _sum6); + _sum6 = _mm256_fmadd_ps(_k21, _r27, _sum6); + _sum6 = _mm256_fmadd_ps(_k22, _r28, _sum6); + + if (has_act) { + _sum6 = activation8_m256(_sum6, act_type); + } + _mm256_storeu_ps(output_data + 48, _sum6); + + __m256 _sum7 = _bias0; + __m256 _r09 = _mm256_loadu_ps(r0 + 72); + __m256 _r19 = _mm256_loadu_ps(r1 + 72); + __m256 _r29 = _mm256_loadu_ps(r2 + 72); + + _sum7 = _mm256_fmadd_ps(_k00, _r07, _sum7); + _sum7 = _mm256_fmadd_ps(_k01, _r08, _sum7); + _sum7 = _mm256_fmadd_ps(_k02, _r09, _sum7); + _sum7 = _mm256_fmadd_ps(_k10, _r17, _sum7); + _sum7 = _mm256_fmadd_ps(_k11, _r18, _sum7); + _sum7 = _mm256_fmadd_ps(_k12, _r19, _sum7); + _sum7 = _mm256_fmadd_ps(_k20, _r27, _sum7); + _sum7 = _mm256_fmadd_ps(_k21, _r28, _sum7); + _sum7 = _mm256_fmadd_ps(_k22, _r29, _sum7); + + if (has_act) { + _sum7 = activation8_m256(_sum7, act_type); + } + _mm256_storeu_ps(output_data + 56, _sum7); - _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0); - _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0); - _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0); - _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0); - _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0); - _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0); - _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0); - _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0); - _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0); + r0 += 64; + r1 += 64; + r2 += 64; + output_data += 64; + } + for (; j + 3 < output_width; j += 4) { + __m256 _sum0 = _bias0; + + __m256 _r00 = _mm256_loadu_ps(r0); + __m256 _r01 = _mm256_loadu_ps(r0 + 8); + __m256 _r02 = _mm256_loadu_ps(r0 + 16); + __m256 _r10 = _mm256_loadu_ps(r1); + __m256 _r11 = _mm256_loadu_ps(r1 + 8); + __m256 _r12 = _mm256_loadu_ps(r1 + 16); + __m256 _r20 = _mm256_loadu_ps(r2); + __m256 _r21 = _mm256_loadu_ps(r2 + 8); + __m256 _r22 = _mm256_loadu_ps(r2 + 16); + + _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0); + _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0); + _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0); + _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0); + _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0); + _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0); + _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0); + _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0); + _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0); + + if (has_act) { + _sum0 = activation8_m256(_sum0, act_type); + } + _mm256_storeu_ps(output_data, _sum0); + + __m256 _sum1 = _bias0; + __m256 _r03 = _mm256_loadu_ps(r0 + 24); + __m256 _r13 = _mm256_loadu_ps(r1 + 24); + __m256 _r23 = _mm256_loadu_ps(r2 + 24); + + _sum1 = _mm256_fmadd_ps(_k00, _r01, _sum1); + _sum1 = _mm256_fmadd_ps(_k01, _r02, _sum1); + _sum1 = _mm256_fmadd_ps(_k02, _r03, _sum1); + _sum1 = _mm256_fmadd_ps(_k10, _r11, _sum1); + _sum1 = _mm256_fmadd_ps(_k11, _r12, _sum1); + _sum1 = _mm256_fmadd_ps(_k12, _r13, _sum1); + _sum1 = _mm256_fmadd_ps(_k20, _r21, _sum1); 
+ _sum1 = _mm256_fmadd_ps(_k21, _r22, _sum1); + _sum1 = _mm256_fmadd_ps(_k22, _r23, _sum1); + + if (has_act) { + _sum1 = activation8_m256(_sum1, act_type); + } + _mm256_storeu_ps(output_data + 8, _sum1); + + __m256 _sum2 = _bias0; + __m256 _r04 = _mm256_loadu_ps(r0 + 32); + __m256 _r14 = _mm256_loadu_ps(r1 + 32); + __m256 _r24 = _mm256_loadu_ps(r2 + 32); + + _sum2 = _mm256_fmadd_ps(_k00, _r02, _sum2); + _sum2 = _mm256_fmadd_ps(_k01, _r03, _sum2); + _sum2 = _mm256_fmadd_ps(_k02, _r04, _sum2); + _sum2 = _mm256_fmadd_ps(_k10, _r12, _sum2); + _sum2 = _mm256_fmadd_ps(_k11, _r13, _sum2); + _sum2 = _mm256_fmadd_ps(_k12, _r14, _sum2); + _sum2 = _mm256_fmadd_ps(_k20, _r22, _sum2); + _sum2 = _mm256_fmadd_ps(_k21, _r23, _sum2); + _sum2 = _mm256_fmadd_ps(_k22, _r24, _sum2); + + if (has_act) { + _sum2 = activation8_m256(_sum2, act_type); + } + _mm256_storeu_ps(output_data + 16, _sum2); + + __m256 _sum3 = _bias0; + __m256 _r05 = _mm256_loadu_ps(r0 + 40); + __m256 _r15 = _mm256_loadu_ps(r1 + 40); + __m256 _r25 = _mm256_loadu_ps(r2 + 40); + + _sum3 = _mm256_fmadd_ps(_k00, _r03, _sum3); + _sum3 = _mm256_fmadd_ps(_k01, _r04, _sum3); + _sum3 = _mm256_fmadd_ps(_k02, _r05, _sum3); + _sum3 = _mm256_fmadd_ps(_k10, _r13, _sum3); + _sum3 = _mm256_fmadd_ps(_k11, _r14, _sum3); + _sum3 = _mm256_fmadd_ps(_k12, _r15, _sum3); + _sum3 = _mm256_fmadd_ps(_k20, _r23, _sum3); + _sum3 = _mm256_fmadd_ps(_k21, _r24, _sum3); + _sum3 = _mm256_fmadd_ps(_k22, _r25, _sum3); + + if (has_act) { + _sum3 = activation8_m256(_sum3, act_type); + } + _mm256_storeu_ps(output_data + 24, _sum3); - if (has_act) { - _sum0 = activation8_m256(_sum0, act_type); - } - _mm256_storeu_ps(output_data, _sum0); + r0 += 32; + r1 += 32; + r2 += 32; + output_data += 32; + } + for (; j + 1 < output_width; j += 2) { + __m256 _sum0 = _bias0; + + __m256 _r00 = _mm256_loadu_ps(r0); + __m256 _r01 = _mm256_loadu_ps(r0 + 8); + __m256 _r02 = _mm256_loadu_ps(r0 + 16); + __m256 _r10 = _mm256_loadu_ps(r1); + __m256 _r11 = _mm256_loadu_ps(r1 + 8); + __m256 _r12 = _mm256_loadu_ps(r1 + 16); + __m256 _r20 = _mm256_loadu_ps(r2); + __m256 _r21 = _mm256_loadu_ps(r2 + 8); + __m256 _r22 = _mm256_loadu_ps(r2 + 16); + + _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0); + _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0); + _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0); + _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0); + _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0); + _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0); + _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0); + _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0); + _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0); + + if (has_act) { + _sum0 = activation8_m256(_sum0, act_type); + } + _mm256_storeu_ps(output_data, _sum0); + + __m256 _sum1 = _bias0; + __m256 _r03 = _mm256_loadu_ps(r0 + 24); + __m256 _r13 = _mm256_loadu_ps(r1 + 24); + __m256 _r23 = _mm256_loadu_ps(r2 + 24); + + _sum1 = _mm256_fmadd_ps(_k00, _r01, _sum1); + _sum1 = _mm256_fmadd_ps(_k01, _r02, _sum1); + _sum1 = _mm256_fmadd_ps(_k02, _r03, _sum1); + _sum1 = _mm256_fmadd_ps(_k10, _r11, _sum1); + _sum1 = _mm256_fmadd_ps(_k11, _r12, _sum1); + _sum1 = _mm256_fmadd_ps(_k12, _r13, _sum1); + _sum1 = _mm256_fmadd_ps(_k20, _r21, _sum1); + _sum1 = _mm256_fmadd_ps(_k21, _r22, _sum1); + _sum1 = _mm256_fmadd_ps(_k22, _r23, _sum1); + + if (has_act) { + _sum1 = activation8_m256(_sum1, act_type); + } + _mm256_storeu_ps(output_data + 8, _sum1); - r0 += 8; - r1 += 8; - r2 += 8; - output_data += 8; + r0 += 16; + r1 += 16; + r2 += 16; + output_data += 16; + } + for (; j < output_width; ++j) { + __m256 _sum0 = _bias0; + + __m256 
_r00 = _mm256_loadu_ps(r0); + __m256 _r01 = _mm256_loadu_ps(r0 + 8); + __m256 _r02 = _mm256_loadu_ps(r0 + 16); + __m256 _r10 = _mm256_loadu_ps(r1); + __m256 _r11 = _mm256_loadu_ps(r1 + 8); + __m256 _r12 = _mm256_loadu_ps(r1 + 16); + __m256 _r20 = _mm256_loadu_ps(r2); + __m256 _r21 = _mm256_loadu_ps(r2 + 8); + __m256 _r22 = _mm256_loadu_ps(r2 + 16); + + _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0); + _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0); + _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0); + _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0); + _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0); + _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0); + _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0); + _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0); + _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0); + + if (has_act) { + _sum0 = activation8_m256(_sum0, act_type); } - r0 += 2 * 8; - r1 += 2 * 8; - r2 += 2 * 8; - } // end of for output_height - } // end of for channel_num - } // end of for batch_size + _mm256_storeu_ps(output_data, _sum0); + + r0 += 8; + r1 += 8; + r2 += 8; + output_data += 8; + } + r0 += 2 * 8; + r1 += 2 * 8; + r2 += 2 * 8; + } // end of for output_height + } // end of for batch_size * channel_num } // input [bs, ic/8, ih, iw, 8] diff --git a/lite/backends/x86/math/conv_utils.cc b/lite/backends/x86/math/conv_utils.cc index 980145a6dde..8c30f8f739b 100644 --- a/lite/backends/x86/math/conv_utils.cc +++ b/lite/backends/x86/math/conv_utils.cc @@ -557,6 +557,124 @@ void padding1_float(lite::Tensor* input, } } +void pack_padding8_m256(lite::Tensor* input, + lite::Tensor* output, + const int channel_num, + const std::vector& paddings) { + CHECK_EQ(input->dims().size(), 4UL); + int batch_size = input->dims()[0]; + int input_channel = input->dims()[1]; + int input_height = input->dims()[2]; + int input_width = input->dims()[3]; + + CHECK_EQ((input_channel & 7), 0); + const float* input_data = input->data(); + + CHECK_EQ(paddings.size(), 4UL); + int top = paddings[0]; + int bottom = paddings[1]; + int left = paddings[2]; + int right = paddings[3]; + + // in + const int kernel_size = input_height * input_width; + const int pack_step = 8 * kernel_size; + const int batch_step = channel_num * pack_step; + + // out + int out_height = input_height + top + bottom; + int out_width = input_width + left + right; + + // output [bs, ic/8, oh, ow, 8] + output->Resize({batch_size, channel_num, out_height, out_width, 8}); + auto output_data = output->mutable_data(); + + int top_size = top * out_width; + int bottom_size = bottom * out_width; + + __m256 pad_val = _mm256_set1_ps(0.f); + + for (int bs = 0; bs < batch_size; ++bs) { + for (int ic = 0; ic < channel_num; ++ic) { + const float* input_ptr = input_data + bs * batch_step + ic * pack_step; + + const float* r0 = (input_ptr); + const float* r1 = (input_ptr + kernel_size); + const float* r2 = (input_ptr + kernel_size * 2); + const float* r3 = (input_ptr + kernel_size * 3); + const float* r4 = (input_ptr + kernel_size * 4); + const float* r5 = (input_ptr + kernel_size * 5); + const float* r6 = (input_ptr + kernel_size * 6); + const float* r7 = (input_ptr + kernel_size * 7); + + // fill top + for (int y = 0; y < top_size; ++y) { + _mm256_storeu_ps(output_data, pad_val); + output_data += 8; + } + // fill center + for (int y = 0; y < input_height; ++y) { + for (int x = 0; x < left; ++x) { + _mm256_storeu_ps(output_data, pad_val); + output_data += 8; + } + // pack and transpose + int pos = 0; + for (; pos + 7 < input_width; pos += 8) { + __m256 _row0 = _mm256_loadu_ps(r0); + __m256 _row1 = 
_mm256_loadu_ps(r1); + __m256 _row2 = _mm256_loadu_ps(r2); + __m256 _row3 = _mm256_loadu_ps(r3); + __m256 _row4 = _mm256_loadu_ps(r4); + __m256 _row5 = _mm256_loadu_ps(r5); + __m256 _row6 = _mm256_loadu_ps(r6); + __m256 _row7 = _mm256_loadu_ps(r7); + transpose8_ps(_row0, _row1, _row2, _row3, _row4, _row5, _row6, _row7); + _mm256_storeu_ps(output_data, _row0); + _mm256_storeu_ps(output_data + 8, _row1); + _mm256_storeu_ps(output_data + 16, _row2); + _mm256_storeu_ps(output_data + 24, _row3); + _mm256_storeu_ps(output_data + 32, _row4); + _mm256_storeu_ps(output_data + 40, _row5); + _mm256_storeu_ps(output_data + 48, _row6); + _mm256_storeu_ps(output_data + 56, _row7); + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + r7 += 8; + output_data += 64; + } + + for (; pos < input_width; ++pos) { + output_data[0] = *r0++; + output_data[1] = *r1++; + output_data[2] = *r2++; + output_data[3] = *r3++; + output_data[4] = *r4++; + output_data[5] = *r5++; + output_data[6] = *r6++; + output_data[7] = *r7++; + output_data += 8; + } + + for (int x = 0; x < right; ++x) { + _mm256_storeu_ps(output_data, pad_val); + output_data += 8; + } + } + // fill bottom + for (int y = 0; y < bottom_size; ++y) { + _mm256_storeu_ps(output_data, pad_val); + output_data += 8; + } + } + } +} + __m256 activation8_m256(__m256 input, const lite_api::ActivationType act_type) { if (act_type == lite_api::ActivationType::kRelu) { return _mm256_max_ps(input, _mm256_setzero_ps()); diff --git a/lite/backends/x86/math/conv_utils.h b/lite/backends/x86/math/conv_utils.h index c48124e3283..d2deb8831a9 100644 --- a/lite/backends/x86/math/conv_utils.h +++ b/lite/backends/x86/math/conv_utils.h @@ -48,6 +48,11 @@ void padding1_float(lite::Tensor* input, lite::Tensor* output, const std::vector& paddings); +void pack_padding8_m256(lite::Tensor* input, + lite::Tensor* output, + const int channel_num, + const std::vector& paddings); + // for activation - only support relu, relu6 __m256 activation8_m256(__m256 input, const lite_api::ActivationType act_type); __m128 activation4_m128(__m128 input, const lite_api::ActivationType act_type); diff --git a/lite/kernels/x86/conv_compute.cc b/lite/kernels/x86/conv_compute.cc index 1c8d48d6c59..75b47f70cb9 100644 --- a/lite/kernels/x86/conv_compute.cc +++ b/lite/kernels/x86/conv_compute.cc @@ -36,8 +36,12 @@ void Conv2dCompute::PrepareForRun() { const int stride_h = param.strides[0]; const int stride_w = param.strides[1]; + auto dilations = *param.dilations; + bool no_dilation = (static_cast(dilations[0]) == 1) && + (static_cast(dilations[1]) == 1); + if (input_channel == groups && output_channel == groups && - (groups & 3) == 0) { + (groups & 3) == 0 && no_dilation) { if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1) { impl_ = new DepthwiseConv; VLOG(3) << "invoking conv_depthwise_3x3s1"; diff --git a/lite/kernels/x86/conv_compute.h b/lite/kernels/x86/conv_compute.h index 80eeaae499a..9eb8567acc7 100644 --- a/lite/kernels/x86/conv_compute.h +++ b/lite/kernels/x86/conv_compute.h @@ -69,6 +69,24 @@ class Conv2dCompute : public KernelLite { param.output->template mutable_data(); const int batch_size = static_cast(param.x->dims()[0]); + const int kh = static_cast(param.filter->dims()[2]); + const int kw = static_cast(param.filter->dims()[3]); + + const int sh = static_cast(param.strides[0]); + const int sw = static_cast(param.strides[1]); + + auto paddings = *param.paddings; + const int ph = paddings[0]; + const int pw = paddings[2]; + + bool kps_equal = (pw == 
ph) && (sw == sh) && (kw == kh); + bool pads_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); + bool flag_1x1gemm = false; + if (kw == 1 && sw == 1 && pw == 0 && kps_equal && pads_equal) { + flag_1x1gemm = true; + } + std::vector filter_shape_vec(filter.dims().Vectorize()); std::vector output_shape_vec(param.output->dims().Vectorize()); size_t data_dim = filter_shape_vec.size() - 2; @@ -122,7 +140,7 @@ class Conv2dCompute : public KernelLite { col.ShareDataWith(in_slice); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - } else if (data_dim == 2U) { + } else if (data_dim == 2U && !flag_1x1gemm) { // im2col im2col(context, in_slice, diff --git a/lite/kernels/x86/conv_depthwise.cc b/lite/kernels/x86/conv_depthwise.cc index f2f74824d6a..4601f96c910 100644 --- a/lite/kernels/x86/conv_depthwise.cc +++ b/lite/kernels/x86/conv_depthwise.cc @@ -36,12 +36,12 @@ void DepthwiseConv::Run() { input_channel % 8 == 0 ? 8 : input_channel % 4 == 0 ? 4 : 1; const int pack_num = input_channel / pack_size; - // [bs, ic, ih, iw] & pack_size=8 => [bs, ic/8, ih, iw, 8] - // [bs, ic, ih, iw] & pack_size=4 => [bs, ic/4, ih, iw, 4] if (pack_size == 8) { - lite::x86::math::pack8_m256(param.x, &input_pack_, pack_num, false); - lite::x86::math::padding8_m256( - &input_pack_, &input_padding_, *(param.paddings)); + // lite::x86::math::pack8_m256(param.x, &input_pack_, pack_num, false); + // lite::x86::math::padding8_m256( + // &input_pack_, &input_padding_, *(param.paddings)); + lite::x86::math::pack_padding8_m256( + param.x, &input_padding_, pack_num, *(param.paddings)); } else if (pack_size == 4) { lite::x86::math::pack4_m128(param.x, &input_pack_, pack_num, false); lite::x86::math::padding4_m128(
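
For reference, a minimal C++ sketch (not part of the commit; the function and variable names here are illustrative only) of the batch/channel loop collapse that conv_depthwise_3x3s1_m256 switches to in the first hunk above. It shows only how the single flat index is split back into (bs, ic), which is what makes the commented-out "#pragma omp parallel for" applicable to one loop; the packed layouts remain [bs, ic/8, ih, iw, 8] for the input and [bs, ic/8, oh, ow, 8] after pack_padding8_m256, as stated in the patch.

    // Sketch only: collapsed batch x packed-channel loop, hypothetical body.
    void depthwise_collapsed_loop_sketch(int batch_size, int channel_num) {
      const int total_count = batch_size * channel_num;
      // #pragma omp parallel for   // one flat loop, as the patch hints
      for (int idx = 0; idx < total_count; ++idx) {
        const int bs = idx / channel_num;  // original batch index
        const int ic = idx % channel_num;  // original ic/8 group index
        // Per-(bs, ic) work: load the 3x3 packed filter registers, then run
        // the 8/4/2/1-wide output-column loops shown in the hunk above.
        (void)bs;
        (void)ic;
      }
    }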