From 4247e4aed631dae5c07d1a54140cf9fff789b797 Mon Sep 17 00:00:00 2001 From: Hoi Vo Date: Tue, 30 Apr 2024 11:41:51 -0700 Subject: [PATCH] Update more parallelization for mul_mat --- CMakeLists.txt | 6 +-- ggml-common.h | 4 +- ggml-impl.h | 2 +- ggml-quants.c | 144 ++++++++++++++++++++++++++++++++++++++++++++++--- ggml.c | 107 ++++++++++-------------------------- 5 files changed, 169 insertions(+), 94 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ed1e776455a11..df5db44f6a751 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -999,8 +999,8 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW if (LLAMA_NATIVE) include(cmake/FindSIMD.cmake) endif () - if (LLAMA_AVX512) - list(APPEND ARCH_FLAGS /arch:AVX512) + if (LLAMA_AVX512_) + list(APPEND ARCH_FLAGS /arch:AVX2 /arch:AVX512 /fp:fast /d2jumptablerdata /Zc:forScope /O2 /Ob1) # MSVC has no compile-time flags enabling specific # AVX512 extensions, neither it defines the # macros corresponding to the extensions. @@ -1014,7 +1014,7 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW add_compile_definitions($<$:__AVX512VNNI__>) endif() elseif (LLAMA_AVX2) - list(APPEND ARCH_FLAGS /arch:AVX2) + list(APPEND ARCH_FLAGS /arch:AVX2 /fp:fast /d2jumptablerdata /Zc:forScope /O2 /Ob1) elseif (LLAMA_AVX) list(APPEND ARCH_FLAGS /arch:AVX) endif() diff --git a/ggml-common.h b/ggml-common.h index 517c9bb43b380..58185b24d3b57 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -66,10 +66,10 @@ typedef sycl::half2 ggml_half2; // QK_K = super-block size #ifdef GGML_QKK_64 -#define QK_K 64 +#define QK_K 64u #define K_SCALE_SIZE 4 #else -#define QK_K 256 +#define QK_K 256u #define K_SCALE_SIZE 12 #endif // GGML_QKK_64 diff --git a/ggml-impl.h b/ggml-impl.h index e68b728775c41..0c997d3ed521f 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -88,7 +88,7 @@ typedef uint16_t ggml_fp16_internal_t; #if defined(_MSC_VER) || defined(__MINGW32__) #include #else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) #if !defined(__riscv) #include #endif diff --git a/ggml-quants.c b/ggml-quants.c index c35e4e3b877c8..419f62850c929 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3644,7 +3644,7 @@ void quantize_row_q8_K(const float * restrict x, block_q8_K * restrict y, uint32 #if defined(__AVX2__) || defined(__AVX512F__) for (int64_t i = 0; i < nb; i++) { -#ifdef __AVX512F__ +#if 0 // def __AVX512F__ __m512 ax[4]; __m512 maxabs = _mm512_set1_ps(0.0f); @@ -3705,7 +3705,7 @@ void quantize_row_q8_K(const float * restrict x, block_q8_K * restrict y, uint32 const float iscale = 126.99999f / amax; const float shift = 12582912.f; -#ifdef __AVX512F__ +#if 0 // def __AVX512F__ const __m512 xshift = _mm512_set1_ps(shift); const __m512 xscale = _mm512_set1_ps(iscale); @@ -5301,13 +5301,56 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sum; -#elif defined __AVX2__ +#elif defined(__AVX2__) || defined(__AVX512F__) + +#if 0 // def __AVX512F__ + + static __declspec(align(64)) const uint8_t shuf0[64] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3 + }; + + static __declspec(align(64)) const uint8_t shuf1[64] = { + 4, 5, 4, 5, 4, 5, 4, 5, 4, 
5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7 + }; + + static __declspec(align(64)) const uint8_t shuf2[64] = { + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11 + }; + + static __declspec(align(64)) const uint8_t shuf3[64] = { + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + + const __m512i m3 = _mm512_set1_epi8(3); + const __m128i m4 = _mm_set1_epi8(0xF); + + __m512 acc = _mm512_setzero_ps(); + + __m256 zero256 = _mm256_setzero_ps(); + __m512 zero512 = _mm512_setzero_ps(); + __m512i zero512i = _mm512_setzero_si512(); + +#else const __m256i m3 = _mm256_set1_epi8(3); const __m128i m4 = _mm_set1_epi8(0xF); __m256 acc = _mm256_setzero_ps(); +#endif // __AVX512F__ + for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); @@ -5322,6 +5365,84 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r const __m256i mins = _mm256_cvtepi8_epi16(mins8); const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); +#if 0 // def __AVX512F__ + + const __m256 bdmin_ss = _mm256_broadcastss_ps(_mm_load_ss(&dmin)); + const __m256 acc_mins = _mm256_fmadd_ps(bdmin_ss, _mm256_cvtepi32_ps(prod), zero256); + + acc = _mm512_add_ps(acc, _mm512_insertf32x8(zero512, acc_mins, 0)); + + const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); + + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + __m512i scales_low = _mm512_inserti64x2(_mm512_castps_si512(zero512), l_scales, 0); + scales_low = _mm512_inserti64x2(scales_low, l_scales, 1); + scales_low = _mm512_inserti64x2(scales_low, h_scales, 2); + scales_low = _mm512_inserti64x2(scales_low, h_scales, 3); + + const __m512i q2bits = _mm512_loadu_si512((const __m512i*)q2); + + const __m256i q8_0_low = _mm256_loadu_si256((const __m256i*)(q8 + 0)); + __m512i q8_0 = _mm512_inserti64x4(zero512i, q8_0_low, 0); + const __m256i q8_0_high = _mm256_loadu_si256((const __m256i*)(q8 + 128)); + q8_0 = _mm512_inserti64x4(q8_0, q8_0_high, 1); + + const __m256i q8_1_low = _mm256_loadu_si256((const __m256i*)(q8 + 32)); + __m512i q8_1 = _mm512_inserti64x4(zero512i, q8_1_low, 0); + const __m256i q8_1_high = _mm256_loadu_si256((const __m256i*)(q8 + 160)); + q8_1 = _mm512_inserti64x4(q8_1, q8_1_high, 1); + + const __m256i q8_2_low = _mm256_loadu_si256((const __m256i*)(q8 + 64)); + __m512i q8_2 = _mm512_inserti64x4(zero512i, q8_2_low, 0); + const __m256i q8_2_high = _mm256_loadu_si256((const __m256i*)(q8 + 192)); + q8_2 = _mm512_inserti64x4(q8_2, q8_2_high, 1); + + const __m256i q8_3_low = _mm256_loadu_si256((const __m256i*)(q8 + 96)); + __m512i q8_3 = _mm512_inserti64x4(zero512i, q8_3_low, 0); + const __m256i q8_3_high = _mm256_loadu_si256((const __m256i*)(q8 + 224)); + q8_3 = _mm512_inserti64x4(q8_3, q8_3_high, 1); + + const __m512i q2_0 = _mm512_and_si512(q2bits, m3); + const __m512i q2_1 = _mm512_and_si512(_mm512_srli_epi16(q2bits, 2), m3); + const __m512i q2_2 = _mm512_and_si512(_mm512_srli_epi16(q2bits, 4), m3); + const __m512i q2_3 = _mm512_and_si512(_mm512_srli_epi16(q2bits, 6), m3); + 
+ __m512i p0 = _mm512_maddubs_epi16(q2_0, q8_0); + __m512i p1 = _mm512_maddubs_epi16(q2_1, q8_1); + __m512i p2 = _mm512_maddubs_epi16(q2_2, q8_2); + __m512i p3 = _mm512_maddubs_epi16(q2_3, q8_3); + + const __m512i ctrl0 = _mm512_load_si512((__m512i *)&shuf0); + const __m512i ctrl1 = _mm512_load_si512((__m512i *)&shuf1); + const __m512i ctrl2 = _mm512_load_si512((__m512i *)&shuf2); + const __m512i ctrl3 = _mm512_load_si512((__m512i *)&shuf3); + + __m512i v_scale0 = _mm512_shuffle_epi8(scales_low, ctrl0); + __m512i v_scale1 = _mm512_shuffle_epi8(scales_low, ctrl1); + __m512i v_scale2 = _mm512_shuffle_epi8(scales_low, ctrl2); + __m512i v_scale3 = _mm512_shuffle_epi8(scales_low, ctrl3); + + p0 = _mm512_madd_epi16(v_scale0, p0); + p1 = _mm512_madd_epi16(v_scale1, p1); + p2 = _mm512_madd_epi16(v_scale2, p2); + p3 = _mm512_madd_epi16(v_scale3, p3); + + p0 = _mm512_add_epi32(p0, p1); + p2 = _mm512_add_epi32(p2, p3); + p0 = _mm512_add_epi32(p0, p2); + + const __m512 bd_ss = _mm512_broadcastss_ps(_mm_load_ss(&d)); + acc = _mm512_fmadd_ps(bd_ss, _mm512_cvtepi32_ps(p0), acc); + } + + __m256 res = _mm512_extractf32x8_ps(acc, 1); + res = _mm256_add_ps(res, _mm512_castps512_ps256(acc)); +// __debugbreak(); + *s = hsum_float_8(res); + +#else + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); @@ -5350,10 +5471,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); - p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); - p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); - p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); - p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); + __m256i v_scale0 = _mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)); + __m256i v_scale1 = _mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)); + __m256i v_scale2 = _mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)); + __m256i v_scale3 = _mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)); + + p0 = _mm256_madd_epi16(v_scale0, p0); + p1 = _mm256_madd_epi16(v_scale1, p1); + p2 = _mm256_madd_epi16(v_scale2, p2); + p3 = _mm256_madd_epi16(v_scale3, p3); p0 = _mm256_add_epi32(p0, p1); p2 = _mm256_add_epi32(p2, p3); @@ -5362,11 +5488,13 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r } acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - } +// __debugbreak(); *s = hsum_float_8(acc); +#endif // __AVX2__ || __AVX512F__ + #elif defined __AVX__ const __m128i m3 = _mm_set1_epi8(0x3); diff --git a/ggml.c b/ggml.c index c54a5ab6d8819..bc9f6c13823e0 100644 --- a/ggml.c +++ b/ggml.c @@ -1629,7 +1629,7 @@ void ggml_vec_add_f32(const int32_t n, float * z, const float * x, const float * #ifdef GGML_SIMD int64_t i = 0; -#ifdef __AVX512F__ +#if 0 // def __AVX512F__ const int64_t xn = (n & ~(GGML_F32_EPR16 - 1)); @@ -2055,7 +2055,7 @@ void ggml_vec_mul_f32(const int32_t n, float * z, const float * x, const float * #ifdef GGML_SIMD int64_t i = 0; -#ifdef __AVX512F__ +#if 0 // def __AVX512F__ const int64_t xn = (n & ~(GGML_F32_EPR16 - 1)); @@ -2390,7 +2390,7 @@ void ggml_vec_dot_f32(const int n, float * restrict s, size_t bs, const float * float sumf = 0.0f; #ifdef GGML_SIMD int64_t i = 0; -#ifdef __AVX512F__ +#if 0 // def 
__AVX512F__ const int64_t xn = (n & ~(GGML_F32_EPR16 - 1)); if (xn) { @@ -2503,7 +2503,7 @@ void ggml_vec_dot_f16(const int n, float * restrict s, size_t bs, const ggml_fp1 float sumf = 0.0; #if defined(GGML_SIMD) int64_t i = 0; -#ifdef __AVX512F__ +#if 0 // def __AVX512F__ const int64_t xn = (n & ~(GGML_F16_EPR16 - 1)); if (xn) { @@ -3398,9 +3398,6 @@ static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12"); // #ifdef GGML_TENSOR_OP_PERF -atomic_int64 mul_mat_time = 0; -atomic_int mul_mat_type = 0; -atomic_int64 mul_mat_element_sum = 0; #define GGML_TENSOR_NODE_COUNT 4096 atomic_int thread_create_count = 0; atomic_int64 thread_create_time = 0; @@ -3459,14 +3456,6 @@ print_tensor_op_perf_data ( printf("\n%8d %5.2f\n\n", total_count, total_percent); - printf("mul_mat init time %8.2fsec\n", (double)mul_mat_time / (1000. * 1000.)); - printf("mul_mat type mismatch %d\n", mul_mat_type); - printf("mul_mat element sum %zd\n", mul_mat_element_sum); - printf("mul_mat type mismatch ration %d / %d = %5.2f\n\n", - mul_mat_type, - compute_op_counts[GGML_OP_MUL_MAT], - (double)mul_mat_type / (double)compute_op_counts[GGML_OP_MUL_MAT]); - printf(" Total Total Tensor\n"); printf(" Count Time(sec) %% Time(ms) Tensor Op\n\n"); @@ -12103,17 +12092,29 @@ static void ggml_compute_forward_mul_mat( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS + const enum ggml_type type = src0->type; const int ith = params->ith; const int nth = params->nth; - const enum ggml_type type = src0->type; +#ifdef _WIN32 + + if (params->type == GGML_TASK_TYPE_INIT) { + +#ifdef GGML_TENSOR_OP_PERF + vec_dot_type_counts[type_traits[type].vec_dot_type] += 1; +#endif // GGML_TENSOR_OP_PERF + + *params->barrier0 = nth; + return; + } + +#endif // _WIN32 + + GGML_TENSOR_BINARY_OP_LOCALS const bool src1_cont = ggml_is_contiguous(src1); + ggml_vec_dot_t vec_dot = type_traits[type].vec_dot; enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; @@ -12217,54 +12218,6 @@ static void ggml_compute_forward_mul_mat( } #endif - const enum ggml_type src1_type = src1->type; - -#if 0 // #ifdef GGML_TENSOR_OP_PERF - const bool init_mat = (vec_dot_type != src1_type); -#else - bool init_mat = ((vec_dot_type != src1_type) && - (vec_dot_type != GGML_TYPE_F16) && - (vec_dot_type != GGML_TYPE_Q8_K)); -#endif // GGML_TENSOR_OP_PERF - - if (params->type == GGML_TASK_TYPE_INIT) { - -#ifdef GGML_TENSOR_OP_PERF - vec_dot_type_counts[vec_dot_type] += 1; -#endif // GGML_TENSOR_OP_PERF - - *params->barrier0 = nth; - - if (init_mat) { - -#ifdef GGML_TENSOR_OP_PERF - int64_t t0 = ggml_time_us(); - atomic_fetch_add(&mul_mat_type, 1); - atomic_fetch_add64(&mul_mat_element_sum, ne00 * ne1 * ne12 * ne13); -#endif // GGML_TENSOR_OP_PERF - - char * wdata = params->wdata; - const size_t row_size = ggml_row_size(vec_dot_type, ne10); - - assert(params->wsize >= ne11*ne12*ne13*row_size); - GGML_ASSERT(src1_type == GGML_TYPE_F32); - - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - from_float_to_vec_dot((float *)((char *)src1->data + i13*nb13 + i12*nb12 + i11*nb11), wdata, ne10); - wdata += row_size; - } - } - } -#ifdef GGML_TENSOR_OP_PERF - atomic_fetch_add64(&mul_mat_time, ggml_time_us() - t0); -#endif // GGML_TENSOR_OP_PERF - - } - - return; - } const 
int64_t nr0 = ne01; // src0 rows const int64_t nr1 = ne1*ne12*ne13; // src1 rows @@ -12346,19 +12299,15 @@ static void ggml_compute_forward_mul_mat( atomic_int64 dot_time32 = 0; #endif // GGML_VECTOR_DOT_PERF + const enum ggml_type src1_type = src1->type; + const bool init_mat = ((vec_dot_type != src1_type) && (vec_dot_type != GGML_TYPE_F16)); + size_t row_size = ggml_row_size(vec_dot_type, ne10); char * wdata = src1->data; if (init_mat) { wdata = params->wdata; - } else if ((vec_dot_type != src1_type) && (vec_dot_type == GGML_TYPE_F16)) { - row_size = ggml_row_size(src1_type, ne10); - vec_dot = (ggml_vec_dot_t)ggml_vec_dot_f16_f32; - - } else if ((vec_dot_type != src1_type) && (vec_dot_type == GGML_TYPE_Q8_K)) { - wdata = params->wdata; - assert(params->wsize >= ne11*ne12*ne13*row_size); GGML_ASSERT(src1_type == GGML_TYPE_F32); @@ -12395,11 +12344,9 @@ static void ggml_compute_forward_mul_mat( } while (*params->barrier0); } - // - // Mark initialization done. - // - - init_mat = TRUE; + } else if (vec_dot_type != src1_type) { + row_size = ggml_row_size(src1_type, ne10); + vec_dot = (ggml_vec_dot_t)ggml_vec_dot_f16_f32; } // attempt to reduce false-sharing (does not seem to make a difference)
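
Note on the 2-bit unpacking in the new q2_K x q8_K path: the q2_0..q2_3 planes are produced by shifting each packed byte right by 0/2/4/6 and masking with 3. A minimal standalone sketch of that layout (not ggml code; names and the demo values are mine):

// Each byte of a q2_K "qs" block packs four 2-bit weights. The AVX512/AVX2
// kernel peels them into four planes with srli + and(3); this scalar loop
// does the same thing one byte at a time.
#include <stdint.h>
#include <stdio.h>

static void unpack_q2(const uint8_t *qs, int n_bytes, uint8_t *out /* 4*n_bytes values */) {
    for (int i = 0; i < n_bytes; ++i) {
        out[0*n_bytes + i] = (qs[i] >> 0) & 3;   // plane 0 -> q2_0
        out[1*n_bytes + i] = (qs[i] >> 2) & 3;   // plane 1 -> q2_1
        out[2*n_bytes + i] = (qs[i] >> 4) & 3;   // plane 2 -> q2_2
        out[3*n_bytes + i] = (qs[i] >> 6) & 3;   // plane 3 -> q2_3
    }
}

int main(void) {
    const uint8_t qs[2] = { 0xE4, 0x1B };        // 11 10 01 00 and 00 01 10 11
    uint8_t out[8];
    unpack_q2(qs, 2, out);
    for (int i = 0; i < 8; ++i) printf("%u ", out[i]);   // 0 3 1 2 2 1 3 0
    printf("\n");
    return 0;
}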
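The AVX512 kernel ends by folding the upper 256 bits of the 512-bit float accumulator onto the lower half and then doing the usual 8-lane horizontal sum (hsum_float_8). A self-contained sketch of that reduction, assuming AVX512F plus AVX512DQ for _mm512_extractf32x8_ps (build with -mavx512f -mavx512dq, or /arch:AVX512 on MSVC); the helper name hsum_float_16 is mine:

#include <immintrin.h>
#include <stdio.h>

static float hsum_float_16(__m512 acc) {
    __m256 r8 = _mm256_add_ps(_mm512_castps512_ps256(acc),
                              _mm512_extractf32x8_ps(acc, 1));  // 16 -> 8 lanes
    __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(r8),
                           _mm256_extractf128_ps(r8, 1));       // 8 -> 4 lanes
    r4 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));                 // 4 -> 2 lanes
    r4 = _mm_add_ss(r4, _mm_movehdup_ps(r4));                   // 2 -> 1 lane
    return _mm_cvtss_f32(r4);
}

int main(void) {
    float v[16];
    for (int i = 0; i < 16; ++i) v[i] = (float)(i + 1);          // 1..16, sum = 136
    __m512 acc = _mm512_loadu_ps(v);
    printf("%.1f\n", hsum_float_16(acc));
    return 0;
}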
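On the ggml.c side, the patch moves the src1 -> vec_dot_type conversion out of the single-threaded INIT task: INIT only arms *params->barrier0 with the thread count, and each compute thread converts a share of the rows, then spins on the barrier before starting its dot products. A minimal sketch of that pattern under stated assumptions: the row partitioning, the quantize_row body, and the use of POSIX threads plus C11 <stdatomic.h> are stand-ins for illustration only, not the patch's exact scheme (ggml has its own thread pool and atomic wrappers, and the change is gated on _WIN32):

#include <pthread.h>
#include <stdatomic.h>
#include <stdint.h>
#include <stdio.h>

#define N_THREADS 4
#define N_ROWS    64

static float      src [N_ROWS];   // stand-in for src1 (F32 rows)
static float      wdat[N_ROWS];   // stand-in for params->wdata (converted rows)
static atomic_int barrier0;       // stand-in for *params->barrier0

static void quantize_row(int r) { wdat[r] = src[r] * 0.5f; }   // dummy conversion

static void *worker(void *arg) {
    const int ith = (int)(intptr_t)arg, nth = N_THREADS;

    for (int r = ith; r < N_ROWS; r += nth)      // convert this thread's rows
        quantize_row(r);

    atomic_fetch_sub(&barrier0, 1);              // signal: my slice is ready
    while (atomic_load(&barrier0) != 0)          // spin until all slices are ready,
        ;                                        // as the compute phase does here

    // ... per-thread mat-mul over wdat[] would start at this point ...
    return NULL;
}

int main(void) {
    for (int r = 0; r < N_ROWS; ++r) src[r] = (float)r;
    atomic_store(&barrier0, N_THREADS);          // INIT step: just arm the barrier

    pthread_t t[N_THREADS];
    for (int i = 0; i < N_THREADS; ++i) pthread_create(&t[i], NULL, worker, (void *)(intptr_t)i);
    for (int i = 0; i < N_THREADS; ++i) pthread_join(t[i], NULL);

    printf("wdat[10] = %.1f\n", wdat[10]);       // 5.0
    return 0;
}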