Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Fix batched impl for NVidia GPU #6164

Merged
merged 3 commits into from Mar 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 25 additions & 4 deletions ggml-sycl.cpp
Expand Up @@ -15246,6 +15246,9 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
SYCL_CHECK(ggml_sycl_set_device(g_main_device));
dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];

bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda ||
main_stream->get_backend() == sycl::backend::ext_oneapi_hip;

SYCL_CHECK(
CHECK_TRY_ERROR(g_sycl_handles[g_main_device] = main_stream));

Expand Down Expand Up @@ -15276,24 +15279,38 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,

dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
if (no_mixed_dtypes) {
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
cu_compute_type = dpct::library_data_t::real_half;
cu_data_type = dpct::library_data_t::real_half;
}

// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];

const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;

const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;

const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;

const void * alpha = &alpha_f32;
const void * beta = &beta_f32;
if (no_mixed_dtypes) {
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
alpha = &alpha_f16;
beta = &beta_f16;
}

// TODO: Re-enable (dst->op_params[0] != GGML_PREC_DEFAULT) pathway
// oneMKL open source supports half, half, float, float: datatypes
// when oneMKL open source supports half, half, float, float: datatypes

dst_t = (char *) dst_ddf;
if (no_mixed_dtypes) {
dst_t = (char *) dst_f16.alloc(ne_dst);

nbd2 /= sizeof(float) / sizeof(sycl::half);
nbd3 /= sizeof(float) / sizeof(sycl::half);
}

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
Expand Down Expand Up @@ -15379,6 +15396,10 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
}
#endif

if (no_mixed_dtypes) {
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
Expand Down