Skip to content

Commit

Permalink
chore: fold div_rebasing parameter into calibration (#699)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Jan 31, 2024
1 parent 45fd12a commit 04d7b5f
Showing 4 changed files with 31 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/commands.rs
Original file line number Diff line number Diff line change
@@ -336,6 +336,9 @@ pub enum Commands {
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long)]
max_logrows: Option<u32>,
// whether to fix the div_rebasing value truthiness during calibration. this changes how we rebase
#[arg(long)]
div_rebasing: Option<bool>,
},

/// Generates a dummy SRS
23 changes: 20 additions & 3 deletions src/execute.rs
Original file line number Diff line number Diff line change
@@ -178,6 +178,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
scales,
scale_rebase_multiplier,
max_logrows,
div_rebasing,
} => calibrate(
model,
data,
@@ -186,6 +187,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
lookup_safety_margin,
scales,
scale_rebase_multiplier,
div_rebasing,
max_logrows,
)
.map(|e| serde_json::to_string(&e).unwrap()),
@@ -782,6 +784,7 @@ pub(crate) fn calibrate(
lookup_safety_margin: i128,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
div_rebasing: Option<bool>,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use std::collections::HashMap;
@@ -825,6 +828,12 @@ pub(crate) fn calibrate(
}
};

let div_rebasing = if let Some(div_rebasing) = div_rebasing {
vec![div_rebasing]
} else {
vec![true, false]
};

let mut found_params: Vec<GraphSettings> = vec![];

// 2 x 2 grid
@@ -862,15 +871,21 @@ pub(crate) fn calibrate(
.map(|(a, b)| (*a, *b))
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();

let range_grid = range_grid
.iter()
.cartesian_product(div_rebasing.iter())
.map(|(a, b)| (*a, *b))
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();

let mut forward_pass_res = HashMap::new();

let pb = init_bar(range_grid.len() as u64);
pb.set_message("calibrating...");

for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
pb.set_message(format!(
"input scale: {}, param scale: {}, scale rebase multiplier: {}",
input_scale, param_scale, scale_rebase_multiplier
"input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
));

#[cfg(unix)]
@@ -890,6 +905,7 @@ pub(crate) fn calibrate(
input_scale,
param_scale,
scale_rebase_multiplier,
div_rebasing,
..settings.run_args.clone()
};

@@ -964,6 +980,7 @@ pub(crate) fn calibrate(
let found_run_args = RunArgs {
input_scale: new_settings.run_args.input_scale,
param_scale: new_settings.run_args.param_scale,
div_rebasing: new_settings.run_args.div_rebasing,
lookup_range: new_settings.run_args.lookup_range,
logrows: new_settings.run_args.logrows,
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
3 changes: 3 additions & 0 deletions src/python.rs
Original file line number Diff line number Diff line change
@@ -521,6 +521,7 @@ fn gen_settings(
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
div_rebasing = None,
))]
fn calibrate_settings(
data: PathBuf,
@@ -531,6 +532,7 @@ fn calibrate_settings(
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
div_rebasing: Option<bool>,
) -> Result<bool, PyErr> {
crate::execute::calibrate(
model,
@@ -540,6 +542,7 @@ fn calibrate_settings(
lookup_safety_margin,
scales,
scale_rebase_multiplier,
div_rebasing,
max_logrows,
)
.map_err(|e| {
10 changes: 5 additions & 5 deletions tests/integration_tests.rs
Original file line number Diff line number Diff line change
@@ -836,10 +836,10 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
env_logger::init();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, Some(vec![0,1]), true, "single");
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testWasm");
test_dir.close().unwrap();
// test_dir.close().unwrap();
}

#(#[test_case(WASM_TESTS[N])])*
@@ -849,7 +849,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
env_logger::init();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, Some(vec![0,1]), true, "single");
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, None, true, "single");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testWasm");
test_dir.close().unwrap();
@@ -865,7 +865,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, Some(vec![0,6]), false, "single");
kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, None, false, "single");
test_dir.close().unwrap();
}

@@ -875,7 +875,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", Some(vec![0,6]));
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None);
test_dir.close().unwrap();
}
});

0 comments on commit 04d7b5f

Please sign in to comment.