diff --git a/src/flint/flintlib/fmpz_mod.pxd b/src/flint/flintlib/fmpz_mod.pxd index a96a0e3a..b6d96f82 100644 --- a/src/flint/flintlib/fmpz_mod.pxd +++ b/src/flint/flintlib/fmpz_mod.pxd @@ -1,10 +1,11 @@ from flint.flintlib.flint cimport ulong, slong -from flint.flintlib.fmpz cimport fmpz_t, fmpz_preinvn_struct +from flint.flintlib.fmpz cimport fmpz_t, fmpz_struct, fmpz_preinvn_struct from flint.flintlib.nmod cimport nmod_t -# unimported types {'fmpz_mod_discrete_log_pohlig_hellman_t'} - cdef extern from "flint/fmpz_mod.h": + # + # fmpz_mod structs, a la Pohlig - Hellman + # ctypedef struct fmpz_mod_ctx_struct: fmpz_t n nmod_t mod @@ -13,6 +14,36 @@ cdef extern from "flint/fmpz_mod.h": fmpz_preinvn_struct * ninv_huge ctypedef fmpz_mod_ctx_struct fmpz_mod_ctx_t[1] + # + # discrete logs structs, a la Pohlig - Hellman + # + + ctypedef struct fmpz_mod_discrete_log_pohlig_hellman_table_entry_struct: + fmpz_t gammapow + ulong cm + + ctypedef struct fmpz_mod_discrete_log_pohlig_hellman_entry_struct: + slong exp + ulong prime + fmpz_t gamma + fmpz_t gammainv + fmpz_t startingbeta + fmpz_t co + fmpz_t startinge + fmpz_t idem + ulong cbound + ulong dbound + fmpz_mod_discrete_log_pohlig_hellman_table_entry_struct * table # length cbound */ + + ctypedef struct fmpz_mod_discrete_log_pohlig_hellman_struct: + fmpz_mod_ctx_t fpctx + fmpz_t pm1 # p - 1 */ + fmpz_t alpha # p.r. of p */ + fmpz_t alphainv + slong num_factors # factors of p - 1 + fmpz_mod_discrete_log_pohlig_hellman_entry_struct * entries + ctypedef fmpz_mod_discrete_log_pohlig_hellman_struct fmpz_mod_discrete_log_pohlig_hellman_t[1] + # Parsed from here void fmpz_mod_ctx_init(fmpz_mod_ctx_t ctx, const fmpz_t n) void fmpz_mod_ctx_clear(fmpz_mod_ctx_t ctx) @@ -37,9 +68,9 @@ cdef extern from "flint/fmpz_mod.h": int fmpz_mod_divides(fmpz_t a, const fmpz_t b, const fmpz_t c, const fmpz_mod_ctx_t ctx) void fmpz_mod_pow_ui(fmpz_t a, const fmpz_t b, ulong e, const fmpz_mod_ctx_t ctx) int fmpz_mod_pow_fmpz(fmpz_t a, const fmpz_t b, const fmpz_t e, const fmpz_mod_ctx_t ctx) - # void fmpz_mod_discrete_log_pohlig_hellman_init(fmpz_mod_discrete_log_pohlig_hellman_t L) - # void fmpz_mod_discrete_log_pohlig_hellman_clear(fmpz_mod_discrete_log_pohlig_hellman_t L) - # double fmpz_mod_discrete_log_pohlig_hellman_precompute_prime(fmpz_mod_discrete_log_pohlig_hellman_t L, const fmpz_t p) - # const fmpz_struct * fmpz_mod_discrete_log_pohlig_hellman_primitive_root(const fmpz_mod_discrete_log_pohlig_hellman_t L) - # void fmpz_mod_discrete_log_pohlig_hellman_run(fmpz_t x, const fmpz_mod_discrete_log_pohlig_hellman_t L, const fmpz_t y) + void fmpz_mod_discrete_log_pohlig_hellman_init(fmpz_mod_discrete_log_pohlig_hellman_t L) + void fmpz_mod_discrete_log_pohlig_hellman_clear(fmpz_mod_discrete_log_pohlig_hellman_t L) + double fmpz_mod_discrete_log_pohlig_hellman_precompute_prime(fmpz_mod_discrete_log_pohlig_hellman_t L, const fmpz_t p) + const fmpz_struct * fmpz_mod_discrete_log_pohlig_hellman_primitive_root(const fmpz_mod_discrete_log_pohlig_hellman_t L) + void fmpz_mod_discrete_log_pohlig_hellman_run(fmpz_t x, const fmpz_mod_discrete_log_pohlig_hellman_t L, const fmpz_t y) int fmpz_next_smooth_prime(fmpz_t a, const fmpz_t b) diff --git a/src/flint/test/test.py b/src/flint/test/test.py index 8121d1ed..3a5b7622 100644 --- a/src/flint/test/test.py +++ b/src/flint/test/test.py @@ -1770,6 +1770,50 @@ def test_fmpz_mod(): assert fmpz(test_y) / F_test(test_x) == (test_y * pow(test_x, -1, test_mod)) % test_mod assert test_y / F_test(test_x) == (test_y * pow(test_x, -1, test_mod)) % test_mod +def test_fmpz_mod_dlog(): + from flint import fmpz, fmpz_mod_ctx + + # Input modulus must be prime + F = fmpz_mod_ctx(4) + g, a = F(1), F(2) + assert raises(lambda: g.discrete_log(a, check=True), ValueError) + + # Moduli must match + F1, F2 = fmpz_mod_ctx(2), fmpz_mod_ctx(3) + g = F1(2) + a = F2(4) + assert raises(lambda: g.discrete_log(a, check=True), ValueError) + + # Need to use either fmpz_mod or something which can be case to + # fmpz + assert raises(lambda: g.discrete_log("A", check=True), TypeError) + + F = fmpz_mod_ctx(163) + g = F(2) + a = g**123 + + assert 123 == g.discrete_log(a) + + a_int = pow(2, 123, 163) + a_fmpz = fmpz(a_int) + assert 123 == g.discrete_log(a_int) + assert 123 == g.discrete_log(a_fmpz) + + # Randomised testing with smooth large modulus + e2, e3 = 92, 79 + p = 2**e2 * 3**e3 + 1 + F = fmpz_mod_ctx(p) + + import random + for _ in range(10): + g = F(random.randint(0,p)) + for _ in range(10): + i = random.randint(0,p) + a = g**i + x = g.discrete_log(a) + assert g**x == a + + all_tests = [ test_pyflint, @@ -1790,4 +1834,5 @@ def test_fmpz_mod(): test_nmod_mat, test_arb, test_fmpz_mod, + test_fmpz_mod_dlog ] diff --git a/src/flint/types/fmpz_mod.pxd b/src/flint/types/fmpz_mod.pxd index 264d35d7..11b68928 100644 --- a/src/flint/types/fmpz_mod.pxd +++ b/src/flint/types/fmpz_mod.pxd @@ -1,14 +1,18 @@ from flint.flint_base.flint_base cimport flint_scalar from flint.flintlib.fmpz cimport fmpz_t -from flint.flintlib.fmpz_mod cimport fmpz_mod_ctx_t +from flint.flintlib.fmpz_mod cimport ( + fmpz_mod_ctx_t, + fmpz_mod_discrete_log_pohlig_hellman_t +) cdef class fmpz_mod_ctx: cdef fmpz_mod_ctx_t val - + cdef fmpz_mod_discrete_log_pohlig_hellman_t *L + cdef any_as_fmpz_mod(self, obj) + cdef _precompute_dlog_prime(self) + cdef class fmpz_mod(flint_scalar): cdef fmpz_mod_ctx ctx cdef fmpz_t val - - cdef any_as_fmpz_mod(self, obj) - + cdef fmpz_t *x_g diff --git a/src/flint/types/fmpz_mod.pyx b/src/flint/types/fmpz_mod.pyx index 362a9161..686d21e4 100644 --- a/src/flint/types/fmpz_mod.pyx +++ b/src/flint/types/fmpz_mod.pyx @@ -1,21 +1,29 @@ -from flint.flintlib.fmpz cimport ( +from flint.flintlib.fmpz cimport( fmpz_t, fmpz_one, fmpz_set, fmpz_init, fmpz_clear, - fmpz_equal + fmpz_equal, + fmpz_is_probabprime, + fmpz_mul, + fmpz_invmod, + fmpz_divexact, + fmpz_gcd, + fmpz_is_one ) +from flint.flintlib.fmpz cimport fmpz_mod as fmpz_type_mod from flint.flintlib.fmpz_mod cimport * from flint.utils.typecheck cimport typecheck from flint.flint_base.flint_base cimport flint_scalar -from flint.types.fmpz cimport ( +from flint.types.fmpz cimport( fmpz, any_as_fmpz, fmpz_get_intlong ) - +cimport cython +cimport libc.stdlib cdef class fmpz_mod_ctx: r""" @@ -30,24 +38,30 @@ cdef class fmpz_mod_ctx: cdef fmpz one = fmpz.__new__(fmpz) fmpz_one(one.val) fmpz_mod_ctx_init(self.val, one.val) + self.L = NULL + def __dealloc__(self): fmpz_mod_ctx_clear(self.val) + if self.L: + fmpz_mod_discrete_log_pohlig_hellman_clear(self.L[0]) def __init__(self, mod): # Ensure modulus is fmpz type if not typecheck(mod, fmpz): mod = any_as_fmpz(mod) if mod is NotImplemented: - raise TypeError("Context modulus must be able to be case to an `fmpz` type") + raise TypeError( + "Context modulus must be able to be case to an `fmpz` type" + ) # Ensure modulus is positive if mod < 1: raise ValueError("Modulus is expected to be positive") - # Init the context - fmpz_mod_ctx_init(self.val, (mod).val) - + # Set the modulus + fmpz_mod_ctx_set_modulus(self.val, (mod).val) + def modulus(self): """ Return the modulus from the context as an fmpz @@ -62,7 +76,55 @@ cdef class fmpz_mod_ctx: fmpz_set(n.val, (self.val.n)) return n + cdef _precompute_dlog_prime(self): + """ + Initalise the dlog data, all discrete logs are solved with an + internally chosen base `y` + """ + self.L = libc.stdlib.malloc( + cython.sizeof(fmpz_mod_discrete_log_pohlig_hellman_struct) + ) + fmpz_mod_discrete_log_pohlig_hellman_init(self.L[0]) + fmpz_mod_discrete_log_pohlig_hellman_precompute_prime( + self.L[0], self.val.n + ) + + cdef any_as_fmpz_mod(self, obj): + # If `obj` is an `fmpz_mod`, just check moduli + # match + # TODO: we could allow conversion from one modulus to another? + if typecheck(obj, fmpz_mod): + if self != (obj).ctx: + raise ValueError("moduli must match") + return obj + + # Try and convert obj to fmpz + if not typecheck(obj, fmpz): + obj = any_as_fmpz(obj) + if obj is NotImplemented: + return NotImplemented + + # We have been able to cast `obj` to an `fmpz` so + # we create a new `fmpz_mod` and set the val + cdef fmpz_mod res + res = fmpz_mod.__new__(fmpz_mod) + res.ctx = self + fmpz_mod_set_fmpz(res.val, (obj).val, self.val) + + return res + def __eq__(self, other): + # TODO: + # If we could cache contexts, then we would ensure that only + # the a is b check is needed for equality. + + # Most often, we expect both `fmpz_mod` to be pointing to the + # same ctx, so this seems the fastest way to check + if self is other: + return True + + # If they're not the same object in memory, they may have the + # same modulus, which is good enough if typecheck(other, fmpz_mod_ctx): return fmpz_equal(self.val.n, (other).val.n) return False @@ -99,9 +161,12 @@ cdef class fmpz_mod(flint_scalar): def __cinit__(self): fmpz_init(self.val) + self.x_g = NULL def __dealloc__(self): fmpz_clear(self.val) + if self.x_g: + fmpz_clear(self.x_g[0]) def __init__(self, val, ctx): if not typecheck(ctx, fmpz_mod_ctx): @@ -125,12 +190,6 @@ cdef class fmpz_mod(flint_scalar): raise NotImplementedError fmpz_mod_set_fmpz(self.val, (val).val, self.ctx.val) - cdef any_as_fmpz_mod(self, obj): - try: - return self.ctx(obj) - except NotImplementedError: - return NotImplemented - def is_zero(self): """ Return whether an element is equal to zero @@ -142,7 +201,7 @@ cdef class fmpz_mod(flint_scalar): False """ return self == 0 - + def is_one(self): """ Return whether an element is equal to one @@ -158,13 +217,133 @@ cdef class fmpz_mod(flint_scalar): res = fmpz_mod_is_one(self.val, self.ctx.val) return res == 1 + def inverse(self, check=True): + r""" + Computes :math:`a^{-1} \pmod N` + + When check=False, the solutions is assumed to exist and Flint will abort on + failure. + + >>> mod_ctx = fmpz_mod_ctx(163) + >>> mod_ctx(2).inverse() + fmpz_mod(82, 163) + >>> mod_ctx(2).inverse(check=False) + fmpz_mod(82, 163) + """ + cdef fmpz_mod res + res = fmpz_mod.__new__(fmpz_mod) + res.ctx = self.ctx + + if check is False: + fmpz_mod_inv(res.val, self.val, self.ctx.val) + return res + + cdef bint r + cdef fmpz one = fmpz.__new__(fmpz) + fmpz_one(one.val) + + r = fmpz_mod_divides( + res.val, one.val, self.val, self.ctx.val + ) + if r == 0: + raise ZeroDivisionError( + f"{self} is not invertible modulo {self.ctx.modulus()}" + ) + + return res + + def discrete_log(self, a, check=False): + """ + Solve the discrete logarithm problem, using `self = g` as a base. + Assumes a solution, :math:`a = g^x \pmod p` exists. + + NOTE: Requires that the context modulus is prime. + + >>> F = fmpz_mod_ctx(163) + >>> g = F(2) + >>> x = 123 + >>> a = g**123 + >>> g.discrete_log(a) + 123 + """ + cdef bint is_prime + + # Ensure that the modulus is prime + if check: + is_prime = fmpz_is_probabprime(self.ctx.val.n) + if not is_prime: + raise ValueError("modulus must be prime") + + # Then check the type of the input + if typecheck(a, fmpz_mod): + if self.ctx != (a).ctx: + raise ValueError("moduli must match") + else: + a = self.ctx.any_as_fmpz_mod(a) + if a is NotImplemented: + raise TypeError + + # First, Ensure that self.ctx.L has performed precomputations + # This generates a `y` which is a primative root, and used as + # the base in `fmpz_mod_discrete_log_pohlig_hellman_run` + if not self.ctx.L: + self.ctx._precompute_dlog_prime() + + # Solve the discrete log for the chosen base and target + # g = y^x_g and a = y^x_a + # We want to find x such that a = g^x => + # (y^x_a) = (y^x_g)^x => x = (x_a / x_g) mod (p-1) + + # For repeated calls to discrete_log, it's more efficient to + # store x_g rather than keep computing it + if not self.x_g: + self.x_g = libc.stdlib.malloc( + cython.sizeof(fmpz_t) + ) + fmpz_mod_discrete_log_pohlig_hellman_run( + self.x_g[0], self.ctx.L[0], self.val + ) + + # Then we need to compute x_a which will be different for each call + cdef fmpz_t x_a + fmpz_init(x_a) + fmpz_mod_discrete_log_pohlig_hellman_run( + x_a, self.ctx.L[0], (a).val + ) + + # If g is not a primative root, then x_g and pm1 will share + # a common factor. We can use this to compute the order of + # g. + cdef fmpz_t g, g_order, x_g + fmpz_init(g) + fmpz_init(g_order) + fmpz_init(x_g) + + fmpz_gcd(g, self.x_g[0], self.ctx.L[0].pm1) + if not fmpz_is_one(g): + fmpz_divexact(x_g, self.x_g[0], g) + fmpz_divexact(x_a, x_a, g) + fmpz_divexact(g_order, self.ctx.L[0].pm1, g) + else: + fmpz_set(g_order, self.ctx.L[0].pm1) + fmpz_set(x_g, self.x_g[0]) + + # Finally, compute output exponent by computing + # (x_a / x_g) mod g_order + cdef fmpz x = fmpz.__new__(fmpz) + fmpz_invmod(x.val, x_g, g_order) + fmpz_mul(x.val, x.val, x_a) + fmpz_type_mod(x.val, x.val, g_order) + + return x + def __richcmp__(self, other, int op): cdef bint res if op != 2 and op != 3: raise TypeError("fmpz_mod cannot be ordered") if not typecheck(other, fmpz_mod): - other = self.any_as_fmpz_mod(other) + other = self.ctx.any_as_fmpz_mod(other) if typecheck(self, fmpz_mod) and typecheck(other, fmpz_mod): res = fmpz_equal(self.val, (other).val) and \ @@ -186,7 +365,7 @@ cdef class fmpz_mod(flint_scalar): ) def __hash__(self): - return hash((int(self))) + return hash((int(self))) def __int__(self): return fmpz_get_intlong(self.val) @@ -209,7 +388,7 @@ cdef class fmpz_mod(flint_scalar): return res def __add__(self, other): - other = self.any_as_fmpz_mod(other) + other = self.ctx.any_as_fmpz_mod(other) if other is NotImplemented: return NotImplemented @@ -224,25 +403,42 @@ cdef class fmpz_mod(flint_scalar): def __radd__(self, other): return self.__add__(other) - def __sub__(self, other): - other = self.any_as_fmpz_mod(other) - if other is NotImplemented: - return NotImplemented - + @staticmethod + def _sub_(left, right): cdef fmpz_mod res res = fmpz_mod.__new__(fmpz_mod) - res.ctx = self.ctx + # Case when left and right are already fmpz_mod + if typecheck(left, fmpz_mod) and typecheck(right, fmpz_mod): + if not (left).ctx == (right).ctx: + raise ValueError("moduli must match") + + # Case when right is not fmpz_mod, try to convert to fmpz + elif typecheck(left, fmpz_mod): + right = (left).ctx.any_as_fmpz_mod(right) + if right is NotImplemented: + return NotImplemented + + # Case when left is not fmpz_mod, try to convert to fmpz + else: + left = (right).ctx.any_as_fmpz_mod(left) + if left is NotImplemented: + return NotImplemented + + res.ctx = (left).ctx fmpz_mod_sub( - res.val, self.val, (other).val, self.ctx.val + res.val, (left).val, (right).val, res.ctx.val ) return res - def __rsub__(self, other): - return self.__sub__(other).__neg__() + def __sub__(s, t): + return fmpz_mod._sub_(s, t) + + def __rsub__(s, t): + return fmpz_mod._sub_(t, s) def __mul__(self, other): - other = self.any_as_fmpz_mod(other) + other = self.ctx.any_as_fmpz_mod(other) if other is NotImplemented: return NotImplemented @@ -263,38 +459,32 @@ cdef class fmpz_mod(flint_scalar): cdef bint check cdef fmpz_mod res res = fmpz_mod.__new__(fmpz_mod) - - # Division when left and right are fmpz_mod + + # Case when left and right are already fmpz_mod if typecheck(left, fmpz_mod) and typecheck(right, fmpz_mod): - res.ctx = (left).ctx if not (left).ctx == (right).ctx: raise ValueError("moduli must match") - check = fmpz_mod_divides( - res.val, (left).val, (right).val, res.ctx.val - ) - - # Case when only left is fmpz_mod + + # Case when right is not fmpz_mod, try to convert to fmpz elif typecheck(left, fmpz_mod): - res.ctx = (left).ctx - right = any_as_fmpz(right) + right = (left).ctx.any_as_fmpz_mod(right) if right is NotImplemented: - return NotImplemented - check = fmpz_mod_divides( - res.val, (left).val, (right).val, res.ctx.val - ) + return NotImplemented - # Case when right is an fmpz_mod + # Case when left is not fmpz_mod, try to convert to fmpz else: - res.ctx = (right).ctx - left = any_as_fmpz(left) + left = (right).ctx.any_as_fmpz_mod(left) if left is NotImplemented: - return NotImplemented - check = fmpz_mod_divides( - res.val, (left).val, (right).val, res.ctx.val - ) - + return NotImplemented + + res.ctx = (left).ctx + check = fmpz_mod_divides( + res.val, (left).val, (right).val, res.ctx.val + ) if check == 0: - raise ZeroDivisionError(f"{right} is not invertible modulo {res.ctx.modulus()}") + raise ZeroDivisionError( + f"{right} is not invertible modulo {res.ctx.modulus()}" + ) return res @@ -307,39 +497,6 @@ cdef class fmpz_mod(flint_scalar): def __floordiv__(self, other): return NotImplemented - def inverse(self, check=True): - r""" - Computes :math:`a^{-1} \pmod N` - - When check=False, the solutions is assumed to exist and Flint will abort on - failure. - - >>> mod_ctx = fmpz_mod_ctx(163) - >>> mod_ctx(2).inverse() - fmpz_mod(82, 163) - >>> mod_ctx(2).inverse(check=False) - fmpz_mod(82, 163) - """ - cdef fmpz_mod res - res = fmpz_mod.__new__(fmpz_mod) - res.ctx = self.ctx - - if check is False: - fmpz_mod_inv(res.val, self.val, self.ctx.val) - return res - - cdef bint r - cdef fmpz one = fmpz.__new__(fmpz) - fmpz_one(one.val) - - r = fmpz_mod_divides( - res.val, one.val, self.val, self.ctx.val - ) - if r == 0: - raise ZeroDivisionError(f"{self} is not invertible modulo {self.ctx.modulus()}") - - return res - def __invert__(self): return self.inverse() @@ -359,6 +516,8 @@ cdef class fmpz_mod(flint_scalar): ) if check == 0: - raise ZeroDivisionError(f"{self} is not invertible modulo {self.ctx.modulus()}") + raise ZeroDivisionError( + f"{self} is not invertible modulo {self.ctx.modulus()}" + ) return res