diff --git a/demo/test.c b/demo/test.c index c5141ef7..f290dbf2 100644 --- a/demo/test.c +++ b/demo/test.c @@ -1229,11 +1229,20 @@ static int test_mp_cnt_lsb(void) static int test_mp_reduce_2k(void) { int ix, cnt; + bool is2k; mp_int a, b, c, d; DOR(mp_init_multi(&a, &b, &c, &d, NULL)); /* test mp_reduce_2k */ + + /* Algorithm as implemented does not work if the least significant digit is zero */ + DO(mp_2expt(&a, 100)); + DO(mp_sub_d(&a, 1, &a)); + DO(mp_sub_d(&a, MP_MASK, &a)); + is2k = mp_reduce_is_2k(&a); + EXPECT(!is2k); + for (cnt = 3; cnt <= 128; ++cnt) { mp_digit tmp; diff --git a/mp_reduce_is_2k.c b/mp_reduce_is_2k.c index 9774f96e..d5496338 100644 --- a/mp_reduce_is_2k.c +++ b/mp_reduce_is_2k.c @@ -11,9 +11,16 @@ bool mp_reduce_is_2k(const mp_int *a) } else if (a->used == 1) { return true; } else if (a->used > 1) { - int ix, iy = mp_count_bits(a), iw = 1; - mp_digit iz = 1; + int ix, iy, iw = 1; + mp_digit iz; + /* Algorithm as implemented does not work if the least significant digit is zero */ + iz = a->dp[0] & MP_MASK; + if (iz == 0u) { + return false; + } + iy = mp_count_bits(a); + iz = 1; /* Test every bit from the second digit up, must be 1 */ for (ix = MP_DIGIT_BIT; ix < iy; ix++) { if ((a->dp[iw] & iz) == 0u) {