-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Matmul nbits to optimize memory layout for avx instructions #22203
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
TEST(MatMulNBits, Float32_Accuracy4) { | ||
TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>(); | ||
TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>(); | ||
TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>(); | ||
TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>(); | ||
TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>(); | ||
TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>(); | ||
TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>(); | ||
TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>(); | ||
TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>(); | ||
//TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>(); | ||
//TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>(); | ||
//TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>(); | ||
//TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>(); | ||
TestMatMulNBitsTyped<float, 2, 4, 128, 32, 4>(); | ||
//TestMatMulNBitsTyped<float, 100, 288, 1234, 32, 4>(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TEST(MatMulNBits, Float32_Accuracy4) { | |
TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>(); | |
TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>(); | |
TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>(); | |
TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>(); | |
TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>(); | |
TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>(); | |
TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>(); | |
TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>(); | |
TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>(); | |
TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>(); | |
TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>(); | |
TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>(); | |
TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>(); | |
TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>(); | |
TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>(); | |
TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>(); | |
TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>(); | |
TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>(); | |
TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>(); | |
TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>(); | |
TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>(); | |
TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>(); | |
TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>(); | |
TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>(); | |
//TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>(); | |
//TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>(); | |
//TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>(); | |
//TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>(); | |
TestMatMulNBitsTyped<float, 2, 4, 128, 32, 4>(); | |
//TestMatMulNBitsTyped<float, 100, 288, 1234, 32, 4>(); | |
} | |
TEST(MatMulNBits, Float32_Accuracy4) { | |
// TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>(); | |
// TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>(); | |
// TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>(); | |
// TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>(); | |
// TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>(); | |
TestMatMulNBitsTyped<float, 2, 4, 128, 32, 4>(); | |
// TestMatMulNBitsTyped<float, 100, 288, 1234, 32, 4>(); | |
} |
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
float* AScaledBlkSum // scale_k * Sum_blklen(a_i) | ||
) | ||
{ | ||
const size_t BlkLen = 32; |
Check warning
Code scanning / PREfast
The const variable 'BlkLen' can be computed at compile-time. Consider using constexpr (con.5). Warning
) | ||
{ | ||
const size_t BlkLen = 32; | ||
const int64_t SubBlkLen = 4 * BlkLen; // process 128 weights at a time and then process the remaining weights |
Check warning
Code scanning / PREfast
The const variable 'SubBlkLen' can be computed at compile-time. Consider using constexpr (con.5). Warning
|
||
// Convert int32 to int8 | ||
i_16_epi8[index] = _mm512_cvtepi32_epi8(i0); | ||
//_mm_storeu_si128(dst++, i0_8); |
Check notice
Code scanning / CodeQL
Commented-out code
} | ||
|
||
while (k_remaining > 0) { | ||
// for (size_t k = 0; k < CountK; k += BlkLen) { |
Check notice
Code scanning / CodeQL
Commented-out code
} | ||
|
||
TEST(MatMulNBits, LongTestFloat32) { | ||
// onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling<char>("profile.json"); |
Check notice
Code scanning / CodeQL
Commented-out code
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
…hus not to implement avx512 Signed-off-by: liqunfu <liqun.fu@microsoft.com>
… to be in a separate loop. defer this work later Signed-off-by: liqunfu <liqun.fu@microsoft.com>
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
static MLAS_FORCEINLINE | ||
__m512 load_1blksum_512(const float* BlksumPtr) { | ||
// Create a mask to set only the lowest element | ||
const __mmask16 mask = 0x01; // Binary: 0000 0000 0000 0001 |
Check warning
Code scanning / PREfast
The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5). Warning
// Function to load a single float value into the lowest element of a __m256 register | ||
static MLAS_FORCEINLINE | ||
__m256 load_1blksum_256(const float* BlksumPtr) { | ||
const __mmask8 mask = 0x01; // Binary: 0000 0001 |
Check warning
Code scanning / PREfast
The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5). Warning
@@ -32,6 +32,42 @@ | |||
return result; | |||
} | |||
|
|||
static MLAS_FORCEINLINE | |||
__m512 load_broadcast_512(const float& combined_scale) { | |||
const __mmask16 mask = 00000001; // Binary: 0000 0000 0000 0001, lowest element |
Check warning
Code scanning / PREfast
The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5). Warning
// const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); | ||
// return; | ||
|
||
const __mmask8 mask = 0xff; // Binary: 1111 1111, to set all elements |
Check warning
Code scanning / PREfast
The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5). Warning
Signed-off-by: Liqun Fu <liqun.fu@microsoft.com>
Signed-off-by: Liqun Fu <liqun_fu@hotmail.com>
Description
Motivation and Context