Skip to content

Commit

Permalink
simplify code, make functions constexpr
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Jun 22, 2024
1 parent 79e1e30 commit bc2cbd5
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 57 deletions.
4 changes: 2 additions & 2 deletions ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qi = QI3_S;
};

static int get_mmq_x_max_host(const int cc) {
static constexpr int get_mmq_x_max_host(int cc) {
#ifdef CUDA_USE_TENSOR_CORES
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
#else
Expand All @@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) {
}

// Round rows to this value for --split-mode row:
static int get_mmq_y_host(const int cc) {
static constexpr int get_mmq_y_host(int cc) {
return cc >= CC_VOLTA ? 128 : 64;
}

Expand Down
99 changes: 44 additions & 55 deletions ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,18 @@ static constexpr __device__ int get_mmq_y_device() {
#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}

#define GET_MMQ_DP4A_TXS_BODY \
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : \
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : \
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : \
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : \
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : \
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : \
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : \
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : \
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : \
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : \
tile_x_sizes{0, 0, 0}

static tile_x_sizes mmq_get_dp4a_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
GET_MMQ_DP4A_TXS_BODY;
}

template <int mmq_y>
static constexpr __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes_device(ggml_type type) {
GET_MMQ_DP4A_TXS_BODY;
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
tile_x_sizes{0, 0, 0};
}

#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
Expand All @@ -111,21 +103,18 @@ static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");

#define MMQ_MMA_GET_TILE_X_K_BODY \
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : \
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : \
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : \
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : \
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : \
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : \
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : \
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : \
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : \
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \
0

static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
MMQ_MMA_GET_TILE_X_K_BODY;
return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
0;
}

#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
Expand Down Expand Up @@ -154,7 +143,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_0);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
Expand Down Expand Up @@ -204,7 +193,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_0);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
Expand Down Expand Up @@ -317,7 +306,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_1);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
Expand Down Expand Up @@ -367,7 +356,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_1);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
Expand Down Expand Up @@ -479,7 +468,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_0);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
Expand Down Expand Up @@ -548,7 +537,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_0);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
Expand Down Expand Up @@ -644,7 +633,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_1);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
Expand Down Expand Up @@ -711,7 +700,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_1);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
Expand Down Expand Up @@ -808,7 +797,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_tile + WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q8_0);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
Expand Down Expand Up @@ -858,7 +847,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q8_0);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
Expand Down Expand Up @@ -954,7 +943,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q2_K);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE
Expand Down Expand Up @@ -1013,7 +1002,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q2_K);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
Expand Down Expand Up @@ -1135,7 +1124,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
float * x_df = (float *) (x_qs + WARP_SIZE*2);
int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q3_K);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
int * x_sc = (int *) (x_df + txs.dm);
Expand Down Expand Up @@ -1233,7 +1222,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q3_K);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs;
const int * x_sc = (const int *) x_df + txs.dm;
Expand Down Expand Up @@ -1361,7 +1350,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_K);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
int * x_sc = (int *) (x_dm + txs.dm);
Expand Down Expand Up @@ -1437,7 +1426,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q4_K);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * x_sc = (const int *) x_dm + txs.dm;
Expand Down Expand Up @@ -1578,7 +1567,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_K);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_dm = (half2 *) (x_qs + txs.qs);
int * x_sc = (int *) (x_dm + txs.dm);
Expand Down Expand Up @@ -1668,7 +1657,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q5_K);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + txs.qs;
const int * x_sc = (const int *) x_dm + txs.dm;
Expand Down Expand Up @@ -1800,7 +1789,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
float * x_df = (float *) (x_qs + WARP_SIZE*2);
int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q6_K);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
int * x_sc = (int *) (x_df + txs.dm);
Expand Down Expand Up @@ -1882,7 +1871,7 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y>(GGML_TYPE_Q6_K);
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs;
const int * x_sc = (const int *) x_df + txs.dm;
Expand Down Expand Up @@ -2422,7 +2411,7 @@ struct mmq_args {

template<ggml_type type>
static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y);
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
Expand Down

0 comments on commit bc2cbd5

Please sign in to comment.