Skip to content

Commit

Permalink
fix mul_mat_id to match the change of api
Browse files Browse the repository at this point in the history
  • Loading branch information
arthw committed May 21, 2024
1 parent 917dc8c commit 4e3db77
Showing 1 changed file with 256 additions and 54 deletions.
310 changes: 256 additions & 54 deletions ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2944,6 +2944,57 @@ namespace dpct
using shared_memory = detail::device_memory<T, shared, Dimension>;


template <typename T,
sycl::access::address_space addressSpace =
sycl::access::address_space::global_space,
sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
sycl::memory_scope memoryScope = sycl::memory_scope::device>
inline T atomic_fetch_add(T *addr, T operand) {
auto atm =
sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
return atm.fetch_add(operand);
}

template <sycl::access::address_space addressSpace =
sycl::access::address_space::global_space,
sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
sycl::memory_scope memoryScope = sycl::memory_scope::device,
typename T1, typename T2>
inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
auto atm =
sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
return atm.fetch_add(operand);
}

template <typename T, sycl::access::address_space addressSpace =
sycl::access::address_space::global_space>
inline T atomic_fetch_add(T *addr, T operand,
sycl::memory_order memoryOrder) {
switch (memoryOrder) {
case sycl::memory_order::relaxed:
return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
sycl::memory_scope::device>(addr, operand);
case sycl::memory_order::acq_rel:
return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
sycl::memory_scope::device>(addr, operand);
case sycl::memory_order::seq_cst:
return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
sycl::memory_scope::device>(addr, operand);
default:
assert(false && "Invalid memory_order for atomics. Valid memory_order for "
"atomics are: sycl::memory_order::relaxed, "
"sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
}
}

template <sycl::access::address_space addressSpace =
sycl::access::address_space::global_space,
typename T1, typename T2>
inline T1 atomic_fetch_add(T1 *addr, T2 operand,
sycl::memory_order memoryOrder) {
atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
}

} // COPY from DPCT head files

#define GGML_COMMON_DECL_SYCL
Expand Down Expand Up @@ -3060,6 +3111,7 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d
bool ggml_backend_is_sycl(ggml_backend_t backend);
int ggml_backend_sycl_get_device(ggml_backend_t backend);
int get_main_device();
static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
void print_ggml_tensor(const char*name, struct ggml_tensor *src);
void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt);

Expand Down Expand Up @@ -15899,22 +15951,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
}
#endif

struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};

__dpct_inline__ static void k_copy_src1_to_contiguous(
const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
const sycl::nd_item<3> &item_ct1, int &src1_row) {
int32_t iid1 = item_ct1.get_group(2);
int32_t id = item_ct1.get_group(1);

const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);

if (row_id_i != i02) {
return;
}

const int64_t i11 = id % ne11;
const int64_t i12 = iid1;

if (item_ct1.get_local_id(2) == 0) {
src1_row =
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
cur_src1_row, 1);
row_mapping[src1_row] = {id, iid1};
}
/*
DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
performance if there is no access to global memory.
*/
item_ct1.barrier();

const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);

#pragma unroll
for (int i = item_ct1.get_local_id(2); i < ne10;
i += item_ct1.get_local_range(2)) {
src1_row_contiguous[i] = src1_row_original[i];
}
}

__dpct_inline__ static void k_copy_dst_from_contiguous(
char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
size_t nb2, const sycl::nd_item<3> &item_ct1) {
int32_t i = item_ct1.get_group(2);

const int32_t i1 = row_mapping[i].i1;
const int32_t i2 = row_mapping[i].i2;

const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);

#pragma unroll
for (int j = item_ct1.get_local_id(2); j < ne0;
j += item_ct1.get_local_range(2)) {
dst_row_original[j] = dst_row_contiguous[j];
}
}

static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst) try {
GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
"mul_mat_id does not support split buffers");
const ggml_tensor *ids = dst->src[2];
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];

const size_t nb11 = src1->nb[1];
const size_t nb1 = dst->nb[1];

const int32_t id = ((int32_t *)dst->op_params)[0];
const int32_t n_as = src0->ne[2];
const int64_t n_as = ne02;
const int64_t n_ids = ids->ne[0];

std::vector<char> ids_host(ggml_nbytes(ids));
const char *ids_dev = (const char *)ids->data;
const char * ids_dev = (const char *) ids->data;

SYCL_CHECK(CHECK_TRY_ERROR(
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
Expand Down Expand Up @@ -15954,24 +16070,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,

src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[3] = src0->nb[2];

if (src1->ne[1] == 1) {
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id =
*(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
id * ids->nb[0]);

GGML_ASSERT(row_id >= 0 && row_id < n_as);
src0_row.nb[3] = nb02;

src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
src1_row.ne[3] = 1;
src1_row.nb[2] = nb11;
src1_row.nb[3] = nb11;

dst_row.ne[1] = 1;
dst_row.ne[2] = 1;
dst_row.ne[3] = 1;
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;
if (ne12 == 1) {
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);

const int64_t i11 = id % ne11;
const int64_t i12 = iid1;

const int64_t i1 = id;
const int64_t i2 = i12;

src0_row_extra.data_device[g_main_device] =
src0_original + row_id * src0->nb[2];
src0_original + i02*nb02;
src1_row_extra.data_device[g_main_device] =
src1_original + i01 * src1->nb[1];
src1_original + + i11*nb11 + i12*nb12;
dst_row_extra.data_device[g_main_device] =
dst_original + i01 * dst->nb[1];
dst_original + i1*nb1 + i2*nb2;

ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
}
}
} else {
sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
Expand All @@ -15980,63 +16112,134 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
dst_row_extra.data_device[g_main_device] = dst_contiguous.get();

for (int32_t row_id = 0; row_id < n_as; ++row_id) {
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

if (row_id_i != row_id) {
continue;
}
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);

GGML_ASSERT(row_id >= 0 && row_id < n_as);
if (row_id_i != i02) {
continue;
}

SYCL_CHECK(CHECK_TRY_ERROR(
stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
src1_original + i01 * nb11, nb11)));
num_src1_rows++;
num_src1_rows++;
}
}

if (num_src1_rows == 0) {
continue;
}

src0_row_extra.data_device[g_main_device] =
src0_original + row_id * src0->nb[2];

sycl_pool_alloc<int> dev_cur_src1_row(1);
sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(num_src1_rows);
SYCL_CHECK(CHECK_TRY_ERROR(
stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));

{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
/*
DPCT1049:81: The work-group size passed to the SYCL kernel may
exceed the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if
needed.
*/
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<int, 0> src1_row_acc_ct1(cgh);

char *__restrict src1_contiguous_get_ct1 =
src1_contiguous.get();
int *__restrict dev_cur_src1_row_get_ct2 =
dev_cur_src1_row.get();
mmid_row_mapping *__restrict dev_row_mapping_get_ct3 =
dev_row_mapping.get();
size_t ids_nb_ct6 = ids->nb[1];
size_t ids_nb_ct7 = ids->nb[0];

cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_copy_src1_to_contiguous(
src1_original, src1_contiguous_get_ct1,
dev_cur_src1_row_get_ct2,
dev_row_mapping_get_ct3, ids_dev, i02,
ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
item_ct1, src1_row_acc_ct1);
});
});
/*
DPCT1010:196: SYCL uses exceptions to report errors and does not
use the error codes. The call was replaced with 0. You need to
rewrite this code.
*/
/*
DPCT1009:197: SYCL uses exceptions to report errors and does not
use the error codes. The call was replaced by a placeholder
string. You need to rewrite this code.
*/
SYCL_CHECK(0);
}

src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02;

GGML_ASSERT(nb11 == sizeof(float)*ne10);
GGML_ASSERT(nb1 == sizeof(float)*ne0);
src1_row.ne[1] = num_src1_rows;
dst_row.ne[1] = num_src1_rows;

src1_row.nb[1] = nb11;
src1_row.nb[2] = num_src1_rows*nb11;
src1_row.nb[3] = num_src1_rows*nb11;

dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;

ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);

num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);

if (row_id_i != row_id) {
continue;
}

GGML_ASSERT(row_id >= 0 && row_id < n_as);

SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
dst_original + i01 * nb1,
dst_contiguous.get() + num_src1_rows * nb1, nb1)));
num_src1_rows++;
{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
sycl::range<3> grid_dims(1, 1, num_src1_rows);
/*
DPCT1049:82: The work-group size passed to the SYCL kernel may
exceed the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if
needed.
*/
stream->submit([&](sycl::handler &cgh) {
const char *__restrict dst_contiguous_get_ct1 =
dst_contiguous.get();
const mmid_row_mapping *__restrict dev_row_mapping_get_ct2 =
dev_row_mapping.get();

cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_copy_dst_from_contiguous(dst_original,
dst_contiguous_get_ct1,
dev_row_mapping_get_ct2,
ne0, nb1, nb2, item_ct1);
});
});
/*
DPCT1010:198: SYCL uses exceptions to report errors and does not
use the error codes. The call was replaced with 0. You need to
rewrite this code.
*/
/*
DPCT1009:199: SYCL uses exceptions to report errors and does not
use the error codes. The call was replaced by a placeholder
string. You need to rewrite this code.
*/
SYCL_CHECK(0);
}
}
}

if (dst->backend == GGML_BACKEND_TYPE_CPU) {
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
}
}
}
catch (sycl::exception const &exc) {
Expand Down Expand Up @@ -17020,10 +17223,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe
UNUSED(buffer);
}

// unused at the moment
//static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
// return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
//}
static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
}

GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
Expand Down

0 comments on commit 4e3db77

Please sign in to comment.