From e30683a4c2c19b7811a49a8aaec21b074d32ddc4 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Tue, 21 Nov 2023 20:17:20 +0000 Subject: [PATCH] chore: allow for multiple redirect calls during calibrate (#625) --- src/bin/ezkl.rs | 2 +- src/commands.rs | 5 +++-- src/execute.rs | 25 +++++++++++++++++++------ src/graph/mod.rs | 2 +- src/pfsys/mod.rs | 4 ++-- 5 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/bin/ezkl.rs b/src/bin/ezkl.rs index ba4b210ab..b75424504 100644 --- a/src/bin/ezkl.rs +++ b/src/bin/ezkl.rs @@ -33,7 +33,7 @@ pub async fn main() -> Result<(), Box> { info!("Running with CPU"); } info!("command: \n {}", &args.as_json()?.to_colored_json_auto()?); - let res = run(args).await; + let res = run(args.command).await; match &res { Ok(_) => info!("succeeded"), Err(e) => error!("failed: {}", e), diff --git a/src/commands.rs b/src/commands.rs index 6a4a0b2cb..62f10b272 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -51,7 +51,7 @@ impl<'source> FromPyObject<'source> for TranscriptType { } } -#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] /// Determines what the calibration pass should optimize for pub enum CalibrationTarget { /// Optimizes for reducing cpu and memory usage @@ -171,8 +171,9 @@ impl Cli { } #[allow(missing_docs)] -#[derive(Debug, Subcommand, Clone, Deserialize, Serialize)] +#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd)] pub enum Commands { + Empty, /// Loads model and prints model table #[command(arg_required_else_help = true)] Table { diff --git a/src/execute.rs b/src/execute.rs index bf9e09530..937bbfdf4 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -1,7 +1,7 @@ use crate::circuit::CheckMode; #[cfg(not(target_arch = "wasm32"))] use crate::commands::CalibrationTarget; -use crate::commands::{Cli, Commands}; +use crate::commands::Commands; #[cfg(not(target_arch = "wasm32"))] use crate::eth::{deploy_da_verifier_via_solidity, deploy_verifier_via_solidity}; #[cfg(not(target_arch = "wasm32"))] @@ -111,8 +111,9 @@ pub enum ExecutionError { } /// Run an ezkl command with given args -pub async fn run(cli: Cli) -> Result<(), Box> { - match cli.command { +pub async fn run(command: Commands) -> Result<(), Box> { + match command { + Commands::Empty => Ok(()), #[cfg(not(target_arch = "wasm32"))] Commands::Fuzz { witness, @@ -615,7 +616,13 @@ pub(crate) fn calibrate( let settings = GraphSettings::load(&settings_path)?; // now retrieve the run args // we load the model to get the input and output shapes - let _r = Gag::stdout().unwrap(); + // check if gag already exists + + let _r = match Gag::stdout() { + Ok(r) => Some(r), + Err(_) => None, + }; + let model = Model::from_run_args(&settings.run_args, &model_path).unwrap(); // drop the gag std::mem::drop(_r); @@ -679,8 +686,14 @@ pub(crate) fn calibrate( // vec of settings copied chunks.len() times let run_args_iterable = vec![settings.run_args.clone(); chunks.len()]; - let _r = Gag::stdout().unwrap(); - let _q = Gag::stderr().unwrap(); + let _r = match Gag::stdout() { + Ok(r) => Some(r), + Err(_) => None, + }; + let _q = match Gag::stderr() { + Ok(r) => Some(r), + Err(_) => None, + }; let tasks = chunks .iter() diff --git a/src/graph/mod.rs b/src/graph/mod.rs index d0c256a8d..8dcb16ddb 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -501,7 +501,7 @@ impl GraphCircuit { } } -#[derive(Clone, Debug, Default, Deserialize, Serialize)] +#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)] /// The data source for a test pub enum TestDataSource { /// The data is loaded from a file diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs index 34ef13051..943faf589 100644 --- a/src/pfsys/mod.rs +++ b/src/pfsys/mod.rs @@ -43,7 +43,7 @@ use thiserror::Error as thisError; use halo2curves::bn256::{Bn256, Fr, G1Affine}; #[allow(missing_docs)] -#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)] pub enum ProofType { Single, ForAggr, @@ -142,7 +142,7 @@ pub enum PfSysError { } #[allow(missing_docs)] -#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)] pub enum TranscriptType { Poseidon, EVM,