Skip to content

Commit

Permalink
ggml : simplify scalar dot
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed May 5, 2023
1 parent f08f6f7 commit 79e49c9
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -2236,15 +2236,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
float sumf = 0.0;

for (int i = 0; i < nb; i++) {
const int8_t * py = y[i].qs;

int sumi = 0;

for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0xf) - 8;
const int v1 = (x[i].qs[j] >> 4) - 8;

sumi += (v0 * py[j]) + (v1 * py[j + qk/2]);
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
}

sumf += (x[i].d*y[i].d)*sumi;
Expand Down Expand Up @@ -2360,15 +2358,13 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
float sumf = 0.0;

for (int i = 0; i < nb; i++) {
const int8_t * py = y[i].qs;

int sumi = 0;

for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0xf);
const int v1 = (x[i].qs[j] >> 4);

sumi += (v0 * py[j]) + (v1 * py[j + qk/2]);
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
}

sumf += (x[i].d*y[i].d)*sumi + x[i].m*(y[i].s0 + y[i].s1);
Expand Down Expand Up @@ -2694,8 +2690,6 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
float sumf = 0.0;

for (int i = 0; i < nb; i++) {
const int8_t * py = y[i].qs;

uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));

Expand All @@ -2708,7 +2702,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;

sumi += (x0 * py[j]) + (x1 * py[j + qk/2]);
sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
}

sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi;
Expand Down Expand Up @@ -2889,8 +2883,6 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
float sumf = 0.0;

for (int i = 0; i < nb; i++) {
const int8_t * py = y[i].qs;

uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));

Expand All @@ -2903,7 +2895,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;

sumi += (x0 * py[j]) + (x1 * py[j + qk/2]);
sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
}

sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*(y[i].s0 + y[i].s1);
Expand Down

0 comments on commit 79e49c9

Please sign in to comment.