diff --git a/au/code/au/utility/probable_primes.hh b/au/code/au/utility/probable_primes.hh index a9af8efc..0b7c917a 100644 --- a/au/code/au/utility/probable_primes.hh +++ b/au/code/au/utility/probable_primes.hh @@ -14,6 +14,8 @@ #pragma once +#include + #include "au/utility/mod.hh" namespace au { @@ -82,5 +84,85 @@ constexpr PrimeResult miller_rabin(std::size_t a, uint64_t n) { return PrimeResult::COMPOSITE; } +constexpr uint64_t gcd(uint64_t a, uint64_t b) { + while (b != 0u) { + const auto remainder = a % b; + a = b; + b = remainder; + } + return a; +} + +// Map `true` onto `1`, and `false` onto `0`. +// +// The conversions `true` -> `1` and `false` -> `0` are guaranteed by the standard. This is a +// branchless implementation, which should generally be faster. +constexpr int bool_sign(bool x) { return x - (!x); } + +// +// The Jacobi symbol (a/n) is defined for odd positive `n` and any integer `a` as the product of the +// Legendre symbols (a/p) for all prime factors `p` of n. There are several rules that make this +// easier to calculate, including: +// +// 1. (a/n) = (b/n) whenever (a % n) == (b % n). +// +// 2. (2a/n) = (a/n) if n is congruent to 1 or 7 (mod 8), and -(a/n) if n is congruent to 3 or 5. +// +// 3. (1/n) = 1 for all n. +// +// 4. (a/n) = 0 whenever a and n have a nontrivial common factor. +// +// 5. (a/n) = (n/a) * (-1)^x if a and n are both odd, positive, and coprime. Here, x is 0 if +// either a or n is congruent to 1 (mod 4), and 1 otherwise. +// +constexpr int jacobi_symbol_positive_numerator(uint64_t a, uint64_t n, int start) { + int result = start; + + while (a != 0u) { + // Handle even numbers in the "numerator". + const int sign_for_even = bool_sign(n % 8u == 1u || n % 8u == 7u); + while (a % 2u == 0u) { + a /= 2u; + result *= sign_for_even; + } + + // `jacobi_symbol(1, n)` is `1` for all `n`. + if (a == 1u) { + return result; + } + + // `jacobi_symbol(a, n)` is `0` whenever `a` and `n` have a common factor. + if (gcd(a, n) != 1u) { + return 0; + } + + // At this point, `a` and `n` are odd, positive, and coprime. We can use the reciprocity + // relationship to "flip" them, and modular arithmetic to reduce them. + + // First, compute the sign change from the flip. + result *= bool_sign((a % 4u == 1u) || (n % 4u == 1u)); + + // Now, do the flip-and-reduce. + const uint64_t new_a = n % a; + n = a; + a = new_a; + } + return 0; +} +constexpr int jacobi_symbol(int64_t raw_a, uint64_t n) { + // Degenerate case: n = 1. + if (n == 1u) { + return 1; + } + + // Starting conditions: transform `a` to strictly non-negative values, setting `result` to the + // sign we pick up from this operation (if any). + int result = bool_sign((raw_a >= 0) || (n % 4u == 1u)); + auto a = static_cast(std::abs(raw_a)) % n; + + // Delegate to an implementation which can only handle positive numbers. + return jacobi_symbol_positive_numerator(a, n, result); +} + } // namespace detail } // namespace au diff --git a/au/code/au/utility/test/probable_primes_test.cc b/au/code/au/utility/test/probable_primes_test.cc index c5babb9a..41285c47 100644 --- a/au/code/au/utility/test/probable_primes_test.cc +++ b/au/code/au/utility/test/probable_primes_test.cc @@ -200,6 +200,62 @@ TEST(MillerRabin, SupportsConstexpr) { static_assert(result == PrimeResult::PROBABLY_PRIME, "997 is prime"); } +TEST(Gcd, ResultIsAlwaysAFactorAndGCDFindsNoLargerFactor) { + for (auto i = 0u; i < 500u; ++i) { + for (auto j = 1u; j < i; ++j) { + const auto g = gcd(i, j); + EXPECT_EQ(i % g, 0u); + EXPECT_EQ(j % g, 0u); + + // Brute force: no larger factors. + for (auto k = g + 1u; k < j / 2u; ++k) { + EXPECT_FALSE((i % k == 0u) && (j % k == 0u)); + } + } + } +} + +TEST(Gcd, HandlesZeroCorrectly) { + // The usual convention: if one argument is 0, return the other argument. + EXPECT_EQ(gcd(0u, 0u), 0u); + EXPECT_EQ(gcd(10u, 0u), 10u); + EXPECT_EQ(gcd(0u, 10u), 10u); +} + +TEST(JacobiSymbol, ZeroWhenCommonFactorExists) { + for (int i = -20; i <= 20; ++i) { + for (auto j = 1u; j <= 19u; j += 2u) { + for (auto factor = 3u; factor < 200u; factor += 2u) { + // Make sure that `j * factor` is odd, or else the result is undefined. + EXPECT_EQ(jacobi_symbol(i * static_cast(factor), j * factor), 0) + << "jacobi(" << i * static_cast(factor) << ", " << j * factor + << ") should be 0"; + } + } + } +} + +TEST(JacobiSymbol, AlwaysOneWhenFirstInputIsOne) { + for (auto i = 3u; i < 99u; i += 2u) { + EXPECT_EQ(jacobi_symbol(1, i), 1) << "jacobi(1, " << i << ") should be 1"; + } +} + +TEST(JacobiSymbol, ReproducesExamplesFromWikipedia) { + // https://en.wikipedia.org/wiki/Jacobi_symbol#Example_of_calculations + EXPECT_EQ(jacobi_symbol(1001, 9907), -1); + + // https://en.wikipedia.org/wiki/Jacobi_symbol#Primality_testing + EXPECT_EQ(jacobi_symbol(19, 45), 1); + EXPECT_EQ(jacobi_symbol(8, 21), -1); + EXPECT_EQ(jacobi_symbol(5, 21), 1); +} + +TEST(BoolSign, ReturnsCorrectValues) { + EXPECT_EQ(bool_sign(true), 1); + EXPECT_EQ(bool_sign(false), -1); +} + } // namespace } // namespace detail } // namespace au