Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

include modsqrt with lifting over prime powers #18

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions crypto_commons/rsa/rsa_commons.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import itertools
from collections import Counter
from functools import reduce

from crypto_commons.generic import bytes_to_long, find_divisor, multiply, long_to_bytes
from crypto_commons.generic import bytes_to_long, find_divisor, multiply, long_to_bytes, long_range


def rsa_printable(x, exp, n):
Expand Down Expand Up @@ -269,9 +270,37 @@ def homomorphic_blinding_rsa(payload, get_signature, N, splits=2):
return result_sig


def modular_sqrt_composite_powers(c, primes):
"""
Calculates square root mod composite value given all modulus factors, even they are repeated
For a = b^2 mod p^k1*q^k2*r^k3*m^k4... calculates b
:param c: residue
:param factors: list of modulus prime factors
:return: all potential root values
"""
factors = Counter(primes).items()
simple_roots = {prime: list({modular_sqrt(c, prime), prime - modular_sqrt(c, prime)}) for prime, _ in factors}
assert all(pow(root, 2, prime) == c % prime for prime, roots in simple_roots.items() for root in roots)
f = lambda x: x ** 2 - c
df = lambda x: 2 * x
roots = {}
for prime, k in factors:
if k > 1:
lifted = hensel_lifting(f, df, prime, k, simple_roots[prime])
assert all(pow(root, 2, prime ** k) == (c % prime ** k) for root in lifted)
roots[prime ** k] = lifted
else:
roots[prime] = simple_roots[prime]
res = [[(residue, modulo) for residue in roots] for modulo, roots in roots.items()]
solutions = [solve_crt(x) for x in itertools.product(*res)]
n = multiply(primes)
assert all(pow(solution, 2, n) for solution in solutions)
return solutions


def modular_sqrt_composite(c, factors):
"""
Calculates modular square root of composite value for given all modulus factors
Calculates square root mod composite value given all co-prime modulus factors
For a = b^2 mod p*q*r*m... calculates b
:param c: residue
:param factors: list of modulus prime factors
Expand All @@ -295,18 +324,19 @@ def modular_sqrt(a, p):
:param p: modulus
:return: root value
"""
if legendre_symbol(a, p) != 1:
return 0
elif a == 0:
a = a % p
if a == 0:
return 0
elif p == 2:
return p
return a
elif legendre_symbol(a, p) != 1:
return 0
elif p % 4 == 3:
return pow(a, (p + 1) // 4, p)
s = p - 1
e = 0
while s % 2 == 0:
s /= 2
s //= 2
e += 1
n = 2
while legendre_symbol(n, p) != -1:
Expand All @@ -318,7 +348,7 @@ def modular_sqrt(a, p):
while True:
t = b
m = 0
for m in xrange(r):
for m in long_range(0, r):
if t == 1:
break
t = pow(t, 2, p)
Expand Down
Empty file.
Empty file.
38 changes: 38 additions & 0 deletions crypto_commons/tests/rsa/test_mod_sqrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import random
import unittest

from crypto_commons.generic import get_primes, multiply
from crypto_commons.rsa.rsa_commons import modular_sqrt, modular_sqrt_composite, modular_sqrt_composite_powers


class TestModSqrt(unittest.TestCase):
def test_roots_over_prime(self):
primes = get_primes(10000)
for i in range(0, 10):
prime = random.choice(primes)
root = random.randint(prime // 2, prime - 1)
residue = pow(root, 2, prime)
with self.subTest(prime=prime, root=root, residue=residue):
computed_root = modular_sqrt(residue, prime)
self.assertTrue(root in [computed_root, prime - computed_root])

def test_roots_over_simple_composites(self):
primes = get_primes(10000)
for i in range(0, 10):
factors = random.sample(primes, k=random.randint(2, 10))
modulus = multiply(factors)
root = random.randint(modulus // 2, modulus - 1)
residue = pow(root, 2, modulus)
with self.subTest(root=root, residue=residue, factors=factors):
self.assertTrue(root in modular_sqrt_composite(residue, factors))

def test_roots_over_composites_with_prime_powers(self):
primes = get_primes(10000)
for i in range(0, 10):
simple_factors = random.sample(primes, k=random.randint(2, 10))
factors = sum([[prime] * random.randint(2, 10) for prime in simple_factors], [])
modulus = multiply(factors)
root = random.randint(modulus // 2, modulus - 1)
residue = pow(root, 2, modulus)
with self.subTest(root=root, residue=residue, factors=factors):
self.assertTrue(root in modular_sqrt_composite_powers(residue, factors))