diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 6d56845821f69..65f36ae7a6522 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -4894,7 +4894,6 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr const uint64_t *iq1s_grid, const uint8_t *ksigns_iq2xs, const uint8_t *kmask_iq2xs) { - const int i = item_ct1.get_group(2); const block_iq1_s * x = (const block_iq1_s *) vx; @@ -4903,11 +4902,15 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr const int il = tid/8; // 0...3 const int ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; - const int i8 = 4*ib+il; - uint8_t h = x[i].scales[i8/2] >> 4*(i8%2); - const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5))); - const float d = (float)x[i].d * (2*(h & 7) + 1); - for (int j = 0; j < 8; ++j) y[j] = d * grid[j]; + const uint8_t * qs = x[i].qs + 8*ib; + const uint8_t * grid1 = (const uint8_t *)(iq1s_grid + qs[2*il+0]); + const uint8_t * grid2 = (const uint8_t *)(iq1s_grid + qs[2*il+1]); + const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1); + const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7]; + for (int j = 0; j < 4; ++j) { + y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } #else assert(false); #endif @@ -7808,23 +7811,22 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq, const block_iq1_s * bq1 = (const block_iq1_s *) vbq; const int ib32 = iqs; - int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; - const uint8_t h1 = bq1->scales[2*ib32+0]; - const uint8_t h2 = bq1->scales[2*ib32+1]; - const int * q8 = (const int *)bq8_1[ib32].qs; - const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5))); - const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1))); - const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5))); - const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1))); - for (int j = 0; j < 2; ++j) { - sumi1 = dpct::dp4a(q8[j+0], grid1[j], sumi1); - sumi2 = dpct::dp4a(q8[j+2], grid2[j], sumi2); - sumi3 = dpct::dp4a(q8[j+4], grid3[j], sumi3); - sumi4 = dpct::dp4a(q8[j+6], grid4[j], sumi4); - } - const float d = (float)bq1->d * bq8_1[ib32].ds[0]; - return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) + - sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1)); + const uint8_t * qs = bq1->qs + 4*ib32; + const int8_t * q8 = bq8_1[ib32].qs; + int sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint32_t * grid = (const uint32_t *)(iq1s_grid + qs[l]); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (qs[l] >> 8)); + const int grid_l = dpct::vectorized_binary( + grid[0] ^ signs[0], signs[0], std::minus<>()); + const int grid_h = dpct::vectorized_binary( + grid[1] ^ signs[1], signs[1], std::minus<>()); + sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi); + sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi); + q8 += 8; + } + const float d = (float)bq1->d * bq8_1[ib32].ds[0] * 0.25f; + return d * sumi; #else assert(false); return 0.f;