diff --git a/Cargo.lock b/Cargo.lock index 4c1f122b0..c074a45fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3848,6 +3848,7 @@ version = "0.1.0" dependencies = [ "alloy", "anyhow", + "clap", "evm_arithmetization", "futures", "num-traits", diff --git a/proof_gen/src/proof_gen.rs b/proof_gen/src/proof_gen.rs index 7876baad4..754032aba 100644 --- a/proof_gen/src/proof_gen.rs +++ b/proof_gen/src/proof_gen.rs @@ -16,8 +16,8 @@ use plonky2::{ use crate::{ proof_types::{ - GeneratedBlockProof, GeneratedSegmentAggProof, GeneratedSegmentProof, GeneratedTxnAggProof, - SegmentAggregatableProof, TxnAggregatableProof, + BatchAggregatableProof, GeneratedBlockProof, GeneratedSegmentAggProof, + GeneratedSegmentProof, GeneratedTxnAggProof, SegmentAggregatableProof, }, prover_state::ProverState, types::{Field, PlonkyProofIntern, EXTENSION_DEGREE}, @@ -121,8 +121,8 @@ pub fn generate_segment_agg_proof( /// Note that the child proofs may be either transaction or aggregation proofs. pub fn generate_transaction_agg_proof( p_state: &ProverState, - lhs_child: &TxnAggregatableProof, - rhs_child: &TxnAggregatableProof, + lhs_child: &BatchAggregatableProof, + rhs_child: &BatchAggregatableProof, ) -> ProofGenResult { let (b_proof_intern, p_vals) = p_state .state diff --git a/proof_gen/src/proof_types.rs b/proof_gen/src/proof_types.rs index fea8f845f..9807f0b25 100644 --- a/proof_gen/src/proof_types.rs +++ b/proof_gen/src/proof_types.rs @@ -67,7 +67,7 @@ pub enum SegmentAggregatableProof { /// we can combine it into an agg proof. For these cases, we want to abstract /// away whether or not the proof was a txn or agg proof. #[derive(Clone, Debug, Deserialize, Serialize)] -pub enum TxnAggregatableProof { +pub enum BatchAggregatableProof { /// The underlying proof is a segment proof. It first needs to be aggregated /// with another segment proof, or a dummy one. Segment(GeneratedSegmentProof), @@ -100,28 +100,28 @@ impl SegmentAggregatableProof { } } -impl TxnAggregatableProof { +impl BatchAggregatableProof { pub(crate) fn public_values(&self) -> PublicValues { match self { - TxnAggregatableProof::Segment(info) => info.p_vals.clone(), - TxnAggregatableProof::Txn(info) => info.p_vals.clone(), - TxnAggregatableProof::Agg(info) => info.p_vals.clone(), + BatchAggregatableProof::Segment(info) => info.p_vals.clone(), + BatchAggregatableProof::Txn(info) => info.p_vals.clone(), + BatchAggregatableProof::Agg(info) => info.p_vals.clone(), } } pub(crate) fn is_agg(&self) -> bool { match self { - TxnAggregatableProof::Segment(_) => false, - TxnAggregatableProof::Txn(_) => false, - TxnAggregatableProof::Agg(_) => true, + BatchAggregatableProof::Segment(_) => false, + BatchAggregatableProof::Txn(_) => false, + BatchAggregatableProof::Agg(_) => true, } } pub(crate) fn intern(&self) -> &PlonkyProofIntern { match self { - TxnAggregatableProof::Segment(info) => &info.intern, - TxnAggregatableProof::Txn(info) => &info.intern, - TxnAggregatableProof::Agg(info) => &info.intern, + BatchAggregatableProof::Segment(info) => &info.intern, + BatchAggregatableProof::Txn(info) => &info.intern, + BatchAggregatableProof::Agg(info) => &info.intern, } } } @@ -138,23 +138,23 @@ impl From for SegmentAggregatableProof { } } -impl From for TxnAggregatableProof { +impl From for BatchAggregatableProof { fn from(v: GeneratedSegmentAggProof) -> Self { Self::Txn(v) } } -impl From for TxnAggregatableProof { +impl From for BatchAggregatableProof { fn from(v: GeneratedTxnAggProof) -> Self { Self::Agg(v) } } -impl From for TxnAggregatableProof { +impl From for BatchAggregatableProof { fn from(v: SegmentAggregatableProof) -> Self { match v { - SegmentAggregatableProof::Agg(agg) => TxnAggregatableProof::Txn(agg), - SegmentAggregatableProof::Seg(seg) => TxnAggregatableProof::Segment(seg), + SegmentAggregatableProof::Agg(agg) => BatchAggregatableProof::Txn(agg), + SegmentAggregatableProof::Seg(seg) => BatchAggregatableProof::Segment(seg), } } } diff --git a/zero_bin/leader/src/cli.rs b/zero_bin/leader/src/cli.rs index 9ec32a420..ccb09fd1f 100644 --- a/zero_bin/leader/src/cli.rs +++ b/zero_bin/leader/src/cli.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use alloy::transports::http::reqwest::Url; use clap::{Parser, Subcommand, ValueHint}; +use prover::cli::CliProverConfig; use rpc::RpcType; use zero_bin_common::prover_state::cli::CliProverStateConfig; @@ -14,6 +15,9 @@ pub(crate) struct Cli { #[clap(flatten)] pub(crate) paladin: paladin::config::Config, + #[clap(flatten)] + pub(crate) prover_config: CliProverConfig, + // Note this is only relevant for the leader when running in in-memory // mode. #[clap(flatten)] @@ -27,13 +31,6 @@ pub(crate) enum Command { /// The previous proof output. #[arg(long, short = 'f', value_hint = ValueHint::FilePath)] previous_proof: Option, - #[arg(short, long, default_value_t = 20)] - max_cpu_len_log: usize, - #[arg(short, long, default_value_t = 1)] - batch_size: usize, - /// If true, save the public inputs to disk on error. - #[arg(short, long, default_value_t = false)] - save_inputs_on_error: bool, }, /// Reads input from a node rpc and writes output to stdout. Rpc { @@ -56,14 +53,6 @@ pub(crate) enum Command { /// stdout. #[arg(long, short = 'o', value_hint = ValueHint::FilePath)] proof_output_dir: Option, - /// The log of the max number of CPU cycles per proof. - #[arg(short, long, default_value_t = 20)] - max_cpu_len_log: usize, - #[arg(short, long, default_value_t = 1)] - batch_size: usize, - /// If true, save the public inputs to disk on error. - #[arg(short, long, default_value_t = false)] - save_inputs_on_error: bool, /// Network block time in milliseconds. This value is used /// to determine the blockchain node polling interval. #[arg(short, long, env = "ZERO_BIN_BLOCK_TIME", default_value_t = 2000)] @@ -92,12 +81,5 @@ pub(crate) enum Command { /// The directory to which output should be written. #[arg(short, long, value_hint = ValueHint::DirPath)] output_dir: PathBuf, - #[arg(short, long, default_value_t = 20)] - max_cpu_len_log: usize, - #[arg(short, long, default_value_t = 1)] - batch_size: usize, - /// If true, save the public inputs to disk on error. - #[arg(short, long, default_value_t = false)] - save_inputs_on_error: bool, }, } diff --git a/zero_bin/leader/src/client.rs b/zero_bin/leader/src/client.rs index 50993d4a3..74910f621 100644 --- a/zero_bin/leader/src/client.rs +++ b/zero_bin/leader/src/client.rs @@ -5,6 +5,7 @@ use alloy::transports::http::reqwest::Url; use anyhow::Result; use paladin::runtime::Runtime; use proof_gen::proof_types::GeneratedBlockProof; +use prover::ProverConfig; use rpc::{retry::build_http_retry_provider, RpcType}; use tracing::{error, info, warn}; use zero_bin_common::block_interval::BlockInterval; @@ -18,14 +19,12 @@ pub struct RpcParams { pub max_retries: u32, } -#[derive(Debug, Default)] +#[derive(Debug)] pub struct ProofParams { pub checkpoint_block_number: u64, pub previous_proof: Option, pub proof_output_dir: Option, - pub max_cpu_len_log: usize, - pub batch_size: usize, - pub save_inputs_on_error: bool, + pub prover_config: ProverConfig, pub keep_intermediate_proofs: bool, } @@ -56,10 +55,8 @@ pub(crate) async fn client_main( let proved_blocks = prover_input .prove( &runtime, - params.max_cpu_len_log, params.previous_proof.take(), - params.batch_size, - params.save_inputs_on_error, + params.prover_config, params.proof_output_dir.clone(), ) .await; diff --git a/zero_bin/leader/src/http.rs b/zero_bin/leader/src/http.rs index 971192384..9137622be 100644 --- a/zero_bin/leader/src/http.rs +++ b/zero_bin/leader/src/http.rs @@ -5,7 +5,7 @@ use anyhow::{bail, Result}; use axum::{http::StatusCode, routing::post, Json, Router}; use paladin::runtime::Runtime; use proof_gen::proof_types::GeneratedBlockProof; -use prover::BlockProverInput; +use prover::{BlockProverInput, ProverConfig}; use serde::{Deserialize, Serialize}; use serde_json::to_writer; use tracing::{debug, error, info}; @@ -15,9 +15,7 @@ pub(crate) async fn http_main( runtime: Runtime, port: u16, output_dir: PathBuf, - max_cpu_len_log: usize, - batch_size: usize, - save_inputs_on_error: bool, + prover_config: ProverConfig, ) -> Result<()> { let addr = SocketAddr::from(([0, 0, 0, 0], port)); debug!("listening on {}", addr); @@ -27,16 +25,7 @@ pub(crate) async fn http_main( "/prove", post({ let runtime = runtime.clone(); - move |body| { - prove( - body, - runtime, - output_dir.clone(), - max_cpu_len_log, - batch_size, - save_inputs_on_error, - ) - } + move |body| prove(body, runtime, output_dir.clone(), prover_config) }), ); let listener = tokio::net::TcpListener::bind(&addr).await?; @@ -76,9 +65,7 @@ async fn prove( Json(payload): Json, runtime: Arc, output_dir: PathBuf, - max_cpu_len_log: usize, - batch_size: usize, - save_inputs_on_error: bool, + prover_config: ProverConfig, ) -> StatusCode { debug!("Received payload: {:#?}", payload); @@ -88,10 +75,8 @@ async fn prove( .prover_input .prove( &runtime, - max_cpu_len_log, payload.previous.map(futures::future::ok), - batch_size, - save_inputs_on_error, + prover_config, ) .await { diff --git a/zero_bin/leader/src/main.rs b/zero_bin/leader/src/main.rs index f76e94300..5886e264f 100644 --- a/zero_bin/leader/src/main.rs +++ b/zero_bin/leader/src/main.rs @@ -62,30 +62,14 @@ async fn main() -> Result<()> { let runtime = Runtime::from_config(&args.paladin, register()).await?; + let cli_prover_config = args.prover_config; + match args.command { - Command::Stdio { - previous_proof, - max_cpu_len_log, - batch_size, - save_inputs_on_error, - } => { + Command::Stdio { previous_proof } => { let previous_proof = get_previous_proof(previous_proof)?; - stdio::stdio_main( - runtime, - max_cpu_len_log, - previous_proof, - batch_size, - save_inputs_on_error, - ) - .await?; + stdio::stdio_main(runtime, previous_proof, cli_prover_config.into()).await?; } - Command::Http { - port, - output_dir, - max_cpu_len_log, - batch_size, - save_inputs_on_error, - } => { + Command::Http { port, output_dir } => { // check if output_dir exists, is a directory, and is writable let output_dir_metadata = std::fs::metadata(&output_dir); if output_dir_metadata.is_err() { @@ -95,15 +79,7 @@ async fn main() -> Result<()> { panic!("output-dir is not a writable directory"); } - http::http_main( - runtime, - port, - output_dir, - max_cpu_len_log, - batch_size, - save_inputs_on_error, - ) - .await?; + http::http_main(runtime, port, output_dir, cli_prover_config.into()).await?; } Command::Rpc { rpc_url, @@ -112,9 +88,6 @@ async fn main() -> Result<()> { checkpoint_block_number, previous_proof, proof_output_dir, - max_cpu_len_log, - batch_size, - save_inputs_on_error, block_time, keep_intermediate_proofs, backoff, @@ -145,9 +118,7 @@ async fn main() -> Result<()> { checkpoint_block_number, previous_proof, proof_output_dir, - max_cpu_len_log, - batch_size, - save_inputs_on_error, + prover_config: cli_prover_config.into(), keep_intermediate_proofs, }, ) diff --git a/zero_bin/leader/src/stdio.rs b/zero_bin/leader/src/stdio.rs index 3b8bc2660..d74f4dce6 100644 --- a/zero_bin/leader/src/stdio.rs +++ b/zero_bin/leader/src/stdio.rs @@ -3,16 +3,14 @@ use std::io::{Read, Write}; use anyhow::Result; use paladin::runtime::Runtime; use proof_gen::proof_types::GeneratedBlockProof; -use prover::ProverInput; +use prover::{ProverConfig, ProverInput}; use tracing::info; /// The main function for the stdio mode. pub(crate) async fn stdio_main( runtime: Runtime, - max_cpu_len_log: usize, previous: Option, - batch_size: usize, - save_inputs_on_error: bool, + prover_config: ProverConfig, ) -> Result<()> { let mut buffer = String::new(); std::io::stdin().read_to_string(&mut buffer)?; @@ -23,14 +21,7 @@ pub(crate) async fn stdio_main( }; let proved_blocks = prover_input - .prove( - &runtime, - max_cpu_len_log, - previous, - batch_size, - save_inputs_on_error, - None, - ) + .prove(&runtime, previous, prover_config, None) .await; runtime.close().await?; let proved_blocks = proved_blocks?; diff --git a/zero_bin/ops/src/lib.rs b/zero_bin/ops/src/lib.rs index 803d651af..ad8241215 100644 --- a/zero_bin/ops/src/lib.rs +++ b/zero_bin/ops/src/lib.rs @@ -13,7 +13,7 @@ use paladin::{ use proof_gen::{ proof_gen::{generate_block_proof, generate_segment_agg_proof, generate_transaction_agg_proof}, proof_types::{ - GeneratedBlockProof, GeneratedTxnAggProof, SegmentAggregatableProof, TxnAggregatableProof, + BatchAggregatableProof, GeneratedBlockProof, GeneratedTxnAggProof, SegmentAggregatableProof, }, }; use serde::{Deserialize, Serialize}; @@ -207,23 +207,23 @@ impl Monoid for SegmentAggProof { } #[derive(Deserialize, Serialize, RemoteExecute)] -pub struct TxnAggProof { +pub struct BatchAggProof { pub save_inputs_on_error: bool, } -fn get_agg_proof_public_values(elem: TxnAggregatableProof) -> PublicValues { +fn get_agg_proof_public_values(elem: BatchAggregatableProof) -> PublicValues { match elem { - TxnAggregatableProof::Segment(info) => info.p_vals, - TxnAggregatableProof::Txn(info) => info.p_vals, - TxnAggregatableProof::Agg(info) => info.p_vals, + BatchAggregatableProof::Segment(info) => info.p_vals, + BatchAggregatableProof::Txn(info) => info.p_vals, + BatchAggregatableProof::Agg(info) => info.p_vals, } } -impl Monoid for TxnAggProof { - type Elem = TxnAggregatableProof; +impl Monoid for BatchAggProof { + type Elem = BatchAggregatableProof; fn combine(&self, a: Self::Elem, b: Self::Elem) -> Result { let lhs = match a { - TxnAggregatableProof::Segment(segment) => TxnAggregatableProof::from( + BatchAggregatableProof::Segment(segment) => BatchAggregatableProof::from( generate_segment_agg_proof( p_state(), &SegmentAggregatableProof::from(segment.clone()), @@ -236,7 +236,7 @@ impl Monoid for TxnAggProof { }; let rhs = match b { - TxnAggregatableProof::Segment(segment) => TxnAggregatableProof::from( + BatchAggregatableProof::Segment(segment) => BatchAggregatableProof::from( generate_segment_agg_proof( p_state(), &SegmentAggregatableProof::from(segment.clone()), diff --git a/zero_bin/prover/Cargo.toml b/zero_bin/prover/Cargo.toml index 3c2d9e131..d750f5bb1 100644 --- a/zero_bin/prover/Cargo.toml +++ b/zero_bin/prover/Cargo.toml @@ -25,6 +25,7 @@ ruint = { workspace = true, features = ["num-traits", "primitive-types"] } ops = { workspace = true } zero_bin_common ={ workspace = true } num-traits = { workspace = true } +clap = {workspace = true} [features] default = [] diff --git a/zero_bin/prover/src/cli.rs b/zero_bin/prover/src/cli.rs new file mode 100644 index 000000000..83a8f2e5a --- /dev/null +++ b/zero_bin/prover/src/cli.rs @@ -0,0 +1,27 @@ +use clap::Args; + +const HELP_HEADING: &str = "Prover options"; + +/// Represents the main configuration structure for the runtime. +#[derive(Args, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Default)] +pub struct CliProverConfig { + /// The log of the max number of CPU cycles per proof. + #[arg(short, long, help_heading = HELP_HEADING, default_value_t = 20)] + max_cpu_len_log: usize, + /// Number of transactions in a batch to process at once. + #[arg(short, long, help_heading = HELP_HEADING, default_value_t = 1)] + batch_size: usize, + /// If true, save the public inputs to disk on error. + #[arg(short='i', long, help_heading = HELP_HEADING, default_value_t = false)] + save_inputs_on_error: bool, +} + +impl From for crate::ProverConfig { + fn from(cli: CliProverConfig) -> Self { + Self { + batch_size: cli.batch_size, + max_cpu_len_log: cli.max_cpu_len_log, + save_inputs_on_error: cli.save_inputs_on_error, + } + } +} diff --git a/zero_bin/prover/src/lib.rs b/zero_bin/prover/src/lib.rs index 82b6c3ec2..dcbb15c7e 100644 --- a/zero_bin/prover/src/lib.rs +++ b/zero_bin/prover/src/lib.rs @@ -1,3 +1,5 @@ +pub mod cli; + use std::future::Future; use std::path::PathBuf; @@ -18,6 +20,13 @@ use trace_decoder::{ use tracing::info; use zero_bin_common::fs::generate_block_proof_file_name; +#[derive(Debug, Clone, Copy)] +pub struct ProverConfig { + pub batch_size: usize, + pub max_cpu_len_log: usize, + pub save_inputs_on_error: bool, +} + #[derive(Debug, Deserialize, Serialize)] pub struct BlockProverInput { pub block_trace: BlockTrace, @@ -36,10 +45,8 @@ impl BlockProverInput { pub async fn prove( self, runtime: &Runtime, - max_cpu_len_log: usize, previous: Option>>, - batch_size: usize, - save_inputs_on_error: bool, + prover_config: ProverConfig, ) -> Result { use anyhow::Context as _; use evm_arithmetization::prover::SegmentDataIterator; @@ -47,51 +54,62 @@ impl BlockProverInput { use paladin::directive::{Directive, IndexedStream}; use proof_gen::types::Field; - let block_number = self.get_block_number(); + let ProverConfig { + max_cpu_len_log, + batch_size, + save_inputs_on_error, + } = prover_config; + let block_number = self.get_block_number(); let other_data = self.other_data; - let txs = self.block_trace.into_txn_proof_gen_ir( + let block_generation_inputs = self.block_trace.into_txn_proof_gen_ir( &ProcessingMeta::new(resolve_code_hash_fn), other_data.clone(), batch_size, )?; - // Generate segment data. - let agg_ops = ops::SegmentAggProof { + // Create segment proof. + let seg_prove_ops = ops::SegmentProof { save_inputs_on_error, }; - let seg_ops = ops::SegmentProof { + // Aggregate multiple segment proofs to resulting segment proof. + let seg_agg_ops = ops::SegmentAggProof { save_inputs_on_error, }; - // Map the transactions to a stream of transaction proofs. - let tx_proof_futs: FuturesUnordered<_> = txs + // Aggregate batch proofs to a single proof. + let batch_agg_ops = ops::BatchAggProof { + save_inputs_on_error, + }; + + // Segment the batches, prove segments and aggregate them to resulting batch + // proofs. + let batch_proof_futs: FuturesUnordered<_> = block_generation_inputs .iter() .enumerate() - .map(|(idx, txn)| { - let data_iterator = SegmentDataIterator::::new(txn, Some(max_cpu_len_log)); + .map(|(idx, txn_batch)| { + let segment_data_iterator = + SegmentDataIterator::::new(txn_batch, Some(max_cpu_len_log)); - Directive::map(IndexedStream::from(data_iterator), &seg_ops) - .fold(&agg_ops) + Directive::map(IndexedStream::from(segment_data_iterator), &seg_prove_ops) + .fold(&seg_agg_ops) .run(runtime) .map(move |e| { - e.map(|p| (idx, proof_gen::proof_types::TxnAggregatableProof::from(p))) + e.map(|p| (idx, proof_gen::proof_types::BatchAggregatableProof::from(p))) }) }) .collect(); - // Fold the transaction proof stream into a single transaction proof. - let final_txn_proof = Directive::fold( - IndexedStream::new(tx_proof_futs), - &ops::TxnAggProof { - save_inputs_on_error, - }, + // Fold the batch aggregated proof stream into a single proof. + let final_batch_proof = Directive::fold( + IndexedStream::new(batch_proof_futs), + &batch_agg_ops, ) .run(runtime) .await?; - if let proof_gen::proof_types::TxnAggregatableProof::Agg(proof) = final_txn_proof { + if let proof_gen::proof_types::BatchAggregatableProof::Agg(proof) = final_batch_proof { let block_number = block_number .to_u64() .context("block number overflows u64")?; @@ -120,10 +138,8 @@ impl BlockProverInput { pub async fn prove( self, _runtime: &Runtime, - max_cpu_len_log: usize, - _previous: Option>>, - batch_size: usize, - _save_inputs_on_error: bool, + previous: Option>>, + prover_config: ProverConfig, ) -> Result { use evm_arithmetization::prover::testing::simulate_execution_all_segments; use plonky2::field::goldilocks_field::GoldilocksField; @@ -135,14 +151,20 @@ impl BlockProverInput { let txs = self.block_trace.into_txn_proof_gen_ir( &ProcessingMeta::new(resolve_code_hash_fn), other_data.clone(), - batch_size, + prover_config.batch_size, )?; type F = GoldilocksField; for txn in txs.into_iter() { - simulate_execution_all_segments::(txn, max_cpu_len_log)?; + simulate_execution_all_segments::(txn, prover_config.max_cpu_len_log)?; } + // Wait for previous block proof + let _prev = match previous { + Some(it) => Some(it.await?), + None => None, + }; + info!("Successfully generated witness for block {block_number}."); // Dummy proof to match expected output type. @@ -167,10 +189,8 @@ impl ProverInput { pub async fn prove( self, runtime: &Runtime, - max_cpu_len_log: usize, previous_proof: Option, - batch_size: usize, - save_inputs_on_error: bool, + prover_config: ProverConfig, proof_output_dir: Option, ) -> Result)>> { let mut prev: Option>> = @@ -180,21 +200,12 @@ impl ProverInput { .blocks .into_iter() .map(|block| { - let block_number = block.get_block_number(); - info!("Proving block {block_number}"); - let (tx, rx) = oneshot::channel::(); // Prove the block let proof_output_dir = proof_output_dir.clone(); let fut = block - .prove( - runtime, - max_cpu_len_log, - prev.take(), - batch_size, - save_inputs_on_error, - ) + .prove(runtime, prev.take(), prover_config) .then(move |proof| async move { let proof = proof?; let block_number = proof.b_height;