From 35f151d6cb2486d106cde31552c5820d693c4dd5 Mon Sep 17 00:00:00 2001 From: Chip Hogg Date: Fri, 8 Nov 2024 13:40:24 -0500 Subject: [PATCH] Add utilities for modular arithmetic (#320) These utilities all operate on unsigned 64-bit integers. The main design goal is to produce correct answers for all inputs while avoiding overflow. Most operations are standard: addition, subtraction, multiplication, powers. The one non-standard one is "half_mod_odd". This operation is `(x/2) % n`, but only if `n` is odd. If `x` is also odd, we add `n` before dividing by 2, which gives us an integer result. We'll need this operation for the strong Lucas probable prime checking later on. These are Au-internal functions, so we use unchecked preconditions: as long as the caller makes sure the inputs are already reduced-mod-n, we'll keep them that way. Helps #217. --- au/code/au/CMakeLists.txt | 2 + au/code/au/utility/mod.hh | 102 +++++++++++++++++++++++++ au/code/au/utility/test/mod_test.cc | 114 ++++++++++++++++++++++++++++ 3 files changed, 218 insertions(+) create mode 100644 au/code/au/utility/mod.hh create mode 100644 au/code/au/utility/test/mod_test.cc diff --git a/au/code/au/CMakeLists.txt b/au/code/au/CMakeLists.txt index c6e4f2af..640064df 100644 --- a/au/code/au/CMakeLists.txt +++ b/au/code/au/CMakeLists.txt @@ -158,6 +158,7 @@ header_only_library( units/yards.hh units/yards_fwd.hh utility/factoring.hh + utility/mod.hh utility/string_constant.hh utility/type_traits.hh ) @@ -458,6 +459,7 @@ gtest_based_test( NAME utility_test SRCS utility/test/factoring_test.cc + utility/test/mod_test.cc utility/test/string_constant_test.cc utility/test/type_traits_test.cc DEPS diff --git a/au/code/au/utility/mod.hh b/au/code/au/utility/mod.hh new file mode 100644 index 00000000..4b7f3d39 --- /dev/null +++ b/au/code/au/utility/mod.hh @@ -0,0 +1,102 @@ +// Copyright 2024 Aurora Operations, Inc. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <limits>
+
+namespace au {
+namespace detail {
+
+// (a + b) % n
+//
+// Precondition: (a < n).
+// Precondition: (b < n).
+constexpr uint64_t add_mod(uint64_t a, uint64_t b, uint64_t n) {
+    if (a >= n - b) {
+        return a - (n - b);
+    } else {
+        return a + b;
+    }
+}
+
+// (a - b) % n
+//
+// Precondition: (a < n).
+// Precondition: (b < n).
+constexpr uint64_t sub_mod(uint64_t a, uint64_t b, uint64_t n) {
+    if (a >= b) {
+        return a - b;
+    } else {
+        return n - (b - a);
+    }
+}
+
+// (a * b) % n
+//
+// Precondition: (a < n).
+// Precondition: (b < n).
+constexpr uint64_t mul_mod(uint64_t a, uint64_t b, uint64_t n) {
+    // Start by trying the simplest case, where everything "fits" (the product cannot overflow).
+    if (b == 0u || a <= std::numeric_limits<uint64_t>::max() / b) {
+        return (a * b) % n;
+    }
+
+    // We know the "negative" result is smaller, because we've taken as many copies of `a` as will
+    // fit into `n`. So, do the reduced calculation in "negative space", and then subtract it at
+    // the end (via `sub_mod`, so a zero chunk result stays reduced instead of turning into `n`).
+    uint64_t chunk_size = n / a;
+    uint64_t num_chunks = b / chunk_size;
+    uint64_t negative_chunk = n - (a * chunk_size);  // == n % a (but this should be cheaper)
+    uint64_t negative_result = mul_mod(negative_chunk, num_chunks, n);
+
+    // Compute the leftover. (We don't need to recurse, because we know it will fit.)
+    uint64_t leftover = b - num_chunks * chunk_size;
+    uint64_t leftover_result = (a * leftover) % n;
+
+    return sub_mod(leftover_result, negative_result, n);
+}
+
+// (a / 2) % n
+//
+// Precondition: (a < n).
+// Precondition: (n is odd).
+//
+// If `a` is even, this is of course simply `a / 2` (because `(a < n)` as a precondition).
+// Otherwise, we give the result one would obtain by first adding `n` (guaranteeing an even number,
+// since `n` is also odd as a precondition), and _then_ dividing by `2`.
+constexpr uint64_t half_mod_odd(uint64_t a, uint64_t n) {
+    return (a / 2u) + ((a % 2u == 0u) ? 0u : (n / 2u + 1u));
+}
+
+// (base ^ exp) % n
+//
+// Precondition: (n > 0). No restrictions on `base` or `exp`: we reduce `base` before using it.
+constexpr uint64_t pow_mod(uint64_t base, uint64_t exp, uint64_t n) {
+    uint64_t result = 1u % n;  // The empty product, reduced (so `n == 1` correctly gives 0).
+    base %= n;
+    while (exp > 0u) {
+        if (exp % 2u == 1u) {
+            result = mul_mod(result, base, n);
+        }
+
+        exp /= 2u;
+        base = mul_mod(base, base, n);
+    }
+    return result;
+}
+
+}  // namespace detail
+}  // namespace au
diff --git a/au/code/au/utility/test/mod_test.cc b/au/code/au/utility/test/mod_test.cc
new file mode 100644
index 00000000..44dc9891
--- /dev/null
+++ b/au/code/au/utility/test/mod_test.cc
@@ -0,0 +1,114 @@
+// Copyright 2024 Aurora Operations, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "au/utility/mod.hh"
+
+#include <limits>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace au {
+namespace detail {
+namespace {
+
+constexpr auto MAX = std::numeric_limits<uint64_t>::max();
+
+TEST(AddMod, HandlesSimpleCases) {
+    EXPECT_EQ(add_mod(1u, 2u, 5u), 3u);
+    EXPECT_EQ(add_mod(4u, 4u, 5u), 3u);
+}
+
+TEST(AddMod, HandlesVeryLargeNumbers) { EXPECT_EQ(add_mod(MAX - 1u, MAX - 2u, MAX), MAX - 3u); }
+
+TEST(SubMod, HandlesSimpleCases) {
+    EXPECT_EQ(sub_mod(2u, 1u, 5u), 1u);
+    EXPECT_EQ(sub_mod(1u, 2u, 5u), 4u);
+}
+
+TEST(SubMod, HandlesVeryLargeNumbers) {
+    EXPECT_EQ(sub_mod(MAX - 2u, MAX - 1u, MAX), MAX - 1u);
+    EXPECT_EQ(sub_mod(1u, MAX - 1u, MAX), 2u);
+}
+
+TEST(MulMod, HandlesSimpleCases) {
+    EXPECT_EQ(mul_mod(6u, 7u, 10u), 2u);
+    EXPECT_EQ(mul_mod(13u, 11u, 50u), 43u);
+}
+
+TEST(MulMod, HandlesHugeNumbers) {
+    constexpr auto JUST_UNDER_HALF = MAX / 2u;
+    ASSERT_EQ(JUST_UNDER_HALF * 2u + 1u, MAX);
+
+    EXPECT_EQ(mul_mod(JUST_UNDER_HALF, 10u, MAX), MAX - 5u);
+}
+
+TEST(HalfModOdd, HalvesEvenNumbers) {
+    EXPECT_EQ(half_mod_odd(0u, 11u), 0u);
+    EXPECT_EQ(half_mod_odd(10u, 11u), 5u);
+}
+
+TEST(HalfModOdd, HalvesSumWithNForOddNumbers) {
+    EXPECT_EQ(half_mod_odd(1u, 11u), 6u);
+    EXPECT_EQ(half_mod_odd(9u, 11u), 10u);
+}
+
+TEST(HalfModOdd, SameAsMultiplyingByCeilOfNOver2WhenNIsOdd) {
+    // An interesting test case, which helps us make sense of the operation of "dividing by 2" in
+    // modular arithmetic. When `n` is odd, `2` has a multiplicative inverse, so we can understand
+    // "dividing by two" in terms of multiplying by this inverse.
+    //
+    // This fails when `n` is even, but so does dividing by 2 generally.
+    //
+    // In principle, we could replace our `half_mod_odd` implementation with this, and it would have
+    // the same preconditions, but there's a chance it would be less efficient (because `mul_mod`
+    // may recurse multiple times). Also, keeping them separate lets us use this test case as an
+    // independent check.
+    std::vector<uint64_t> n_values{9u, 11u, 8723493u, MAX};
+    for (const auto &n : n_values) {
+        const auto half_n = n / 2u + 1u;
+
+        std::vector<uint64_t> x_values{0u, 1u, 2u, (n / 2u), (n / 2u + 1u), (n - 2u), (n - 1u)};
+        for (const auto &x : x_values) {
+            EXPECT_EQ(half_mod_odd(x, n), mul_mod(x, half_n, n));
+        }
+    }
+}
+
+TEST(PowMod, HandlesSimpleCases) {
+    auto to_the_eighth = [](uint64_t x) {
+        x *= x;
+        x *= x;
+        x *= x;
+        return x;
+    };
+    EXPECT_EQ(pow_mod(5u, 8u, 9u), to_the_eighth(5u) % 9u);
+}
+
+TEST(PowMod, HandlesNumbersThatWouldOverflow) { EXPECT_EQ(pow_mod(2u, 64u, MAX), 1u); }
+
+TEST(PowMod, ProducesSameAnswerAsRepeatedModMulForLargeNumbers) {
+    const auto x = MAX / 3u * 2u;
+    const auto to_pow_2 = mul_mod(x, x, MAX);
+    const auto to_pow_4 = mul_mod(to_pow_2, to_pow_2, MAX);
+    const auto to_pow_5 = mul_mod(x, to_pow_4, MAX);
+    const auto to_pow_10 = mul_mod(to_pow_5, to_pow_5, MAX);
+    const auto to_pow_11 = mul_mod(x, to_pow_10, MAX);
+    const auto to_pow_22 = mul_mod(to_pow_11, to_pow_11, MAX);
+    EXPECT_EQ(pow_mod(x, 22u, MAX), to_pow_22);
+}
+
+}  // namespace
+}  // namespace detail
+}  // namespace au