Make IQ1_M work for QK_K = 64 #6327

Merged · 3 commits · Mar 27, 2024
ggml-common.h: 11 changes (9 additions, 2 deletions)
@@ -377,13 +377,20 @@ typedef struct {
} block_iq1_s;
static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");

-// 1.8125 bpw
+// 1.75 bpw
typedef struct {
uint8_t qs[QK_K/8]; // grid index, low 8 bits
uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8)
-uint8_t scales[QK_K/32]; // 4-bit block scales
+#if QK_K == 64
+ggml_half d;
+#endif
+uint8_t scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64)
} block_iq1_m;
+#if QK_K == 64
+static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32 + sizeof(ggml_half), "wrong iq1_m block size/padding");
+#else
static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
+#endif

// Used by IQ1_M quants
typedef union {
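
For QK_K = 256, the 1.75 bpw figure works out because the fp16 super-block scale is hidden in the top four bits of the four 16-bit scale words, which the quantization and dequantization code later in this diff packs and unpacks. A minimal round-trip sketch of that bit layout in plain C (the helper names are illustrative, and a raw uint16_t stands in for the ggml_half bit pattern):

#include <assert.h>
#include <stdint.h>

// Pack the 16-bit super-block scale s into the high nibble of each of
// sc[0..3]; the low 12 bits of every word hold four 3-bit block scales.
static void pack_superblock_scale(uint16_t sc[4], uint16_t s) {
    sc[0] |= (uint16_t)((s & 0x000f) << 12);
    sc[1] |= (uint16_t)((s & 0x00f0) <<  8);
    sc[2] |= (uint16_t)((s & 0x0f00) <<  4);
    sc[3] |= (uint16_t)((s & 0xf000) <<  0);
}

// Reassemble it, mirroring the scale.u16 = ... expression used by the
// dequantize and vec_dot paths below.
static uint16_t unpack_superblock_scale(const uint16_t sc[4]) {
    return (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) |
           ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
}

int main(void) {
    uint16_t sc[4] = {0x0123, 0x0456, 0x0789, 0x0abc}; // 3-bit scales only
    pack_superblock_scale(sc, 0xdead);
    assert(unpack_superblock_scale(sc) == 0xdead); // fp16 bits round-trip
    assert((sc[2] & 0x0fff) == 0x0789);            // block scales untouched
    return 0;
}

With QK_K = 64 the struct has only a single 16-bit scale word, so there is no room to hide the fp16 scale in high nibbles, hence the dedicated ggml_half d member and the extra static_assert branch.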
ggml-metal.metal: 20 changes (19 additions, 1 deletion)
@@ -4497,7 +4497,9 @@ void kernel_mul_mv_iq1_m_f32_impl(

device const float * y4 = y + 32 * ix;

+#if QK_K != 64
iq1m_scale_t scale;
+#endif

for (int ib32 = ix; ib32 < nb32; ib32 += 32) {

@@ -4519,7 +4521,9 @@

for (int row = 0; row < N_DST; row++) {

+#if QK_K != 64
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+#endif

constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
@@ -4535,8 +4539,14 @@
}
const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+#if QK_K == 64
+const float d = (float) *((device const half *)(sc - 1));
+sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
+(sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
+#else
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
+#endif

sc += nb*sizeof(block_iq1_m)/2;
qs += nb*sizeof(block_iq1_m);
@@ -5277,13 +5287,21 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
-iq1m_scale_t scale;
device const uint16_t * sc = (device const uint16_t *)xb->scales;
+#if QK_K == 64
+const float d = xb->d;
+#else
+iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const float d = scale.f16;
+#endif
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint8_t * qh = xb->qh + 2*ib32 + il;
+#if QK_K == 64
+const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
+#else
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+#endif
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
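
The two Metal paths above differ only in where the scale bits live: with QK_K = 256 each of the four uint16_t scale words carries four 3-bit block scales plus one nibble of the packed fp16 scale, while with QK_K = 64 a single uint16_t carries four 4-bit block scales and the fp16 scale sits in the struct's own d field (which is why the kernel reaches it as sc - 1). A scalar sketch of the two extraction rules and the per-group shift; the helper names are mine, not ggml's, and IQ1M_DELTA is taken from ggml-common.h:

#include <stdint.h>

#define IQ1M_DELTA 0.125f // value defined in ggml-common.h

// QK_K == 256: block-of-32 index ib32 in 0..7, half il in {0,1};
// 3-bit scale fields at bit offsets 0/3/6/9 of sc[ib32/2].
static inline int iq1m_block_scale_256(const uint16_t * sc, int ib32, int il) {
    return 2*((sc[ib32/2] >> (6*(ib32%2) + 3*il)) & 0x7) + 1;
}

// QK_K == 64: ib32 in 0..1; 4-bit scale fields at bit offsets 0/4/8/12
// of the single scale word.
static inline int iq1m_block_scale_64(const uint16_t * sc, int ib32, int il) {
    return 2*((sc[ib32/2] >> (8*(ib32%2) + 4*il)) & 0xf) + 1;
}

// Each group of 8 weights is shifted by -1 +/- IQ1M_DELTA; one qh bit
// per group (0x08 for the first, 0x80 for the second) selects the sign.
static inline float iq1m_shift(uint8_t qh, uint8_t bit) {
    return (qh & bit) ? -1.0f - IQ1M_DELTA : -1.0f + IQ1M_DELTA;
}

So every weight in a group of 8 dequantizes to d * block_scale * (±1 ± IQ1M_DELTA): the grid supplies the ±1 pattern and one qh bit fixes the delta sign for the whole group, which is exactly what the delta1/delta2 terms in the kernel accumulate.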
ggml-quants.c: 70 changes (69 additions, 1 deletion)
@@ -3481,19 +3481,30 @@ void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, in
float delta[4];
uint16_t idx[4];

+#if QK_K != 64
iq1m_scale_t scale;
+#endif

for (int i = 0; i < nb; i++) {

const uint16_t * sc = (const uint16_t *)x[i].scales;
+#if QK_K == 64
+const float d = GGML_FP16_TO_FP32(x[i].d);
+#else
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const float d = GGML_FP16_TO_FP32(scale.f16);
+#endif
const uint8_t * qs = x[i].qs;
const uint8_t * qh = x[i].qh;

for (int ib = 0; ib < QK_K/32; ++ib) {
+#if QK_K == 64
+const float dl1 = d * (2*((sc[ib/2] >> (8*(ib%2)+0)) & 0xf) + 1);
+const float dl2 = d * (2*((sc[ib/2] >> (8*(ib%2)+4)) & 0xf) + 1);
+#else
const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1);
const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1);
+#endif
idx[0] = qs[0] | ((qh[0] << 8) & 0x700);
idx[1] = qs[1] | ((qh[0] << 4) & 0x700);
idx[2] = qs[2] | ((qh[1] << 8) & 0x700);
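
The idx[] assembly above is the 11-bit grid lookup shared with IQ1_S: qs supplies the low 8 bits of each index and every qh nibble contributes the 3 high bits, with the nibble's remaining bit (0x08 or 0x80) reserved for the group's shift. A short sketch of just the indexing, with an illustrative helper name and the grid contents omitted:

#include <stdint.h>

// Build the four grid indices for one block of 32 weights (four groups
// of 8), exactly as the idx[0..3] lines in dequantize_row_iq1_m do.
static void iq1m_grid_indices(const uint8_t qs[4], const uint8_t qh[2],
                              uint16_t idx[4]) {
    idx[0] = (uint16_t)(qs[0] | ((qh[0] << 8) & 0x700));
    idx[1] = (uint16_t)(qs[1] | ((qh[0] << 4) & 0x700));
    idx[2] = (uint16_t)(qs[2] | ((qh[1] << 8) & 0x700));
    idx[3] = (uint16_t)(qs[3] | ((qh[1] << 4) & 0x700));
}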
@@ -9756,11 +9767,17 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void

const int nb = n / QK_K;

+#if QK_K != 64
iq1m_scale_t scale;
+#endif

#if defined __ARM_NEON

+#if QK_K == 64
+const int32x4_t mask = vdupq_n_s32(0xf);
+#else
const int32x4_t mask = vdupq_n_s32(0x7);
+#endif
const int32x4_t mone = vdupq_n_s32(1);
const int32x4_t mzero = vdupq_n_s32(0);

@@ -9784,7 +9801,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
const uint8_t * qh = x[i].qh;
const uint16_t * sc = (const uint16_t *)x[i].scales;

+#if QK_K != 64
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+#endif

int32x4_t sumi1 = mzero;
int32x4_t sumi2 = mzero;
@@ -9813,7 +9832,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
const int32x4_t p34 = vpaddq_s32(p3, p4);

+#if QK_K == 64
+int32x4_t scales_4 = ggml_vld1q_u32(sc[0] >> 0, sc[0] >> 4, sc[0] >> 8, sc[0] >> 12);
+#else
int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
+#endif
scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);

sumi1 = vmlaq_s32(sumi1, scales_4, p12);
@@ -9823,14 +9846,22 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void

}

+#if QK_K == 64
+sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
+#else
sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
+#endif
}

*s = sumf;

#elif defined __AVX2__

+#if QK_K == 64
+const __m256i mask = _mm256_set1_epi16(0xf);
+#else
const __m256i mask = _mm256_set1_epi16(0x7);
+#endif
const __m256i mone = _mm256_set1_epi16(1);

__m256 accum1 = _mm256_setzero_ps();
@@ -9842,7 +9873,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
const uint8_t * qh = x[i].qh;
const uint16_t * sc = (const uint16_t *)x[i].scales;

+#if QK_K != 64
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+#endif

__m256i sumi1 = _mm256_setzero_si256();
__m256i sumi2 = _mm256_setzero_si256();
@@ -9872,8 +9905,13 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void

const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
+#if QK_K == 64
+__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 4), _mm_set1_epi16(sc[0] >> 0));
+__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 12), _mm_set1_epi16(sc[0] >> 8));
+#else
__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
+#endif
scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
@@ -9887,7 +9925,11 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
qs += 8; qh += 4;
}

+#if QK_K == 64
+const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
+#else
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
+#endif
accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);

@@ -9907,7 +9949,9 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
const uint8_t * qh = x[i].qh;
const uint16_t * sc = (const uint16_t *)x[i].scales;

+#if QK_K != 64
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+#endif

int sumi1 = 0, sumi2 = 0;
for (int ib = 0; ib < QK_K/32; ++ib) {
@@ -9927,15 +9971,24 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
sum1[l/2] += lsum1;
sum2[l/2] += lsum2*delta[l];
}
+#if QK_K == 64
+const int ls1 = 2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1;
+const int ls2 = 2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1;
+#else
const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
+#endif
sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
qs += 4;
qh += 2;
}

+#if QK_K == 64
+sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
+#else
sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
+#endif
}

*s = sumf;
@@ -11986,7 +12039,9 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy

for (int ibl = 0; ibl < nbl; ++ibl) {

-//y[ibl].d = GGML_FP32_TO_FP16(0.f);
+#if QK_K == 64
+y[ibl].d = GGML_FP32_TO_FP16(0.f);
+#endif
memset(y[ibl].qs, 0, QK_K/8);
memset(y[ibl].qh, 0, QK_K/16);
memset(y[ibl].scales, 0, QK_K/32);
@@ -12161,13 +12216,22 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
}

uint16_t * sc = (uint16_t *)y[ibl].scales;
+#if QK_K == 64
+float d = max_scale/31;
+#else
float d = max_scale/15;
+#endif
float id = 1/d;
float sumqx_f = 0, sumq2_f = 0;
for (int ib = 0; ib < QK_K/block_size; ++ib) {
int l = nearest_int(0.5f*(id*scales[ib+0]-1));
+#if QK_K == 64
+l = MAX(0, MIN(15, l));
+sc[ib/4] |= (l << 4*(ib%4));
+#else
l = MAX(0, MIN(7, l));
sc[ib/4] |= (l << 3*(ib%4));
+#endif
y[ibl].qh[ib] |= masks[shifts[ib]];
const float * xb = xbl + block_size*ib;
if (quant_weights) {
@@ -12190,10 +12254,14 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
}
if (sumq2_f > 0) d = sumqx_f/sumq2_f;
s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
+#if QK_K == 64
+y[ibl].d = s.f16;
+#else
sc[0] |= ((s.u16 & 0x000f) << 12);
sc[1] |= ((s.u16 & 0x00f0) << 8);
sc[2] |= ((s.u16 & 0x0f00) << 4);
sc[3] |= ((s.u16 & 0xf000) << 0);
+#endif
}
}

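
As a final size sanity check: the QK_K = 256 block is 32 + 16 + 8 = 56 bytes per 256 weights, i.e. the 1.75 bpw in the updated comment, while the QK_K = 64 block with its extra ggml_half is 8 + 4 + 2 + 2 = 16 bytes per 64 weights, i.e. 2.0 bpw. The same parity also explains max_scale/31 versus max_scale/15 in quantize_row_iq1_m_impl: a stored scale l becomes the odd multiplier 2l + 1, which tops out at 31 for 4 bits and at 15 for 3 bits. A throwaway check of the byte math (illustrative only, not ggml code):

#include <stdio.h>

// Mirrors the two static_asserts in ggml-common.h; sizeof(ggml_half) is 2.
int main(void) {
    const int qks[2] = {64, 256};
    for (int i = 0; i < 2; ++i) {
        const int qk = qks[i];
        const int bytes = qk/8 + qk/16 + qk/32 + (qk == 64 ? 2 : 0);
        printf("QK_K = %3d: %2d bytes -> %.4f bpw\n", qk, bytes, 8.0*bytes/qk);
    }
    return 0; // prints 2.0000 bpw for QK_K = 64 and 1.7500 for QK_K = 256
}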