Skip to content

Commit

Permalink
Add helpers for modular arithmetic
Browse files Browse the repository at this point in the history
The prime-testing techniques we will use (Miller-Rabin, Strong Lucas)
all make heavy usage of modular arithmetic.  Therefore, we lay those
foundations here, adding utilities to perform the basic arithmetic
operations robustly.

Since these are internal-only helper functions, we don't bother checking
the preconditions, although we state them clearly in the contract
comment for each utility.  After C++26, we could add contracts for
these.

Helps mpusz#509.
  • Loading branch information
chiphogg committed Nov 11, 2024
1 parent 727a898 commit 8a7483f
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
80 changes: 80 additions & 0 deletions src/core/include/mp-units/ext/prime.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,86 @@ import std;

namespace mp_units::detail {

// (a + b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval uint64_t add_mod(uint64_t a, uint64_t b, uint64_t n)
{
if (a >= n - b) {
return a - (n - b);
} else {
return a + b;
}
}

// (a - b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval uint64_t sub_mod(uint64_t a, uint64_t b, uint64_t n)
{
if (a >= b) {
return a - b;
} else {
return n - (b - a);
}
}

// (a * b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval uint64_t mul_mod(uint64_t a, uint64_t b, uint64_t n)
{
if (b == 0u || a < std::numeric_limits<uint64_t>::max() / b) {
return (a * b) % n;
}

const uint64_t batch_size = n / a;
const uint64_t num_batches = b / batch_size;

return add_mod(
// Transform into "negative space" to make the first parameter as small as possible;
// then, transform back.
n - mul_mod(n % a, num_batches, n),

// Handle the leftover product (which is guaranteed to fit in the integer type).
(a * (b % batch_size)) % n,

n);
}

// (a / 2) % n.
//
// Precondition: (a < n).
// Precondition: (n % 2 == 1).
[[nodiscard]] consteval uint64_t half_mod_odd(uint64_t a, uint64_t n)
{
return (a / 2u) + ((a % 2u == 0u) ? 0u : (n / 2u + 1u));
}

// (base ^ exp) % n.
[[nodiscard]] consteval uint64_t pow_mod(uint64_t base, uint64_t exp, uint64_t n)
{
uint64_t result = 1u;
base %= n;

while (exp > 0u) {
if (exp % 2u == 1u) {
result = mul_mod(result, base, n);
}

exp /= 2u;
base = mul_mod(base, base, n);
}

return result;
}

[[nodiscard]] consteval bool is_prime_by_trial_division(std::uintmax_t n)
{
for (std::uintmax_t f = 2; f * f <= n; f += 1 + (f % 2)) {
Expand Down
26 changes: 26 additions & 0 deletions test/static/prime_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ using namespace mp_units::detail;

namespace {

inline constexpr auto MAX_U64 = std::numeric_limits<std::uint64_t>::max();

template<std::size_t BasisSize, std::size_t... Is>
constexpr bool check_primes(std::index_sequence<Is...>)
{
Expand Down Expand Up @@ -78,4 +80,28 @@ static_assert(!wheel_factorizer<3>::is_prime(0));
static_assert(!wheel_factorizer<3>::is_prime(1));
static_assert(wheel_factorizer<3>::is_prime(2));

// Modular arithmetic.
static_assert(add_mod(1u, 2u, 5u) == 3u);
static_assert(add_mod(4u, 4u, 5u) == 3u);
static_assert(add_mod(MAX_U64 - 1u, MAX_U64 - 2u, MAX_U64) == MAX_U64 - 3u);

static_assert(sub_mod(2u, 1u, 5u) == 1u);
static_assert(sub_mod(1u, 2u, 5u) == 4u);
static_assert(sub_mod(MAX_U64 - 2u, MAX_U64 - 1u, MAX_U64) == MAX_U64 - 1u);
static_assert(sub_mod(1u, MAX_U64 - 1u, MAX_U64) == 2u);

static_assert(mul_mod(6u, 7u, 10u) == 2u);
static_assert(mul_mod(13u, 11u, 50u) == 43u);
static_assert(mul_mod(MAX_U64 / 2u, 10u, MAX_U64) == MAX_U64 - 5u);

static_assert(half_mod_odd(0u, 11u) == 0u);
static_assert(half_mod_odd(10u, 11u) == 5u);
static_assert(half_mod_odd(1u, 11u) == 6u);
static_assert(half_mod_odd(9u, 11u) == 10u);
static_assert(half_mod_odd(MAX_U64 - 1u, MAX_U64) == (MAX_U64 - 1u) / 2u);
static_assert(half_mod_odd(MAX_U64 - 2u, MAX_U64) == MAX_U64 - 1u);

static_assert(pow_mod(5u, 8u, 9u) == ((5u * 5u * 5u * 5u) * (5u * 5u * 5u * 5u)) % 9u);
static_assert(pow_mod(2u, 64u, MAX_U64) == 1u);

} // namespace

0 comments on commit 8a7483f

Please sign in to comment.