Skip to content

Commit

Permalink
State machine prover
Browse files Browse the repository at this point in the history
  • Loading branch information
shaharsamocha7 committed Sep 19, 2024
1 parent 2495710 commit cce5da9
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 52 deletions.
88 changes: 85 additions & 3 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
use num_traits::One;
use num_traits::{One, Zero};

use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkEval};
use crate::core::fields::qm31::QM31;
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::core::air::{Component, ComponentProver};
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::Channel;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::{SecureField, QM31};
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::prover::StarkProof;
use crate::core::vcs::ops::MerkleHasher;

pub const N_STATE: usize = 2;
pub type StateMachineElements = LookupElements<N_STATE>;
pub type State = [M31; N_STATE];

pub type StateMachineOp0Component = FrameworkComponent<StateTransitionEval<0>>;
pub type StateMachineOp1Component = FrameworkComponent<StateTransitionEval<1>>;

/// State machine with state of size `N_STATE`.
/// Transition `INDEX` of state increments the state by 1 at that offset.
Expand Down Expand Up @@ -42,3 +53,74 @@ impl<const INDEX: usize> FrameworkEval for StateTransitionEval<INDEX> {
eval
}
}

pub struct StateMachineStatement0 {
pub n: u32,
pub m: u32,
}
impl StateMachineStatement0 {
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
let sizes = vec![
state_transition_info::<0>()
.mask_offsets
.as_cols_ref()
.map_cols(|_| self.n),
state_transition_info::<1>()
.mask_offsets
.as_cols_ref()
.map_cols(|_| self.m),
];
TreeVec::concat_cols(sizes.into_iter())
}
pub fn mix_into(&self, channel: &mut impl Channel) {
channel.mix_u64(self.n as u64);
channel.mix_u64(self.m as u64);
}
}

pub struct StateMachineStatement1 {
pub x_axis_claimed_sum: SecureField,
pub y_axis_claimed_sum: SecureField,
}
impl StateMachineStatement1 {
pub fn mix_into(&self, channel: &mut impl Channel) {
channel.mix_felts(&[self.x_axis_claimed_sum, self.y_axis_claimed_sum])
}
}

fn state_transition_info<const INDEX: usize>() -> InfoEvaluator {
let component = StateTransitionEval::<INDEX> {
log_n_rows: 1,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
};
component.evaluate(InfoEvaluator::default())
}

pub struct StateMachineComponents {
pub component0: StateMachineOp0Component,
pub component1: StateMachineOp1Component,
}

impl StateMachineComponents {
pub fn components(&self) -> Vec<&dyn Component> {
vec![
&self.component0 as &dyn Component,
&self.component1 as &dyn Component,
]
}

pub fn component_provers(&self) -> Vec<&dyn ComponentProver<SimdBackend>> {
vec![
&self.component0 as &dyn ComponentProver<SimdBackend>,
&self.component1 as &dyn ComponentProver<SimdBackend>,
]
}
}

pub struct StateMachineProof<H: MerkleHasher> {
pub public_input: [State; 2], // Initial and final state.
pub stmt0: StateMachineStatement0,
pub stmt1: StateMachineStatement1,
pub stark_proof: StarkProof<H>,
}
4 changes: 1 addition & 3 deletions crates/prover/src/examples/state_machine/gen.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use itertools::Itertools;
use num_traits::One;

use super::components::N_STATE;
use super::components::{State, N_STATE};
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES};
Expand All @@ -14,8 +14,6 @@ use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;

pub type State = [M31; N_STATE];

// Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the
// `inc_index` dimension.
// E.g. [x, y] -> [x, y + 1] -> [x, y + 2] -> [x, y + 1 << log_size].
Expand Down
Loading

0 comments on commit cce5da9

Please sign in to comment.