Skip to content

Commit

Permalink
Update more parallelization for mul_mat
Browse files Browse the repository at this point in the history
  • Loading branch information
HoiV committed Apr 30, 2024
1 parent 79e5396 commit 4247e4a
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 94 deletions.
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -999,8 +999,8 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
if (LLAMA_NATIVE)
include(cmake/FindSIMD.cmake)
endif ()
if (LLAMA_AVX512)
list(APPEND ARCH_FLAGS /arch:AVX512)
if (LLAMA_AVX512_)
list(APPEND ARCH_FLAGS /arch:AVX2 /arch:AVX512 /fp:fast /d2jumptablerdata /Zc:forScope /O2 /Ob1)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, neither it defines the
# macros corresponding to the extensions.
Expand All @@ -1014,7 +1014,7 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
elseif (LLAMA_AVX2)
list(APPEND ARCH_FLAGS /arch:AVX2)
list(APPEND ARCH_FLAGS /arch:AVX2 /fp:fast /d2jumptablerdata /Zc:forScope /O2 /Ob1)
elseif (LLAMA_AVX)
list(APPEND ARCH_FLAGS /arch:AVX)
endif()
Expand Down
4 changes: 2 additions & 2 deletions ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ typedef sycl::half2 ggml_half2;
// QK_K = super-block size

#ifdef GGML_QKK_64
#define QK_K 64
#define QK_K 64u
#define K_SCALE_SIZE 4
#else
#define QK_K 256
#define QK_K 256u
#define K_SCALE_SIZE 12
#endif // GGML_QKK_64

Expand Down
2 changes: 1 addition & 1 deletion ggml-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ typedef uint16_t ggml_fp16_internal_t;
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
#else
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
#if !defined(__riscv)
#include <immintrin.h>
#endif
Expand Down
144 changes: 136 additions & 8 deletions ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -3644,7 +3644,7 @@ void quantize_row_q8_K(const float * restrict x, block_q8_K * restrict y, uint32
#if defined(__AVX2__) || defined(__AVX512F__)
for (int64_t i = 0; i < nb; i++) {

#ifdef __AVX512F__
#if 0 // def __AVX512F__

__m512 ax[4];
__m512 maxabs = _mm512_set1_ps(0.0f);
Expand Down Expand Up @@ -3705,7 +3705,7 @@ void quantize_row_q8_K(const float * restrict x, block_q8_K * restrict y, uint32
const float iscale = 126.99999f / amax;
const float shift = 12582912.f;

#ifdef __AVX512F__
#if 0 // def __AVX512F__

const __m512 xshift = _mm512_set1_ps(shift);
const __m512 xscale = _mm512_set1_ps(iscale);
Expand Down Expand Up @@ -5301,13 +5301,56 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r

*s = sum;

#elif defined __AVX2__
#elif defined(__AVX2__) || defined(__AVX512F__)

#if 0 // def __AVX512F__

static __declspec(align(64)) const uint8_t shuf0[64] = {
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3
};

static __declspec(align(64)) const uint8_t shuf1[64] = {
4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7
};

static __declspec(align(64)) const uint8_t shuf2[64] = {
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11
};

static __declspec(align(64)) const uint8_t shuf3[64] = {
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
};

const __m512i m3 = _mm512_set1_epi8(3);
const __m128i m4 = _mm_set1_epi8(0xF);

__m512 acc = _mm512_setzero_ps();

__m256 zero256 = _mm256_setzero_ps();
__m512 zero512 = _mm512_setzero_ps();
__m512i zero512i = _mm512_setzero_si512();

#else

const __m256i m3 = _mm256_set1_epi8(3);
const __m128i m4 = _mm_set1_epi8(0xF);

__m256 acc = _mm256_setzero_ps();

#endif // __AVX512F__

for (int i = 0; i < nb; ++i) {

const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
Expand All @@ -5322,6 +5365,84 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m256i mins = _mm256_cvtepi8_epi16(mins8);
const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));

#if 0 // def __AVX512F__

const __m256 bdmin_ss = _mm256_broadcastss_ps(_mm_load_ss(&dmin));
const __m256 acc_mins = _mm256_fmadd_ps(bdmin_ss, _mm256_cvtepi32_ps(prod), zero256);

acc = _mm512_add_ps(acc, _mm512_insertf32x8(zero512, acc_mins, 0));

const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);

const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
__m512i scales_low = _mm512_inserti64x2(_mm512_castps_si512(zero512), l_scales, 0);
scales_low = _mm512_inserti64x2(scales_low, l_scales, 1);
scales_low = _mm512_inserti64x2(scales_low, h_scales, 2);
scales_low = _mm512_inserti64x2(scales_low, h_scales, 3);

const __m512i q2bits = _mm512_loadu_si512((const __m512i*)q2);

const __m256i q8_0_low = _mm256_loadu_si256((const __m256i*)(q8 + 0));
__m512i q8_0 = _mm512_inserti64x4(zero512i, q8_0_low, 0);
const __m256i q8_0_high = _mm256_loadu_si256((const __m256i*)(q8 + 128));
q8_0 = _mm512_inserti64x4(q8_0, q8_0_high, 1);

const __m256i q8_1_low = _mm256_loadu_si256((const __m256i*)(q8 + 32));
__m512i q8_1 = _mm512_inserti64x4(zero512i, q8_1_low, 0);
const __m256i q8_1_high = _mm256_loadu_si256((const __m256i*)(q8 + 160));
q8_1 = _mm512_inserti64x4(q8_1, q8_1_high, 1);

const __m256i q8_2_low = _mm256_loadu_si256((const __m256i*)(q8 + 64));
__m512i q8_2 = _mm512_inserti64x4(zero512i, q8_2_low, 0);
const __m256i q8_2_high = _mm256_loadu_si256((const __m256i*)(q8 + 192));
q8_2 = _mm512_inserti64x4(q8_2, q8_2_high, 1);

const __m256i q8_3_low = _mm256_loadu_si256((const __m256i*)(q8 + 96));
__m512i q8_3 = _mm512_inserti64x4(zero512i, q8_3_low, 0);
const __m256i q8_3_high = _mm256_loadu_si256((const __m256i*)(q8 + 224));
q8_3 = _mm512_inserti64x4(q8_3, q8_3_high, 1);

const __m512i q2_0 = _mm512_and_si512(q2bits, m3);
const __m512i q2_1 = _mm512_and_si512(_mm512_srli_epi16(q2bits, 2), m3);
const __m512i q2_2 = _mm512_and_si512(_mm512_srli_epi16(q2bits, 4), m3);
const __m512i q2_3 = _mm512_and_si512(_mm512_srli_epi16(q2bits, 6), m3);

__m512i p0 = _mm512_maddubs_epi16(q2_0, q8_0);
__m512i p1 = _mm512_maddubs_epi16(q2_1, q8_1);
__m512i p2 = _mm512_maddubs_epi16(q2_2, q8_2);
__m512i p3 = _mm512_maddubs_epi16(q2_3, q8_3);

const __m512i ctrl0 = _mm512_load_si512((__m512i *)&shuf0);
const __m512i ctrl1 = _mm512_load_si512((__m512i *)&shuf1);
const __m512i ctrl2 = _mm512_load_si512((__m512i *)&shuf2);
const __m512i ctrl3 = _mm512_load_si512((__m512i *)&shuf3);

__m512i v_scale0 = _mm512_shuffle_epi8(scales_low, ctrl0);
__m512i v_scale1 = _mm512_shuffle_epi8(scales_low, ctrl1);
__m512i v_scale2 = _mm512_shuffle_epi8(scales_low, ctrl2);
__m512i v_scale3 = _mm512_shuffle_epi8(scales_low, ctrl3);

p0 = _mm512_madd_epi16(v_scale0, p0);
p1 = _mm512_madd_epi16(v_scale1, p1);
p2 = _mm512_madd_epi16(v_scale2, p2);
p3 = _mm512_madd_epi16(v_scale3, p3);

p0 = _mm512_add_epi32(p0, p1);
p2 = _mm512_add_epi32(p2, p3);
p0 = _mm512_add_epi32(p0, p2);

const __m512 bd_ss = _mm512_broadcastss_ps(_mm_load_ss(&d));
acc = _mm512_fmadd_ps(bd_ss, _mm512_cvtepi32_ps(p0), acc);
}

__m256 res = _mm512_extractf32x8_ps(acc, 1);
res = _mm256_add_ps(res, _mm512_castps512_ps256(acc));
// __debugbreak();
*s = hsum_float_8(res);

#else

acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);

const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
Expand Down Expand Up @@ -5350,10 +5471,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
__m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
__m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);

p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
__m256i v_scale0 = _mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0));
__m256i v_scale1 = _mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1));
__m256i v_scale2 = _mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2));
__m256i v_scale3 = _mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3));

p0 = _mm256_madd_epi16(v_scale0, p0);
p1 = _mm256_madd_epi16(v_scale1, p1);
p2 = _mm256_madd_epi16(v_scale2, p2);
p3 = _mm256_madd_epi16(v_scale3, p3);

p0 = _mm256_add_epi32(p0, p1);
p2 = _mm256_add_epi32(p2, p3);
Expand All @@ -5362,11 +5488,13 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
}

acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);

}

// __debugbreak();
*s = hsum_float_8(acc);

#endif // __AVX2__ || __AVX512F__

#elif defined __AVX__

const __m128i m3 = _mm_set1_epi8(0x3);
Expand Down
Loading

0 comments on commit 4247e4a

Please sign in to comment.