From 0402d4f8a049b9fc4402a9245ab863f017f36a4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 24 Jun 2024 20:50:41 +0200 Subject: [PATCH] CUDA: fix MMQ writeback for int8 tensor cores --- ggml-cuda/mmq.cuh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 1fc948be5bbe8..31fcbf1397b6b 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -2054,15 +2054,13 @@ static __device__ __forceinline__ void mmq_write_back_mma( static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); #endif // INT8_MMA_AVAILABLE - dst += (threadIdx.y % ntx) * mma_C::J*stride; - #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - const int j = j0 + mma_C::get_j(l); + const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l); if (j > j_max) { continue;