Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: clearer duplication functions #895

Merged
merged 5 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 9 additions & 76 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,48 +420,29 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values[0].remove_indices(&mut removal_indices, true)?;
values[1].remove_indices(&mut removal_indices, true)?;

let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);

let start = instant::Instant::now();
let mut inputs = vec![];
let block_width = config.custom_gates.output.num_inner_cols();

let mut assigned_len = 0;
for (i, input) in values.iter_mut().enumerate() {
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let inp = {
let (res, len) = region.assign_with_duplication(
&config.custom_gates.inputs[i],
input,
&config.check_mode,
false,
)?;
let (res, len) = region
.assign_with_duplication_unconstrained(&config.custom_gates.inputs[i], input)?;
assigned_len = len;
res.get_inner()?
};
inputs.push(inp);
}

let elapsed = start.elapsed();
trace!("assigning inputs took: {:?}", elapsed);

// Now we can assign the dot product
// time this step
let start = instant::Instant::now();
let accumulated_dot = accumulated::dot(&[inputs[0].clone(), inputs[1].clone()], block_width)?;
let elapsed = start.elapsed();
trace!("calculating accumulated dot took: {:?}", elapsed);

let start = instant::Instant::now();
let (output, output_assigned_len) = region.assign_with_duplication(
let (output, output_assigned_len) = region.assign_with_duplication_constrained(
&config.custom_gates.output,
&accumulated_dot.into(),
&config.check_mode,
true,
)?;
let elapsed = start.elapsed();
trace!("assigning output took: {:?}", elapsed);

// enable the selectors
if !region.is_dummy() {
Expand Down Expand Up @@ -1002,7 +983,6 @@ fn select<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, CircuitError> {
let start = instant::Instant::now();
let (mut input, index) = (values[0].clone(), values[1].clone());
input.flatten();

Expand Down Expand Up @@ -1030,9 +1010,6 @@ fn select<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let (_, assigned_output) =
dynamic_lookup(config, region, &[index, output], &[dim_indices, input])?;

let end = start.elapsed();
trace!("select took: {:?}", end);

Ok(assigned_output)
}

Expand Down Expand Up @@ -1094,7 +1071,6 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
lookups: &[ValTensor<F>; 2],
tables: &[ValTensor<F>; 2],
) -> Result<(ValTensor<F>, ValTensor<F>), CircuitError> {
let start = instant::Instant::now();
// if not all lookups same length err
if lookups[0].len() != lookups[1].len() {
return Err(CircuitError::MismatchedLookupLength(
Expand Down Expand Up @@ -1128,28 +1104,20 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
}
let table_len = table_0.len();

trace!("assigning tables took: {:?}", start.elapsed());

// now create a vartensor of constants for the dynamic lookup index
let table_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), table_len);
let _table_index =
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index)?;

trace!("assigning table index took: {:?}", start.elapsed());

let lookup_0 = region.assign(&config.dynamic_lookups.inputs[0], &lookup_0)?;
let lookup_1 = region.assign(&config.dynamic_lookups.inputs[1], &lookup_1)?;
let lookup_len = lookup_0.len();

trace!("assigning lookups took: {:?}", start.elapsed());

// now set the lookup index
let lookup_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), lookup_len);

let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index)?;

trace!("assigning lookup index took: {:?}", start.elapsed());

let mut lookup_block = 0;

if !region.is_dummy() {
Expand Down Expand Up @@ -1196,9 +1164,6 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
region.increment_dynamic_lookup_index(1);
region.increment(lookup_len);

let end = start.elapsed();
trace!("dynamic lookup took: {:?}", end);

Ok((lookup_0, lookup_1))
}

Expand Down Expand Up @@ -1443,7 +1408,6 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd +
dim: usize,
is_flat_index: bool,
) -> Result<ValTensor<F>, CircuitError> {
let start_time = instant::Instant::now();
let index = values[0].clone();
if !is_flat_index {
assert_eq!(index.dims().len(), dims.len());
Expand Down Expand Up @@ -1517,9 +1481,6 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd +

region.apply_in_loop(&mut output, inner_loop_function)?;

let elapsed = start_time.elapsed();
trace!("linearize_element_index took: {:?}", elapsed);

Ok(output.into())
}

Expand Down Expand Up @@ -1951,16 +1912,11 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

region.flush()?;
// time this entire function run
let global_start = instant::Instant::now();

let mut values = values.clone();

// this section has been optimized to death, don't mess with it
values[0].remove_const_zero_values();

let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);

// if empty return a const
if values[0].is_empty() {
return Ok(create_zero_tensor(1));
Expand All @@ -1972,24 +1928,19 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let (res, len) = region.assign_with_duplication(
&config.custom_gates.inputs[1],
&input,
&config.check_mode,
false,
)?;
let (res, len) =
region.assign_with_duplication_unconstrained(&config.custom_gates.inputs[1], &input)?;
assigned_len = len;
res.get_inner()?
};

// Now we can assign the dot product
let accumulated_sum = accumulated::sum(&input, block_width)?;

let (output, output_assigned_len) = region.assign_with_duplication(
let (output, output_assigned_len) = region.assign_with_duplication_constrained(
&config.custom_gates.output,
&accumulated_sum.into(),
&config.check_mode,
true,
)?;

// enable the selectors
Expand Down Expand Up @@ -2055,13 +2006,10 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
) -> Result<ValTensor<F>, CircuitError> {
region.flush()?;
// time this entire function run
let global_start = instant::Instant::now();

// this section has been optimized to death, don't mess with it
let removal_indices = values[0].get_const_zero_indices();

let elapsed = global_start.elapsed();
trace!("finding const zero indices took: {:?}", elapsed);
// if empty return a const
if !removal_indices.is_empty() {
return Ok(create_zero_tensor(1));
Expand All @@ -2072,24 +2020,19 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
let (res, len) = region.assign_with_duplication(
&config.custom_gates.inputs[1],
&input,
&config.check_mode,
false,
)?;
let (res, len) =
region.assign_with_duplication_unconstrained(&config.custom_gates.inputs[1], &input)?;
assigned_len = len;
res.get_inner()?
};

// Now we can assign the dot product
let accumulated_prod = accumulated::prod(&input, block_width)?;

let (output, output_assigned_len) = region.assign_with_duplication(
let (output, output_assigned_len) = region.assign_with_duplication_constrained(
&config.custom_gates.output,
&accumulated_prod.into(),
&config.check_mode,
true,
)?;

// enable the selectors
Expand Down Expand Up @@ -2442,7 +2385,6 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
let orig_lhs = lhs.clone();
let orig_rhs = rhs.clone();

let start = instant::Instant::now();
let first_zero_indices = HashSet::from_iter(lhs.get_const_zero_indices());
let second_zero_indices = HashSet::from_iter(rhs.get_const_zero_indices());

Expand All @@ -2457,7 +2399,6 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
BaseOp::Sub => second_zero_indices.clone(),
_ => return Err(CircuitError::UnsupportedOp),
};
trace!("setting up indices took {:?}", start.elapsed());

if lhs.len() != rhs.len() {
return Err(CircuitError::DimMismatch(format!(
Expand All @@ -2482,7 +2423,6 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash

// Now we can assign the dot product
// time the calc
let start = instant::Instant::now();
let op_result = match op {
BaseOp::Add => add(&inputs),
BaseOp::Sub => sub(&inputs),
Expand All @@ -2493,20 +2433,13 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
trace!("pairwise {} calc took {:?}", op.as_str(), start.elapsed());

let start = instant::Instant::now();
let assigned_len = op_result.len() - removal_indices.len();
let mut output = region.assign_with_omissions(
&config.custom_gates.output,
&op_result.into(),
&removal_indices,
)?;
trace!(
"pairwise {} input assign took {:?}",
op.as_str(),
start.elapsed()
);

// Enable the selectors
if !region.is_dummy() {
Expand Down
35 changes: 30 additions & 5 deletions src/circuit/ops/region.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,22 +671,47 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}

/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication(
pub fn assign_with_duplication_unconstrained(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len) = var.assign_with_duplication_unconstrained(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)?;
Ok((res, len))
} else {
let (_, len) = var.dummy_assign_with_duplication(
self.row,
self.linear_coord,
values,
false,
&mut self.assigned_constants,
)?;
Ok((values.clone(), len))
}
}

/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_constrained(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
check_mode: &crate::circuit::CheckMode,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len) = var.assign_with_duplication(
let (res, len) = var.assign_with_duplication_constrained(
&mut region.borrow_mut(),
self.row,
self.linear_coord,
values,
check_mode,
single_inner_col,
&mut self.assigned_constants,
)?;
Ok((res, len))
Expand All @@ -695,7 +720,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.row,
self.linear_coord,
values,
single_inner_col,
true,
&mut self.assigned_constants,
)?;
Ok((values.clone(), len))
Expand Down
2 changes: 2 additions & 0 deletions src/graph/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,7 @@ impl Model {
values.iter().map(|v| v.dims()).collect_vec()
);

let start = instant::Instant::now();
match &node {
NodeType::Node(n) => {
let res = if node.is_constant() && node.num_uses() == 1 {
Expand Down Expand Up @@ -1363,6 +1364,7 @@ impl Model {
results.insert(*idx, full_results);
}
}
debug!("------------ layout of {} took {:?}", idx, start.elapsed());
}

// we do this so we can support multiple passes of the same model and have deterministic results (Non-assigned inputs etc... etc...)
Expand Down
30 changes: 16 additions & 14 deletions src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ impl<T: Clone + TensorType> Tensor<T> {
num_repeats: usize,
initial_offset: usize,
) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
let mut offset = initial_offset;
for (i, elem) in self.inner.clone().into_iter().enumerate() {
if (i + offset + 1) % n == 0 {
Expand Down Expand Up @@ -862,20 +862,22 @@ impl<T: Clone + TensorType> Tensor<T> {
num_repeats: usize,
initial_offset: usize,
) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
let mut indices_to_remove = std::collections::HashSet::new();
for i in 0..self.inner.len() {
if (i + initial_offset + 1) % n == 0 {
for j in 1..(1 + num_repeats) {
indices_to_remove.insert(i + j);
}
}
}
// Pre-calculate capacity to avoid reallocations
let estimated_size = self.inner.len() - (self.inner.len() / n) * num_repeats;
let mut inner = Vec::with_capacity(estimated_size);

let old_inner = self.inner.clone();
for (i, elem) in old_inner.into_iter().enumerate() {
if !indices_to_remove.contains(&i) {
inner.push(elem.clone());
// Use iterator directly instead of creating intermediate collections
let mut i = 0;
while i < self.inner.len() {
// Add the current element
inner.push(self.inner[i].clone());

// If this is an nth position (accounting for offset)
if (i + initial_offset + 1) % n == 0 {
// Skip the next num_repeats elements
i += num_repeats + 1;
} else {
i += 1;
}
}

Expand Down
Loading
Loading