/*
 * Patch overview (everything after this comment is the unmodified patch text).
 *
 * Purpose: move destination allocation and alias handling for squaring out of
 * the low-level routines s_mp_sqr / s_mp_sqr_comba and into mp_sqr.
 *
 * demo/test.c
 *   - Adds static mp_sqr_simple(a, b), a reference squaring wrapper around
 *     s_mp_sqr. When b->dp == a->dp it squares into a temporary created with
 *     mp_init_size(&t, 2*a->used + 1) and swaps it into b with mp_exch;
 *     otherwise it calls mp_grow(b, 2*a->used + 1), mp_zero(b), then
 *     s_mp_sqr(a, b). On the non-early-return path it sets b->sign = MP_ZPOS.
 *   - The Karatsuba and Toom-Cook squaring tests now build their reference
 *     value with mp_sqr_simple instead of calling s_mp_sqr directly.
 *
 * mp_sqr.c
 *   - Adds static s_mp_sqr_noalias(a, b): calls s_mp_sqr_comba when
 *     MP_HAS(S_MP_SQR_COMBA) and a->used < MP_MAX_COMBA / 2, else s_mp_sqr
 *     when MP_HAS(S_MP_SQR), else returns MP_VAL.
 *   - mp_sqr replaces its comba / schoolbook branches with the same
 *     temporary-or-grow-and-zero logic as above, dispatching through
 *     s_mp_sqr_noalias.
 *
 * s_mp_sqr.c
 *   - Drops the internal temporary; partial products are accumulated directly
 *     in b->dp, then b->used is set to 2*pa + 1 and b is clamped. The function
 *     now reads b->dp[2*ix] and b->dp[ix + iy] before writing them, so the
 *     caller has to supply a destination with at least 2*a->used + 1 digits
 *     whose contents are zero, and b must not share storage with a.
 *     (Presumably mp_init_size and mp_grow + mp_zero provide that state;
 *     confirm against their implementations.)
 *
 * s_mp_sqr_comba.c
 *   - Drops the mp_grow call and the trailing s_mp_zero_digs clean-up of
 *     digits above the new used count; b->used is set to 2 * a->used and
 *     clamped. Sizing b and clearing stale high digits is left to the caller.
 *
 * tommath_class.h
 *   - MP_SQR_C gains MP_CLEAR_C, MP_EXCH_C, MP_GROW_C, MP_INIT_SIZE_C and
 *     S_MP_SQR_NOALIAS_C; S_MP_SQR_C loses MP_CLEAR_C, MP_EXCH_C and
 *     MP_INIT_SIZE_C; S_MP_SQR_COMBA_C loses MP_GROW_C.
 *
 * NOTE(review): mp_sqr now calls mp_zero, but no MP_ZERO_C entry is added to
 *   the MP_SQR_C dependency list -- confirm whether one is required.
 * NOTE(review): S_MP_SQR_COMBA_C still lists S_MP_ZERO_DIGS_C although the
 *   visible hunk removes the s_mp_zero_digs call -- verify whether any use
 *   remains outside the hunk.
 * NOTE(review): S_MP_SQR_NOALIAS_C is defined, yet s_mp_sqr_noalias is added
 *   here as a file-local static function in mp_sqr.c -- looks like generated
 *   output; confirm it is intended.
 * NOTE(review): any caller of s_mp_sqr / s_mp_sqr_comba other than the ones
 *   changed in this patch must now pre-size and zero the destination itself
 *   -- TODO audit remaining call sites.
 * NOTE(review): this copy of the patch appears to have had its line breaks
 *   collapsed, and the first #include in the demo/test.c context has lost its
 *   header name; regenerate the patch from version control before applying.
 */
diff --git a/demo/test.c b/demo/test.c index 651e2b419..3b23d8629 100644 --- a/demo/test.c +++ b/demo/test.c @@ -1,6 +1,31 @@ #include #include "shared.h" +static mp_err mp_sqr_simple(const mp_int *a, mp_int *b) +{ + mp_err err; + if (b->dp == a->dp) { + mp_int t; + if ((err = mp_init_size(&t, (2 * a->used) + 1)) != MP_OKAY) { + return err; + } + if ((err = s_mp_sqr(a, &t)) != MP_OKAY) { + mp_clear(&t); + return err; + } + mp_exch(b, &t); + mp_clear(&t); + } else { + if ((err = mp_grow(b, (2 * a->used) + 1)) != MP_OKAY) { + return err; + } + mp_zero(b); + err = s_mp_sqr(a, b); + } + b->sign = MP_ZPOS; + return err; +} + static long rand_long(void) { long x; @@ -1930,7 +1955,7 @@ static int test_s_mp_sqr_karatsuba(void) for (size = MP_SQR_KARATSUBA_CUTOFF; size < MP_SQR_KARATSUBA_CUTOFF + 20; size++) { DO(mp_rand(&a, size)); DO(s_mp_sqr_karatsuba(&a, &b)); - DO(s_mp_sqr(&a, &c)); + DO(mp_sqr_simple(&a, &c)); if (mp_cmp(&b, &c) != MP_EQ) { fprintf(stderr, "Karatsuba squaring failed at size %d\n", size); goto LBL_ERR; @@ -2003,7 +2028,7 @@ static int test_s_mp_sqr_toom(void) for (size = MP_SQR_TOOM_CUTOFF; size < MP_SQR_TOOM_CUTOFF + 20; size++) { DO(mp_rand(&a, size)); DO(s_mp_sqr_toom(&a, &b)); - DO(s_mp_sqr(&a, &c)); + DO(mp_sqr_simple(&a, &c)); if (mp_cmp(&b, &c) != MP_EQ) { fprintf(stderr, "Toom-Cook 3-way squaring failed at size %d\n", size); goto LBL_ERR; diff --git a/mp_sqr.c b/mp_sqr.c index c88fe449b..66f4971f2 100644 --- a/mp_sqr.c +++ b/mp_sqr.c @@ -3,6 +3,17 @@ /* LibTomMath, multiple-precision integer library -- Tom St Denis */ /* SPDX-License-Identifier: Unlicense */ +static mp_err s_mp_sqr_noalias(const mp_int *a, mp_int *b) +{ + if (MP_HAS(S_MP_SQR_COMBA) && (a->used < (MP_MAX_COMBA / 2))) { + return s_mp_sqr_comba(a, b); + } else if (MP_HAS(S_MP_SQR)) { + return s_mp_sqr(a, b); + } else { + return MP_VAL; + } +} + /* computes b = a*a */ mp_err mp_sqr(const mp_int *a, mp_int *b) { @@ -13,14 +24,23 @@ mp_err mp_sqr(const mp_int *a, mp_int *b) } else if 
(MP_HAS(S_MP_SQR_KARATSUBA) && /* Karatsuba? */ (a->used >= MP_SQR_KARATSUBA_CUTOFF)) { err = s_mp_sqr_karatsuba(a, b); - } else if (MP_HAS(S_MP_SQR_COMBA) && /* can we use the fast comba multiplier? */ - (a->dp != b->dp) && - (a->used < (MP_MAX_COMBA / 2))) { - err = s_mp_sqr_comba(a, b); - } else if (MP_HAS(S_MP_SQR)) { - err = s_mp_sqr(a, b); + } else if (b->dp == a->dp) { + mp_int t; + if ((err = mp_init_size(&t, (2 * a->used) + 1)) != MP_OKAY) { + return err; + } + if ((err = s_mp_sqr_noalias(a, &t)) != MP_OKAY) { + mp_clear(&t); + return err; + } + mp_exch(b, &t); + mp_clear(&t); } else { - err = MP_VAL; + if ((err = mp_grow(b, (2 * a->used) + 1)) != MP_OKAY) { + return err; + } + mp_zero(b); + err = s_mp_sqr_noalias(a, b); } b->sign = MP_ZPOS; return err; diff --git a/s_mp_sqr.c b/s_mp_sqr.c index 4a2030638..c09662a9a 100644 --- a/s_mp_sqr.c +++ b/s_mp_sqr.c @@ -6,17 +6,7 @@ /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */ mp_err s_mp_sqr(const mp_int *a, mp_int *b) { - mp_int t; - int ix, pa; - mp_err err; - - pa = a->used; - if ((err = mp_init_size(&t, (2 * pa) + 1)) != MP_OKAY) { - return err; - } - - /* default used is maximum possible size */ - t.used = (2 * pa) + 1; + int ix, pa = a->used; for (ix = 0; ix < pa; ix++) { mp_digit u; @@ -24,11 +14,11 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b) /* first calculate the digit at 2*ix */ /* calculate double precision result */ - mp_word r = (mp_word)t.dp[2*ix] + + mp_word r = (mp_word)b->dp[2*ix] + ((mp_word)a->dp[ix] * (mp_word)a->dp[ix]); /* store lower part in result */ - t.dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK); + b->dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK); /* get the carry */ u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); @@ -40,26 +30,25 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b) /* now calculate the double precision result, note we use * addition instead of *2 since it's easier to optimize */ - r = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u; + r = (mp_word)b->dp[ix + 
iy] + r + r + (mp_word)u; /* store lower part */ - t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK); + b->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK); /* get carry */ u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); } /* propagate upwards */ while (u != 0uL) { - r = (mp_word)t.dp[ix + iy] + (mp_word)u; - t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK); + r = (mp_word)b->dp[ix + iy] + (mp_word)u; + b->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK); u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); ++iy; } } - mp_clamp(&t); - mp_exch(&t, b); - mp_clear(&t); + b->used = (2 * pa) + 1; + mp_clamp(b); return MP_OKAY; } #endif diff --git a/s_mp_sqr_comba.c b/s_mp_sqr_comba.c index 854aa859c..accda25b6 100644 --- a/s_mp_sqr_comba.c +++ b/s_mp_sqr_comba.c @@ -15,19 +15,12 @@ After that loop you do the squares and add them in. mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b) { - int oldused, pa, ix; + int ix, pa = a->used, pb = 2 * pa; mp_word W1; - mp_err err; - - /* grow the destination as required */ - pa = a->used + a->used; - if ((err = mp_grow(b, pa)) != MP_OKAY) { - return err; - } /* number of output digits to produce */ W1 = 0; - for (ix = 0; ix < pa; ix++) { + for (ix = 0; ix < pb; ix++) { int tx, ty, iy, iz; mp_word W; @@ -35,13 +28,13 @@ mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b) W = 0; /* get offsets into the two bignums */ - ty = MP_MIN(a->used-1, ix); + ty = MP_MIN(pa-1, ix); tx = ix - ty; /* this is the number of times the loop will iterrate, essentially - while (tx++ < a->used && ty-- >= 0) { ... } + while (tx++ < pa && ty-- >= 0) { ... 
} */ - iy = MP_MIN(a->used-tx, ty+1); + iy = MP_MIN(pa-tx, ty+1); /* now for squaring tx can never equal ty * we halve the distance since they approach at a rate of 2x @@ -69,11 +62,7 @@ mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b) W1 = W >> (mp_word)MP_DIGIT_BIT; } - /* clear unused digits [that existed in the old copy of c] */ - oldused = b->used; - b->used = a->used + a->used; - s_mp_zero_digs(b->dp + b->used, oldused - b->used); - + b->used = pb; mp_clamp(b); return MP_OKAY; } diff --git a/tommath_class.h b/tommath_class.h index 5c8622465..2403325f4 100644 --- a/tommath_class.h +++ b/tommath_class.h @@ -889,9 +889,14 @@ #endif #if defined(MP_SQR_C) +# define MP_CLEAR_C +# define MP_EXCH_C +# define MP_GROW_C +# define MP_INIT_SIZE_C # define S_MP_SQR_C # define S_MP_SQR_COMBA_C # define S_MP_SQR_KARATSUBA_C +# define S_MP_SQR_NOALIAS_C # define S_MP_SQR_TOOM_C #endif @@ -1222,14 +1227,10 @@ #if defined(S_MP_SQR_C) # define MP_CLAMP_C -# define MP_CLEAR_C -# define MP_EXCH_C -# define MP_INIT_SIZE_C #endif #if defined(S_MP_SQR_COMBA_C) # define MP_CLAMP_C -# define MP_GROW_C # define S_MP_ZERO_DIGS_C #endif