Skip to content

Commit

Permalink
Add utilities to make Strong Lucas tests easier
Browse files Browse the repository at this point in the history
The Strong Lucas test coming in the next PR will already be complicated
enough.  It'll be convenient, and less distracting, if we already have
functions for certain operations we'll need.

One thing we'll need to do is detect inputs that are perfect squares.
Fortunately, this is pretty easy to do robustly and quickly, with
Newton's method.  We _don't_ want to use `std::sqrt`, because that takes
us into the floating point domain for no good reason, which could give
us wrong answers for larger integers.

The other thing we need is Jacobi symbols.  These are a lot more
obscure, but thankfully, still resonably straightforward to compute.
The Wikipedia page (https://en.wikipedia.org/wiki/Jacobi_symbol) has a
good explanation, and in particular, good instructions for computing
values.

With these utilities in place, the Strong Lucas code should be easier to
review.
  • Loading branch information
chiphogg committed Nov 15, 2024
1 parent 7e1baf6 commit d56ffc0
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
77 changes: 77 additions & 0 deletions src/core/include/mp-units/ext/prime.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,83 @@ struct NumberDecomposition {
return false;
}

// The Jacobi symbol, notated as `(a/n)`, is defined for odd positive `n` and any integer `a`, taking values
// in the set `{-1, 0, 1}`. Besides being a completely multiplicative function (so that, for example, both
// (a*b/n) = (a/n) * (b/n), and (a/n*m) = (a/n) * (a/m)), it obeys the following symmetry rules, which enable
// its calculation:
//
// 1. (a/1) = 1, and (1/n) = 1, for all a and n.
//
// 2. (a/n) = 0 whenever a and n have a nontrivial common factor.
//
// 3. (a/n) = (b/n) whenever (a % n) = (b % n).
//
// 4. (2a/n) = (a/n) if n % 8 = 1 or 7, and -(a/n) if n % 8 = 3 or 5.
//
// 5. (a/n) = (n/a) * x if a and n are both odd, positive, and coprime. Here, x is 1 if either (a % 4) = 1
// or (n % 4) = 1, and -1 otherwise.
//
// 6. (-1/n) = 1 if n % 4 = 1, and -1 if n % 4 = 3.
[[nodiscard]] consteval int jacobi_symbol(int64_t raw_a, uint64_t n)
{
// Rule 1: n=1 case.
if (n == 1u) {
return 1;
}

// Starting conditions: transform `a` to strictly non-negative values, setting `result` to the sign that we
// pick up (if any) from following these rules (i.e., rules 3 and 6).
int result = ((raw_a >= 0) || (n % 4u == 1u)) ? 1 : -1;
auto a = static_cast<uint64_t>(raw_a < 0 ? -raw_a : raw_a) % n;

while (a != 0u) {
// Rule 4.
const int sign_for_even = (n % 8u == 1u || n % 8u == 7u) ? 1 : -1;
while (a % 2u == 0u) {
a /= 2u;
result *= sign_for_even;
}

// Rule 1: a=1 case.
if (a == 1u) {
return result;
}

// Rule 2.
if (std::gcd(a, n) != 1u) {
return 0;
}

// Note that at this point, we know that `a` and `n` are coprime, and are both odd and positive.
// Therefore, we meet the preconditions for rule 5 (the "flip-and-reduce" rule).
result *= (n % 4u == 1u || a % 4u == 1u) ? 1 : -1;
const uint64_t new_a = n % a;
n = a;
a = new_a;
}

return 0;
}

[[nodiscard]] consteval bool is_perfect_square(uint64_t n)
{
if (n < 2u) {
return true;
}

uint64_t prev = n / 2u;
while (true) {
const uint64_t curr = (prev + n / prev) / 2u;
if (curr * curr == n) {
return true;
}
if (curr >= prev) {
return false;
}
prev = curr;
}
}

[[nodiscard]] consteval bool is_prime_by_trial_division(std::uintmax_t n)
{
for (std::uintmax_t f = 2; f * f <= n; f += 1 + (f % 2)) {
Expand Down
43 changes: 43 additions & 0 deletions test/static/prime_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,47 @@ static_assert(miller_rabin_probable_prime(2u, 9'007'199'254'740'881u), "Large kn

static_assert(miller_rabin_probable_prime(2u, 18'446'744'073'709'551'557u), "Largest 64-bit prime");

// Jacobi symbols --- a building block for the Strong Lucas probable prime test, needed for Baillie-PSW.
static_assert(jacobi_symbol(1, 1u) == 1, "Jacobi symbol always 1 when 'numerator' is 1");
static_assert(jacobi_symbol(1, 3u) == 1, "Jacobi symbol always 1 when 'numerator' is 1");
static_assert(jacobi_symbol(1, 5u) == 1, "Jacobi symbol always 1 when 'numerator' is 1");
static_assert(jacobi_symbol(1, 987654321u) == 1, "Jacobi symbol always 1 when 'numerator' is 1");

static_assert(jacobi_symbol(3, 1u) == 1, "Jacobi symbol always 1 when 'denominator' is 1");
static_assert(jacobi_symbol(5, 1u) == 1, "Jacobi symbol always 1 when 'denominator' is 1");
static_assert(jacobi_symbol(-1234567890, 1u) == 1, "Jacobi symbol always 1 when 'denominator' is 1");

static_assert(jacobi_symbol(10, 5u) == 0, "Jacobi symbol always 0 when there's a common factor");
static_assert(jacobi_symbol(25, 15u) == 0, "Jacobi symbol always 0 when there's a common factor");
static_assert(jacobi_symbol(-24, 9u) == 0, "Jacobi symbol always 0 when there's a common factor");

static_assert(jacobi_symbol(14, 9u) == +jacobi_symbol(7, 9u),
"Divide numerator by 2: positive when (denom % 8) in {1, 7}");
static_assert(jacobi_symbol(14, 15u) == +jacobi_symbol(7, 15u),
"Divide numerator by 2: positive when (denom % 8) in {1, 7}");
static_assert(jacobi_symbol(14, 11u) == -jacobi_symbol(7, 11u),
"Divide numerator by 2: negative when (denom % 8) in {3, 5}");
static_assert(jacobi_symbol(14, 13u) == -jacobi_symbol(7, 13u),
"Divide numerator by 2: negative when (denom % 8) in {3, 5}");
static_assert(jacobi_symbol(19, 9u) == +jacobi_symbol(9, 19u), "Flip is identity when (n % 4) = 1");
static_assert(jacobi_symbol(17, 7u) == +jacobi_symbol(7, 17u), "Flip is identity when (a % 4) = 1");
static_assert(jacobi_symbol(19, 7u) == -jacobi_symbol(9, 7u), "Flip changes sign when (n % 4) = 3 and (a % 4) = 3");

static_assert(jacobi_symbol(1001, 9907u) == -1, "Example from Wikipedia page");
static_assert(jacobi_symbol(19, 45u) == 1, "Example from Wikipedia page");
static_assert(jacobi_symbol(8, 21u) == -1, "Example from Wikipedia page");
static_assert(jacobi_symbol(5, 21u) == 1, "Example from Wikipedia page");

// Tests for perfect square finder
static_assert(is_perfect_square(0u));
static_assert(is_perfect_square(1u));
static_assert(!is_perfect_square(2u));
static_assert(is_perfect_square(4u));

constexpr uint64_t BIG_SQUARE = [](auto x) { return x * x; }((uint64_t{1u} << 32) - 1u);
static_assert(!is_perfect_square(BIG_SQUARE - 1u));
static_assert(is_perfect_square(BIG_SQUARE));
static_assert(!is_perfect_square(BIG_SQUARE + 1u));

} // namespace

0 comments on commit d56ffc0

Please sign in to comment.