diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index d43dafa0e..d301bde49 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -1,12 +1,23 @@ -use num_traits::One; +use num_traits::{One, Zero}; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkEval}; -use crate::core::fields::qm31::QM31; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; +use crate::core::air::{Component, ComponentProver}; +use crate::core::backend::simd::SimdBackend; +use crate::core::channel::Channel; +use crate::core::fields::m31::M31; +use crate::core::fields::qm31::{SecureField, QM31}; use crate::core::lookups::utils::Fraction; +use crate::core::pcs::TreeVec; +use crate::core::prover::StarkProof; +use crate::core::vcs::ops::MerkleHasher; pub const N_STATE: usize = 2; pub type StateMachineElements = LookupElements; +pub type State = [M31; N_STATE]; + +pub type StateMachineOp0Component = FrameworkComponent>; +pub type StateMachineOp1Component = FrameworkComponent>; /// State machine with state of size `N_STATE`. /// Transition `INDEX` of state increments the state by 1 at that offset. @@ -42,3 +53,74 @@ impl FrameworkEval for StateTransitionEval { eval } } + +pub struct StateMachineStatement0 { + pub n: u32, + pub m: u32, +} +impl StateMachineStatement0 { + pub fn log_sizes(&self) -> TreeVec> { + let sizes = vec![ + state_transition_info::<0>() + .mask_offsets + .as_cols_ref() + .map_cols(|_| self.n), + state_transition_info::<1>() + .mask_offsets + .as_cols_ref() + .map_cols(|_| self.m), + ]; + TreeVec::concat_cols(sizes.into_iter()) + } + pub fn mix_into(&self, channel: &mut impl Channel) { + channel.mix_u64(self.n as u64); + channel.mix_u64(self.m as u64); + } +} + +pub struct StateMachineStatement1 { + pub x_axis_claimed_sum: SecureField, + pub y_axis_claimed_sum: SecureField, +} +impl StateMachineStatement1 { + pub fn mix_into(&self, channel: &mut impl Channel) { + channel.mix_felts(&[self.x_axis_claimed_sum, self.y_axis_claimed_sum]) + } +} + +fn state_transition_info() -> InfoEvaluator { + let component = StateTransitionEval:: { + log_n_rows: 1, + lookup_elements: StateMachineElements::dummy(), + total_sum: QM31::zero(), + }; + component.evaluate(InfoEvaluator::default()) +} + +pub struct StateMachineComponents { + pub component0: StateMachineOp0Component, + pub component1: StateMachineOp1Component, +} + +impl StateMachineComponents { + pub fn components(&self) -> Vec<&dyn Component> { + vec![ + &self.component0 as &dyn Component, + &self.component1 as &dyn Component, + ] + } + + pub fn component_provers(&self) -> Vec<&dyn ComponentProver> { + vec![ + &self.component0 as &dyn ComponentProver, + &self.component1 as &dyn ComponentProver, + ] + } +} + +pub struct StateMachineProof { + pub public_input: [State; 2], // Initial and final state. + pub stmt0: StateMachineStatement0, + pub stmt1: StateMachineStatement1, + pub stark_proof: StarkProof, +} diff --git a/crates/prover/src/examples/state_machine/gen.rs b/crates/prover/src/examples/state_machine/gen.rs index 0f6732dd7..7723d010b 100644 --- a/crates/prover/src/examples/state_machine/gen.rs +++ b/crates/prover/src/examples/state_machine/gen.rs @@ -1,7 +1,7 @@ use itertools::Itertools; use num_traits::One; -use super::components::N_STATE; +use super::components::{State, N_STATE}; use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; @@ -14,8 +14,6 @@ use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::ColumnVec; -pub type State = [M31; N_STATE]; - // Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the // `inc_index` dimension. // E.g. [x, y] -> [x, y + 1] -> [x, y + 2] -> [x, y + 1 << log_size]. diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 26d4d57bd..2cc2801fc 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -1,23 +1,27 @@ pub mod components; pub mod gen; -use components::{StateMachineElements, StateTransitionEval}; -use gen::{gen_interaction_trace, gen_trace, State}; -use itertools::Itertools; +use components::{ + State, StateMachineComponents, StateMachineElements, StateMachineOp0Component, + StateMachineOp1Component, StateMachineProof, StateMachineStatement0, StateMachineStatement1, + StateTransitionEval, +}; +use gen::{gen_interaction_trace, gen_trace}; +use itertools::{chain, Itertools}; use crate::constraint_framework::constant_columns::gen_is_first; -use crate::constraint_framework::{FrameworkComponent, TraceLocationAllocator}; -use crate::core::air::Component; +use crate::constraint_framework::TraceLocationAllocator; +// use crate::core::air::Component; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::SimdBackend; use crate::core::channel::Blake2sChannel; +use crate::core::fields::m31::M31; +use crate::core::fields::qm31::QM31; use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps}; -use crate::core::prover::{prove, verify, StarkProof, VerificationError}; +use crate::core::prover::{prove, verify, VerificationError}; use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; -pub type StateMachineOp0Component = FrameworkComponent>; - #[allow(unused)] pub fn prove_state_machine( log_n_rows: u32, @@ -25,11 +29,17 @@ pub fn prove_state_machine( config: PcsConfig, channel: &mut Blake2sChannel, ) -> ( - StateMachineOp0Component, - StarkProof, + StateMachineComponents, + StateMachineProof, TreeVec>>, ) { assert!(log_n_rows >= LOG_N_LANES); + let x_axis_log_rows = log_n_rows; + let y_axis_log_rows = log_n_rows - 1; + let mut intermediate_state = initial_state; + intermediate_state[0] += M31::from_u32_unchecked(1 << x_axis_log_rows); + let mut final_state = intermediate_state; + final_state[1] += M31::from_u32_unchecked(1 << y_axis_log_rows); // Precompute twiddles. let twiddles = SimdBackend::precompute_twiddles( @@ -43,9 +53,17 @@ pub fn prove_state_machine( &mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); // Trace. - let trace_op0 = gen_trace(log_n_rows, initial_state, 0); + let trace_op0 = gen_trace(x_axis_log_rows, initial_state, 0); + let trace_op1 = gen_trace(y_axis_log_rows, intermediate_state, 1); + + let stmt0 = StateMachineStatement0 { + n: x_axis_log_rows, + m: y_axis_log_rows, + }; + stmt0.mix_into(channel); + let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace_op0); + tree_builder.extend_evals(chain![trace_op0, trace_op1].collect_vec()); tree_builder.commit(channel); // Draw lookup element. @@ -53,14 +71,26 @@ pub fn prove_state_machine( // Interaction trace. let (interaction_trace_op0, total_sum_op0) = - gen_interaction_trace(log_n_rows, initial_state, 0, &lookup_elements); + gen_interaction_trace(x_axis_log_rows, initial_state, 0, &lookup_elements); + let (interaction_trace_op1, total_sum_op1) = + gen_interaction_trace(y_axis_log_rows, intermediate_state, 1, &lookup_elements); + + let stmt1 = StateMachineStatement1 { + x_axis_claimed_sum: total_sum_op0, + y_axis_claimed_sum: total_sum_op1, + }; + stmt1.mix_into(channel); + let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(interaction_trace_op0); + tree_builder.extend_evals(chain![interaction_trace_op0, interaction_trace_op1].collect_vec()); tree_builder.commit(channel); // Constant trace. let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(vec![gen_is_first(log_n_rows)]); + tree_builder.extend_evals(vec![ + gen_is_first(x_axis_log_rows), + gen_is_first(y_axis_log_rows), + ]); tree_builder.commit(channel); let trace_polys = commitment_scheme @@ -69,82 +99,154 @@ pub fn prove_state_machine( .map(|t| t.polynomials.iter().cloned().collect_vec()); // Prove constraints. - let component_op0 = StateMachineOp0Component::new( - &mut TraceLocationAllocator::default(), + let mut tree_span_provider = &mut TraceLocationAllocator::default(); + let component0 = StateMachineOp0Component::new( + tree_span_provider, StateTransitionEval { - log_n_rows, - lookup_elements, + log_n_rows: x_axis_log_rows, + lookup_elements: lookup_elements.clone(), total_sum: total_sum_op0, }, ); - - let proof = prove(&[&component_op0], channel, commitment_scheme).unwrap(); - - (component_op0, proof, trace_polys) + let component1 = StateMachineOp1Component::new( + tree_span_provider, + StateTransitionEval { + log_n_rows: y_axis_log_rows, + lookup_elements, + total_sum: total_sum_op1, + }, + ); + let components = StateMachineComponents { + component0, + component1, + }; + let stark_proof = prove(&components.component_provers(), channel, commitment_scheme).unwrap(); + let proof = StateMachineProof { + public_input: [initial_state, final_state], + stmt0, + stmt1, + stark_proof, + }; + (components, proof, trace_polys) } pub fn verify_state_machine( config: PcsConfig, channel: &mut Blake2sChannel, - component: StateMachineOp0Component, - proof: StarkProof, + components: StateMachineComponents, + proof: StateMachineProof, ) -> Result<(), VerificationError> { let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); - // Decommit. // Retrieve the expected column sizes in each commitment interaction, from the AIR. - let sizes = component.trace_log_degree_bounds(); + let sizes = proof.stmt0.log_sizes(); // Trace columns. - commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); + proof.stmt0.mix_into(channel); + commitment_scheme.commit(proof.stark_proof.commitments[0], &sizes[0], channel); + + // Assert state machine statement. + let lookup_elements = StateMachineElements::draw(channel); + let initial_state_comb: QM31 = lookup_elements.combine(&proof.public_input[0]); + let final_state_comb: QM31 = lookup_elements.combine(&proof.public_input[1]); + assert_eq!( + (proof.stmt1.x_axis_claimed_sum + proof.stmt1.y_axis_claimed_sum) + * initial_state_comb + * final_state_comb, + final_state_comb - initial_state_comb + ); + // Interaction columns. - commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); + proof.stmt1.mix_into(channel); + commitment_scheme.commit(proof.stark_proof.commitments[1], &sizes[1], channel); // Constant columns. - commitment_scheme.commit(proof.commitments[2], &sizes[2], channel); + commitment_scheme.commit(proof.stark_proof.commitments[2], &sizes[2], channel); - verify(&[&component], channel, commitment_scheme, proof) + verify( + &components.components(), + channel, + commitment_scheme, + proof.stark_proof, + ) } #[cfg(test)] mod tests { use num_traits::Zero; - use super::components::N_STATE; + use super::components::{ + StateMachineElements, StateMachineOp0Component, StateTransitionEval, N_STATE, + }; + use super::gen::{gen_interaction_trace, gen_trace}; use super::{prove_state_machine, verify_state_machine}; - use crate::constraint_framework::{assert_constraints, FrameworkEval}; + use crate::constraint_framework::constant_columns::gen_is_first; + use crate::constraint_framework::{assert_constraints, FrameworkEval, TraceLocationAllocator}; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::M31; use crate::core::fields::qm31::QM31; - use crate::core::pcs::PcsConfig; + use crate::core::pcs::{PcsConfig, TreeVec}; use crate::core::poly::circle::CanonicCoset; #[test] fn test_state_machine_constraints() { + let log_n_rows = 8; + let initial_state = [M31::zero(); N_STATE]; + + let trace = gen_trace(log_n_rows, initial_state, 0); + + let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default()); + + // Interaction trace. + let (interaction_trace, total_sum) = + gen_interaction_trace(log_n_rows, initial_state, 0, &lookup_elements); + + let component = StateMachineOp0Component::new( + &mut TraceLocationAllocator::default(), + StateTransitionEval { + log_n_rows, + lookup_elements, + total_sum, + }, + ); + + let trace = TreeVec::new(vec![ + trace, + interaction_trace, + vec![gen_is_first(log_n_rows)], + ]); + let trace_polys = trace.map_cols(|c| c.interpolate()); + assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |eval| { + component.evaluate(eval); + }); + } + + #[test] + fn test_state_machine_total_sum() { let log_n_rows = 8; let config = PcsConfig::default(); // Initial and last state. let initial_state = [M31::zero(); N_STATE]; - let last_state = [M31::from_u32_unchecked(1 << log_n_rows), M31::zero()]; + let last_state = [ + M31::from_u32_unchecked(1 << log_n_rows), + M31::from_u32_unchecked(1 << (log_n_rows - 1)), + ]; // Setup protocol. let channel = &mut Blake2sChannel::default(); - let (component, _, trace_polys) = + let (component, _, _trace_polys) = prove_state_machine(log_n_rows, initial_state, config, channel); - let interaction_elements = component.lookup_elements.clone(); + let interaction_elements = component.component0.lookup_elements.clone(); let initial_state_comb: QM31 = interaction_elements.combine(&initial_state); let last_state_comb: QM31 = interaction_elements.combine(&last_state); - // Assert total sum is `(1 / initial_state_comb) - (1 / last_state_comb)`. + // Assert total_sum is `(1 / initial_state_comb) - (1 / last_state_comb)`. assert_eq!( - component.total_sum * initial_state_comb * last_state_comb, + (component.component0.total_sum + component.component1.total_sum) + * initial_state_comb + * last_state_comb, last_state_comb - initial_state_comb ); - - // Assert constraints. - assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |eval| { - component.evaluate(eval); - }); } #[test] @@ -152,11 +254,12 @@ mod tests { let log_n_rows = 8; let config = PcsConfig::default(); let initial_state = [M31::zero(); N_STATE]; + let prover_channel = &mut Blake2sChannel::default(); - let (component_op0, proof, _) = + let (components, proof, _) = prove_state_machine(log_n_rows, initial_state, config, prover_channel); let verifier_channel = &mut Blake2sChannel::default(); - verify_state_machine(config, verifier_channel, component_op0, proof).unwrap(); + verify_state_machine(config, verifier_channel, components, proof).unwrap(); } }