Skip to content

Commit

Permalink
Implement trace generator for fibonacci.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Jul 2, 2024
1 parent 0d56bd7 commit 2d40280
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 40 deletions.
2 changes: 1 addition & 1 deletion crates/prover/benches/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub fn simd_poseidon(c: &mut Criterion) {
let trace = gen_trace(component.log_column_size());
let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));
let air = PoseidonAir { component };
prove::<SimdBackend>(&air, channel, trace).unwrap()
prove::<SimdBackend>(air, channel, trace).unwrap()
});
});
}
Expand Down
23 changes: 14 additions & 9 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ pub fn generate_proof<B: Backend + MerkleOps<MerkleHasher>>(
}

pub fn prove<B: Backend + MerkleOps<MerkleHasher>>(
air: &impl AirTraceGenerator<B>,
air: impl AirTraceGenerator<B>,
channel: &mut Channel,
trace: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
) -> Result<StarkProof, ProvingError> {
Expand All @@ -173,8 +173,7 @@ pub fn prove<B: Backend + MerkleOps<MerkleHasher>>(

// Check that the composition polynomial is not too big.
// TODO(AlonH): Get traces log degree bounds from trace writer.
let composition_polynomial_log_degree_bound =
air.to_air_prover().composition_log_degree_bound();
let composition_polynomial_log_degree_bound = air.composition_log_degree_bound();
if composition_polynomial_log_degree_bound + LOG_BLOWUP_FACTOR > MAX_CIRCLE_DOMAIN_LOG_SIZE {
return Err(ProvingError::MaxCompositionDegreeExceeded {
degree: composition_polynomial_log_degree_bound,
Expand All @@ -190,10 +189,10 @@ pub fn prove<B: Backend + MerkleOps<MerkleHasher>>(
span.exit();

let (mut commitment_scheme, interaction_elements) =
evaluate_and_commit_on_trace(air, channel, &twiddles, trace)?;
evaluate_and_commit_on_trace(&air, channel, &twiddles, trace)?;

generate_proof(
air.to_air_prover(),
&air.to_air_prover(),
channel,
&interaction_elements,
&twiddles,
Expand Down Expand Up @@ -390,6 +389,7 @@ mod tests {
use crate::trace_generation::registry::ComponentGenerationRegistry;
use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator};

#[derive(Clone)]
struct TestAir<C: ComponentProver<CpuBackend>> {
component: C,
}
Expand All @@ -415,9 +415,13 @@ mod tests {
vec![]
}

fn to_air_prover(&self) -> &impl AirProver<CpuBackend> {
fn to_air_prover(self) -> impl AirProver<CpuBackend> {
self
}

fn composition_log_degree_bound(&self) -> u32 {
self.component.max_constraint_log_degree_bound()
}
}

impl AirProver<CpuBackend> for TestAir<TestComponent> {
Expand All @@ -426,6 +430,7 @@ mod tests {
}
}

#[derive(Clone)]
struct TestComponent {
log_size: u32,
max_constraint_log_degree_bound: u32,
Expand Down Expand Up @@ -527,7 +532,7 @@ mod tests {
let values = vec![BaseField::zero(); 1 << LOG_DOMAIN_SIZE];
let trace = vec![CpuCircleEvaluation::new(domain, values)];

let proof_error = prove(&air, &mut test_channel(), trace).unwrap_err();
let proof_error = prove(air, &mut test_channel(), trace).unwrap_err();
assert!(matches!(
proof_error,
ProvingError::MaxTraceDegreeExceeded {
Expand All @@ -554,7 +559,7 @@ mod tests {
let values = vec![BaseField::zero(); 1 << LOG_DOMAIN_SIZE];
let trace = vec![CpuCircleEvaluation::new(domain, values)];

let proof_error = prove(&air, &mut test_channel(), trace).unwrap_err();
let proof_error = prove(air, &mut test_channel(), trace).unwrap_err();
assert!(matches!(
proof_error,
ProvingError::MaxCompositionDegreeExceeded {
Expand All @@ -576,7 +581,7 @@ mod tests {
let values = vec![BaseField::zero(); 1 << LOG_DOMAIN_SIZE];
let trace = vec![CpuCircleEvaluation::new(domain, values)];

let proof = prove(&air, &mut test_channel(), trace).unwrap_err();
let proof = prove(air, &mut test_channel(), trace).unwrap_err();
assert!(matches!(proof, ProvingError::ConstraintsNotSatisfied));
}
}
86 changes: 82 additions & 4 deletions crates/prover/src/examples/fibonacci/air.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,80 @@
use itertools::{zip_eq, Itertools};

use super::component::FibonacciComponent;
use super::component::{FibonacciComponent, FibonacciInput, FibonacciTraceGenerator};
use crate::core::air::{Air, AirProver, Component, ComponentProver};
use crate::core::backend::CpuBackend;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::{ColumnVec, InteractionElements};
use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier};
use crate::trace_generation::registry::ComponentGenerationRegistry;
use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator};

pub struct FibonacciAirGenerator {
pub registry: ComponentGenerationRegistry,
}

impl Clone for FibonacciAirGenerator {
fn clone(&self) -> Self {
Self {
registry: ComponentGenerationRegistry::default(),
}
}
}

impl FibonacciAirGenerator {
pub fn new(inputs: &FibonacciInput) -> Self {
let mut component_generator = FibonacciTraceGenerator::new();
component_generator.add_inputs(inputs);
let mut registry = ComponentGenerationRegistry::default();
registry.register("fibonacci", component_generator);
Self { registry }
}
}

impl AirTraceVerifier for FibonacciAirGenerator {
fn interaction_elements(&self, _channel: &mut Blake2sChannel) -> InteractionElements {
InteractionElements::default()
}
}

impl AirTraceGenerator<CpuBackend> for FibonacciAirGenerator {
fn write_trace(&mut self) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
FibonacciTraceGenerator::write_trace("fibonacci", &mut self.registry)
}

fn interact(
&self,
_trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn to_air_prover(self) -> impl AirProver<CpuBackend> {
let component_generator = self
.registry
.get_generator::<FibonacciTraceGenerator>("fibonacci");
// TODO(AlonH): Take instead of clone.
FibonacciAir {
component: component_generator.clone().component(),
}
}

fn composition_log_degree_bound(&self) -> u32 {
let component_generator = self
.registry
.get_generator::<FibonacciTraceGenerator>("fibonacci");
assert!(component_generator.inputs_set(), "Fibonacci input not set.");
component_generator
.clone()
.component()
.max_constraint_log_degree_bound()
}
}

#[derive(Clone)]
pub struct FibonacciAir {
pub component: FibonacciComponent,
}
Expand Down Expand Up @@ -41,9 +106,13 @@ impl AirTraceGenerator<CpuBackend> for FibonacciAir {
vec![]
}

fn to_air_prover(&self) -> &impl AirProver<CpuBackend> {
fn to_air_prover(self) -> impl AirProver<CpuBackend> {
self
}

fn composition_log_degree_bound(&self) -> u32 {
self.component.max_constraint_log_degree_bound()
}
}

impl AirProver<CpuBackend> for FibonacciAir {
Expand All @@ -52,6 +121,7 @@ impl AirProver<CpuBackend> for FibonacciAir {
}
}

#[derive(Clone)]
pub struct MultiFibonacciAir {
pub components: Vec<FibonacciComponent>,
}
Expand Down Expand Up @@ -90,9 +160,17 @@ impl AirTraceGenerator<CpuBackend> for MultiFibonacciAir {
vec![]
}

fn to_air_prover(&self) -> &impl AirProver<CpuBackend> {
fn to_air_prover(self) -> impl AirProver<CpuBackend> {
self
}

fn composition_log_degree_bound(&self) -> u32 {
self.components
.iter()
.map(|component| component.max_constraint_log_degree_bound())
.max()
.unwrap()
}
}

impl AirProver<CpuBackend> for MultiFibonacciAir {
Expand Down
64 changes: 55 additions & 9 deletions crates/prover/src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ use crate::core::prover::BASE_TRACE;
use crate::core::utils::bit_reverse_index;
use crate::core::{ColumnVec, InteractionElements, LookupValues};
use crate::trace_generation::registry::ComponentGenerationRegistry;
use crate::trace_generation::ComponentTraceGenerator;
use crate::trace_generation::{ComponentGen, ComponentTraceGenerator};

#[derive(Clone)]
pub struct FibonacciComponent {
pub log_size: u32,
pub claim: BaseField,
Expand Down Expand Up @@ -130,17 +131,61 @@ impl Component for FibonacciComponent {
}
}

impl ComponentTraceGenerator<CpuBackend> for FibonacciComponent {
type Component = Self;
type Inputs = ();
pub struct FibonacciInput(pub u32, pub BaseField);

fn add_inputs(&mut self, _inputs: &Self::Inputs) {}
#[derive(Clone)]
pub struct FibonacciTraceGenerator {
log_size: Option<u32>,
claim: Option<BaseField>,
}

impl ComponentGen for FibonacciTraceGenerator {}

impl FibonacciTraceGenerator {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self {
log_size: None,
claim: None,
}
}

pub fn inputs_set(&self) -> bool {
self.log_size.is_some() && self.claim.is_some()
}
}

impl ComponentTraceGenerator<CpuBackend> for FibonacciTraceGenerator {
type Component = FibonacciComponent;
type Inputs = FibonacciInput;

fn add_inputs(&mut self, inputs: &Self::Inputs) {
assert!(!self.inputs_set(), "Fibonacci input already set.");
self.log_size = Some(inputs.0);
self.claim = Some(inputs.1);
}

fn write_trace(
_component_id: &str,
_registry: &mut ComponentGenerationRegistry,
component_id: &str,
registry: &mut ComponentGenerationRegistry,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
let component = registry.get_generator_mut::<Self>(component_id);
assert!(component.inputs_set(), "Fibonacci input not set.");
let trace_domain = CanonicCoset::new(component.log_size.unwrap());
let mut trace = Vec::with_capacity(trace_domain.size());

// Fill trace with fibonacci squared.
let mut a = BaseField::one();
let mut b = BaseField::one();
for _ in 0..trace_domain.size() {
trace.push(a);
let tmp = a.square() + b.square();
a = b;
b = tmp;
}

// Returns as a CircleEvaluation.
vec![CircleEvaluation::new_canonical_ordered(trace_domain, trace)]
}

fn write_interaction_trace(
Expand All @@ -152,7 +197,8 @@ impl ComponentTraceGenerator<CpuBackend> for FibonacciComponent {
}

fn component(self) -> Self::Component {
self
assert!(self.inputs_set(), "Fibonacci input not set.");
FibonacciComponent::new(self.log_size.unwrap(), self.claim.unwrap())
}
}

Expand Down
28 changes: 25 additions & 3 deletions crates/prover/src/examples/fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::core::vcs::hasher::Hasher;
pub mod air;
mod component;

#[derive(Clone)]
pub struct Fibonacci {
pub air: FibonacciAir,
}
Expand Down Expand Up @@ -55,7 +56,7 @@ impl Fibonacci {
.air
.component
.claim])));
prove(&self.air, channel, vec![trace])
prove(self.air.clone(), channel, vec![trace])
}

pub fn verify(&self, proof: StarkProof) -> Result<(), VerificationError> {
Expand Down Expand Up @@ -97,7 +98,8 @@ impl MultiFibonacci {
pub fn prove(&self) -> Result<StarkProof, ProvingError> {
let channel =
&mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&self.claims)));
prove(&self.air, channel, self.get_trace())
let trace = self.get_trace();
prove(self.air.clone(), channel, trace)
}

pub fn verify(&self, proof: StarkProof) -> Result<(), VerificationError> {
Expand All @@ -120,15 +122,22 @@ mod tests {
use super::{Fibonacci, MultiFibonacci};
use crate::core::air::accumulation::PointEvaluationAccumulator;
use crate::core::air::{AirExt, AirProverExt, Component, ComponentTrace};
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::IntoSlice;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CanonicCoset;
use crate::core::prover::{VerificationError, BASE_TRACE};
use crate::core::prover::{prove, VerificationError, BASE_TRACE};
use crate::core::queries::Queries;
use crate::core::utils::bit_reverse;
use crate::core::vcs::blake2_hash::Blake2sHasher;
use crate::core::vcs::hasher::Hasher;
use crate::core::{InteractionElements, LookupValues};
use crate::examples::fibonacci::air::FibonacciAirGenerator;
use crate::examples::fibonacci::component::FibonacciInput;
use crate::trace_generation::AirTraceGenerator;
use crate::{m31, qm31};

pub fn generate_test_queries(n_queries: usize, trace_length: usize) -> Vec<usize> {
Expand Down Expand Up @@ -232,6 +241,19 @@ mod tests {
fib.verify(proof).unwrap();
}

#[test]
fn test_fib_prove_2() {
const FIB_LOG_SIZE: u32 = 5;
const CLAIM: BaseField = m31!(443693538);
let mut fib_trace_generator =
FibonacciAirGenerator::new(&FibonacciInput(FIB_LOG_SIZE, CLAIM));

let trace = fib_trace_generator.write_trace();
let channel =
&mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[CLAIM])));
prove(fib_trace_generator, channel, trace).unwrap();
}

#[test]
fn test_prove_invalid_trace_value() {
const FIB_LOG_SIZE: u32 = 5;
Expand Down
Loading

0 comments on commit 2d40280

Please sign in to comment.