Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matmul nbits to optimize memory layout for avx instructions #22203

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from

Conversation

liqunfu
Copy link
Contributor

@liqunfu liqunfu commented Sep 24, 2024

Description

Motivation and Context

Signed-off-by: liqunfu <liqun.fu@microsoft.com>
@liqunfu liqunfu requested a review from a team as a code owner September 24, 2024 15:50
@liqunfu liqunfu marked this pull request as draft September 24, 2024 15:50
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment on lines 367 to 394
TEST(MatMulNBits, Float32_Accuracy4) {
TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>();
TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>();
TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>();
TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>();
TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>();
TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>();
TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>();
TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>();
TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>();
TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>();
TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>();
TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>();
TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>();
TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>();
TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>();
//TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>();
TestMatMulNBitsTyped<float, 2, 4, 128, 32, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1234, 32, 4>();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
TEST(MatMulNBits, Float32_Accuracy4) {
TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>();
TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>();
TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>();
TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>();
TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>();
TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>();
TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>();
TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>();
TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>();
TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>();
TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>();
TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>();
TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>();
TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>();
TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>();
TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>();
TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>();
//TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>();
//TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>();
TestMatMulNBitsTyped<float, 2, 4, 128, 32, 4>();
//TestMatMulNBitsTyped<float, 100, 288, 1234, 32, 4>();
}
TEST(MatMulNBits, Float32_Accuracy4) {
// TestMatMulNBitsTyped<float, 1, 1, 16, 16, 4>();
// TestMatMulNBitsTyped<float, 1, 2, 16, 16, 4>();
// TestMatMulNBitsTyped<float, 1, 32, 16, 16, 4>();
// TestMatMulNBitsTyped<float, 1, 32, 32, 16, 4>();
// TestMatMulNBitsTyped<float, 1, 32, 16, 128, 4>();
// TestMatMulNBitsTyped<float, 1, 288, 16, 16, 4>();
// TestMatMulNBitsTyped<float, 1, 288, 1024, 16, 4>();
// TestMatMulNBitsTyped<float, 1, 288, 1024, 128, 4>();
// TestMatMulNBitsTyped<float, 1, 288, 93, 32, 4>();
// TestMatMulNBitsTyped<float, 1, 288, 93, 128, 4>();
// TestMatMulNBitsTyped<float, 1, 288, 1234, 16, 4>();
// TestMatMulNBitsTyped<float, 2, 1, 16, 16, 4>();
// TestMatMulNBitsTyped<float, 2, 2, 16, 16, 4>();
// TestMatMulNBitsTyped<float, 100, 1, 16, 16, 4>();
// TestMatMulNBitsTyped<float, 100, 2, 16, 16, 4>();
// TestMatMulNBitsTyped<float, 100, 32, 16, 16, 4>();
// TestMatMulNBitsTyped<float, 100, 32, 32, 16, 4>();
// TestMatMulNBitsTyped<float, 100, 32, 16, 128, 4>();
// TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>();
// TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>();
// TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>();
// TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>();
// TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>();
// TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>();
TestMatMulNBitsTyped<float, 2, 4, 128, 32, 4>();
// TestMatMulNBitsTyped<float, 100, 288, 1234, 32, 4>();
}

float* AScaledBlkSum // scale_k * Sum_blklen(a_i)
)
{
const size_t BlkLen = 32;

Check warning

Code scanning / PREfast

The const variable 'BlkLen' can be computed at compile-time. Consider using constexpr (con.5). Warning

The const variable 'BlkLen' can be computed at compile-time. Consider using constexpr (con.5).
)
{
const size_t BlkLen = 32;
const int64_t SubBlkLen = 4 * BlkLen; // process 128 weights at a time and then process the remaining weights

Check warning

Code scanning / PREfast

The const variable 'SubBlkLen' can be computed at compile-time. Consider using constexpr (con.5). Warning

The const variable 'SubBlkLen' can be computed at compile-time. Consider using constexpr (con.5).

// Convert int32 to int8
i_16_epi8[index] = _mm512_cvtepi32_epi8(i0);
//_mm_storeu_si128(dst++, i0_8);

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
}

while (k_remaining > 0) {
// for (size_t k = 0; k < CountK; k += BlkLen) {

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
}

TEST(MatMulNBits, LongTestFloat32) {
// onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling<char>("profile.json");

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
…hus not to implement avx512

Signed-off-by: liqunfu <liqun.fu@microsoft.com>
… to be in a separate loop. defer this work later

Signed-off-by: liqunfu <liqun.fu@microsoft.com>
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
static MLAS_FORCEINLINE
__m512 load_1blksum_512(const float* BlksumPtr) {
// Create a mask to set only the lowest element
const __mmask16 mask = 0x01; // Binary: 0000 0000 0000 0001

Check warning

Code scanning / PREfast

The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5). Warning

The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5).
// Function to load a single float value into the lowest element of a __m256 register
static MLAS_FORCEINLINE
__m256 load_1blksum_256(const float* BlksumPtr) {
const __mmask8 mask = 0x01; // Binary: 0000 0001

Check warning

Code scanning / PREfast

The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5). Warning

The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5).
@@ -32,6 +32,42 @@
return result;
}

static MLAS_FORCEINLINE
__m512 load_broadcast_512(const float& combined_scale) {
const __mmask16 mask = 00000001; // Binary: 0000 0000 0000 0001, lowest element

Check warning

Code scanning / PREfast

The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5). Warning

The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5).
// const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps);
// return;

const __mmask8 mask = 0xff; // Binary: 1111 1111, to set all elements

Check warning

Code scanning / PREfast

The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5). Warning

The const variable 'mask' can be computed at compile-time. Consider using constexpr (con.5).
liqunfu and others added 2 commits December 13, 2024 10:09
Signed-off-by: Liqun Fu <liqun.fu@microsoft.com>
Signed-off-by: Liqun Fu <liqun_fu@hotmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant