Skip to content

Commit

Permalink
try alternative 2 for mp_sqr
Browse files Browse the repository at this point in the history
  • Loading branch information
minad committed Oct 30, 2019
1 parent ead719c commit 44caa78
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 50 deletions.
29 changes: 27 additions & 2 deletions demo/test.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,31 @@
#include <inttypes.h>
#include "shared.h"

static mp_err mp_sqr_simple(const mp_int *a, mp_int *b)
{
mp_err err;
if (b->dp == a->dp) {
mp_int t;
if ((err = mp_init_size(&t, (2 * a->used) + 1)) != MP_OKAY) {
return err;
}
if ((err = s_mp_sqr(a, &t)) != MP_OKAY) {
mp_clear(&t);
return err;
}
mp_exch(b, &t);
mp_clear(&t);
} else {
if ((err = mp_grow(b, (2 * a->used) + 1)) != MP_OKAY) {
return err;
}
mp_zero(b);
err = s_mp_sqr(a, b);
}
b->sign = MP_ZPOS;
return err;
}

static long rand_long(void)
{
long x;
Expand Down Expand Up @@ -1930,7 +1955,7 @@ static int test_s_mp_sqr_karatsuba(void)
for (size = MP_SQR_KARATSUBA_CUTOFF; size < MP_SQR_KARATSUBA_CUTOFF + 20; size++) {
DO(mp_rand(&a, size));
DO(s_mp_sqr_karatsuba(&a, &b));
DO(s_mp_sqr(&a, &c));
DO(mp_sqr_simple(&a, &c));
if (mp_cmp(&b, &c) != MP_EQ) {
fprintf(stderr, "Karatsuba squaring failed at size %d\n", size);
goto LBL_ERR;
Expand Down Expand Up @@ -2003,7 +2028,7 @@ static int test_s_mp_sqr_toom(void)
for (size = MP_SQR_TOOM_CUTOFF; size < MP_SQR_TOOM_CUTOFF + 20; size++) {
DO(mp_rand(&a, size));
DO(s_mp_sqr_toom(&a, &b));
DO(s_mp_sqr(&a, &c));
DO(mp_sqr_simple(&a, &c));
if (mp_cmp(&b, &c) != MP_EQ) {
fprintf(stderr, "Toom-Cook 3-way squaring failed at size %d\n", size);
goto LBL_ERR;
Expand Down
34 changes: 27 additions & 7 deletions mp_sqr.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@
/* LibTomMath, multiple-precision integer library -- Tom St Denis */
/* SPDX-License-Identifier: Unlicense */

static mp_err s_mp_sqr_noalias(const mp_int *a, mp_int *b)
{
if (MP_HAS(S_MP_SQR_COMBA) && (a->used < (MP_MAX_COMBA / 2))) {
return s_mp_sqr_comba(a, b);
} else if (MP_HAS(S_MP_SQR)) {
return s_mp_sqr(a, b);
} else {
return MP_VAL;
}
}

/* computes b = a*a */
mp_err mp_sqr(const mp_int *a, mp_int *b)
{
Expand All @@ -13,14 +24,23 @@ mp_err mp_sqr(const mp_int *a, mp_int *b)
} else if (MP_HAS(S_MP_SQR_KARATSUBA) && /* Karatsuba? */
(a->used >= MP_SQR_KARATSUBA_CUTOFF)) {
err = s_mp_sqr_karatsuba(a, b);
} else if (MP_HAS(S_MP_SQR_COMBA) && /* can we use the fast comba multiplier? */
(a->dp != b->dp) &&
(a->used < (MP_MAX_COMBA / 2))) {
err = s_mp_sqr_comba(a, b);
} else if (MP_HAS(S_MP_SQR)) {
err = s_mp_sqr(a, b);
} else if (b->dp == a->dp) {
mp_int t;
if ((err = mp_init_size(&t, (2 * a->used) + 1)) != MP_OKAY) {
return err;
}
if ((err = s_mp_sqr_noalias(a, &t)) != MP_OKAY) {
mp_clear(&t);
return err;
}
mp_exch(b, &t);
mp_clear(&t);
} else {
err = MP_VAL;
if ((err = mp_grow(b, (2 * a->used) + 1)) != MP_OKAY) {
return err;
}
mp_zero(b);
err = s_mp_sqr_noalias(a, b);
}
b->sign = MP_ZPOS;
return err;
Expand Down
29 changes: 9 additions & 20 deletions s_mp_sqr.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,19 @@
/* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
mp_err s_mp_sqr(const mp_int *a, mp_int *b)
{
mp_int t;
int ix, pa;
mp_err err;

pa = a->used;
if ((err = mp_init_size(&t, (2 * pa) + 1)) != MP_OKAY) {
return err;
}

/* default used is maximum possible size */
t.used = (2 * pa) + 1;
int ix, pa = a->used;

for (ix = 0; ix < pa; ix++) {
mp_digit u;
int iy;

/* first calculate the digit at 2*ix */
/* calculate double precision result */
mp_word r = (mp_word)t.dp[2*ix] +
mp_word r = (mp_word)b->dp[2*ix] +
((mp_word)a->dp[ix] * (mp_word)a->dp[ix]);

/* store lower part in result */
t.dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK);
b->dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK);

/* get the carry */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
Expand All @@ -40,26 +30,25 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b)
/* now calculate the double precision result, note we use
* addition instead of *2 since it's easier to optimize
*/
r = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u;
r = (mp_word)b->dp[ix + iy] + r + r + (mp_word)u;

/* store lower part */
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
b->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);

/* get carry */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
}
/* propagate upwards */
while (u != 0uL) {
r = (mp_word)t.dp[ix + iy] + (mp_word)u;
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
r = (mp_word)b->dp[ix + iy] + (mp_word)u;
b->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
++iy;
}
}

mp_clamp(&t);
mp_exch(&t, b);
mp_clear(&t);
b->used = (2 * pa) + 1;
mp_clamp(b);
return MP_OKAY;
}
#endif
23 changes: 6 additions & 17 deletions s_mp_sqr_comba.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,26 @@ After that loop you do the squares and add them in.

mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b)
{
int oldused, pa, ix;
int ix, pa = a->used, pb = 2 * pa;
mp_word W1;
mp_err err;

/* grow the destination as required */
pa = a->used + a->used;
if ((err = mp_grow(b, pa)) != MP_OKAY) {
return err;
}

/* number of output digits to produce */
W1 = 0;
for (ix = 0; ix < pa; ix++) {
for (ix = 0; ix < pb; ix++) {
int tx, ty, iy, iz;
mp_word W;

/* clear counter */
W = 0;

/* get offsets into the two bignums */
ty = MP_MIN(a->used-1, ix);
ty = MP_MIN(pa-1, ix);
tx = ix - ty;

/* this is the number of times the loop will iterrate, essentially
while (tx++ < a->used && ty-- >= 0) { ... }
while (tx++ < pa && ty-- >= 0) { ... }
*/
iy = MP_MIN(a->used-tx, ty+1);
iy = MP_MIN(pa-tx, ty+1);

/* now for squaring tx can never equal ty
* we halve the distance since they approach at a rate of 2x
Expand Down Expand Up @@ -69,11 +62,7 @@ mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b)
W1 = W >> (mp_word)MP_DIGIT_BIT;
}

/* clear unused digits [that existed in the old copy of c] */
oldused = b->used;
b->used = a->used + a->used;
s_mp_zero_digs(b->dp + b->used, oldused - b->used);

b->used = pb;
mp_clamp(b);
return MP_OKAY;
}
Expand Down
9 changes: 5 additions & 4 deletions tommath_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,14 @@
#endif

#if defined(MP_SQR_C)
# define MP_CLEAR_C
# define MP_EXCH_C
# define MP_GROW_C
# define MP_INIT_SIZE_C
# define S_MP_SQR_C
# define S_MP_SQR_COMBA_C
# define S_MP_SQR_KARATSUBA_C
# define S_MP_SQR_NOALIAS_C
# define S_MP_SQR_TOOM_C
#endif

Expand Down Expand Up @@ -1222,14 +1227,10 @@

#if defined(S_MP_SQR_C)
# define MP_CLAMP_C
# define MP_CLEAR_C
# define MP_EXCH_C
# define MP_INIT_SIZE_C
#endif

#if defined(S_MP_SQR_COMBA_C)
# define MP_CLAMP_C
# define MP_GROW_C
# define S_MP_ZERO_DIGS_C
#endif

Expand Down

0 comments on commit 44caa78

Please sign in to comment.