Skip to content

Commit

Permalink
refactor!: make rebasing multiplicative by default (#698)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: adds a `required_range_checks` field to `cs`
  • Loading branch information
alexander-camuto authored Jan 30, 2024
1 parent bc7c331 commit 45fd12a
Show file tree
Hide file tree
Showing 23 changed files with 659 additions and 81 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,8 @@ jobs:
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
- name: Div rebase
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
- name: Public inputs
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
- name: fixed params
Expand Down
3 changes: 2 additions & 1 deletion benches/accum_matmul_relu.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::table::Range;
use ezkl::circuit::*;

use ezkl::circuit::lookup::LookupOp;
Expand All @@ -16,7 +17,7 @@ use halo2_proofs::{
use halo2curves::bn256::{Bn256, Fr};
use std::marker::PhantomData;

const BITS: (i128, i128) = (-32768, 32768);
const BITS: Range = (-32768, 32768);
static mut LEN: usize = 4;
const K: usize = 16;

Expand Down
3 changes: 2 additions & 1 deletion benches/accum_matmul_relu_overflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use ezkl::circuit::*;

use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::table::Range;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
Expand All @@ -16,7 +17,7 @@ use halo2_proofs::{
use halo2curves::bn256::{Bn256, Fr};
use std::marker::PhantomData;

const BITS: (i128, i128) = (-8180, 8180);
const BITS: Range = (-8180, 8180);
static mut LEN: usize = 4;
static mut K: usize = 16;

Expand Down
3 changes: 2 additions & 1 deletion benches/relu.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::TranscriptType;
Expand All @@ -14,7 +15,7 @@ use halo2_proofs::{
use halo2curves::bn256::{Bn256, Fr};
use rand::Rng;

const BITS: (i128, i128) = (-32768, 32768);
const BITS: Range = (-32768, 32768);
static mut LEN: usize = 4;
const K: usize = 16;

Expand Down
97 changes: 95 additions & 2 deletions src/circuit/ops/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ use serde::{Deserialize, Serialize};

use crate::{
circuit::ops::base::BaseOp,
circuit::{table::Table, utils},
circuit::{
table::{Range, RangeCheck, Table},
utils,
},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
Expand Down Expand Up @@ -176,6 +179,10 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
///
pub tables: BTreeMap<LookupOp, Table<F>>,
///
pub range_checks: BTreeMap<Range, RangeCheck<F>>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
pub range_check_selectors: BTreeMap<(Range, usize, usize), Selector>,
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
Expand All @@ -194,7 +201,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
lookup_index: dummy_var,
selectors: BTreeMap::new(),
lookup_selectors: BTreeMap::new(),
range_check_selectors: BTreeMap::new(),
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
check_mode: CheckMode::SAFE,
_marker: PhantomData,
}
Expand Down Expand Up @@ -325,11 +334,13 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
Self {
selectors,
lookup_selectors: BTreeMap::new(),
range_check_selectors: BTreeMap::new(),
inputs: inputs.to_vec(),
lookup_input: VarTensor::Empty,
lookup_output: VarTensor::Empty,
lookup_index: VarTensor::Empty,
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
output: output.clone(),
check_mode,
_marker: PhantomData,
Expand All @@ -344,7 +355,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
input: &VarTensor,
output: &VarTensor,
index: &VarTensor,
lookup_range: (i128, i128),
lookup_range: Range,
logrows: usize,
nl: &LookupOp,
) -> Result<(), Box<dyn Error>>
Expand Down Expand Up @@ -482,6 +493,74 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
Ok(())
}

/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_range_check(
&mut self,
cs: &mut ConstraintSystem<F>,
input: &VarTensor,
range: Range,
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
let mut selectors = BTreeMap::new();

if !input.is_advice() {
return Err("wrong input type for lookup input".into());
}

// we borrow mutably twice so we need to do this dance

let range_check = if !self.range_checks.contains_key(&range) {
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range);
self.range_checks.insert(range, range_check.clone());
range_check
} else {
return Ok(());
};

for x in 0..input.num_blocks() {
for y in 0..input.num_inner_cols() {
let single_col_sel = cs.complex_selector();

cs.lookup("", |cs| {
let mut res = vec![];
let sel = cs.query_selector(single_col_sel);

let input_query = match &input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
};

let default_x = range_check.get_first_element();

let not_sel = Expression::Constant(F::ONE) - sel.clone();

res.extend([(
sel.clone() * input_query.clone()
+ not_sel.clone() * Expression::Constant(default_x),
range_check.input,
)]);

res
});
selectors.insert((range, x, y), single_col_sel);
}
}
self.range_check_selectors.extend(selectors);
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
debug!("assigning lookup input");
self.lookup_input = input.clone();
}

Ok(())
}

/// layout_tables must be called before layout.
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
for (i, table) in self.tables.values_mut().enumerate() {
Expand All @@ -500,6 +579,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
Ok(())
}

/// layout_range_checks must be called before layout.
pub fn layout_range_checks(
&mut self,
layouter: &mut impl Layouter<F>,
) -> Result<(), Box<dyn Error>> {
for range_check in self.range_checks.values_mut() {
if !range_check.is_assigned {
debug!("laying out range check for {:?}", range_check.range);
range_check.layout(layouter)?;
}
}
Ok(())
}

/// Assigns variables to the regions created when calling `configure`.
/// # Arguments
/// * `values` - The explicit values to the operations.
Expand Down
106 changes: 105 additions & 1 deletion src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use super::{
};
use crate::{
circuit::{ops::base::BaseOp, utils},
fieldutils::i128_to_felt,
fieldutils::{felt_to_i128, i128_to_felt},
tensor::{
get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
Expand Down Expand Up @@ -51,6 +51,66 @@ pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usi
total_len
}

/// Div accumulated layout
pub fn div<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
div: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let input = value[0].clone();
let input_dims = input.dims();

let range_check_bracket = felt_to_i128(div) - 1;

let mut divisor = Tensor::from(vec![ValType::Constant(div)].into_iter());
divisor.set_visibility(&crate::graph::Visibility::Fixed);
let divisor = region.assign(&config.inputs[1], &divisor.into())?;
region.increment(divisor.len());

let is_assigned = !input.any_unknowns()? && !divisor.any_unknowns()?;

let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
let divisor_evals = divisor.get_int_evals()?;
tensor::ops::div(&[input_evals.clone(), divisor_evals.clone()])?
.iter()
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.into()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
&[input.len()],
)?
.into()
};
claimed_output.reshape(input_dims)?;

let product = pairwise(
config,
region,
&[claimed_output.clone(), divisor.clone()],
BaseOp::Mult,
)?;

let diff_with_input = pairwise(
config,
region,
&[product.clone(), input.clone()],
BaseOp::Sub,
)?;

range_check(
config,
region,
&[diff_with_input],
&(-range_check_bracket, range_check_bracket),
)?;

Ok(claimed_output)
}

/// Dot product accumulated layout
pub fn dot<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
Expand Down Expand Up @@ -2304,6 +2364,50 @@ pub fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}

/// layout for range check.
pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
range: &crate::circuit::table::Range,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// time the entire operation
let timer = instant::Instant::now();

let x = values[0].clone();

let w = region.assign(&config.lookup_input, &x)?;

let assigned_len = x.len();

let is_dummy = region.is_dummy();

if !is_dummy {
(0..assigned_len)
.map(|i| {
let (x, y, z) = config
.lookup_input
.cartesian_coord(region.linear_coord() + i);
let selector = config.range_check_selectors.get(&(range.clone(), x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}

region.increment(assigned_len);

let elapsed = timer.elapsed();
trace!(
"range check {:?} layout took {:?}, row: {:?}",
range,
elapsed,
region.row()
);

Ok(w)
}

/// layout for nonlinearity check.
pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
Expand Down
4 changes: 2 additions & 2 deletions src/circuit/ops/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
use std::error::Error;

use crate::{
circuit::{layouts, utils},
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_i128, i128_to_felt},
graph::{multiplier_to_scale, scale_to_multiplier},
tensor::{self, Tensor, TensorError, TensorType},
Expand Down Expand Up @@ -57,7 +57,7 @@ pub enum LookupOp {

impl LookupOp {
/// Returns the range of values that can be represented by the table
pub fn bit_range(max_len: usize) -> (i128, i128) {
pub fn bit_range(max_len: usize) -> Range {
let range = (max_len - 1) as f64 / 2_f64;
let range = range as i128;
(-range, range)
Expand Down
7 changes: 7 additions & 0 deletions src/circuit/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use halo2curves::ff::PrimeField;

use self::{lookup::LookupOp, region::RegionCtx};

use super::table::Range;

///
pub mod base;
///
Expand Down Expand Up @@ -60,6 +62,11 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
vec![]
}

/// Returns the range checks required by the operation.
fn required_range_checks(&self) -> Vec<Range> {
vec![]
}

/// Returns true if the operation is an input.
fn is_input(&self) -> bool {
false
Expand Down
Loading

0 comments on commit 45fd12a

Please sign in to comment.