Skip to content

Commit

Permalink
fix UT of concat
Browse files Browse the repository at this point in the history
  • Loading branch information
arthw committed Jul 14, 2024
1 parent e700d37 commit a364ec7
Showing 1 changed file with 4 additions and 34 deletions.
38 changes: 4 additions & 34 deletions ggml/src/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3343,10 +3343,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
queue_ptr main_stream = ctx.stream();;

bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda ||
main_stream->get_backend() == sycl::backend::ext_oneapi_hip;


void * src0_ddq = src0->data;
sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
float * src1_ddf = (float *) src1->data;
Expand All @@ -3364,15 +3360,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
: src1_f16_alloc.get();

ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
char * dst_t;

dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
if (no_mixed_dtypes) {
cu_compute_type = dpct::library_data_t::real_half;
cu_data_type = dpct::library_data_t::real_half;
}

// dst strides
size_t nbd2 = dst->nb[2];
Expand All @@ -3381,26 +3372,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;

const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;

const void * alpha = &alpha_f32;
const void * beta = &beta_f32;
if (no_mixed_dtypes) {
alpha = &alpha_f16;
beta = &beta_f16;
}

// TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway
// when oneMKL open source supports half, half, float, float: datatypes

dst_t = (char *) dst_ddf;
if (no_mixed_dtypes) {
dst_t = (char *) dst_f16.alloc(ne_dst);

nbd2 /= sizeof(float) / sizeof(sycl::half);
nbd3 /= sizeof(float) / sizeof(sycl::half);
}

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
Expand Down Expand Up @@ -3462,11 +3437,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
cu_compute_type)));
}

if (no_mixed_dtypes) {
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
Expand Down Expand Up @@ -5069,9 +5039,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons

ggml_type a_type = a->type;

if (op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_MUL_MAT_ID){
if (op->src[0]->type == GGML_TYPE_BF16) return false;
}

if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
Expand All @@ -5082,10 +5049,12 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
return false;
}
}

ggml_type src0_type = op->src[0]->type;
if (src0_type == GGML_TYPE_BF16) {
return false;
}

return true;
} break;
case GGML_OP_GET_ROWS:
Expand Down Expand Up @@ -5133,7 +5102,8 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
case GGML_OP_CONCAT:
{
ggml_type src0_type = op->src[0]->type;
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
int dim = op->op_params[0];
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
} break;
case GGML_OP_DUP:
case GGML_OP_NONE:
Expand Down

0 comments on commit a364ec7

Please sign in to comment.