Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matmul nbits to optimize memory layout for avx instructions #22203

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,158 @@ QuantizeARow_CompInt8_avx512(
}
}

__m512
ComputeMulScal(const float* a_ptr, size_t step, float& scale)
{
const __m512 signBit = _mm512_set1_ps(-0.0f);
__m512 maxAbs = _mm512_setzero_ps();

for (size_t kk = 0; kk < step; kk += 16) {
const size_t klen = std::min(size_t(16), step - kk);

uint32_t mask = 0xffff >> (16 - klen);
__m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), a_ptr + kk);

// Compute max(abs(e)) for the block
maxAbs = _mm512_max_ps(maxAbs, _mm512_andnot_ps(signBit, v0));
}

__m256 max8 =
_mm256_max_ps(_mm512_extractf32x8_ps(maxAbs, 1), _mm512_extractf32x8_ps(maxAbs, 0));
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max8, 1), _mm256_castps256_ps128(max8));
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
const float maxScalar = _mm_cvtss_f32(max4);

// Quantize these floats
scale = maxScalar / 127.f;

const float inverse_scale = (maxScalar != 0.0f) ? 127.f / maxScalar : 0.0f;
return _mm512_set1_ps(inverse_scale);
}

void
QuantizeInt8ComputeBlksum(const float* a_ptr, size_t step, __m512& mul, float scale, __m256i& i0_32_epi8, float& blksum)
{
    // Quantize one block of `step` (<= 32 — TODO confirm callers never exceed this)
    // floats at a_ptr into 32 int8 lanes of i0_32_epi8, using the broadcast
    // multiplier `mul` produced by ComputeMulScal. Also computes
    // blksum = scale * Sum(quantized a_i) for the block.
    const __m256i one_16_epi16 = _mm256_set1_epi16(1);
    __m256i sum_16_epi16 = _mm256_setzero_si256();
    // BUG FIX: zero-initialize so the unwritten half is defined (not garbage)
    // when step <= 16 and the loop runs only once.
    __m128i i_16_epi8[2] = {_mm_setzero_si128(), _mm_setzero_si128()};
    int index = 0;
    for (size_t kk = 0; kk < step; kk += 16, index++) {
        const size_t klen = std::min(size_t(16), step - kk);

        // Mask off lanes past the end of a partial 16-float chunk.
        uint32_t mask = 0xffff >> (16 - klen);
        // BUG FIX: was `A + kk` — `A` is not in scope in this function; the
        // block pointer parameter is `a_ptr`.
        __m512 v0 = _mm512_maskz_loadu_ps(__mmask16(mask), a_ptr + kk);
        v0 = _mm512_mul_ps(v0, mul);

        // Round to nearest integer
        v0 = _mm512_roundscale_ps(v0, _MM_ROUND_NEAREST);

        // Convert floats to integers
        __m512i i0 = _mm512_cvtps_epi32(v0);

        // Narrow int32 -> int8 (values are already within [-127, 127]).
        i_16_epi8[index] = _mm512_cvtepi32_epi8(i0);

        // Accumulate Sum(a_i): widen to int16 and pairwise-add with saturation.
        __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i_16_epi8[index]);
        sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16);
    }
    // NOTE(review): _mm256_set_m128i(hi, lo) places its FIRST argument in the
    // upper 128 bits, so elements 0..15 end up in the high half and 16..31 in
    // the low half — confirm downstream consumers expect this reversed order.
    i0_32_epi8 = _mm256_set_m128i(i_16_epi8[0], i_16_epi8[1]);
    // madd with ones horizontally folds the int16 partial sums into int32 lanes.
    const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16);
    blksum = scale * hsum_8_epi32(sum_8_epi32);
}

void
Quantize1BlkBlkLen32(const float* a_ptr, size_t step, __m256i& i_32_epi8, float& scale, float& blksum)
{
    // Quantize a single BlkLen=32 block: `step` floats at a_ptr become 32 int8
    // lanes in i_32_epi8, with the per-block scale and scaled block sum
    // (scale * Sum(quantized a_i)) returned through the reference parameters.
    __m512 inverse_scale_bcast = ComputeMulScal(a_ptr, step, scale);
    QuantizeInt8ComputeBlksum(a_ptr, step, inverse_scale_bcast, scale, i_32_epi8, blksum);
}

void
stoore_4blk_blklen32_interleaved(__m256i i_32_epi8[4], int8_t* blob)
{
// 0 1 2 3 32 33 34 35 64 65 66 67 96 97 98 99
// 4 5 6 7 36 37 38 39 68 69 70 71 100 101 102 103
// 8 9 10 11 40 41 42 43 72 73 74 75 104 105 106 107
// 12 13 14 15 44 45 46 47 76 77 78 79 108 109 110 111
//
// 16 17 18 19 48 49 50 51 80 81 82 83 112 113 114 115
// 20 21 22 23 52 53 54 55 84 85 86 87 116 117 118 119
// 24 25 26 27 56 57 58 59 88 89 90 91 120 121 122 123
// 28 29 30 31 60 61 62 63 92 93 94 95 124 125 126 127

// Interleave and store i_32_epi8[4] in the specified layout
__m256i a0_lower = _mm256_permute2x128_si256(i_32_epi8[0], i_32_epi8[1], 0x20);
__m256i a0_higher = _mm256_permute2x128_si256(i_32_epi8[0], i_32_epi8[1], 0x31);
__m256i a1_lower = _mm256_permute2x128_si256(i_32_epi8[2], i_32_epi8[3], 0x20);
__m256i a1_higher = _mm256_permute2x128_si256(i_32_epi8[2], i_32_epi8[3], 0x31);

__m512i a_lower = _mm512_inserti64x4(_mm512_castsi256_si512(a0_lower), a1_lower, 1);
__m512i a_higher = _mm512_inserti64x4(_mm512_castsi256_si512(a0_higher), a1_higher, 1);

__m512i idx = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
__m512i a_lower_interleaved = _mm512_permutexvar_epi32(idx, a_lower);
__m512i a_higher_interleaved = _mm512_permutexvar_epi32(idx, a_higher);

_mm512_storeu_si512(reinterpret_cast<__m512i*>(blob + 0 * 64), a_lower_interleaved);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(blob + 1 * 64), a_higher_interleaved);
}

void MLASCALL
QuantizeARow_CompInt8_avx512_blklen32(
    const float* A,
    size_t CountK,
    std::byte* QuantA,
    float* QuantAScale,
    float* AScaledBlkSum  // scale_k * Sum_blklen(a_i)
)
{
    // Quantize one row of A (CountK floats) into int8 with per-32-element block
    // scales. Full groups of 4 blocks (128 floats) are stored in the
    // AVX-friendly interleaved layout; remaining blocks are stored contiguously.
    // Outputs: QuantA (quantized bytes), QuantAScale (one scale per block),
    // AScaledBlkSum (scale * Sum(quantized a_i) per block).
    //
    // constexpr fixes the PREfast con.5 warnings: both values are compile-time
    // constants. SubBlkLen is size_t to avoid the signed/unsigned comparison
    // with k_remaining.
    constexpr size_t BlkLen = 32;
    constexpr size_t SubBlkLen = 4 * BlkLen;  // process 128 weights at a time, then the remainder

    const float* a_ptr = A;
    int8_t* quant_a_ptr = reinterpret_cast<int8_t*>(QuantA);
    float* scale_ptr = QuantAScale;
    float* blksum_ptr = AScaledBlkSum;

    size_t k_remaining = CountK;

    for (; k_remaining >= SubBlkLen; k_remaining -= SubBlkLen) {
        __m256i i_32_epi8[4];
        float scale[4];
        float blksum[4];
        for (int i = 0; i < 4; i++) {
            // BUG FIX: advance the source by one block per iteration (was
            // quantizing the same first 32 floats four times).
            Quantize1BlkBlkLen32(a_ptr + i * BlkLen, BlkLen, i_32_epi8[i], scale[i], blksum[i]);
        }
        stoore_4blk_blklen32_interleaved(i_32_epi8, quant_a_ptr);
        // BUG FIX: a_ptr was never advanced, so every sub-block re-read the
        // start of the row.
        a_ptr += SubBlkLen;
        quant_a_ptr += BlkLen * 4;
        std::copy(scale, scale + 4, scale_ptr);
        scale_ptr += 4;
        std::copy(blksum, blksum + 4, blksum_ptr);
        blksum_ptr += 4;
    }

    // Remainder: fewer than 4 blocks, possibly ending with a partial block.
    while (k_remaining > 0) {
        const size_t step = std::min(BlkLen, k_remaining);
        __m256i i_32_epi8;
        float scale;
        float blksum;
        // BUG FIX: pass `step`, not BlkLen — the final block may be partial and
        // quantizing BlkLen floats would read past the end of A.
        Quantize1BlkBlkLen32(a_ptr, step, i_32_epi8, scale, blksum);
        // The full 32-byte store is safe: QuantA is laid out in whole blocks
        // and masked loads zero the quantized lanes beyond `step`.
        _mm256_storeu_epi8(quant_a_ptr, i_32_epi8);
        a_ptr += step;  // BUG FIX: advance the source pointer
        quant_a_ptr += BlkLen;
        *scale_ptr++ = scale;
        *blksum_ptr++ = blksum;
        // BUG FIX: subtract `step`, not BlkLen — subtracting BlkLen from a
        // smaller unsigned k_remaining wrapped around, never terminating.
        k_remaining -= step;
    }
}

static void
SQ4BitGemmPackQuantBDataAndBlkSum512(
size_t N,
Expand Down
120 changes: 50 additions & 70 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@ load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi
bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127
}

static const uint32_t index_array[16] = {0, 0, 2, 2, 0, 0, 2, 2, 1, 1, 3, 3, 1, 1, 3, 3};
static MLAS_FORCEINLINE
__m512 load_4blksum_512(const float* BlksumPtr)
{
    // Load 4 block sums (128 bits) into an __m128 register.
    __m128 blksum4_4_ps = _mm_loadu_ps(BlksumPtr);

    // Place the __m128 in the lowest 128 bits of a zeroed __m512
    // (upper 12 lanes are zero).
    return _mm512_insertf32x4(_mm512_setzero_ps(), blksum4_4_ps, 0);
}

static MLAS_FORCEINLINE void
accumulate_blklen32_r1c1blk4_avx512(
Expand All @@ -36,20 +44,13 @@ accumulate_blklen32_r1c1blk4_avx512(
{
const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123
const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps);
__m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123
__m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps);

__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0);
// __m512i idx = _mm512_loadu_epi8(&index_array[0]);
scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133

const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 2~2,3~3

const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333
const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8);
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8);
const __m512i sum_32_epi16 = _mm512_add_epi16(dot0_32_epi16, dot1_32_epi16);
const __m512i one_32_epi16 = generate_ones_32_epi16();
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16);
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);
acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0);
}
Expand All @@ -70,47 +71,37 @@ accumulate_blklen32_r2c1blk4_avx512(
)
{
__m512i bv0_64_epi8, bv1_64_epi8;
load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8);
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8);

const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123
{
const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123
const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps);
__m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123

__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0);
// __m512i idx = _mm512_loadu_epi8(&index_array[0]);
scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133

const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); // 0~0,1~1
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); // 2~2,3~3
const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); // 00112233 x 4 epi16s
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); // 00112233 x 4 epi16s
const __m512i sum_32_epi16 = _mm512_add_epi16(dot0_32_epi16, dot1_32_epi16);

const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333
const __m512i one_32_epi16 = generate_ones_32_epi16();
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0123 x 4 epi32s
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);

const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123
const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps);
const __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123

acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0);
}
{
const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123
const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps);
__m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123

__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0);
// __m512i idx = _mm512_loadu_epi8(&index_array[0]);
scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133
const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8);
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8);
const __m512i sum_32_epi16 = _mm512_add_epi16(dot0_32_epi16, dot1_32_epi16);

const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); // 0~0,1~1
const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); // 2~2,3~3

const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333
const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333
const __m512i one_32_epi16 = generate_ones_32_epi16();
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16);
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);

const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123
const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps);
__m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps);

acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1);
}
}
Expand All @@ -122,30 +113,29 @@ accumulate_blklen32_r1c1blk4_avx512vnni(
const std::byte* QuantBDataPtr,
const float* scale_a,
const float* scale_b,
//const float* blksum_a,
//const float* blksum_b,
__m512& acc0
)
{
__m512i bv0_64_epi8, bv1_64_epi8;
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8);
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); // 0000111122223333 x 4 (64 unsigned int8)

const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123
{
const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123
const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps);
__m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123

__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0);
//__m512i idx = _mm512_loadu_epi8(&index_array[0]);
scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133
const __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123

const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000000011111111
const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 2222222233333333
__m512i sum_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8);
sum_16_epi32 = _mm512_dpbusd_epi32(sum_16_epi32, bv1_64_epi8, av1_64_epi8); // 0123012301230123

const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133
const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);
acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0);

//const __m512 blksum_a0_ps = load_4blksum_512(blksum_a); // 0123000000000000
//const __m512 blksum_b0_ps = load_4blksum_512(blksum_b); // 0123000000000000
//acc0 = _mm512_fmadd_ps(blksum_a0_ps, blksum_b0_ps, acc0);
}
}

Expand All @@ -164,24 +154,17 @@ accumulate_blklen32_r2c1blk4_avx512vnni(
)
{
__m512i bv0_64_epi8, bv1_64_epi8;
load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8);
__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0);
//__m512i idx = _mm512_loadu_epi8(&index_array[0]);
load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); // 0000111122223333 x 4 (64 unsigned int8)

const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123
{
const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123
const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps);
__m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123

scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133

const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000000011111111
const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 2222222233333333
__m512i sum_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8);
sum_16_epi32 = _mm512_dpbusd_epi32(sum_16_epi32, bv1_64_epi8, av01_64_epi8); // 0123012301230123

const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133
const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);
acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0);
}
Expand All @@ -190,14 +173,9 @@ accumulate_blklen32_r2c1blk4_avx512vnni(
const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps);
__m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123

scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133

const __m512i dot0_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000000011111111
const __m512i dot1_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 2222222233333333
__m512i sum_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8);
sum_16_epi32 = _mm512_dpbusd_epi32(sum_16_epi32, bv1_64_epi8, av11_64_epi8); // 0123012301230123

const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133
const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133
const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133
const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32);
acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1);
}
Expand All @@ -208,6 +186,8 @@ accumulate_1blk_dot_avx512vnni(const __m256i& av_32_epi8, const __m256i& bv_32_e
{
__m256i sum_8_epi32 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8);
const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
// TODO: to compare with:
// acc = _mm256_fmadd_ps(sum_ps, _mm256_broadcast_ps((__m128 const*)(&combined_scale)), acc);
acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc);
}

Expand Down
Loading
Loading