Skip to content

Commit

Permalink
feat(zk): add BuildCircuitPoly() to VanishingArgument
Browse files Browse the repository at this point in the history
  • Loading branch information
dongchangYoo committed Dec 18, 2023
1 parent 1e95685 commit 8d0ef06
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 9 deletions.
5 changes: 5 additions & 0 deletions tachyon/zk/plonk/vanishing/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ tachyon_cc_library(
":graph_evaluator",
"//tachyon/base/containers:container_util",
"//tachyon/zk/plonk:constraint_system",
"//tachyon/zk/plonk/vanishing/data:vanishing_lookup",
"//tachyon/zk/plonk/vanishing/data:vanishing_permutation",
"//tachyon/zk/plonk/vanishing/data:vanishing_proving_key",
"//tachyon/zk/plonk/vanishing/data:vanishing_table",
],
)

Expand Down Expand Up @@ -149,5 +153,6 @@ tachyon_cc_unittest(
"//tachyon/zk/plonk/circuit/examples:simple_circuit",
"//tachyon/zk/plonk/halo2:pinned_verifying_key",
"//tachyon/zk/plonk/halo2:prover_test",
"//tachyon/zk/plonk/keys:proving_key",
],
)
4 changes: 2 additions & 2 deletions tachyon/zk/plonk/vanishing/prover_vanishing_argument.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ template <typename PCSTy, typename ExtendedEvals>
[[nodiscard]] bool CommitFinalHPoly(
ProverBase<PCSTy>* prover,
VanishingCommitted<EntityTy::kProver, PCSTy>&& committed,
const VerifyingKey<PCSTy>& vk, ExtendedEvals& linear_combination_of_gates,
const VerifyingKey<PCSTy>& vk, ExtendedEvals& combined_custom_gate_column,
VanishingConstructed<EntityTy::kProver, PCSTy>* constructed_out) {
using F = typename PCSTy::Field;
using Poly = typename PCSTy::Poly;
Expand All @@ -62,7 +62,7 @@ template <typename PCSTy, typename ExtendedEvals>

// Divide by t(X) = X^{params.n} - 1.
ExtendedEvals h_evals = DivideByVanishingPolyInPlace<F>(
linear_combination_of_gates, prover->extended_domain(), prover->domain());
combined_custom_gate_column, prover->extended_domain(), prover->domain());

// Obtain final h(X) polynomial
ExtendedPoly h_poly =
Expand Down
273 changes: 270 additions & 3 deletions tachyon/zk/plonk/vanishing/vanishing_argument.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,24 @@
#include <utility>
#include <vector>

#include "tachyon/base/containers/container_util.h"
#include "tachyon/base/parallelize.h"
#include "tachyon/zk/plonk/constraint_system.h"
#include "tachyon/zk/plonk/vanishing/data/vanishing_lookup.h"
#include "tachyon/zk/plonk/vanishing/data/vanishing_permutation.h"
#include "tachyon/zk/plonk/vanishing/data/vanishing_proving_key.h"
#include "tachyon/zk/plonk/vanishing/data/vanishing_table.h"
#include "tachyon/zk/plonk/vanishing/graph_evaluator.h"

namespace tachyon::zk {

template <typename PCSTy>
class ProvingKey;

template <typename F>
class VanishingArgument {
public:
constexpr static size_t rot_scale = 1;

VanishingArgument() = default;

static VanishingArgument Create(
Expand Down Expand Up @@ -75,10 +84,268 @@ class VanishingArgument {
return evaluator;
}

// TODO(chokobole): Implement EvaluateH. See
// [evaluate_h](https://github.com/kroma-network/halo2/blob/7d0a36990452c8e7ebd600de258420781a9b7917/halo2_proofs/src/plonk/evaluation.rs#L279-L583).
// Returns a evaluation-formed polynomial below.
// - gate₀(X) + y * gate₁(X) + ... + yⁱgateᵢ(X) + ...
template <typename PCSTy, typename Poly,
typename ExtendedEvals = typename PCSTy::ExtendedEvals>
ExtendedEvals BuildCombinedCustomGateColumn(
ProverBase<PCSTy>* prover, const ProvingKey<PCSTy>& proving_key,
const std::vector<Table<Poly>>& poly_tables,
const std::vector<F>& challenges, const F& y, const F& beta,
const F& gamma, const F& theta,
const std::vector<PermutationCommitted<Poly>>& committed_permutations,
const std::vector<std::vector<LookupCommitted<Poly>>>&
committed_lookups_vec) {
using Evals = typename PCSTy::Evals;

size_t cs_degree =
proving_key.verifying_key().constraint_system().ComputeDegree();
F zeta = GetZeta<F>();
VanishingCommon<F> common(prover, cs_degree, &beta, &gamma, &theta, &y,
&zeta, &challenges);

std::vector<std::vector<F>> value_parts;
value_parts.reserve(common.num_parts());
// Calculate the quotient polynomial for each part
for (size_t i = 0; i < common.num_parts(); ++i) {
VanishingProvingKey<Evals> vanishing_proving_key =
VanishingProvingKey<Evals>::Create(prover, proving_key, common);
std::vector<VanishingTable<Poly, Evals>> vanishing_tables = base::Map(
poly_tables, [prover, &common](const Table<Poly>& poly_table) {
return VanishingTable<Poly, Evals>::Create(prover, common,
poly_table);
});

std::vector<F> values_part =
base::CreateVector(prover->pcs().N(), F::Zero());
for (size_t j = 0; j < poly_tables.size(); ++j) {
UpdateValuesByCustomGates(common, vanishing_tables[j], values_part);
UpdateValuesByPermutation(prover, common, vanishing_proving_key,
vanishing_tables[j],
proving_key.permutation_proving_key(),
committed_permutations[j], values_part);
UpdateValuesByLookups(prover, common, vanishing_proving_key,
vanishing_tables[j], committed_lookups_vec[j],
values_part);
}
value_parts.push_back(std::move(values_part));
common.UpdateCurrentExtendedOmega();
}
std::vector<F> extended =
BuildExtendedColumnWithColumns(std::move(value_parts));
return ExtendedEvals(std::move(extended));
}

template <typename PCSTy, typename Evals, typename Poly>
void UpdateValuesByLookups(
const ProverBase<PCSTy>* prover, const VanishingCommon<F>& common,
const VanishingProvingKey<Evals>& vanishing_proving_key,
const VanishingTable<Poly, Evals>& vanishing_table,
const std::vector<LookupCommitted<Poly>>& committed_lookups,
std::vector<F>& values) {
for (size_t i = 0; i < committed_lookups.size(); ++i) {
const GraphEvaluator<F>& ev = lookups_[i];
VanishingLookup<Poly, Evals> vanishing_lookup =
VanishingLookup<Poly, Evals>::Create(prover, &common,
committed_lookups[i]);
base::Parallelize(values, [&ev, &vanishing_proving_key, &vanishing_lookup,
&vanishing_table](absl::Span<F> chunk,
size_t chunk_offset,
size_t chunk_size) {
const VanishingCommon<F>& common = vanishing_lookup.common();
const Evals& l_first = vanishing_proving_key.l_first();
const Evals& l_last = vanishing_proving_key.l_last();

std::vector<F> intermediates = ev.CreateInitialIntermediates();
std::vector<int32_t> rotations = ev.CreateEmptyRotations();
EvaluationInput<Poly, Evals> evaluation_input(
common, std::move(intermediates), std::move(rotations),
&vanishing_table);

size_t start = chunk_offset * chunk_size;
for (size_t j = 0; j < chunk.size(); ++j) {
size_t idx = start + j;

F zero = F::Zero();
F table_value = ev.Evaluate(evaluation_input, idx, rot_scale, zero);

size_t r_next = Rotation(1).GetIndex(idx, rot_scale, common.n());
size_t r_prev = Rotation(-1).GetIndex(idx, rot_scale, common.n());

F a_minus_s = *vanishing_lookup.input_coset()[idx] -
*vanishing_lookup.table_coset()[idx];

// l₀(X) * (1 - z(X)) = 0
chunk[j] *= common.y();
chunk[j] += (common.one() - *vanishing_lookup.product_coset()[idx]) *
*l_first[idx];

// l₋₁(X) * (z(X)² - z(X)) = 0
chunk[j] *= common.y();
chunk[j] += (vanishing_lookup.product_coset()[idx]->Square() -
*vanishing_lookup.product_coset()[idx]) *
*l_last[idx];

// clang-format off
// A * (B - C) = 0 where
// - A = 1 - (l₋₁(X) + l_blind(X))
// - B = z(wX) * (a'(X) + β) * (s'(X) + γ)
// - C = z(X) * (θᵐ⁻¹ a₀(X) + ... + aᵐ⁻¹(X) + β) * (θᵐ⁻¹ s₀(X) + ... + sᵐ⁻¹(X) + γ)
// clang-format on
chunk[j] *= common.y();
chunk[j] +=
(*vanishing_lookup.product_coset()[r_next] *
(*vanishing_lookup.input_coset()[idx] + common.beta()) *
(*vanishing_lookup.table_coset()[idx] + common.gamma()) -
*vanishing_lookup.product_coset()[idx] * table_value) *
*vanishing_proving_key.l_active_row()[idx];

// Check that the first values in the permuted input expression and
// permuted fixed expression are the same. l₀(X) * (a'(X) - s'(X)) = 0
chunk[j] *= common.y();
chunk[j] += a_minus_s * *vanishing_proving_key.l_first()[idx];

// Check that each value in the permuted lookup input expression is
// either equal to the value above it, or the value at the same index
// in the permuted table expression.
// (1 - (l₋₁ + l_blind)) * (a′(X) − s′(X))⋅(a′(X) − a′(w⁻¹X)) = 0
chunk[j] *= common.y();
chunk[j] += a_minus_s *
(*vanishing_lookup.input_coset()[idx] -
*vanishing_lookup.input_coset()[r_prev]) *
*vanishing_proving_key.l_active_row()[idx];
}
});
}
}

template <typename Poly, typename Evals>
void UpdateValuesByCustomGates(
const VanishingCommon<F>& common,
const VanishingTable<Poly, Evals>& vanishing_table,
std::vector<F>& values) {
base::Parallelize(values, [this, &common, &vanishing_table](
absl::Span<F> chunk, size_t chunk_offset,
size_t chunk_size) {
std::vector<F> intermediates = custom_gates_.CreateInitialIntermediates();
std::vector<int32_t> rotations = custom_gates_.CreateEmptyRotations();
EvaluationInput<Poly, Evals> evaluation_input(
common, std::move(intermediates), std::move(rotations),
&vanishing_table);

size_t start = chunk_offset * chunk_size;
for (size_t i = 0; i < chunk.size(); ++i) {
chunk[i] = custom_gates_.Evaluate(evaluation_input, start + i,
rot_scale, chunk[i]);
}
});
}

template <typename PCSTy, typename Evals, typename Poly>
void UpdateValuesByPermutation(
ProverBase<PCSTy>* prover, const VanishingCommon<F>& common,
const VanishingProvingKey<Evals>& vanishing_proving_key,
const VanishingTable<Poly, Evals>& vanishing_table,
const PermutationProvingKey<Poly, Evals>& permutation_proving_key,
const PermutationCommitted<Poly>& committed_permutation,
std::vector<F>& values) {
if (committed_permutation.product_polys().size() == 0) return;

VanishingPermutation<Poly, Evals> vanishing_permutation =
VanishingPermutation<Poly, Evals>::Create(
prover, &common, committed_permutation, permutation_proving_key);

base::Parallelize(values, [this, &vanishing_proving_key,
&vanishing_permutation, &vanishing_table](
absl::Span<F> chunk, size_t chunk_offset,
size_t chunk_size) {
const VanishingCommon<F>& common = vanishing_permutation.common();
const std::vector<Evals>& product_cosets =
vanishing_permutation.product_cosets();
const Evals& l_first = vanishing_proving_key.l_first();

size_t start = chunk_offset * chunk_size;
F beta_term = common.current_extended_omega() * common.omega().Pow(start);
for (size_t i = 0; i < chunk.size(); ++i) {
size_t idx = start + i;

// Enforce only for the first set: l₀(X) * (1 - z₀(X)) = 0
chunk[i] *= common.y();
F tmp = (common.one() - *product_cosets.front()[idx]) * *l_first[idx];
chunk[i] += tmp;

// Enforce only for the last set: l₋₁(X) * (z₋₁(X)² - z₋₁(X)) = 0
const Evals& last_coset = product_cosets.back();
chunk[i] *= common.y();
chunk[i] += *vanishing_proving_key.l_last()[idx] *
(last_coset[idx]->Square() - *last_coset[idx]);

// Except for the first set, enforce: l₀(X) * (zᵢ(X) - zᵢ₋₁(w⁻¹X)) = 0
size_t r_last = vanishing_permutation.last_rotation().GetIndex(
idx, rot_scale, common.n());
for (size_t set_idx = 0; set_idx < product_cosets.size(); ++set_idx) {
if (set_idx == 0) continue;
chunk[i] *= common.y();
chunk[i] += *l_first[idx] * (*product_cosets[set_idx][idx] -
*product_cosets[set_idx - 1][r_last]);
}

// And for all the sets we enforce: (1 - (l₋₁(X) + l_blind(X))) *
// (zᵢ(wX) * Πⱼ(p(X) + βsⱼ(X) + γ) - zᵢ(X) Πⱼ(p(X) + δʲβX + γ))
F current_delta = vanishing_permutation.delta_start() * beta_term;
size_t r_next = Rotation(1).GetIndex(idx, rot_scale, common.n());

std::vector<absl::Span<const AnyColumnKey>> column_key_chunks =
base::ParallelizeMapByChunkSize(
vanishing_proving_key.column_keys(), common.chunk_len(),
[](absl::Span<const AnyColumnKey> chunk) { return chunk; });
std::vector<absl::Span<const Evals>> coset_chunks =
base::ParallelizeMapByChunkSize(
vanishing_permutation.cosets(), common.chunk_len(),
[](absl::Span<const Evals> chunk) { return chunk; });

for (size_t j = 0; j < product_cosets.size(); ++j) {
std::vector<base::Ref<const Evals>> column_chunk =
vanishing_table.GetColumns(column_key_chunks[j]);
F left = CalculateLeft(common, column_chunk, coset_chunks[j], idx,
product_cosets[j][r_next]);
F right = CalculateRight(common, column_chunk, &current_delta, idx,
product_cosets[j][idx]);
chunk[i] *= common.y();
chunk[i] +=
(left - right) * *vanishing_proving_key.l_active_row()[idx];
}
beta_term *= common.omega();
}
});
}

private:
template <typename Evals>
F CalculateLeft(const VanishingCommon<F>& common,
const std::vector<base::Ref<const Evals>>& column_chunk,
absl::Span<const Evals> coset_chunk, size_t idx,
const F* initial_value) {
F left = *initial_value;
for (size_t i = 0; i < column_chunk.size(); ++i) {
left *= *(*column_chunk[i])[idx] + common.beta() * *coset_chunk[i][idx] +
common.gamma();
}
return left;
}

template <typename Evals>
F CalculateRight(const VanishingCommon<F>& common,
const std::vector<base::Ref<const Evals>>& column_chunk,
F* current_delta, size_t idx, const F* initial_value) {
F right = *initial_value;
for (size_t i = 0; i < column_chunk.size(); ++i) {
right *= *(*column_chunk[i])[idx] + common.delta() + common.gamma();
*current_delta *= common.delta();
}
return right;
}

GraphEvaluator<F> custom_gates_;
std::vector<GraphEvaluator<F>> lookups_;
};
Expand Down
Loading

0 comments on commit 8d0ef06

Please sign in to comment.