Enable X32-GIO-PACKW AVX microkernel
- Enable the AVX x32 GIO packw microkernel (a reference sketch of the packed layout follows the change summary below)
- Prefetch offsets now account for k_stride
- AVX and AVX512 kernels use SIMD_TILE and SIMD_SIZE for consistency

PiperOrigin-RevId: 698217641
fbarchard authored and xnnpack-bot committed Nov 20, 2024
1 parent 52f3146 commit e297edd
Showing 22 changed files with 387 additions and 407 deletions.
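For context, the GIO packw microkernels enabled here consume a weight matrix stored input-channel-major (one row of nc output channels per k step) and repack it into 16-wide panels, with the bias copied ahead of each panel, as the per-file diffs below show. The following is a minimal scalar sketch of that layout for float32 weights; the function name and the zero-filling of partial panels are illustrative assumptions, not the XNNPACK API.

#include <stddef.h>

/* Hypothetical reference, not the XNNPACK ukernel: pack a kc x nc GIO (input-major)
 * weight matrix, whose rows are k_stride elements apart, into panels of nr columns.
 * Each panel holds nr bias values followed by kc rows of nr weights; lanes past nc
 * are assumed to be zero-filled, mirroring the masked loads in the AVX kernels. */
static void pack_gio_reference(size_t nc, size_t kc, size_t nr, size_t k_stride,
                               const float* w, const float* bias, float* packed_w) {
  for (size_t n0 = 0; n0 < nc; n0 += nr) {
    const size_t valid = (nc - n0) < nr ? (nc - n0) : nr;
    for (size_t i = 0; i < nr; ++i) {   /* bias, zero when absent or past nc */
      *packed_w++ = (bias != NULL && i < valid) ? bias[n0 + i] : 0.0f;
    }
    for (size_t k = 0; k < kc; ++k) {   /* one nr-wide row per k step */
      for (size_t i = 0; i < nr; ++i) {
        *packed_w++ = (i < valid) ? w[k * k_stride + n0 + i] : 0.0f;
      }
    }
  }
}

With this layout the GEMM microkernel streams nr consecutive weights per k step from packed_w instead of gathering them across strided rows.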
2 changes: 1 addition & 1 deletion cmake/gen/avx_microkernels.cmake
@@ -101,6 +101,7 @@ SET(PROD_AVX_MICROKERNEL_SRCS
src/qu8-vmul/gen/qu8-vmul-minmax-fp32-avx-mul16-ld64-u16.c
src/qu8-vmulc/gen/qu8-vmulc-minmax-fp32-avx-mul16-ld64-u16.c
src/x8-lut/gen/x8-lut-avx-u64.c
src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u8.c
src/x32-packw/gen/x32-packw-x16-gemm-goi-avx-u4.c
src/x32-packw/gen/x32-packw-x16s4-gemm-goi-avx-u4.c
src/x32-transposec/gen/x32-transposec-8x8-reuse-multi-avx.c
@@ -482,7 +483,6 @@ SET(NON_PROD_AVX_MICROKERNEL_SRCS
src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u1-prfm.c
src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u1.c
src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u8-prfm.c
src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u8.c
src/x32-packw/gen/x32-packw-x16-gemm-goi-avx-u4-prfm.c
src/x32-packw/gen/x32-packw-x16s4-gemm-goi-avx-u4-prfm.c
src/x32-packw/gen/x32-packw-x32-gemm-gio-avx-u1-prfm.c
2 changes: 1 addition & 1 deletion gen/avx_microkernels.bzl
@@ -97,6 +97,7 @@ PROD_AVX_MICROKERNEL_SRCS = [
"src/qu8-vmul/gen/qu8-vmul-minmax-fp32-avx-mul16-ld64-u16.c",
"src/qu8-vmulc/gen/qu8-vmulc-minmax-fp32-avx-mul16-ld64-u16.c",
"src/x8-lut/gen/x8-lut-avx-u64.c",
"src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u8.c",
"src/x32-packw/gen/x32-packw-x16-gemm-goi-avx-u4.c",
"src/x32-packw/gen/x32-packw-x16s4-gemm-goi-avx-u4.c",
"src/x32-transposec/gen/x32-transposec-8x8-reuse-multi-avx.c",
@@ -479,7 +480,6 @@ NON_PROD_AVX_MICROKERNEL_SRCS = [
"src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u1-prfm.c",
"src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u1.c",
"src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u8-prfm.c",
"src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u8.c",
"src/x32-packw/gen/x32-packw-x16-gemm-goi-avx-u4-prfm.c",
"src/x32-packw/gen/x32-packw-x16s4-gemm-goi-avx-u4-prfm.c",
"src/x32-packw/gen/x32-packw-x32-gemm-gio-avx-u1-prfm.c",
4 changes: 2 additions & 2 deletions src/configs/gemm-config.c
@@ -709,7 +709,7 @@ static void init_f32_gemm_config(void) {
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x16__fma3_broadcast);
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_5x16__fma3_broadcast_prfm);
f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params;
f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w;
f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8;
f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x16__avx_u4;
f32_gemm_config.mr = 5;
f32_gemm_config.nr = 16;
@@ -721,7 +721,7 @@ static void init_f32_gemm_config(void) {
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x16__avx_broadcast);
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_5x16__avx_broadcast);
f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params;
f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w;
f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8;
f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x16__avx_u4;
f32_gemm_config.mr = 5;
f32_gemm_config.nr = 16;
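The two config lines changed above swap the portable xnn_pack_f32_gemm_gio_w packer for the vectorized xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8 on the GIO path, while the GOI path keeps xnn_x32_packw_gemm_goi_ukernel_x16__avx_u4. As a reminder of what the two source layouts mean, here is an illustrative pair of accessors (float32, groups omitted; the helper names are made up for this note):

#include <stddef.h>

/* GOI: weights stored output-channel major, w[n][k]. */
static float goi_at(const float* w, size_t kc, size_t n, size_t k) {
  return w[n * kc + k];
}

/* GIO: weights stored input-channel major, w[k][n], rows k_stride elements apart. */
static float gio_at(const float* w, size_t k_stride, size_t n, size_t k) {
  return w[k * k_stride + n];
}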
18 changes: 9 additions & 9 deletions src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u1-prfm.c
@@ -60,9 +60,9 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u1_prfm(
for (; n >= 16; n -= 16) {
if XNN_LIKELY(b != NULL) {
const __m256 vb0 = _mm256_loadu_ps(b + 0);
const __m256 vb8 = _mm256_loadu_ps(b + 8);
const __m256 vb1 = _mm256_loadu_ps(b + 8);
_mm256_store_ps(packed_w + 0, vb0);
_mm256_store_ps(packed_w + 8, vb8);
_mm256_store_ps(packed_w + 8, vb1);
b += 16;
} else {
_mm256_store_ps(packed_w + 0, vzero);
@@ -75,10 +75,10 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u1_prfm(
// KC remainder loop
for (; k > 0; --k) {
const __m256 v0 = _mm256_loadu_ps(w + 0);
const __m256 v8 = _mm256_loadu_ps(w + 8);
const __m256 v1 = _mm256_loadu_ps(w + 8);
xnn_prefetch_to_l1((const int8_t*) w + 960);
_mm256_store_ps(packed_w + 0, v0);
_mm256_store_ps(packed_w + 8, v8);
_mm256_store_ps(packed_w + 8, v1);
w += k_stride;
packed_w += 16;
}
@@ -90,13 +90,13 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u1_prfm(
assert(n >= 1);
assert(n <= 15);
const __m256i vmask0 = _mm256_loadu_si256((const __m256i*) &mask_table[16 - n]);
const __m256i vmask8 = _mm256_loadu_si256((const __m256i*) &mask_table[24 - n]);
const __m256i vmask1 = _mm256_loadu_si256((const __m256i*) &mask_table[24 - n]);

if XNN_LIKELY(b != NULL) {
const __m256 vb0 = _mm256_maskload_ps(b + 0, vmask0);
const __m256 vb8 = _mm256_maskload_ps(b + 8, vmask8);
const __m256 vb1 = _mm256_maskload_ps(b + 8, vmask1);
_mm256_store_ps(packed_w + 0, vb0);
_mm256_store_ps(packed_w + 8, vb8);
_mm256_store_ps(packed_w + 8, vb1);
b += n;
} else {
_mm256_store_ps(packed_w + 0, vzero);
@@ -107,9 +107,9 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u1_prfm(
// KC main loop
for (size_t k = kc; k > 0; --k) {
const __m256 v0 = _mm256_maskload_ps(w + 0, vmask0);
const __m256 v8 = _mm256_maskload_ps(w + 8, vmask8);
const __m256 v1 = _mm256_maskload_ps(w + 8, vmask1);
_mm256_store_ps(packed_w + 0, v0);
_mm256_store_ps(packed_w + 8, v8);
_mm256_store_ps(packed_w + 8, v1);
w += k_stride;
packed_w += 16;
}
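The register renames in this file (vb8 to vb1, v8 to v1, vmask8 to vmask1) switch from naming registers by element offset to naming them by SIMD tile index, matching the convention the AVX512 kernels already follow via SIMD_TILE and SIMD_SIZE. A minimal sketch of the pattern, assuming SIMD_SIZE = 8 floats per __m256 and SIMD_TILE = 2 for the 16-wide panel (unaligned stores are used here for simplicity; the generated kernels store to an aligned packed_w):

#include <immintrin.h>

/* Copy one 16-float row using tile-indexed register names: v0, v1 (formerly v0, v8). */
static void copy_row_x16(const float* w, float* packed_w) {
  const __m256 v0 = _mm256_loadu_ps(w + 0 * 8);   /* tile 0: elements 0..7  */
  const __m256 v1 = _mm256_loadu_ps(w + 1 * 8);   /* tile 1: elements 8..15 */
  _mm256_storeu_ps(packed_w + 0 * 8, v0);
  _mm256_storeu_ps(packed_w + 1 * 8, v1);
}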
18 changes: 9 additions & 9 deletions src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u1.c
@@ -59,9 +59,9 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u1(
for (; n >= 16; n -= 16) {
if XNN_LIKELY(b != NULL) {
const __m256 vb0 = _mm256_loadu_ps(b + 0);
const __m256 vb8 = _mm256_loadu_ps(b + 8);
const __m256 vb1 = _mm256_loadu_ps(b + 8);
_mm256_store_ps(packed_w + 0, vb0);
_mm256_store_ps(packed_w + 8, vb8);
_mm256_store_ps(packed_w + 8, vb1);
b += 16;
} else {
_mm256_store_ps(packed_w + 0, vzero);
@@ -74,9 +74,9 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u1(
// KC remainder loop
for (; k > 0; --k) {
const __m256 v0 = _mm256_loadu_ps(w + 0);
const __m256 v8 = _mm256_loadu_ps(w + 8);
const __m256 v1 = _mm256_loadu_ps(w + 8);
_mm256_store_ps(packed_w + 0, v0);
_mm256_store_ps(packed_w + 8, v8);
_mm256_store_ps(packed_w + 8, v1);
w += k_stride;
packed_w += 16;
}
@@ -88,13 +88,13 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u1(
assert(n >= 1);
assert(n <= 15);
const __m256i vmask0 = _mm256_loadu_si256((const __m256i*) &mask_table[16 - n]);
const __m256i vmask8 = _mm256_loadu_si256((const __m256i*) &mask_table[24 - n]);
const __m256i vmask1 = _mm256_loadu_si256((const __m256i*) &mask_table[24 - n]);

if XNN_LIKELY(b != NULL) {
const __m256 vb0 = _mm256_maskload_ps(b + 0, vmask0);
const __m256 vb8 = _mm256_maskload_ps(b + 8, vmask8);
const __m256 vb1 = _mm256_maskload_ps(b + 8, vmask1);
_mm256_store_ps(packed_w + 0, vb0);
_mm256_store_ps(packed_w + 8, vb8);
_mm256_store_ps(packed_w + 8, vb1);
b += n;
} else {
_mm256_store_ps(packed_w + 0, vzero);
@@ -105,9 +105,9 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u1(
// KC main loop
for (size_t k = kc; k > 0; --k) {
const __m256 v0 = _mm256_maskload_ps(w + 0, vmask0);
const __m256 v8 = _mm256_maskload_ps(w + 8, vmask8);
const __m256 v1 = _mm256_maskload_ps(w + 8, vmask1);
_mm256_store_ps(packed_w + 0, v0);
_mm256_store_ps(packed_w + 8, v8);
_mm256_store_ps(packed_w + 8, v1);
w += k_stride;
packed_w += 16;
}
66 changes: 33 additions & 33 deletions src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u8-prfm.c
@@ -60,9 +60,9 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8_prfm(
for (; n >= 16; n -= 16) {
if XNN_LIKELY(b != NULL) {
const __m256 vb0 = _mm256_loadu_ps(b + 0);
const __m256 vb8 = _mm256_loadu_ps(b + 8);
const __m256 vb1 = _mm256_loadu_ps(b + 8);
_mm256_store_ps(packed_w + 0, vb0);
_mm256_store_ps(packed_w + 8, vb8);
_mm256_store_ps(packed_w + 8, vb1);
b += 16;
} else {
_mm256_store_ps(packed_w + 0, vzero);
@@ -74,56 +74,56 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8_prfm(
// KC main loop 8x16
for (; k >= 8; k -= 8) {
const __m256 v0_0 = _mm256_loadu_ps(w + 0 + 0 * k_stride);
const __m256 v8_0 = _mm256_loadu_ps(w + 8 + 0 * k_stride);
const __m256 v1_0 = _mm256_loadu_ps(w + 8 + 0 * k_stride);
const __m256 v0_1 = _mm256_loadu_ps(w + 0 + 1 * k_stride);
const __m256 v8_1 = _mm256_loadu_ps(w + 8 + 1 * k_stride);
const __m256 v1_1 = _mm256_loadu_ps(w + 8 + 1 * k_stride);
const __m256 v0_2 = _mm256_loadu_ps(w + 0 + 2 * k_stride);
const __m256 v8_2 = _mm256_loadu_ps(w + 8 + 2 * k_stride);
const __m256 v1_2 = _mm256_loadu_ps(w + 8 + 2 * k_stride);
const __m256 v0_3 = _mm256_loadu_ps(w + 0 + 3 * k_stride);
const __m256 v8_3 = _mm256_loadu_ps(w + 8 + 3 * k_stride);
const __m256 v1_3 = _mm256_loadu_ps(w + 8 + 3 * k_stride);
const __m256 v0_4 = _mm256_loadu_ps(w + 0 + 4 * k_stride);
const __m256 v8_4 = _mm256_loadu_ps(w + 8 + 4 * k_stride);
const __m256 v1_4 = _mm256_loadu_ps(w + 8 + 4 * k_stride);
const __m256 v0_5 = _mm256_loadu_ps(w + 0 + 5 * k_stride);
const __m256 v8_5 = _mm256_loadu_ps(w + 8 + 5 * k_stride);
const __m256 v1_5 = _mm256_loadu_ps(w + 8 + 5 * k_stride);
const __m256 v0_6 = _mm256_loadu_ps(w + 0 + 6 * k_stride);
const __m256 v8_6 = _mm256_loadu_ps(w + 8 + 6 * k_stride);
const __m256 v1_6 = _mm256_loadu_ps(w + 8 + 6 * k_stride);
const __m256 v0_7 = _mm256_loadu_ps(w + 0 + 7 * k_stride);
const __m256 v8_7 = _mm256_loadu_ps(w + 8 + 7 * k_stride);
xnn_prefetch_to_l1((const int8_t*) w + 960);
xnn_prefetch_to_l1((const int8_t*) w + 960);
xnn_prefetch_to_l1((const int8_t*) w + 960);
xnn_prefetch_to_l1((const int8_t*) w + 960);
xnn_prefetch_to_l1((const int8_t*) w + 960);
xnn_prefetch_to_l1((const int8_t*) w + 960);
xnn_prefetch_to_l1((const int8_t*) w + 960);
xnn_prefetch_to_l1((const int8_t*) w + 960);
const __m256 v1_7 = _mm256_loadu_ps(w + 8 + 7 * k_stride);
xnn_prefetch_to_l1((const int8_t*) w + 960 + 0 * k_stride);
xnn_prefetch_to_l1((const int8_t*) w + 960 + 1 * k_stride);
xnn_prefetch_to_l1((const int8_t*) w + 960 + 2 * k_stride);
xnn_prefetch_to_l1((const int8_t*) w + 960 + 3 * k_stride);
xnn_prefetch_to_l1((const int8_t*) w + 960 + 4 * k_stride);
xnn_prefetch_to_l1((const int8_t*) w + 960 + 5 * k_stride);
xnn_prefetch_to_l1((const int8_t*) w + 960 + 6 * k_stride);
xnn_prefetch_to_l1((const int8_t*) w + 960 + 7 * k_stride);
_mm256_store_ps(packed_w + 0, v0_0);
_mm256_store_ps(packed_w + 8, v8_0);
_mm256_store_ps(packed_w + 8, v1_0);
_mm256_store_ps(packed_w + 16, v0_1);
_mm256_store_ps(packed_w + 24, v8_1);
_mm256_store_ps(packed_w + 24, v1_1);
_mm256_store_ps(packed_w + 32, v0_2);
_mm256_store_ps(packed_w + 40, v8_2);
_mm256_store_ps(packed_w + 40, v1_2);
_mm256_store_ps(packed_w + 48, v0_3);
_mm256_store_ps(packed_w + 56, v8_3);
_mm256_store_ps(packed_w + 56, v1_3);
_mm256_store_ps(packed_w + 64, v0_4);
_mm256_store_ps(packed_w + 72, v8_4);
_mm256_store_ps(packed_w + 72, v1_4);
_mm256_store_ps(packed_w + 80, v0_5);
_mm256_store_ps(packed_w + 88, v8_5);
_mm256_store_ps(packed_w + 88, v1_5);
_mm256_store_ps(packed_w + 96, v0_6);
_mm256_store_ps(packed_w + 104, v8_6);
_mm256_store_ps(packed_w + 104, v1_6);
_mm256_store_ps(packed_w + 112, v0_7);
_mm256_store_ps(packed_w + 120, v8_7);
_mm256_store_ps(packed_w + 120, v1_7);
w += k_stride * 8;
packed_w += 128;
}

// KC remainder loop
for (; k > 0; --k) {
const __m256 v0 = _mm256_loadu_ps(w + 0);
const __m256 v8 = _mm256_loadu_ps(w + 8);
const __m256 v1 = _mm256_loadu_ps(w + 8);
xnn_prefetch_to_l1((const int8_t*) w + 960);
_mm256_store_ps(packed_w + 0, v0);
_mm256_store_ps(packed_w + 8, v8);
_mm256_store_ps(packed_w + 8, v1);
w += k_stride;
packed_w += 16;
}
@@ -135,13 +135,13 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8_prfm(
assert(n >= 1);
assert(n <= 15);
const __m256i vmask0 = _mm256_loadu_si256((const __m256i*) &mask_table[16 - n]);
const __m256i vmask8 = _mm256_loadu_si256((const __m256i*) &mask_table[24 - n]);
const __m256i vmask1 = _mm256_loadu_si256((const __m256i*) &mask_table[24 - n]);

if XNN_LIKELY(b != NULL) {
const __m256 vb0 = _mm256_maskload_ps(b + 0, vmask0);
const __m256 vb8 = _mm256_maskload_ps(b + 8, vmask8);
const __m256 vb1 = _mm256_maskload_ps(b + 8, vmask1);
_mm256_store_ps(packed_w + 0, vb0);
_mm256_store_ps(packed_w + 8, vb8);
_mm256_store_ps(packed_w + 8, vb1);
b += n;
} else {
_mm256_store_ps(packed_w + 0, vzero);
@@ -152,9 +152,9 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8_prfm(
// KC main loop
for (size_t k = kc; k > 0; --k) {
const __m256 v0 = _mm256_maskload_ps(w + 0, vmask0);
const __m256 v8 = _mm256_maskload_ps(w + 8, vmask8);
const __m256 v1 = _mm256_maskload_ps(w + 8, vmask1);
_mm256_store_ps(packed_w + 0, v0);
_mm256_store_ps(packed_w + 8, v8);
_mm256_store_ps(packed_w + 8, v1);
w += k_stride;
packed_w += 16;
}
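In the u8 variant each pass through the KC main loop reads eight rows of w that are k_stride elements apart, so the previous eight identical prefetches all touched the same address; the change above offsets each prefetch so every upcoming row is covered. A hedged sketch of the idea follows; the 960-byte lookahead is taken from the generated kernels, and the per-row offset is assumed to be scaled in elements (the exact scaling in the generated code is a generator detail):

#include <xmmintrin.h>
#include <stddef.h>

/* Prefetch roughly 960 bytes ahead in each of the 8 rows that the next unrolled
 * iteration will load. Illustrative, not the exact addresses computed by
 * xnn_prefetch_to_l1 in the generated code. */
static inline void prefetch_next_rows(const float* w, size_t k_stride) {
  for (size_t i = 0; i < 8; ++i) {
    _mm_prefetch((const char*) (w + i * k_stride) + 960, _MM_HINT_T0);
  }
}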
50 changes: 25 additions & 25 deletions src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u8.c
@@ -59,9 +59,9 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8(
for (; n >= 16; n -= 16) {
if XNN_LIKELY(b != NULL) {
const __m256 vb0 = _mm256_loadu_ps(b + 0);
const __m256 vb8 = _mm256_loadu_ps(b + 8);
const __m256 vb1 = _mm256_loadu_ps(b + 8);
_mm256_store_ps(packed_w + 0, vb0);
_mm256_store_ps(packed_w + 8, vb8);
_mm256_store_ps(packed_w + 8, vb1);
b += 16;
} else {
_mm256_store_ps(packed_w + 0, vzero);
@@ -73,47 +73,47 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8(
// KC main loop 8x16
for (; k >= 8; k -= 8) {
const __m256 v0_0 = _mm256_loadu_ps(w + 0 + 0 * k_stride);
const __m256 v8_0 = _mm256_loadu_ps(w + 8 + 0 * k_stride);
const __m256 v1_0 = _mm256_loadu_ps(w + 8 + 0 * k_stride);
const __m256 v0_1 = _mm256_loadu_ps(w + 0 + 1 * k_stride);
const __m256 v8_1 = _mm256_loadu_ps(w + 8 + 1 * k_stride);
const __m256 v1_1 = _mm256_loadu_ps(w + 8 + 1 * k_stride);
const __m256 v0_2 = _mm256_loadu_ps(w + 0 + 2 * k_stride);
const __m256 v8_2 = _mm256_loadu_ps(w + 8 + 2 * k_stride);
const __m256 v1_2 = _mm256_loadu_ps(w + 8 + 2 * k_stride);
const __m256 v0_3 = _mm256_loadu_ps(w + 0 + 3 * k_stride);
const __m256 v8_3 = _mm256_loadu_ps(w + 8 + 3 * k_stride);
const __m256 v1_3 = _mm256_loadu_ps(w + 8 + 3 * k_stride);
const __m256 v0_4 = _mm256_loadu_ps(w + 0 + 4 * k_stride);
const __m256 v8_4 = _mm256_loadu_ps(w + 8 + 4 * k_stride);
const __m256 v1_4 = _mm256_loadu_ps(w + 8 + 4 * k_stride);
const __m256 v0_5 = _mm256_loadu_ps(w + 0 + 5 * k_stride);
const __m256 v8_5 = _mm256_loadu_ps(w + 8 + 5 * k_stride);
const __m256 v1_5 = _mm256_loadu_ps(w + 8 + 5 * k_stride);
const __m256 v0_6 = _mm256_loadu_ps(w + 0 + 6 * k_stride);
const __m256 v8_6 = _mm256_loadu_ps(w + 8 + 6 * k_stride);
const __m256 v1_6 = _mm256_loadu_ps(w + 8 + 6 * k_stride);
const __m256 v0_7 = _mm256_loadu_ps(w + 0 + 7 * k_stride);
const __m256 v8_7 = _mm256_loadu_ps(w + 8 + 7 * k_stride);
const __m256 v1_7 = _mm256_loadu_ps(w + 8 + 7 * k_stride);
_mm256_store_ps(packed_w + 0, v0_0);
_mm256_store_ps(packed_w + 8, v8_0);
_mm256_store_ps(packed_w + 8, v1_0);
_mm256_store_ps(packed_w + 16, v0_1);
_mm256_store_ps(packed_w + 24, v8_1);
_mm256_store_ps(packed_w + 24, v1_1);
_mm256_store_ps(packed_w + 32, v0_2);
_mm256_store_ps(packed_w + 40, v8_2);
_mm256_store_ps(packed_w + 40, v1_2);
_mm256_store_ps(packed_w + 48, v0_3);
_mm256_store_ps(packed_w + 56, v8_3);
_mm256_store_ps(packed_w + 56, v1_3);
_mm256_store_ps(packed_w + 64, v0_4);
_mm256_store_ps(packed_w + 72, v8_4);
_mm256_store_ps(packed_w + 72, v1_4);
_mm256_store_ps(packed_w + 80, v0_5);
_mm256_store_ps(packed_w + 88, v8_5);
_mm256_store_ps(packed_w + 88, v1_5);
_mm256_store_ps(packed_w + 96, v0_6);
_mm256_store_ps(packed_w + 104, v8_6);
_mm256_store_ps(packed_w + 104, v1_6);
_mm256_store_ps(packed_w + 112, v0_7);
_mm256_store_ps(packed_w + 120, v8_7);
_mm256_store_ps(packed_w + 120, v1_7);
w += k_stride * 8;
packed_w += 128;
}

// KC remainder loop
for (; k > 0; --k) {
const __m256 v0 = _mm256_loadu_ps(w + 0);
const __m256 v8 = _mm256_loadu_ps(w + 8);
const __m256 v1 = _mm256_loadu_ps(w + 8);
_mm256_store_ps(packed_w + 0, v0);
_mm256_store_ps(packed_w + 8, v8);
_mm256_store_ps(packed_w + 8, v1);
w += k_stride;
packed_w += 16;
}
@@ -125,13 +125,13 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8(
assert(n >= 1);
assert(n <= 15);
const __m256i vmask0 = _mm256_loadu_si256((const __m256i*) &mask_table[16 - n]);
const __m256i vmask8 = _mm256_loadu_si256((const __m256i*) &mask_table[24 - n]);
const __m256i vmask1 = _mm256_loadu_si256((const __m256i*) &mask_table[24 - n]);

if XNN_LIKELY(b != NULL) {
const __m256 vb0 = _mm256_maskload_ps(b + 0, vmask0);
const __m256 vb8 = _mm256_maskload_ps(b + 8, vmask8);
const __m256 vb1 = _mm256_maskload_ps(b + 8, vmask1);
_mm256_store_ps(packed_w + 0, vb0);
_mm256_store_ps(packed_w + 8, vb8);
_mm256_store_ps(packed_w + 8, vb1);
b += n;
} else {
_mm256_store_ps(packed_w + 0, vzero);
@@ -142,9 +142,9 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8(
// KC main loop
for (size_t k = kc; k > 0; --k) {
const __m256 v0 = _mm256_maskload_ps(w + 0, vmask0);
const __m256 v8 = _mm256_maskload_ps(w + 8, vmask8);
const __m256 v1 = _mm256_maskload_ps(w + 8, vmask1);
_mm256_store_ps(packed_w + 0, v0);
_mm256_store_ps(packed_w + 8, v8);
_mm256_store_ps(packed_w + 8, v1);
w += k_stride;
packed_w += 16;
}
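The remainder path in these kernels (1 <= n <= 15) builds two lane masks by loading eight consecutive entries from mask_table at offsets [16 - n] and [24 - n], then uses _mm256_maskload_ps so out-of-range columns read back as zero without faulting. A self-contained sketch under the assumption that the table is 16 all-ones words followed by 16 zeros (the actual declaration in the XNNPACK source may differ):

#include <immintrin.h>
#include <stdint.h>
#include <stddef.h>

static const int32_t mask_table[32] = {
  -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
   0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
};

/* Load an n-column remainder (1..15) of one row: lanes beyond n come back as zero. */
static void load_remainder_x16(const float* w, size_t n, __m256* lo, __m256* hi) {
  const __m256i vmask0 = _mm256_loadu_si256((const __m256i*) &mask_table[16 - n]);  /* lanes 0..7  */
  const __m256i vmask1 = _mm256_loadu_si256((const __m256i*) &mask_table[24 - n]);  /* lanes 8..15 */
  *lo = _mm256_maskload_ps(w + 0, vmask0);  /* min(n, 8) valid lanes     */
  *hi = _mm256_maskload_ps(w + 8, vmask1);  /* max(n - 8, 0) valid lanes */
}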
1 change: 0 additions & 1 deletion src/x32-packw/gen/x32-packw-x16-gemm-gio-avx512f-u1-prfm.c
@@ -63,7 +63,6 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__avx512f_u1_prfm(
}
packed_w += 16;

// KC main loop 1x16
size_t k = kc;

// KC remainder loop
(Diffs for the remaining 14 changed files were not loaded in this view.)
