Skip to content

Commit

Permalink
CUDA: fix MMQ writeback for int8 tensor cores ggerganov#8100
Browse files Browse the repository at this point in the history
Co-Authored-By: Johannes Gäßler <johannesg@5d6.de>
  • Loading branch information
Nexesenex and JohannesGaessler committed Jun 25, 2024
1 parent 54e1ff4 commit 0d5cc68
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2035,15 +2035,13 @@ static __device__ __forceinline__ void mmq_write_back_mma(
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
#endif // INT8_MMA_AVAILABLE

-dst += (threadIdx.y % ntx) * mma_C::J*stride;

#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
-const int j = j0 + mma_C::get_j(l);
+const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);

if (j > j_max) {
continue;
Expand Down

0 comments on commit 0d5cc68

Please sign in to comment.