diff --git a/Cargo.lock b/Cargo.lock index 3742d707c..3225e59d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2263,7 +2263,7 @@ dependencies = [ [[package]] name = "halo2_gadgets" version = "0.2.0" -source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f" +source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f" dependencies = [ "arrayvec 0.7.4", "bitvec 1.0.1", @@ -2280,7 +2280,7 @@ dependencies = [ [[package]] name = "halo2_proofs" version = "0.3.0" -source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f" +source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f" dependencies = [ "blake2b_simd", "env_logger", diff --git a/examples/notebooks/tictactoe_autoencoder.ipynb b/examples/notebooks/tictactoe_autoencoder.ipynb index 90e026d1b..8cdc43aa6 100644 --- a/examples/notebooks/tictactoe_autoencoder.ipynb +++ b/examples/notebooks/tictactoe_autoencoder.ipynb @@ -633,7 +633,7 @@ "json.dump(data, open(cal_path, 'w'))\n", "\n", "\n", - "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])" + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])" ] }, { @@ -664,7 +664,6 @@ " compiled_model_path,\n", " vk_path,\n", " pk_path,\n", - " \n", ")" ] }, diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs index a6b02edf8..7928be2de 100644 --- a/src/circuit/ops/hybrid.rs +++ b/src/circuit/ops/hybrid.rs @@ -277,7 +277,7 @@ impl Op for HybridOp { .. } => { if denom.0.fract() == 0.0 && *use_range_check_for_int { - layouts::div( + layouts::loop_div( config, region, values[..].try_into()?, diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 0e95b0c19..dffd1405a 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -18,10 +18,7 @@ use super::{ region::RegionCtx, }; use crate::{ - circuit::{ - ops::base::BaseOp, - utils::{self, F32}, - }, + circuit::{ops::base::BaseOp, utils}, fieldutils::{felt_to_i128, i128_to_felt}, tensor::{ get_broadcasted_shape, @@ -54,6 +51,41 @@ pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usi total_len } +/// Same as div but splits the division into N parts +pub fn loop_div( + config: &BaseConfig, + region: &mut RegionCtx, + value: &[ValTensor; 1], + divisor: F, +) -> Result, Box> { + if divisor == F::ONE { + return Ok(value[0].clone()); + } + + // if integer val is divisible by 2, we can use a faster method and div > F::S + let mut divisor = divisor; + let mut num_parts = 1; + + while felt_to_i128(divisor) % 2 == 0 && felt_to_i128(divisor) > (2_i128.pow(F::S - 4)) { + divisor = i128_to_felt(felt_to_i128(divisor) / 2); + num_parts += 1; + } + + let output = div(config, region, value, divisor)?; + if num_parts == 1 { + return Ok(output); + } + + let divisor_int = 2_i128.pow(num_parts - 1); + let divisor_felt = i128_to_felt(divisor_int); + if divisor_int <= 2_i128.pow(F::S - 3) { + div(config, region, &[output], divisor_felt) + } else { + // keep splitting the divisor until it satisfies the condition + loop_div(config, region, &[output], divisor_felt) + } +} + /// Div accumulated layout pub fn div( config: &BaseConfig, @@ -61,6 +93,10 @@ pub fn div( value: &[ValTensor; 1], div: F, ) -> Result, Box> { + if div == F::ONE { + return Ok(value[0].clone()); + } + let input = value[0].clone(); let input_dims = input.dims(); @@ -88,6 +124,8 @@ pub fn div( .into() }; claimed_output.reshape(input_dims)?; + region.assign(&config.output, &claimed_output)?; + region.increment(claimed_output.len()); let product = pairwise( config, @@ -96,8 +134,6 @@ pub fn div( BaseOp::Mult, )?; - log::debug!("product: {:?}", product.get_int_evals()?); - let diff_with_input = pairwise( config, region, @@ -105,8 +141,6 @@ pub fn div( BaseOp::Sub, )?; - log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?); - range_check( config, region, @@ -117,6 +151,46 @@ pub fn div( Ok(claimed_output) } +fn recip_int( + config: &BaseConfig, + region: &mut RegionCtx, + input: &[ValTensor; 1], +) -> Result, Box> { + // assert is boolean + let zero_inverse_val = tensor::ops::nonlinearities::zero_recip(1.0)[0]; + // get values where input is 0 + let zero_mask = equals_zero(config, region, input)?; + + let one_minus_zero_mask = pairwise( + config, + region, + &[ + zero_mask.clone(), + ValTensor::from(Tensor::from([ValType::Constant(F::ONE)].into_iter())), + ], + BaseOp::Sub, + )?; + + let zero_inverse_val = pairwise( + config, + region, + &[ + zero_mask, + ValTensor::from(Tensor::from( + [ValType::Constant(i128_to_felt(zero_inverse_val))].into_iter(), + )), + ], + BaseOp::Mult, + )?; + + pairwise( + config, + region, + &[one_minus_zero_mask, zero_inverse_val], + BaseOp::Add, + ) +} + /// recip accumulated layout pub fn recip( config: &BaseConfig, @@ -125,10 +199,23 @@ pub fn recip( input_scale: F, output_scale: F, ) -> Result, Box> { + if output_scale == F::ONE || output_scale == F::ZERO { + return recip_int(config, region, value); + } + let input = value[0].clone(); let input_dims = input.dims(); - let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2; + let integer_input_scale = felt_to_i128(input_scale); + let integer_output_scale = felt_to_i128(output_scale); + + // range_check_bracket is min of input_scale * output_scale and 2^F::S - 3 + let range_check_len = std::cmp::min(integer_output_scale, 2_i128.pow(F::S - 4)); + + let input_scale_ratio = + i128_to_felt(integer_input_scale * integer_output_scale / range_check_len); + + let range_check_bracket = range_check_len / 2; let is_assigned = !input.any_unknowns()?; @@ -151,6 +238,8 @@ pub fn recip( .into() }; claimed_output.reshape(input_dims)?; + let claimed_output = region.assign(&config.output, &claimed_output)?; + region.increment(claimed_output.len()); // this is now of scale 2 * scale let product = pairwise( @@ -160,15 +249,46 @@ pub fn recip( BaseOp::Mult, )?; - log::debug!("product: {:?}", product.get_int_evals()?); + // divide by input_scale + let rebased_div = loop_div(config, region, &[product], input_scale_ratio)?; - log::debug!("range_check_bracket: {:?}", range_check_bracket); + let zero_inverse_val = + tensor::ops::nonlinearities::zero_recip(felt_to_i128(output_scale) as f64)[0]; + let zero_inverse = + Tensor::from([ValType::Constant(i128_to_felt::(zero_inverse_val))].into_iter()); + + let equal_zero_mask = equals_zero(config, region, &[input.clone()])?; + + let equal_inverse_mask = equals( + config, + region, + &[claimed_output.clone(), zero_inverse.into()], + )?; + + // assert the two masks are equal + enforce_equality( + config, + region, + &[equal_zero_mask.clone(), equal_inverse_mask], + )?; + + let unit_scale = Tensor::from([ValType::Constant(i128_to_felt(range_check_len))].into_iter()); + + let unit_mask = pairwise( + config, + region, + &[equal_zero_mask, unit_scale.into()], + BaseOp::Mult, + )?; + + // now add the unit mask to the rebased_div + let rebased_offset_div = pairwise(config, region, &[rebased_div, unit_mask], BaseOp::Add)?; // at most the error should be in the original unit scale's range range_check( config, region, - &[product], + &[rebased_offset_div], &(range_check_bracket, 3 * range_check_bracket), )?; @@ -1677,9 +1797,23 @@ pub fn equals( values: &[ValTensor; 2], ) -> Result, Box> { let diff = pairwise(config, region, values, BaseOp::Sub)?; - let diff_inverse = diff.inverse()?; - let product_diff_and_invert = - pairwise(config, region, &[diff.clone(), diff_inverse], BaseOp::Mult)?; + equals_zero(config, region, &[diff]) +} + +/// Equality boolean operation +pub fn equals_zero( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], +) -> Result, Box> { + let values = values[0].clone(); + let values_inverse = values.inverse()?; + let product_values_and_invert = pairwise( + config, + region, + &[values.clone(), values_inverse], + BaseOp::Mult, + )?; // constant of 1 let mut ones = Tensor::from(vec![ValType::Constant(F::from(1))].into_iter()); @@ -1689,12 +1823,12 @@ pub fn equals( let output = pairwise( config, region, - &[ones.into(), product_diff_and_invert], + &[ones.into(), product_values_and_invert], BaseOp::Sub, )?; // take the product of diff and output - let prod_check = pairwise(config, region, &[diff, output.clone()], BaseOp::Mult)?; + let prod_check = pairwise(config, region, &[values, output.clone()], BaseOp::Mult)?; is_zero_identity(config, region, &[prod_check], false)?; @@ -1860,7 +1994,7 @@ pub fn sumpool( last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?; if normalized { - last_elem = div( + last_elem = loop_div( config, region, &[last_elem], @@ -2519,6 +2653,17 @@ pub fn range_check( .collect::, Box>>()?; } + if region.throw_range_check_error() { + // assert is within range + let int_values = w.get_int_evals()?; + for v in int_values { + if v < range.0 || v > range.1 { + log::debug!("Value ({:?}) out of range: {:?}", v, range); + return Err(Box::new(TensorError::TableLookupError)); + } + } + } + region.increment(assigned_len); let elapsed = timer.elapsed(); @@ -2945,16 +3090,8 @@ pub fn softmax( let denom = sum(config, region, &[ex.clone()])?; // get the inverse - let inv_denom = nonlinearity( - config, - region, - &[denom], - // we set to input scale + output_scale so the output scale is output)scale - &LookupOp::Recip { - input_scale: scale, - output_scale: scale, - }, - )?; + let felt_scale = F::from(scale.0 as u64); + let inv_denom = recip(config, region, &[denom], felt_scale, felt_scale)?; // product of num * (1 / denom) = 2*output_scale let softmax = pairwise(config, region, &[ex, inv_denom], BaseOp::Mult)?; @@ -2989,29 +3126,44 @@ pub fn range_check_percent( // Calculate the difference between the expected output and actual output let diff = pairwise(config, region, &values, BaseOp::Sub)?; - // Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor - let recip = nonlinearity( + // integer scale + let int_scale = scale.0 as i128; + // felt scale + let felt_scale = i128_to_felt(int_scale); + // range check len capped at 2^(S-3) and make it divisible 2 + let range_check_bracket = std::cmp::min( + utils::F32(scale.0), + utils::F32(2_f32.powf((F::S - 5) as f32)), + ) + .0; + + let range_check_bracket_int = range_check_bracket as i128; + + // input scale ratio we multiply by tol such that in the new scale range_check_len represents tol percent + let input_scale_ratio = ((scale.0.powf(2.0) / range_check_bracket) * tol) as i128 / 2 * 2; + + let recip = recip( config, region, &[values[0].clone()], - &LookupOp::Recip { - input_scale: scale, - // multiply by 100 to get the percent error - output_scale: F32(scale.0 * 100.0), - }, + felt_scale, + felt_scale * F::from(100), )?; + log::debug!("recip: {}", recip.show()); + // Multiply the difference by the recip let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?; - let rebased_product = div(config, region, &[product], F::from(scale.0 as u64))?; - let scaled_tol = (tol * scale.0) as i128; + log::debug!("product: {}", product.show()); + let rebased_product = loop_div(config, region, &[product], i128_to_felt(input_scale_ratio))?; + log::debug!("rebased_product: {}", rebased_product.show()); // check that it is within the tolerance range range_check( config, region, &[rebased_product], - &(-scaled_tol, scaled_tol), + &(-range_check_bracket_int, range_check_bracket_int), ) } diff --git a/src/circuit/ops/region.rs b/src/circuit/ops/region.rs index e546dc375..c69052601 100644 --- a/src/circuit/ops/region.rs +++ b/src/circuit/ops/region.rs @@ -70,8 +70,8 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> { used_range_checks: HashSet, max_lookup_inputs: i128, min_lookup_inputs: i128, - min_range_check: i128, - max_range_check: i128, + max_range_size: i128, + throw_range_check_error: bool, } impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { @@ -80,6 +80,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { self.total_constants += n; } + /// + pub fn throw_range_check_error(&self) -> bool { + self.throw_range_check_error + } + /// Create a new region context pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> { let region = Some(RefCell::new(region)); @@ -95,8 +100,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { used_range_checks: HashSet::new(), max_lookup_inputs: 0, min_lookup_inputs: 0, - max_range_check: 0, - min_range_check: 0, + max_range_size: 0, + throw_range_check_error: false, } } /// Create a new region context from a wrapped region @@ -116,13 +121,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { used_range_checks: HashSet::new(), max_lookup_inputs: 0, min_lookup_inputs: 0, - max_range_check: 0, - min_range_check: 0, + max_range_size: 0, + throw_range_check_error: false, } } /// Create a new region context - pub fn new_dummy(row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> { + pub fn new_dummy( + row: usize, + num_inner_cols: usize, + throw_range_check_error: bool, + ) -> RegionCtx<'a, F> { let region = None; let linear_coord = row * num_inner_cols; @@ -136,8 +145,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { used_range_checks: HashSet::new(), max_lookup_inputs: 0, min_lookup_inputs: 0, - max_range_check: 0, - min_range_check: 0, + max_range_size: 0, + throw_range_check_error, } } @@ -149,6 +158,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { num_inner_cols: usize, used_lookups: HashSet, used_range_checks: HashSet, + throw_range_check_error: bool, ) -> RegionCtx<'a, F> { let region = None; RegionCtx { @@ -161,8 +171,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { used_range_checks, max_lookup_inputs: 0, min_lookup_inputs: 0, - max_range_check: 0, - min_range_check: 0, + max_range_size: 0, + throw_range_check_error, } } @@ -234,6 +244,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { self.num_inner_cols, HashSet::new(), HashSet::new(), + self.throw_range_check_error, ); let res = inner_loop_function(idx, &mut local_reg); // we update the offset and constants @@ -310,8 +321,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { return Err("update_max_min_lookup_range: invalid range".into()); } - self.max_range_check = self.max_range_check.max(range.1); - self.min_range_check = self.min_range_check.min(range.0); + let range_size = (range.1 - range.0).abs(); + + self.max_range_size = self.max_range_size.max(range_size); Ok(()) } @@ -371,14 +383,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { self.min_lookup_inputs } - /// min range check - pub fn min_range_check(&self) -> i128 { - self.min_range_check - } - /// max range check - pub fn max_range_check(&self) -> i128 { - self.max_range_check + pub fn max_range_size(&self) -> i128 { + self.max_range_size } /// Assign a constant value diff --git a/src/circuit/table.rs b/src/circuit/table.rs index cb2b62089..f8aca6246 100644 --- a/src/circuit/table.rs +++ b/src/circuit/table.rs @@ -133,9 +133,7 @@ impl Table { } /// -pub fn num_cols_required(range: Range, col_size: usize) -> usize { - // double it to be safe - let range_len = range.1 - range.0; +pub fn num_cols_required(range_len: i128, col_size: usize) -> usize { // number of cols needed to store the range (range_len / (col_size as i128)) as usize + 1 } @@ -152,7 +150,7 @@ impl Table { let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD; let col_size = Self::cal_col_size(logrows, factors); // number of cols needed to store the range - let num_cols = num_cols_required(range, col_size); + let num_cols = num_cols_required((range.1 - range.0).abs(), col_size); log::debug!("table range: {:?}", range); @@ -313,7 +311,7 @@ impl RangeCheck { let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD; let col_size = Self::cal_col_size(logrows, factors); // number of cols needed to store the range - let num_cols = num_cols_required(range, col_size); + let num_cols = num_cols_required((range.1 - range.0).abs(), col_size); let inputs = { let mut cols = vec![]; diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs index 31a349e6b..2f61c2be2 100644 --- a/src/circuit/tests.rs +++ b/src/circuit/tests.rs @@ -1,4 +1,3 @@ -use crate::circuit::ops::hybrid::HybridOp; use crate::circuit::ops::poly::PolyOp; use crate::circuit::*; use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor}; @@ -2338,113 +2337,3 @@ mod lookup_ultra_overflow { println!("done."); } } - -#[cfg(test)] -mod softmax { - - use super::*; - use halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, - dev::MockProver, - plonk::{Circuit, ConstraintSystem, Error}, - }; - - const K: usize = 18; - const LEN: usize = 3; - const SCALE: f32 = 128.0; - - #[derive(Clone)] - struct SoftmaxCircuit { - pub input: ValTensor, - _marker: PhantomData, - } - - impl Circuit for SoftmaxCircuit { - type Config = BaseConfig; - type FloorPlanner = SimpleFloorPlanner; - type Params = TestParams; - - fn without_witnesses(&self) -> Self { - self.clone() - } - fn configure(cs: &mut ConstraintSystem) -> Self::Config { - let a = VarTensor::new_advice(cs, K, 1, LEN); - let b = VarTensor::new_advice(cs, K, 1, LEN); - let output = VarTensor::new_advice(cs, K, 1, LEN); - let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE); - let advices = (0..3) - .map(|_| VarTensor::new_advice(cs, K, 1, LEN)) - .collect::>(); - - config - .configure_lookup( - cs, - &advices[0], - &advices[1], - &advices[2], - (-32768, 32768), - K, - &LookupOp::Exp { - scale: SCALE.into(), - }, - ) - .unwrap(); - config - .configure_lookup( - cs, - &advices[0], - &advices[1], - &advices[2], - (-32768, 32768), - K, - &LookupOp::Recip { - input_scale: SCALE.into(), - output_scale: SCALE.into(), - }, - ) - .unwrap(); - config - } - - fn synthesize( - &self, - mut config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.layout_tables(&mut layouter).unwrap(); - layouter - .assign_region( - || "", - |region| { - let mut region = RegionCtx::new(region, 0, 1); - let _output = config - .layout( - &mut region, - &[self.input.clone()], - Box::new(HybridOp::Softmax { - scale: SCALE.into(), - axes: vec![0], - }), - ) - .unwrap(); - Ok(()) - }, - ) - .unwrap(); - - Ok(()) - } - } - - #[test] - fn softmax_circuit() { - let input = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); - - let circuit = SoftmaxCircuit:: { - input: ValTensor::from(input), - _marker: PhantomData, - }; - let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - } -} diff --git a/src/execute.rs b/src/execute.rs index d626ad24b..4b2cbbee9 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -618,7 +618,7 @@ pub(crate) async fn gen_witness( let start_time = Instant::now(); - let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref())?; + let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref(), false)?; // print each variable tuple (symbol, value) as symbol=value trace!( @@ -808,16 +808,7 @@ pub(crate) fn calibrate( // we load the model to get the input and output shapes // check if gag already exists - #[cfg(unix)] - let _r = match Gag::stdout() { - Ok(r) => Some(r), - Err(_) => None, - }; - let model = Model::from_run_args(&settings.run_args, &model_path)?; - // drop the gag - #[cfg(unix)] - std::mem::drop(_r); let chunks = data.split_into_batches(model.graph.input_shapes()?)?; info!("num of calibration batches: {}", chunks.len()); @@ -833,7 +824,7 @@ pub(crate) fn calibrate( let range = if let Some(scales) = scales { scales } else { - (10..14).collect::>() + (11..14).collect::>() }; let div_rebasing = if only_range_check_rebase { @@ -896,16 +887,6 @@ pub(crate) fn calibrate( input_scale, param_scale, scale_rebase_multiplier, div_rebasing )); - #[cfg(unix)] - let _r = match Gag::stdout() { - Ok(r) => Some(r), - Err(_) => None, - }; - #[cfg(unix)] - let _q = match Gag::stderr() { - Ok(r) => Some(r), - Err(_) => None, - }; let key = (input_scale, param_scale, scale_rebase_multiplier); forward_pass_res.insert(key, vec![]); @@ -920,17 +901,12 @@ pub(crate) fn calibrate( let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) { Ok(c) => c, Err(e) => { - // drop the gag - #[cfg(unix)] - std::mem::drop(_r); - #[cfg(unix)] - std::mem::drop(_q); debug!("circuit creation from run args failed: {:?}", e); continue; } }; - chunks + let forward_res = chunks .iter() .map(|chunk| { let chunk = chunk.clone(); @@ -940,7 +916,7 @@ pub(crate) fn calibrate( .map_err(|e| format!("failed to load circuit inputs: {}", e))?; let forward_res = circuit - .forward(&mut data.clone(), None, None) + .forward(&mut data.clone(), None, None, true) .map_err(|e| format!("failed to forward: {}", e))?; // push result to the hashmap @@ -951,7 +927,16 @@ pub(crate) fn calibrate( Ok(()) as Result<(), String> }) - .collect::, String>>()?; + .collect::, String>>(); + + match forward_res { + Ok(_) => (), + // typically errors will be due to the circuit overflowing the i128 limit + Err(e) => { + debug!("forward pass failed: {:?}", e); + continue; + } + } let min_lookup_range = forward_pass_res .get(&key) @@ -969,35 +954,21 @@ pub(crate) fn calibrate( .max() .unwrap_or(0); - let min_range_check = forward_pass_res - .get(&key) - .unwrap() - .iter() - .map(|x| x.min_range_check) - .min() - .unwrap_or(0); - - let max_range_check = forward_pass_res + let max_range_size = forward_pass_res .get(&key) .unwrap() .iter() - .map(|x| x.max_range_check) + .map(|x| x.max_range_size) .max() .unwrap_or(0); let res = circuit.calibrate_from_min_max( (min_lookup_range, max_lookup_range), - (min_range_check, max_range_check), + max_range_size, max_logrows, lookup_safety_margin, ); - // // drop the gag - // #[cfg(unix)] - // std::mem::drop(_r); - // #[cfg(unix)] - // std::mem::drop(_q); - if res.is_ok() { let new_settings = circuit.settings().clone(); diff --git a/src/graph/mod.rs b/src/graph/mod.rs index a18c2843e..7dc8f9743 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -61,8 +61,11 @@ use crate::pfsys::field_to_string; /// The safety factor for the range of the lookup table. pub const RANGE_MULTIPLIER: i128 = 2; +/// The maximum number of columns in a lookup table. +pub const MAX_NUM_LOOKUP_COLS: usize = 12; + /// Max representation of a lookup table input -pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS); +pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS); #[cfg(not(target_arch = "wasm32"))] lazy_static! { @@ -134,15 +137,16 @@ pub enum GraphError { MissingResults, } -const ASSUMED_BLINDING_FACTORS: usize = 5; +/// +pub const ASSUMED_BLINDING_FACTORS: usize = 5; /// The minimum number of rows in the grid pub const MIN_LOGROWS: u32 = 6; /// 26 pub const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2; -/// Lookup deg -pub const LOOKUP_DEG: usize = 5; +/// +pub const RESERVED_BLINDING_ROWS: usize = ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD; use std::cell::RefCell; @@ -171,10 +175,8 @@ pub struct GraphWitness { pub max_lookup_inputs: i128, /// max lookup input pub min_lookup_inputs: i128, - /// max range check input - pub max_range_check: i128, - /// max range check input - pub min_range_check: i128, + /// max range check size + pub max_range_size: i128, } impl GraphWitness { @@ -202,8 +204,7 @@ impl GraphWitness { processed_outputs: None, max_lookup_inputs: 0, min_lookup_inputs: 0, - max_range_check: 0, - min_range_check: 0, + max_range_size: 0, } } @@ -376,9 +377,7 @@ impl ToPyObject for GraphWitness { .unwrap(); dict.set_item("min_lookup_inputs", self.min_lookup_inputs) .unwrap(); - dict.set_item("max_range_check", self.max_range_check) - .unwrap(); - dict.set_item("min_range_check", self.min_range_check) + dict.set_item("max_range_size", self.max_range_size) .unwrap(); if let Some(processed_inputs) = &self.processed_inputs { @@ -473,6 +472,20 @@ pub struct GraphSettings { } impl GraphSettings { + fn model_constraint_logrows(&self) -> u32 { + (self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64) + .log2() + .ceil() as u32 + } + + fn module_constraint_logrows(&self) -> u32 { + (self.module_sizes.max_constraints() as f64).log2().ceil() as u32 + } + + fn constants_logrows(&self) -> u32 { + (self.total_const_size as f64).log2().ceil() as u32 + } + /// calculate the total number of instances pub fn total_instances(&self) -> Vec { let mut instances: Vec = self @@ -1005,10 +1018,6 @@ impl GraphCircuit { Ok(data) } - fn reserved_blinding_rows() -> f64 { - (ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64 - } - fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range { let mut margin = ( lookup_safety_margin * min_max_lookup.0, @@ -1022,18 +1031,33 @@ impl GraphCircuit { margin } - fn calc_num_cols(safe_range: Range, max_logrows: u32) -> usize { - let max_col_size = Table::::cal_col_size( - max_logrows as usize, - Self::reserved_blinding_rows() as usize, + fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize { + let max_col_size = Table::::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS); + num_cols_required(range_len, max_col_size) + } + + fn table_size_logrows( + &self, + safe_lookup_range: Range, + max_range_size: i128, + ) -> Result> { + // pick the range with the largest absolute size safe_lookup_range or max_range_size + let safe_range = std::cmp::max( + (safe_lookup_range.1 - safe_lookup_range.0).abs(), + max_range_size, ); - num_cols_required(safe_range, max_col_size) + + let min_bits = (safe_range as f64 + RESERVED_BLINDING_ROWS as f64 + 1.) + .log2() + .ceil() as u32; + + Ok(min_bits) } fn calc_min_logrows( &mut self, min_max_lookup: Range, - min_max_range_checks: Range, + max_range_size: i128, max_logrows: Option, lookup_safety_margin: i128, ) -> Result<(), Box> { @@ -1043,68 +1067,57 @@ impl GraphCircuit { let mut max_logrows = std::cmp::max(max_logrows, MIN_LOGROWS); let mut min_logrows = MIN_LOGROWS; - let reserved_blinding_rows = Self::reserved_blinding_rows(); + let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin); + // check if has overflowed max lookup input - if min_max_lookup.1.abs() > MAX_LOOKUP_ABS / lookup_safety_margin - || min_max_lookup.0.abs() > MAX_LOOKUP_ABS / lookup_safety_margin - { + if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin { let err_string = format!("max lookup input {:?} is too large", min_max_lookup); return Err(err_string.into()); } - if min_max_range_checks.1.abs() > MAX_LOOKUP_ABS - || min_max_range_checks.1.abs() > MAX_LOOKUP_ABS - { - let err_string = format!( - "max range check input {:?} is too large", - min_max_range_checks - ); + if max_range_size.abs() > MAX_LOOKUP_ABS { + let err_string = format!("max range check size {:?} is too large", max_range_size); return Err(err_string.into()); } - let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin); - // pick the range with the largest absolute size between safe_lookup_range and min_max_range_checks - let safe_range = if (safe_lookup_range.1 - safe_lookup_range.0) - > (min_max_range_checks.1 - min_max_range_checks.0) - { - safe_lookup_range - } else { - min_max_range_checks - }; + // These are hard lower limits, we can't overflow instances or modules constraints + let instance_logrows = self.settings().log2_total_instances(); + let module_constraint_logrows = self.settings().module_constraint_logrows(); + min_logrows = std::cmp::max( + min_logrows, + // max of the instance logrows and the module constraint logrows is the lower limit + [instance_logrows, module_constraint_logrows] + .iter() + .max() + .unwrap() + .clone(), + ); - // degrade the max logrows until the extended k is small enough - while min_logrows < max_logrows - && !self.extended_k_is_small_enough( - min_logrows, - Self::calc_num_cols(safe_range, min_logrows), - ) - { - min_logrows += 1; - } + // These are upper limits, going above these is wasteful, but they are not hard limits + let model_constraint_logrows = self.settings().model_constraint_logrows(); + let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?; + let constants_logrows = self.settings().constants_logrows(); + max_logrows = std::cmp::min( + max_logrows, + // max of the model constraint logrows, min_bits, and the constants logrows is the upper limit + [model_constraint_logrows, min_bits, constants_logrows] + .iter() + .max() + .unwrap() + .clone(), + ); - if !self - .extended_k_is_small_enough(min_logrows, Self::calc_num_cols(safe_range, min_logrows)) - { - let err_string = format!( - "extended k is too large to accommodate the quotient polynomial with logrows {}", - min_logrows - ); - debug!("{}", err_string); - return Err(err_string.into()); - } + // we now have a min and max logrows + max_logrows = std::cmp::max(min_logrows, max_logrows); + // degrade the max logrows until the extended k is small enough while min_logrows < max_logrows - && !self.extended_k_is_small_enough( - max_logrows, - Self::calc_num_cols(safe_range, max_logrows), - ) + && !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) { max_logrows -= 1; } - if !self - .extended_k_is_small_enough(max_logrows, Self::calc_num_cols(safe_range, max_logrows)) - { + if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) { let err_string = format!( "extended k is too large to accommodate the quotient polynomial with logrows {}", max_logrows @@ -1113,67 +1126,27 @@ impl GraphCircuit { return Err(err_string.into()); } - let min_bits = ((safe_range.1 - safe_range.0) as f64 + reserved_blinding_rows + 1.) - .log2() - .ceil() as usize; - - let min_rows_from_constraints = (self.settings().num_rows as f64 + reserved_blinding_rows) - .log2() - .ceil() as usize; - - let mut logrows = std::cmp::max(min_bits, min_rows_from_constraints); - - // if public input then public inputs col will have public inputs len - if self.settings().run_args.input_visibility.is_public() - || self.settings().run_args.output_visibility.is_public() - { - let mut max_instance_len = self - .model() - .instance_shapes()? - .iter() - .fold(0, |acc, x| std::cmp::max(acc, x.iter().product::())) - as f64 - + reserved_blinding_rows; - // if there are modules then we need to add the max module size - if self.settings().uses_modules() { - max_instance_len += self - .settings() - .module_sizes - .num_instances() - .iter() - .sum::() as f64; - } - let instance_len_logrows = (max_instance_len).log2().ceil() as usize; - logrows = std::cmp::max(logrows, instance_len_logrows); - // this is for fixed const columns - } - - // ensure logrows is at least 4 - logrows = std::cmp::max(logrows, min_logrows as usize); - logrows = std::cmp::min(logrows, max_logrows as usize); + let logrows = max_logrows; let model = self.model().clone(); let settings_mut = self.settings_mut(); settings_mut.run_args.lookup_range = safe_lookup_range; - settings_mut.run_args.logrows = logrows as u32; + settings_mut.run_args.logrows = logrows; *settings_mut = GraphCircuit::new(model, &settings_mut.run_args)? .settings() .clone(); - // recalculate the total const size give nthe new logrows - let total_const_len = settings_mut.total_const_size; - let const_len_logrows = (total_const_len as f64).log2().ceil() as u32; - settings_mut.run_args.logrows = - std::cmp::max(settings_mut.run_args.logrows, const_len_logrows); - // recalculate the total number of constraints given the new logrows - let min_rows_from_constraints = (settings_mut.num_rows as f64 + reserved_blinding_rows) - .log2() - .ceil() as u32; - settings_mut.run_args.logrows = - std::cmp::max(settings_mut.run_args.logrows, min_rows_from_constraints); - - settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows); + // recalculate the logrows if there has been overflow on the constants + settings_mut.run_args.logrows = std::cmp::max( + settings_mut.run_args.logrows, + settings_mut.constants_logrows(), + ); + // recalculate the logrows if there has been overflow for the model constraints + settings_mut.run_args.logrows = std::cmp::max( + settings_mut.run_args.logrows, + settings_mut.model_constraint_logrows(), + ); debug!( "setting lookup_range to: {:?}, setting logrows to: {}", @@ -1184,12 +1157,37 @@ impl GraphCircuit { Ok(()) } - fn extended_k_is_small_enough(&self, k: u32, num_lookup_cols: usize) -> bool { - let max_degree = self.settings().run_args.num_inner_cols + 2; - let max_lookup_degree = LOOKUP_DEG + num_lookup_cols - 1; // num_lookup_cols - 1 is the degree of the lookup synthetic selector + fn extended_k_is_small_enough( + &self, + k: u32, + safe_lookup_range: Range, + max_range_size: i128, + ) -> bool { + // if num cols is too large then the extended k is too large + if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS { + return false; + } else if Self::calc_num_cols(max_range_size, k) > MAX_NUM_LOOKUP_COLS { + return false; + } - let max_degree = std::cmp::max(max_degree, max_lookup_degree); + let mut settings = self.settings().clone(); + settings.run_args.lookup_range = safe_lookup_range; + settings.run_args.logrows = k; + settings.required_range_checks = vec![(0, max_range_size)]; + let mut cs = ConstraintSystem::default(); + // fetch gag + #[cfg(unix)] + let _r = match gag::Gag::stdout() { + Ok(r) => Some(r), + Err(_) => None, + }; + Self::configure_with_params(&mut cs, settings); + #[cfg(feature = "mv-lookup")] + let cs = cs.chunk_lookups(); // quotient_poly_degree * params.n - 1 is the degree of the quotient polynomial + let max_degree = cs.degree(); + #[cfg(unix)] + std::mem::drop(_r); let quotient_poly_degree = (max_degree - 1) as u64; // n = 2^k let n = 1u64 << k; @@ -1208,13 +1206,13 @@ impl GraphCircuit { pub fn calibrate_from_min_max( &mut self, min_max_lookup: Range, - min_max_range_checks: Range, + max_range_size: i128, max_logrows: Option, lookup_safety_margin: i128, ) -> Result<(), Box> { self.calc_min_logrows( min_max_lookup, - min_max_range_checks, + max_range_size, max_logrows, lookup_safety_margin, )?; @@ -1227,6 +1225,7 @@ impl GraphCircuit { inputs: &mut [Tensor], vk: Option<&VerifyingKey>, srs: Option<&ParamsKZG>, + throw_range_check_error: bool, ) -> Result> { let original_inputs = inputs.to_vec(); @@ -1267,7 +1266,9 @@ impl GraphCircuit { } } - let mut model_results = self.model().forward(inputs, &self.settings().run_args)?; + let mut model_results = + self.model() + .forward(inputs, &self.settings().run_args, throw_range_check_error)?; if visibility.output.requires_processing() { let module_outlets = visibility.output.overwrites_inputs(); @@ -1310,8 +1311,7 @@ impl GraphCircuit { processed_outputs, max_lookup_inputs: model_results.max_lookup_inputs, min_lookup_inputs: model_results.min_lookup_inputs, - max_range_check: model_results.max_range_check, - min_range_check: model_results.min_range_check, + max_range_size: model_results.max_range_size, }; witness.generate_rescaled_elements( diff --git a/src/graph/model.rs b/src/graph/model.rs index e00267f57..a45fa53b7 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -67,10 +67,8 @@ pub struct ForwardResult { pub max_lookup_inputs: i128, /// The minimum value of any input to a lookup operation. pub min_lookup_inputs: i128, - /// The max range check value - pub max_range_check: i128, - /// The min range check value - pub min_range_check: i128, + /// The max range check size + pub max_range_size: i128, } impl From for ForwardResult { @@ -79,8 +77,7 @@ impl From for ForwardResult { outputs: res.outputs, max_lookup_inputs: res.max_lookup_inputs, min_lookup_inputs: res.min_lookup_inputs, - min_range_check: res.min_range_check, - max_range_check: res.max_range_check, + max_range_size: res.max_range_size, } } } @@ -115,9 +112,7 @@ pub struct DummyPassRes { /// min lookup inputs pub min_lookup_inputs: i128, /// min range check - pub min_range_check: i128, - /// max range check - pub max_range_check: i128, + pub max_range_size: i128, /// outputs pub outputs: Vec>, } @@ -531,7 +526,7 @@ impl Model { }) .collect::, Box>>()?; - let res = self.dummy_layout(run_args, &inputs)?; + let res = self.dummy_layout(run_args, &inputs, false)?; // if we're using percentage tolerance, we need to add the necessary range check ops for it. @@ -570,12 +565,13 @@ impl Model { &self, model_inputs: &[Tensor], run_args: &RunArgs, + throw_range_check_error: bool, ) -> Result> { let valtensor_inputs: Vec> = model_inputs .iter() .map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into()) .collect(); - let res = self.dummy_layout(run_args, &valtensor_inputs)?; + let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?; Ok(res.into()) } @@ -1356,6 +1352,7 @@ impl Model { &self, run_args: &RunArgs, inputs: &[ValTensor], + throw_range_check_error: bool, ) -> Result> { debug!("calculating num of constraints using dummy model layout..."); @@ -1374,7 +1371,7 @@ impl Model { vars: ModelVars::new_dummy(), }; - let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols); + let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error); let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?; @@ -1441,8 +1438,7 @@ impl Model { range_checks: region.used_range_checks(), max_lookup_inputs: region.max_lookup_inputs(), min_lookup_inputs: region.min_lookup_inputs(), - min_range_check: region.min_range_check(), - max_range_check: region.max_range_check(), + max_range_size: region.max_range_size(), outputs, }; diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 11304b430..f6f7ccb2c 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -734,7 +734,7 @@ pub fn new_op_from_onnx( SupportedOp::Hybrid(HybridOp::Recip { input_scale: (scale_to_multiplier(in_scale) as f32).into(), output_scale: (scale_to_multiplier(max_scale) as f32).into(), - use_range_check_for_int: false, + use_range_check_for_int: true, }) } diff --git a/src/lib.rs b/src/lib.rs index 7f86615e3..803204da5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -180,6 +180,11 @@ impl RunArgs { if self.num_inner_cols < 1 { return Err("num_inner_cols must be >= 1".into()); } + if self.tolerance.val > 0.0 { + if self.output_visibility != Visibility::Public { + return Err("tolerance > 0.0 requires output_visibility to be public".into()); + } + } Ok(()) } diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 0fe910e29..23577f0fd 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -3773,6 +3773,30 @@ pub mod nonlinearities { .unwrap() } + /// Elementwise inverse. + /// # Arguments + /// * `out_scale` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::zero_recip; + /// let k = 2_f64; + /// let result = zero_recip(1.0); + /// let expected = Tensor::::new(Some(&[4503599627370496]), &[1]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn zero_recip(out_scale: f64) -> Tensor { + let a = Tensor::::new(Some(&[0]), &[1]).unwrap(); + + a.par_enum_map(|_, a_i| { + let rescaled = a_i as f64; + let denom = (1_f64) / (rescaled + f64::EPSILON); + let d_inv_x = out_scale * denom; + Ok::<_, TensorError>(d_inv_x.round() as i128) + }) + .unwrap() + } + /// Elementwise greater than /// # Arguments /// diff --git a/src/wasm.rs b/src/wasm.rs index c8548aacf..bf9b615ae 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -211,7 +211,7 @@ pub fn genWitness( .map_err(|e| JsError::new(&format!("{}", e)))?; let witness = circuit - .forward(&mut input, None, None) + .forward(&mut input, None, None, false) .map_err(|e| JsError::new(&format!("{}", e)))?; serde_json::to_vec(&witness) diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 23ee9d43a..280c406b8 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -2,6 +2,7 @@ #[cfg(test)] mod native_tests { + use ezkl::circuit::Tolerance; use ezkl::fieldutils::{felt_to_i128, i128_to_felt}; // use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD; use ezkl::graph::input::{FileSource, FileSourceInner, GraphData}; @@ -276,7 +277,7 @@ mod native_tests { "bitshift", ]; - const WASM_TESTS: [&str; 48] = [ + const WASM_TESTS: [&str; 46] = [ "1l_mlp", "1l_slice", "1l_concat", @@ -325,8 +326,6 @@ mod native_tests { "1l_where", "boolean", "boolean_identity", - "decision_tree", // "variable_cnn", - "random_forest", "gradient_boosted_trees", "1l_topk", // "xgboost", @@ -586,6 +585,8 @@ mod native_tests { test_dir.close().unwrap(); } + + #(#[test_case(TESTS[N])])* fn mock_large_batch_public_outputs_(test: &str) { crate::native_tests::init_binary(); @@ -841,7 +842,7 @@ mod native_tests { }); - seq!(N in 0..=47 { + seq!(N in 0..=45 { #(#[test_case(WASM_TESTS[N])])* fn kzg_prove_and_verify_with_overflow_(test: &str) { @@ -1288,6 +1289,7 @@ mod native_tests { scales_to_use: Option>, tolerance: f32, ) { + let mut tolerance = tolerance; gen_circuit_settings_and_witness( test_dir, example_name.clone(), @@ -1299,16 +1301,10 @@ mod native_tests { scales_to_use, 2, false, - tolerance, + &mut tolerance, ); - let settings = - GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into()) - .unwrap(); - - let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0); - - if tolerance > 0.0 && !any_output_scales_smol { + if tolerance > 0.0 { // load witness and shift the output by a small amount that is less than tolerance percent let witness = GraphWitness::from_path( format!("{}/{}/witness.json", test_dir, example_name).into(), @@ -1333,7 +1329,7 @@ mod native_tests { as i128, ) }; - + *v + perturbation }) .collect::>() @@ -1444,7 +1440,7 @@ mod native_tests { scales_to_use: Option>, num_inner_columns: usize, div_rebasing: bool, - tolerance: f32, + tolerance: &mut f32, ) { let mut args = vec![ "gen-settings".to_string(), @@ -1502,6 +1498,24 @@ mod native_tests { .expect("failed to execute process"); assert!(status.success()); + let mut settings = + GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into()) + .unwrap(); + + let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0); + + if any_output_scales_smol { + // set the tolerance to 0.0 + settings.run_args.tolerance = Tolerance { + val: 0.0.into(), + scale: 0.0.into(), + }; + settings + .save(&format!("{}/{}/settings.json", test_dir, example_name).into()) + .unwrap(); + *tolerance = 0.0; + } + let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ "compile-circuit", @@ -1559,7 +1573,7 @@ mod native_tests { None, 2, div_rebasing, - 0.0, + &mut 0.0, ); println!( @@ -1819,7 +1833,7 @@ mod native_tests { scales_to_use, num_inner_columns, false, - 0.0, + &mut 0.0, ); let settings_path = format!("{}/{}/settings.json", test_dir, example_name); @@ -1921,7 +1935,7 @@ mod native_tests { None, 2, false, - 0.0, + &mut 0.0, ); let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) @@ -2198,7 +2212,7 @@ mod native_tests { Some(vec![4]), 1, false, - 0.0, + &mut 0.0, ); let model_path = format!("{}/{}/network.compiled", test_dir, example_name); diff --git a/tests/output_comparison.py b/tests/output_comparison.py index 12ac87451..217566311 100644 --- a/tests/output_comparison.py +++ b/tests/output_comparison.py @@ -91,9 +91,7 @@ def compare_outputs(zk_output, onnx_output): print("------- zk_output: ", list1_i) print("------- onnx_output: ", list2_i) - - - return np.mean(np.abs(res)) + return res if __name__ == '__main__': @@ -113,6 +111,9 @@ def compare_outputs(zk_output, onnx_output): onnx_output = get_onnx_output(model_file, input_file) # compare the outputs percentage_difference = compare_outputs(ezkl_output, onnx_output) + mean_percentage_difference = np.mean(np.abs(percentage_difference)) + max_percentage_difference = np.max(np.abs(percentage_difference)) # print the percentage difference - print("mean percent diff: ", percentage_difference) - assert percentage_difference < target, "Percentage difference is too high" + print("mean percent diff: ", mean_percentage_difference) + print("max percent diff: ", max_percentage_difference) + assert mean_percentage_difference < target, "Percentage difference is too high" diff --git a/tests/wasm/model.compiled b/tests/wasm/model.compiled index 84f3c6613..13461def8 100644 Binary files a/tests/wasm/model.compiled and b/tests/wasm/model.compiled differ diff --git a/tests/wasm/witness.json b/tests/wasm/witness.json index daa10bc1a..c7b551bd6 100644 --- a/tests/wasm/witness.json +++ b/tests/wasm/witness.json @@ -1 +1 @@ -{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_check":0,"min_range_check":0} \ No newline at end of file +{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_size":0} \ No newline at end of file