Do multiplicity witgen in separate stage (#2319)
While working on #2306, @Schaeff came across several bugs in
multiplicity witness generation. These went undetected because we
ignored multiplicities in the mock prover, which will be fixed by #2310.
With this PR, #2310 will be green.

The issue was that counting multiplicities inside
`Machine::process_plookup()` fails if the caller actually discards the
result. This happens in a few places, for example during our loop
optimization in the "dynamic machine".

With this PR, we instead have a centralized
`MultiplicityColumnGenerator` that counts multiplicities after the fact,
by going over each lookup, evaluating the two selected tuples on all
rows, and counting how often each element in the LHS appears in the RHS.
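
For illustration, here is a minimal sketch of this after-the-fact counting in plain Rust. The names and signature are hypothetical (the actual `MultiplicityColumnGenerator` works on the analyzed PIL connections and the evaluated trace columns), it ignores selectors, and it assumes the RHS tuples are unique:
```rust
use std::collections::HashMap;

/// Hypothetical sketch: given one lookup whose LHS and RHS tuples have already
/// been evaluated on all selected rows (`u64` stands in for the field element
/// type), count how often each RHS row is hit. The result is the multiplicity
/// column for this lookup.
fn count_multiplicities(lhs_tuples: &[Vec<u64>], rhs_tuples: &[Vec<u64>]) -> Vec<u64> {
    // Index the RHS: tuple -> row in which it appears (assumed unique).
    let index: HashMap<&[u64], usize> = rhs_tuples
        .iter()
        .enumerate()
        .map(|(row, tuple)| (tuple.as_slice(), row))
        .collect();

    // For each LHS tuple, increment the multiplicity of the matching RHS row.
    let mut multiplicities = vec![0u64; rhs_tuples.len()];
    for tuple in lhs_tuples {
        let row = index[tuple.as_slice()]; // the lookup is assumed to hold
        multiplicities[row] += 1;
    }
    multiplicities
}
```
Because this runs once over the finished trace, it no longer matters whether a particular `process_plookup()` result was kept or discarded.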

To measure the runtime of this, I ran:
```sh
export TEST=keccak
export POWDR_JIT_OPT_LEVEL=0

cargo run -r --bin powdr-rs compile riscv/tests/riscv_data/$TEST -o output --max-degree-log 18
cargo run -r --features plonky3,halo2 pil output/$TEST.asm -o output -f --field gl --linker-mode bus
```

I get the following profile on the server:
```
 == Witgen profile (2554126 events)
   32.4% (    2.6s): Secondary machine 0: main_binary (BlockMachine)
   23.1% (    1.9s): Main machine (Dynamic)
   12.7% (    1.0s): Secondary machine 4: main_regs (DoubleSortedWitnesses32)
   10.0% ( 809.9ms): FixedLookup
    7.7% ( 621.1ms): Secondary machine 5: main_shift (BlockMachine)
    5.6% ( 454.6ms): Secondary machine 2: main_poseidon_gl (BlockMachine)
    3.8% ( 312.3ms): multiplicity witgen
    3.8% ( 308.2ms): witgen (outer code)
    0.6% (  45.3ms): Secondary machine 1: main_memory (DoubleSortedWitnesses32)
    0.4% (  33.4ms): Secondary machine 6: main_split_gl (BlockMachine)
    0.0% (   8.0µs): Secondary machine 3: main_publics (WriteOnceMemory)
  ---------------------------
    ==> Total: 8.114630092s
```

So the cost is ~4%. I'm sure it can be optimized further, but I would
like to leave that to a future PR.
georgwiese authored Jan 9, 2025
1 parent 067b633 commit e328eb9
Showing 12 changed files with 268 additions and 244 deletions.
1 change: 1 addition & 0 deletions executor-utils/Cargo.toml
@@ -11,6 +11,7 @@ powdr-number.workspace = true
powdr-ast.workspace = true

serde = { version = "1.0", default-features = false, features = ["alloc", "derive", "rc"] }
itertools = "0.13"

[lib]
bench = false # See https://github.com/bheisler/criterion.rs/issues/458
20 changes: 18 additions & 2 deletions executor-utils/src/expression_evaluator.rs
@@ -1,4 +1,5 @@
use core::ops::{Add, Mul, Sub};
use itertools::Itertools;
use std::collections::BTreeMap;

use powdr_ast::analyzed::{
@@ -35,7 +36,7 @@ pub struct RowValues<'a, F> {
row: usize,
}

impl<F> OwnedTerminalValues<F> {
impl<F: std::fmt::Debug> OwnedTerminalValues<F> {
pub fn new(
pil: &Analyzed<F>,
witness_columns: Vec<(String, Vec<F>)>,
@@ -72,13 +73,28 @@ impl<F> OwnedTerminalValues<F> {
self
}

/// The height of the trace. Panics if columns have different lengths.
pub fn height(&self) -> usize {
self.trace.values().next().map(|v| v.len()).unwrap()
self.trace
.values()
.map(|v| v.len())
.unique()
.exactly_one()
.unwrap()
}

/// The length of a given column.
pub fn column_length(&self, poly_id: &PolyID) -> usize {
self.trace.get(poly_id).unwrap().len()
}

pub fn row(&self, row: usize) -> RowValues<F> {
RowValues { values: self, row }
}

pub fn into_trace(self) -> BTreeMap<PolyID, Vec<F>> {
self.trace
}
}

impl<F: FieldElement, T: From<F>> TerminalAccess<T> for RowValues<'_, F> {
1 change: 0 additions & 1 deletion executor/src/witgen/data_structures/mod.rs
@@ -2,6 +2,5 @@ pub mod caller_data;
pub mod column_map;
pub mod copy_constraints;
pub mod finalizable_data;
pub mod multiplicity_counter;
pub mod mutable_state;
pub mod padded_bitvec;
90 changes: 0 additions & 90 deletions executor/src/witgen/data_structures/multiplicity_counter.rs

This file was deleted.

13 changes: 0 additions & 13 deletions executor/src/witgen/machines/block_machine.rs
@@ -10,7 +10,6 @@ use crate::witgen::affine_expression::AlgebraicVariable;
use crate::witgen::analysis::detect_connection_type_and_block_size;
use crate::witgen::block_processor::BlockProcessor;
use crate::witgen::data_structures::finalizable_data::FinalizableData;
use crate::witgen::data_structures::multiplicity_counter::MultiplicityCounter;
use crate::witgen::data_structures::mutable_state::MutableState;
use crate::witgen::jit::function_cache::FunctionCache;
use crate::witgen::processor::{OuterQuery, Processor, SolverState};
@@ -73,7 +72,6 @@ pub struct BlockMachine<'a, T: FieldElement> {
/// If this block machine can be JITed, we store the witgen functions here.
function_cache: FunctionCache<'a, T>,
name: String,
multiplicity_counter: MultiplicityCounter,
/// Counts the number of blocks created using the JIT.
block_count_jit: usize,
/// Counts the number of blocks created using the runtime solver.
@@ -131,7 +129,6 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
connection_type: is_permutation,
data,
publics: Default::default(),
multiplicity_counter: MultiplicityCounter::new(&parts.connections),
processing_sequence_cache: ProcessingSequenceCache::new(
block_size,
latch_row,
@@ -211,8 +208,6 @@ impl<'a, T: FieldElement> Machine<'a, T> for BlockMachine<'a, T> {
.witnesses
.iter()
.map(|id| (*id, Vec::new()))
// Note that this panics if any count is not 0 (which shouldn't happen).
.chain(self.multiplicity_counter.generate_columns_single_size(0))
.map(|(id, values)| (self.fixed_data.column_name(&id).to_string(), values))
.collect();
}
@@ -338,10 +333,6 @@ impl<'a, T: FieldElement> Machine<'a, T> for BlockMachine<'a, T> {
.collect();
self.handle_last_row(&mut data);
data.into_iter()
.chain(
self.multiplicity_counter
.generate_columns_single_size(self.degree),
)
.map(|(id, values)| (self.fixed_data.column_name(&id).to_string(), values))
.collect()
}
@@ -448,10 +439,6 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {

let updates = updates.report_side_effect();

let global_latch_row_index = self.data.len() - self.block_size + self.latch_row;
self.multiplicity_counter
.increment_at_row(identity_id, global_latch_row_index);

// We solved the query, so report it to the cache.
self.processing_sequence_cache
.report_processing_sequence(&outer_query.left, sequence_iterator);
14 changes: 0 additions & 14 deletions executor/src/witgen/machines/dynamic_machine.rs
@@ -4,7 +4,6 @@ use std::collections::{BTreeMap, HashMap};

use crate::witgen::block_processor::BlockProcessor;
use crate::witgen::data_structures::finalizable_data::FinalizableData;
use crate::witgen::data_structures::multiplicity_counter::MultiplicityCounter;
use crate::witgen::data_structures::mutable_state::MutableState;
use crate::witgen::machines::{Machine, MachineParts};
use crate::witgen::processor::{OuterQuery, SolverState};
@@ -31,7 +30,6 @@ pub struct DynamicMachine<'a, T: FieldElement> {
latch: Option<Expression<T>>,
name: String,
degree: DegreeType,
multiplicity_counter: MultiplicityCounter,
}

impl<'a, T: FieldElement> Machine<'a, T> for DynamicMachine<'a, T> {
@@ -96,12 +94,6 @@ impl<'a, T: FieldElement> Machine<'a, T> for DynamicMachine<'a, T> {
self.data.extend(updated_data.block);
self.publics.extend(updated_data.publics);

// The block we just added contains the first row of the next block,
// so the latch row is the second-to-last row.
let latch_row = self.data.len() - 2;
self.multiplicity_counter
.increment_at_row(identity_id, latch_row);

eval_value.report_side_effect()
} else {
log::trace!("End processing VM '{}' (incomplete)", self.name());
@@ -122,10 +114,6 @@ impl<'a, T: FieldElement> Machine<'a, T> for DynamicMachine<'a, T> {
self.data
.take_transposed()
.map(|(id, (values, _))| (id, values))
.chain(
self.multiplicity_counter
.generate_columns_single_size(self.degree),
)
.map(|(id, values)| (self.fixed_data.column_name(&id).to_string(), values))
.collect()
}
@@ -139,7 +127,6 @@ impl<'a, T: FieldElement> DynamicMachine<'a, T> {
latch: Option<Expression<T>>,
) -> Self {
let data = FinalizableData::new(&parts.witnesses, fixed_data);
let multiplicity_counter = MultiplicityCounter::new(&parts.connections);

Self {
degree: parts.common_degree_range().max,
@@ -149,7 +136,6 @@ impl<'a, T: FieldElement> DynamicMachine<'a, T> {
data,
publics: Default::default(),
latch,
multiplicity_counter,
}
}

54 changes: 3 additions & 51 deletions executor/src/witgen/machines/fixed_lookup_machine.rs
@@ -6,11 +6,10 @@ use std::mem;

use itertools::{Either, Itertools};
use powdr_ast::analyzed::{AlgebraicReference, PolynomialType};
use powdr_number::{DegreeType, FieldElement};
use powdr_number::FieldElement;

use crate::witgen::affine_expression::{AffineExpression, AlgebraicVariable};
use crate::witgen::data_structures::caller_data::CallerData;
use crate::witgen::data_structures::multiplicity_counter::MultiplicityCounter;
use crate::witgen::data_structures::mutable_state::MutableState;
use crate::witgen::global_constraints::{GlobalConstraints, RangeConstraintSet};
use crate::witgen::processor::OuterQuery;
@@ -160,7 +159,6 @@ pub struct FixedLookup<'a, T: FieldElement> {
indices: HashMap<Application, Index<T>>,
connections: BTreeMap<u64, Connection<'a, T>>,
fixed_data: &'a FixedData<'a, T>,
multiplicity_counter: MultiplicityCounter,
}

impl<'a, T: FieldElement> FixedLookup<'a, T> {
@@ -180,23 +178,11 @@ impl<'a, T: FieldElement> FixedLookup<'a, T> {
fixed_data: &'a FixedData<'a, T>,
connections: BTreeMap<u64, Connection<'a, T>>,
) -> Self {
let multiplicity_column_sizes = connections
.values()
.filter_map(|connection| {
connection
.multiplicity_column
.map(|poly_id| (poly_id, unique_size(fixed_data, connection)))
})
.collect();
let multiplicity_counter =
MultiplicityCounter::new_with_sizes(&connections, multiplicity_column_sizes);

Self {
global_constraints,
indices: Default::default(),
connections,
fixed_data,
multiplicity_counter,
}
}

@@ -265,34 +251,6 @@ impl<'a, T: FieldElement> FixedLookup<'a, T> {
}
}

/// Get the unique size of the fixed lookup machine referenced by the provided connection.
/// Panics if any expression in the connection's RHS is not a reference to a fixed column,
/// if the fixed columns are variably-sized, or if the fixed columns have different sizes.
fn unique_size<T: FieldElement>(
fixed_data: &FixedData<T>,
connection: &Connection<T>,
) -> DegreeType {
let fixed_columns = connection
.right
.expressions
.iter()
.map(|expr| try_to_simple_poly(expr).unwrap().poly_id)
.collect::<Vec<_>>();
fixed_columns
.iter()
.map(|fixed_col| {
// Get unique size for fixed column
fixed_data.fixed_cols[fixed_col]
.values
.get_uniquely_sized()
.unwrap()
.len() as DegreeType
})
.unique()
.exactly_one()
.expect("All fixed columns on the same RHS must have the same size")
}

impl<'a, T: FieldElement> Machine<'a, T> for FixedLookup<'a, T> {
fn name(&self) -> &str {
"FixedLookup"
@@ -386,13 +344,11 @@ impl<'a, T: FieldElement> Machine<'a, T> for FixedLookup<'a, T> {
EvalError::FixedLookupFailed(input_assignment)
})?;

let Some((row, output)) = index_value.get() else {
let Some((_, output)) = index_value.get() else {
// multiple matches, we stop and learnt nothing
return Ok(false);
};

self.multiplicity_counter.increment_at_row(identity_id, row);

values
.iter_mut()
.filter_map(|v| match v {
@@ -410,11 +366,7 @@ impl<'a, T: FieldElement> Machine<'a, T> for FixedLookup<'a, T> {
&mut self,
_mutable_state: &'b MutableState<'a, T, Q>,
) -> HashMap<String, Vec<T>> {
self.multiplicity_counter
.generate_columns_different_sizes()
.into_iter()
.map(|(poly_id, column)| (self.fixed_data.column_name(&poly_id).to_string(), column))
.collect()
Default::default()
}

fn identity_ids(&self) -> Vec<u64> {
10 changes: 9 additions & 1 deletion executor/src/witgen/machines/machine_extractor.rs
@@ -166,7 +166,15 @@ impl<'a, T: FieldElement> MachineExtractor<'a, T> {
.filter(|(_, pf)| {
let refs = refs_in_parsed_expression(pf)
.unique()
.filter_map(|n| self.fixed.column_by_name.get(n).cloned())
.flat_map(|n| {
self.fixed.try_column_by_name(n).into_iter().chain(
// The reference might be an array, in which case it wouldn't
// be in the list of columns. So we try the first element as well.
self.fixed
.try_column_by_name(&format!("{n}[0]"))
.into_iter(),
)
})
.collect::<HashSet<_>>();
refs.intersection(&machine_witnesses).next().is_some()
})
