Skip to content

Commit

Permalink
fix: const filtering strat is size dependent (#891)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Dec 27, 2024
1 parent c4354c1 commit caa6ef8
Show file tree
Hide file tree
Showing 5 changed files with 325 additions and 88 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ shellexpand = "3.1.0"
runner = 'wasm-bindgen-test-runner'


[[bench]]
name = "zero_finder"
harness = false

[[bench]]
name = "accum_dot"
harness = false
Expand Down
116 changes: 116 additions & 0 deletions benches/zero_finder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use std::thread;

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use halo2curves::{bn256::Fr as F, ff::Field};
use maybe_rayon::{
iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator},
slice::ParallelSlice,
};
use rand::Rng;

// Assuming these are your types
#[derive(Clone)]
enum ValType {
Constant(F),
AssignedConstant(usize, F),
Other,
}

// Helper to generate test data
fn generate_test_data(size: usize, zero_probability: f64) -> Vec<ValType> {
let mut rng = rand::thread_rng();
(0..size)
.map(|_i| {
if rng.gen::<f64>() < zero_probability {
ValType::Constant(F::ZERO)
} else {
ValType::Constant(F::ONE) // Or some other non-zero value
}
})
.collect()
}

fn bench_zero_finding(c: &mut Criterion) {
let sizes = [
1_000, // 1K
10_000, // 10K
100_000, // 100K
256 * 256 * 2, // Our specific case
1_000_000, // 1M
10_000_000, // 10M
];

let zero_probability = 0.1; // 10% zeros

let mut group = c.benchmark_group("zero_finding");
group.sample_size(10); // Adjust based on your needs

for &size in &sizes {
let data = generate_test_data(size, zero_probability);

// Benchmark sequential version
group.bench_function(format!("sequential_{}", size), |b| {
b.iter(|| {
let result = data
.iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect::<Vec<_>>();
black_box(result)
})
});

// Benchmark parallel version
group.bench_function(format!("parallel_{}", size), |b| {
b.iter(|| {
let result = data
.par_iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect::<Vec<_>>();
black_box(result)
})
});

// Benchmark chunked parallel version
group.bench_function(format!("chunked_parallel_{}", size), |b| {
b.iter(|| {
let num_cores = thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (size / num_cores).max(100);

let result = data
.par_chunks(chunk_size)
.enumerate()
.flat_map(|(chunk_idx, chunk)| {
chunk
.par_iter() // Make sure we use par_iter() here
.enumerate()
.filter_map(move |(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
}
_ => None,
})
})
.collect::<Vec<_>>();
black_box(result)
})
});
}
group.finish();
}

criterion_group!(benches, bench_zero_finding);
criterion_main!(benches);
8 changes: 4 additions & 4 deletions src/eth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ pub async fn deploy_da_verifier_via_solidity(
}
}

let contract = match call_to_account {

match call_to_account {
Some(call) => {
deploy_single_da_contract(
client,
Expand All @@ -514,8 +515,7 @@ pub async fn deploy_da_verifier_via_solidity(
)
.await
}
};
return contract;
}
}

async fn deploy_multi_da_contract(
Expand Down Expand Up @@ -630,7 +630,7 @@ async fn deploy_single_da_contract(
// bytes memory _callData,
PackedSeqToken(call_data.as_ref()),
// uint256 _decimals,
WordToken(B256::from(decimals).into()),
WordToken(B256::from(decimals)),
// uint[] memory _scales,
DynSeqToken(
scales
Expand Down
28 changes: 13 additions & 15 deletions src/graph/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1136,23 +1136,21 @@ pub fn new_op_from_onnx(
a: crate::circuit::utils::F32(exponent),
})
}
} else {
if let Some(c) = inputs[0].opkind().get_mutable_constant() {
inputs[0].decrement_use();
deleted_indices.push(0);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar base")
}
} else if let Some(c) = inputs[0].opkind().get_mutable_constant() {
inputs[0].decrement_use();
deleted_indices.push(0);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar base")
}

let base = c.raw_values[0];
let base = c.raw_values[0];

SupportedOp::Nonlinear(LookupOp::Exp {
scale: scale_to_multiplier(input_scales[1]).into(),
base: base.into(),
})
} else {
unimplemented!("only support constant base or pow for now")
}
SupportedOp::Nonlinear(LookupOp::Exp {
scale: scale_to_multiplier(input_scales[1]).into(),
base: base.into(),
})
} else {
unimplemented!("only support constant base or pow for now")
}
}
"Div" => {
Expand Down
Loading

0 comments on commit caa6ef8

Please sign in to comment.