diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 1fc948be5bbe8..31fcbf1397b6b 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -2054,15 +2054,13 @@ static __device__ __forceinline__ void mmq_write_back_mma( static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); #endif // INT8_MMA_AVAILABLE - dst += (threadIdx.y % ntx) * mma_C::J*stride; - #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - const int j = j0 + mma_C::get_j(l); + const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l); if (j > j_max) { continue;