Skip to content

Commit

Permalink
Merge pull request #205 from kroma-network/feat/implement-proof-verif…
Browse files Browse the repository at this point in the history
…ication

feat: implement proof verification
  • Loading branch information
Ryan Kim authored Dec 20, 2023
2 parents 1dff39e + 295e50c commit a12e82d
Show file tree
Hide file tree
Showing 26 changed files with 1,369 additions and 125 deletions.
2 changes: 1 addition & 1 deletion tachyon/math/base/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ tachyon_cc_library(
hdrs = ["rings.h"],
deps = [
":groups",
"//tachyon/base:template_util",
"//tachyon/base:parallelize",
],
)

Expand Down
20 changes: 8 additions & 12 deletions tachyon/math/base/field.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,16 @@ template <typename F>
class Field : public AdditiveGroup<F>, public MultiplicativeGroup<F> {
public:
// Sum of products: a₁ * b₁ + a₂ * b₂ + ... + aₙ * bₙ
template <
typename InputIterator,
std::enable_if_t<std::is_same_v<F, base::iter_value_t<InputIterator>>>* =
nullptr>
constexpr static F SumOfProducts(InputIterator a_first, InputIterator a_last,
InputIterator b_first,
InputIterator b_last) {
return Ring<F>::SumOfProducts(std::move(a_first), std::move(a_last),
std::move(b_first), std::move(b_last));
template <typename ContainerA, typename ContainerB>
constexpr static F SumOfProducts(const ContainerA& a, const ContainerB& b) {
return Ring<F>::SumOfProducts(a, b);
}

template <typename Container>
constexpr static F SumOfProducts(const Container& a, const Container& b) {
return Ring<F>::SumOfProducts(a, b);
// Sum of products: a₁ * b₁ + a₂ * b₂ + ... + aₙ * bₙ
template <typename ContainerA, typename ContainerB>
constexpr static F SumOfProductsSerial(const ContainerA& a,
const ContainerB& b) {
return Ring<F>::SumOfProductsSerial(a, b);
}
};

Expand Down
20 changes: 18 additions & 2 deletions tachyon/math/base/field_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,28 @@ class FieldTest : public testing::Test {

} // namespace

TEST(FieldTest, SumOfProducts) {
TEST(FieldTest, SumOfProductsSerial) {
std::vector<GF7> a = {GF7(2), GF7(3), GF7(4)};
std::vector<GF7> b = {GF7(1), GF7(2), GF7(3)};

GF7 result = Field<GF7>::SumOfProducts(a, b);
GF7 result = Field<GF7>::SumOfProductsSerial(a, b);
EXPECT_EQ(result, GF7(6));
}

TEST(FieldTest, SumOfProductsParallel) {
#if defined(TACHYON_HAS_OPENMP)
size_t thread_nums = static_cast<size_t>(omp_get_max_threads());
#else
size_t thread_nums = 1;
#endif

std::vector<GF7> a =
base::CreateVector(thread_nums * 10, []() { return GF7::Random(); });
std::vector<GF7> b =
base::CreateVector(thread_nums * 10, []() { return GF7::Random(); });

EXPECT_EQ(Field<GF7>::SumOfProducts(a, b),
Field<GF7>::SumOfProductsSerial(a, b));
}

} // namespace tachyon::math
62 changes: 44 additions & 18 deletions tachyon/math/base/rings.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
#define TACHYON_MATH_BASE_RINGS_H_

#include <type_traits>
#include <vector>

#include "tachyon/base/template_util.h"
#include "tachyon/base/parallelize.h"
#include "tachyon/math/base/groups.h"

namespace tachyon::math {
Expand All @@ -25,27 +26,52 @@ class Ring : public AdditiveGroup<F>, public MultiplicativeSemigroup<F> {
// This is taken and modified from
// https://github.com/arkworks-rs/algebra/blob/5dfeedf560da6937a5de0a2163b7958bd32cd551/ff/src/fields/mod.rs#L298C1-L305
// Sum of products: a₁ * b₁ + a₂ * b₂ + ... + aₙ * bₙ
template <
typename InputIterator,
std::enable_if_t<std::is_same_v<F, base::iter_value_t<InputIterator>>>* =
nullptr>
constexpr static F SumOfProducts(InputIterator a_first, InputIterator a_last,
InputIterator b_first,
InputIterator b_last) {
// TODO(chokobole): If I call |SumOfProducts()| instead of
// |SumOfProductsSerial| for all call sites, it gets stuck when doing
// unittests. I think we need a some general threshold to check whether it is
// good to doing parallelization.
template <typename ContainerA, typename ContainerB>
constexpr static F SumOfProducts(const ContainerA& a, const ContainerB& b) {
size_t size = std::size(a);
CHECK_EQ(size, std::size(b));
CHECK_NE(size, size_t{0});
std::vector<F> partial_sum_of_products = base::ParallelizeMap(
a,
[&b](absl::Span<const F> chunk, size_t chunk_idx, size_t chunk_size) {
F sum = F::Zero();
size_t i = chunk_idx * chunk_size;
for (size_t j = 0; j < chunk.size(); ++j) {
sum += (chunk[j] * b[i + j]);
}
return sum;
});
return std::accumulate(partial_sum_of_products.begin(),
partial_sum_of_products.end(), F::Zero(),
[](F& acc, const F& partial_sum_of_product) {
return acc += partial_sum_of_product;
});
}

template <typename ContainerA, typename ContainerB>
constexpr static F SumOfProductsSerial(const ContainerA& a,
const ContainerB& b) {
size_t size = std::size(a);
CHECK_EQ(size, std::size(b));
CHECK_NE(size, size_t{0});
return DoSumOfProductsSerial(a, b);
}

private:
template <typename ContainerA, typename ContainerB>
constexpr static F DoSumOfProductsSerial(const ContainerA& a,
const ContainerB& b) {
size_t n = std::size(a);
F sum = F::Zero();
while (a_first != a_last) {
sum += (*a_first * *b_first);
++a_first;
++b_first;
for (size_t i = 0; i < n; ++i) {
sum += (a[i] * b[i]);
}
return sum;
}

template <typename Container>
constexpr static F SumOfProducts(const Container& a, const Container& b) {
return SumOfProducts(std::begin(a), std::end(a), std::begin(b),
std::end(b));
}
};

} // namespace tachyon::math
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ constexpr PointXYZZ<Curve> CLASS::DoubleXYZZ() const {
// Y3 = M * (S - X3) - W * Y1
BaseField lefts[] = {std::move(m), -w};
BaseField rights[] = {s - x, y_};
BaseField y = BaseField::SumOfProducts(lefts, rights);
BaseField y = BaseField::SumOfProductsSerial(lefts, rights);

// ZZ3 = V
BaseField zz = std::move(v);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ constexpr CLASS& CLASS::AddInPlace(const JacobianPoint& other) {
y_.DoubleInPlace();
BaseField lefts[] = {std::move(r), y_};
BaseField rights[] = {std::move(v), std::move(j)};
y_ = BaseField::SumOfProducts(lefts, rights);
y_ = BaseField::SumOfProductsSerial(lefts, rights);

// Z3 = ((Z1 + Z2)² - Z1Z1 - Z2Z2) * H
// This is equal to Z3 = 2 * Z1 * Z2 * H, and computing it this way is
Expand Down Expand Up @@ -159,7 +159,7 @@ constexpr CLASS& CLASS::AddInPlace(const AffinePoint<Curve>& other) {
// Y3 = r * (V - X3) + 2 * Y1 * J
BaseField lefts[] = {std::move(r), y_.Double()};
BaseField rights[] = {v - x_, std::move(j)};
y_ = BaseField::SumOfProducts(lefts, rights);
y_ = BaseField::SumOfProductsSerial(lefts, rights);

// Z3 = 2 * Z1 * H;
// Can alternatively be computed as (Z1 + H)² - Z1Z1 - HH, but the latter is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ constexpr CLASS& CLASS::AddInPlace(const PointXYZZ& other) {
// Y3 = R * (Q - X3) - S1 * PPP
BaseField lefts[] = {std::move(r), -s1};
BaseField rights[] = {q - x_, ppp};
y_ = BaseField::SumOfProducts(lefts, rights);
y_ = BaseField::SumOfProductsSerial(lefts, rights);

// ZZ3 = ZZ1 * ZZ2 * PP
zz_ *= other.zz_;
Expand Down Expand Up @@ -136,7 +136,7 @@ constexpr CLASS& CLASS::AddInPlace(const AffinePoint<Curve>& other) {
// Y3 = R * (Q - X3) - Y1 * PPP
BaseField lefts[] = {std::move(r), -y_};
BaseField rights[] = {q - x_, ppp};
y_ = BaseField::SumOfProducts(lefts, rights);
y_ = BaseField::SumOfProductsSerial(lefts, rights);

// ZZ3 = ZZ1 * PP
zz_ *= pp;
Expand Down Expand Up @@ -186,7 +186,7 @@ constexpr CLASS& CLASS::DoubleInPlace() {
// Y3 = M * (S - X3) - W * Y1
BaseField lefts[] = {std::move(m), -w};
BaseField rights[] = {s - x_, y_};
y_ = BaseField::SumOfProducts(lefts, rights);
y_ = BaseField::SumOfProductsSerial(lefts, rights);

// ZZ3 = V * ZZ1
zz_ *= v;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ constexpr CLASS& CLASS::AddInPlace(const ProjectivePoint& other) {
// Y3 = u * (R - A) - vvv * Y1Z2
BaseField lefts[] = {std::move(u), -vvv};
BaseField rights[] = {r - a, std::move(y1z2)};
y_ = BaseField::SumOfProducts(lefts, rights);
y_ = BaseField::SumOfProductsSerial(lefts, rights);

// Z3 = vvv * Z1Z2
z_ = std::move(vvv);
Expand Down Expand Up @@ -139,7 +139,7 @@ constexpr CLASS& CLASS::AddInPlace(const AffinePoint<Curve>& other) {
// Y3 = u * (R - A) - vvv * Y1
BaseField lefts[] = {std::move(u), -vvv};
BaseField rights[] = {r - a, std::move(y_)};
y_ = BaseField::SumOfProducts(lefts, rights);
y_ = BaseField::SumOfProductsSerial(lefts, rights);

// Z3 = vvv * Z1
z_ *= vvv;
Expand Down
4 changes: 2 additions & 2 deletions tachyon/math/finite_fields/prime_field_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ TEST_F(PrimeFieldTest, MultiplicativeGroupOperators) {
EXPECT_EQ(f.Pow(5), GF7(5));
}

TEST_F(PrimeFieldTest, SumOfProducts) {
TEST_F(PrimeFieldTest, SumOfProductsSerial) {
const GF7 a[] = {GF7(3), GF7(2)};
const GF7 b[] = {GF7(2), GF7(5)};
EXPECT_EQ(GF7::SumOfProducts(a, b), GF7(2));
EXPECT_EQ(GF7::SumOfProductsSerial(a, b), GF7(2));
}

TEST_F(PrimeFieldTest, Random) {
Expand Down
4 changes: 2 additions & 2 deletions tachyon/math/finite_fields/quadratic_extension_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ class QuadraticExtensionField
{
BaseField lefts[] = {c0_, Config::MulByNonResidue(c1_)};
BaseField rights[] = {other.c0_, other.c1_};
c0 = BaseField::SumOfProducts(lefts, rights);
c0 = BaseField::SumOfProductsSerial(lefts, rights);
}
BaseField c1;
{
BaseField lefts[] = {c0_, c1_};
BaseField rights[] = {other.c1_, other.c0_};
c1 = BaseField::SumOfProducts(lefts, rights);
c1 = BaseField::SumOfProductsSerial(lefts, rights);
}
c0_ = std::move(c0);
c1_ = std::move(c1);
Expand Down
1 change: 1 addition & 0 deletions tachyon/math/polynomials/univariate/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ tachyon_cc_library(
":univariate_polynomial",
"//tachyon/base:bits",
"//tachyon/base:openmp_util",
"//tachyon/base:range",
"//tachyon/math/polynomials:evaluation_domain",
],
)
Expand Down
56 changes: 40 additions & 16 deletions tachyon/math/polynomials/univariate/univariate_evaluation_domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "tachyon/base/bits.h"
#include "tachyon/base/logging.h"
#include "tachyon/base/openmp_util.h"
#include "tachyon/base/range.h"
#include "tachyon/math/polynomials/evaluation_domain.h"
#include "tachyon/math/polynomials/univariate/univariate_evaluation_domain_forwards.h"
#include "tachyon/math/polynomials/univariate/univariate_evaluations.h"
Expand Down Expand Up @@ -159,6 +160,19 @@ class UnivariateEvaluationDomain : public EvaluationDomain<F, MaxDegree> {
// polynomial P over H, where d < m, P(𝜏) can be computed as P(𝜏) =
// Σ{i in m} Lᵢ_H(𝜏) * P(gⁱ).
constexpr std::vector<F> EvaluateAllLagrangeCoefficients(const F& tau) const {
return EvaluatePartialLagrangeCoefficients(
tau, base::Range<size_t>::Until(size_));
}

// Almost same with above, but it only computes parts of the lagrange
// coefficients defined by |range|.
template <typename T>
constexpr std::vector<F> EvaluatePartialLagrangeCoefficients(
const F& tau, base::Range<T> range) const {
size_t size = range.GetSize();
CHECK_LE(size, size_);
if (size == 0) return {};

// Evaluate all Lagrange polynomials at 𝜏 to get the lagrange
// coefficients.
//
Expand All @@ -174,8 +188,8 @@ class UnivariateEvaluationDomain : public EvaluationDomain<F, MaxDegree> {
// Then i-th lagrange coefficient in this case is then simply 1,
// and all other lagrange coefficients are 0.
// Thus we find i by brute force.
std::vector<F> u(size_, F::Zero());
F omega_i = offset_;
std::vector<F> u(size, F::Zero());
F omega_i = GetElement(range.from);
for (F& u_i : u) {
if (omega_i == tau) {
u_i = F::One();
Expand All @@ -197,21 +211,25 @@ class UnivariateEvaluationDomain : public EvaluationDomain<F, MaxDegree> {
// (See
// https://github.com/arkworks-rs/algebra/blob/4152c41769ae0178fc110bfd15cc699673a2ce4b/poly/src/domain/mod.rs#L198)

// v₀⁻¹ = m * hᵐ⁻¹
F v_0_inv = size_as_field_element_ * offset_pow_size_ * offset_inv_;
// lᵢ = Z_H(𝜏)⁻¹ * v₀⁻¹ = (Z_H(𝜏) * vᵢ)⁻¹
F l_i = z_h_at_tau.Inverse() * v_0_inv;
F negative_cur_elem = -offset_;
// t = m * hᵐ = v₀⁻¹ * h
F t = size_as_field_element_ * offset_pow_size_;
F omega_i = GetElement(range.from);
// lᵢ = (Z_H(𝜏) * h * gᵢ)⁻¹ * t
// = (Z_H(𝜏) * h * gᵢ * t⁻¹)⁻¹
// = (Z_H(𝜏) * h * gᵢ * v₀⁻¹ * h⁻¹)⁻¹
// = (Z_H(𝜏) * gᵢ * v₀)⁻¹
F l_i = (z_h_at_tau * omega_i).Inverse() * t;
F negative_omega_i = -omega_i;
std::vector<F> lagrange_coefficients_inverse =
base::CreateVector(size_, [this, &l_i, &tau, &negative_cur_elem]() {
base::CreateVector(size, [this, &l_i, &tau, &negative_omega_i]() {
// 𝜏 - h * gⁱ
F r_i = tau + negative_cur_elem;
F r_i = tau + negative_omega_i;
// (Z_H(𝜏) * vᵢ)⁻¹ * (𝜏 - h * gⁱ)
F ret = l_i * r_i;
// lᵢ₊₁ = g⁻¹ * lᵢ
l_i *= group_gen_inv_;
// -h * gⁱ⁺¹
negative_cur_elem *= group_gen_;
negative_omega_i *= group_gen_;
return ret;
});

Expand Down Expand Up @@ -275,8 +293,13 @@ class UnivariateEvaluationDomain : public EvaluationDomain<F, MaxDegree> {
}

// Returns the |i|-th element of the domain.
constexpr F GetElement(size_t i) const {
F result = group_gen_.Pow(i);
constexpr F GetElement(int64_t i) const {
F result;
if (i > 0) {
result = group_gen_.Pow(i);
} else {
result = group_gen_inv_.Pow(-i);
}
if (!offset_.IsOne()) {
result *= offset_;
}
Expand All @@ -286,11 +309,12 @@ class UnivariateEvaluationDomain : public EvaluationDomain<F, MaxDegree> {
// Returns all the elements of the domain.
constexpr std::vector<F> GetElements() const {
if (offset_.IsOne()) {
return base::CreateVector(size_,
[this](size_t i) { return group_gen_.Pow(i); });
return F::GetSuccessivePowers(size_, group_gen_);
} else {
return base::CreateVector(
size_, [this](size_t i) { return group_gen_.Pow(i) * offset_; });
F value = offset_;
return base::CreateVector(size_, [this, &value]() {
return std::exchange(value, value * group_gen_);
});
}
}

Expand Down
Loading

0 comments on commit a12e82d

Please sign in to comment.