Skip to content

Commit

Permalink
simple code for loop
Browse files Browse the repository at this point in the history
  • Loading branch information
arthw committed Aug 1, 2024
1 parent 1947c12 commit 6211ac0
Showing 1 changed file with 7 additions and 16 deletions.
23 changes: 7 additions & 16 deletions ggml/src/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2816,8 +2816,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
}
}

for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
int id = ggml_backend_sycl_get_device_id(i);
for (auto & id: ggml_sycl_info().ids) {
if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
continue;
}
Expand Down Expand Up @@ -2882,8 +2881,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;

for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
int id = ggml_backend_sycl_get_device_id(i);
for (auto & id: ggml_sycl_info().ids) {
if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
continue;
}
Expand Down Expand Up @@ -3025,8 +3023,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS;

ggml_sycl_set_device(ctx.device);
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
int id = ggml_backend_sycl_get_device_id(i);
for (auto & id: ggml_sycl_info().ids) {
if (dev[id].row_low == dev[id].row_high) {
continue;
}
Expand Down Expand Up @@ -4343,12 +4340,9 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device_id) {
static bool ggml_backend_sycl_buffer_type_initialized = false;

if (!ggml_backend_sycl_buffer_type_initialized) {
for (int i = 0; i < ggml_sycl_info().device_count; i++) {
int id = ggml_backend_sycl_get_device_id(i);
for (auto & id: ggml_sycl_info().ids) {
auto & device = dpct::dev_mgr::instance().get_device(id);
// queue_ptr stream = &(device.default_queue());
queue_ptr stream = ggml_sycl_info().device_infos[id].qptrs[0];

ggml_backend_sycl_buffer_types[id] = {
/* .iface = */ ggml_backend_sycl_buffer_type_interface,
/* .context = */ new ggml_backend_sycl_buffer_type_context{id, GGML_SYCL_NAME + std::to_string(id), stream},
Expand All @@ -4369,8 +4363,7 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_conte
static bool ggml_backend_sycl_buffer_type_initialized = false;

if (!ggml_backend_sycl_buffer_type_initialized) {
for (int i = 0; i < ggml_sycl_info().device_count; i++) {
int id = ggml_backend_sycl_get_device_id(i);
for (auto & id: ggml_sycl_info().ids) {
ggml_backend_sycl_buffer_types[id] = {
/* .iface = */ ggml_backend_sycl_buffer_type_interface,
/* .context = */ new ggml_backend_sycl_buffer_type_context{id, GGML_SYCL_NAME + std::to_string(id), ctx->stream(id, 0)},
Expand Down Expand Up @@ -4399,8 +4392,7 @@ static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tens
struct ggml_backend_sycl_split_buffer_context {
~ggml_backend_sycl_split_buffer_context() try {
for (ggml_tensor_extra_gpu * extra : tensor_extras) {
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
int id = ggml_backend_sycl_get_device_id(i);
for (auto & id: ggml_sycl_info().ids) {
for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
if (extra->events[id][is] != nullptr) {
/*
Expand Down Expand Up @@ -5169,8 +5161,7 @@ extern "C" int ggml_backend_sycl_reg_devices();

int ggml_backend_sycl_reg_devices() {
assert(ggml_sycl_info().device_count>0);
for (int i = 0; i < ggml_sycl_info().device_count; i++) {
int id = ggml_backend_sycl_get_device_id(i);
for (auto & id: ggml_sycl_info().ids) {
char name[128];
snprintf(name, sizeof(name), "%s%d", GGML_SYCL_NAME, id);
ggml_backend_register(name, ggml_backend_reg_sycl_init, ggml_backend_sycl_buffer_type(id), (void *) (intptr_t) id);
Expand Down

0 comments on commit 6211ac0

Please sign in to comment.