From a5464733ed7818e8d59d2a125d143507b7830443 Mon Sep 17 00:00:00 2001 From: James Parker Date: Mon, 9 Dec 2024 15:06:20 -0500 Subject: [PATCH 1/2] Add feature flag to reorder R1CS constraints in Az, Bz, Cz --- jolt-core/Cargo.toml | 2 + jolt-core/src/jolt/vm/mod.rs | 1 + jolt-core/src/r1cs/builder.rs | 176 +++++++++++++++++++++++++++++++++- jolt-core/src/r1cs/key.rs | 63 +++++++++--- jolt-core/src/r1cs/ops.rs | 17 ++++ jolt-core/src/r1cs/spartan.rs | 34 +++++-- 6 files changed, 266 insertions(+), 27 deletions(-) diff --git a/jolt-core/Cargo.toml b/jolt-core/Cargo.toml index d53e2d7e8..85006e026 100644 --- a/jolt-core/Cargo.toml +++ b/jolt-core/Cargo.toml @@ -105,8 +105,10 @@ default = [ "ark-ff/asm", "host", "rayon", + "reorder", ] host = ["dep:reqwest", "dep:tokio"] +reorder = [] [target.'cfg(not(target_arch = "wasm32"))'.dependencies] memory-stats = "1.0.0" diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index 46177850e..b49520752 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -378,6 +378,7 @@ where let padded_trace_length = trace_length.next_power_of_two(); println!("Trace length: {}", trace_length); + // TODO(JP): Drop padding on number of steps JoltTraceStep::pad(&mut trace); let mut transcript = ProofTranscript::new(b"Jolt transcript"); diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index 85be551c5..db3cb6c20 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -516,12 +516,27 @@ impl OffsetEqConstraint { } } +pub(crate) fn eval_offset_lc( + offset: &OffsetLC, + flattened_polynomials: &[&DensePolynomial], + step: usize, + next_step_m: Option, +) -> F { + if !offset.0 { + offset.1.evaluate_row(flattened_polynomials, step) + } else if let Some(next_step) = next_step_m { + offset.1.evaluate_row(flattened_polynomials, next_step) + } else { + offset.1.constant_term_field() + } +} + // TODO(sragss): Detailed documentation with wiki. pub struct CombinedUniformBuilder { uniform_builder: R1CSBuilder, /// Padded to the nearest power of 2 - uniform_repeat: usize, + uniform_repeat: usize, // TODO(JP): Remove padding of steps offset_equality_constraints: Vec, } @@ -554,12 +569,25 @@ impl CombinedUniformBuilder usize { self.uniform_repeat * self.offset_equality_constraints.len() } + /// Number of constraint rows per step, padded to the next power of two. + // #[cfg(feature = "reorder")] + pub(super) fn padded_rows_per_step(&self) -> usize { + let num_constraints = + self.uniform_builder.constraints.len() + self.offset_equality_constraints.len(); + num_constraints.next_power_of_two() + } + /// Total number of rows used across all repeated constraints. Not padded to nearest power of two. pub(super) fn constraint_rows(&self) -> usize { + if cfg!(feature = "reorder") { + return self.uniform_repeat * self.padded_rows_per_step() + } + self.offset_eq_constraint_rows() + self.uniform_repeat_constraint_rows() } @@ -635,18 +663,162 @@ impl CombinedUniformBuilder Fn(&'a Constraint) -> &'a LC + Send + Sync + Copy + 'static, + O: Fn(&[&DensePolynomial], &OffsetEqConstraint, usize, Option) -> F + + Send + + Sync + + Copy + + 'static, + >( + &self, + flattened_polynomials: &[&DensePolynomial], // N variables of (S steps) + uniform_constraint: U, + offset_constraint: O, + ) -> Vec<(F, usize)> { + let num_steps = flattened_polynomials[0].len(); + let padded_num_constraints = self.padded_rows_per_step(); + + // Filter out constraints that won't contribute ahead of time. + let filtered_uniform_constraints = self + .uniform_builder + .constraints + .iter() + .enumerate() + .filter_map(|(index, constraint)| { + let lc = uniform_constraint(constraint); + if lc.terms().is_empty() { + None + } else { + Some((index, lc)) + } + }) + .collect::>(); + + (0..num_steps) + .into_par_iter() + .flat_map_iter(|step_index| { + let next_step_index_m = if step_index + 1 < num_steps { + Some(step_index + 1) + } else { + None + }; + + // uniform_constraints + let uniform_constraints = filtered_uniform_constraints.par_iter().flat_map( + move |(constraint_index, lc)| { + // Evaluate a constraint on a given step. + let item = lc.evaluate_row(flattened_polynomials, step_index); + if !item.is_zero() { + let global_index = + step_index * padded_num_constraints + constraint_index; + Some((item, global_index)) + } else { + None + } + }, + ); + + // offset_equality_constraints + // (a - b) * condition == 0 + // For the final step we will not compute the offset terms, and will assume the condition to be set to 0 + let non_uniform_constraints = self + .offset_equality_constraints + .par_iter() + .enumerate() + .flat_map(move |(constr_i, constr)| { + let xz = offset_constraint( + flattened_polynomials, + constr, + step_index, + next_step_index_m, + ); + let global_index = step_index * padded_num_constraints + + self.uniform_builder.constraints.len() + + constr_i; + if !xz.is_zero() { + Some((xz, global_index)) + } else { + None + } + }); + + uniform_constraints + .chain(non_uniform_constraints) + .collect::>() + }) + .collect() + } + #[tracing::instrument(skip_all)] pub fn compute_spartan_Az_Bz_Cz< PCS: CommitmentScheme, ProofTranscript: Transcript, >( &self, - flattened_polynomials: &[&DensePolynomial], + flattened_polynomials: &[&DensePolynomial], // N variables of (S steps) ) -> ( SparsePolynomial, SparsePolynomial, SparsePolynomial, ) { + if cfg!(feature = "reorder") { + let span = tracing::span!(tracing::Level::DEBUG, "uniform and non-uniform constraints"); + let _enter = span.enter(); + + let az_sparse = self.compute_spartan_Xz( + flattened_polynomials, + |constraint: &Constraint| &constraint.a, + |flattened_polynomials, constr, step_index, next_step_index_m| { + let eq_a_eval = eval_offset_lc( + &constr.a, + flattened_polynomials, + step_index, + next_step_index_m, + ); + let eq_b_eval = eval_offset_lc( + &constr.b, + flattened_polynomials, + step_index, + next_step_index_m, + ); + let az = eq_a_eval - eq_b_eval; + az + }, + ); + let bz_sparse = self.compute_spartan_Xz( + flattened_polynomials, + |constraint: &Constraint| &constraint.b, + |flattened_polynomials, constr, step_index, next_step_index_m| { + let condition_eval = eval_offset_lc( + &constr.cond, + flattened_polynomials, + step_index, + next_step_index_m, + ); + let bz = condition_eval; + bz + }, + ); + let cz_sparse = self.compute_spartan_Xz( + flattened_polynomials, + |constraint: &Constraint| &constraint.c, + |_, _, _, _| F::zero(), + ); + drop(_enter); + + let num_vars = self.constraint_rows().next_power_of_two().log_2(); + let az_poly = SparsePolynomial::new(num_vars, az_sparse); + let bz_poly = SparsePolynomial::new(num_vars, bz_sparse); + let cz_poly = SparsePolynomial::new(num_vars, cz_sparse); + + #[cfg(test)] + self.assert_valid(flattened_polynomials, &az_poly, &bz_poly, &cz_poly); + + return (az_poly, bz_poly, cz_poly) + } + let uniform_constraint_rows = self.uniform_repeat_constraint_rows(); // uniform_constraints: Xz[0..uniform_constraint_rows] diff --git a/jolt-core/src/r1cs/key.rs b/jolt-core/src/r1cs/key.rs index 4c777d3d9..470e5cf40 100644 --- a/jolt-core/src/r1cs/key.rs +++ b/jolt-core/src/r1cs/key.rs @@ -25,6 +25,7 @@ pub struct UniformSpartanKey { pub offset_eq_r1cs: NonUniformR1CS, /// Number of constraints across all steps padded to nearest power of 2 + // TODO: #[not(cfg(feature = "reorder"))] pub num_cons_total: usize, /// Number of steps padded to the nearest power of 2 @@ -112,6 +113,11 @@ impl NonUniformR1CS { (eq_constants, condition_constants) } + + /// Unpadded number of non-uniform constraints. + fn num_constraints(&self) -> usize { + self.constraints.len() + } } /// Represents a single constraint row where the variables are either from the current step (offset = false) @@ -138,8 +144,9 @@ impl UniformSpartanKey UniformSpartanKey UniformSpartanKey usize { self.num_cons_total } + /// Padded number of constraint rows per step. + // #[cfg(feature = "reorder")] + pub fn padded_row_constraint_per_step(&self) -> usize { + // JP: This is redundant with `padded_rows_per_step`. Can we reuse that instead? + (self.uniform_r1cs.num_rows + self.offset_eq_r1cs.num_constraints()).next_power_of_two() + } + + /// Number of bits needed for all rows. + // #[cfg(feature = "reorder")] + pub fn num_rows_bits(&self) -> usize { + let row_count = self.num_steps * self.padded_row_constraint_per_step(); + row_count.next_power_of_two().log_2() + } + /// Evaluates A(r_x, y) + r_rlc * B(r_x, y) + r_rlc^2 * C(r_x, y) where r_x = r_constr || r_step for all y. #[tracing::instrument(skip_all, name = "UniformSpartanKey::evaluate_r1cs_mle_rlc")] pub fn evaluate_r1cs_mle_rlc(&self, r_constr: &[F], r_step: &[F], r_rlc: F) -> Vec { @@ -213,9 +236,9 @@ impl UniformSpartanKey UniformSpartanKey (F, F, F) { - let total_rows_bits = self.num_rows_total().log_2(); - let total_cols_bits = self.num_cols_total().log_2(); - let steps_bits: usize = self.num_steps.log_2(); - let constraint_rows_bits = (self.uniform_r1cs.num_rows + 1).next_power_of_two().log_2(); + pub fn evaluate_r1cs_matrix_mles(&self, r_row: &[F], r_col: &[F]) -> (F, F, F) { + let (total_cols_bits, steps_bits, constraint_rows_bits, r_row_constr, r_row_step) = if cfg!(feature = "reorder") { + let total_rows_bits = r_row.len(); + let total_cols_bits = r_col.len(); + let constraint_rows_bits = self.padded_row_constraint_per_step().log_2(); + let steps_bits: usize = total_rows_bits - constraint_rows_bits; + let (r_row_step, r_row_constr) = r_row.split_at(total_rows_bits - constraint_rows_bits); // TMP + (total_cols_bits, steps_bits, constraint_rows_bits, r_row_constr, r_row_step) + } else { + let total_rows_bits = self.num_rows_total().log_2(); + let total_cols_bits = self.num_cols_total().log_2(); + let steps_bits: usize = self.num_steps.log_2(); + let constraint_rows_bits = (self.uniform_r1cs.num_rows + 1).next_power_of_two().log_2(); + assert_eq!(r_row.len(), total_rows_bits); + assert_eq!(r_col.len(), total_cols_bits); + assert_eq!(total_rows_bits - steps_bits, constraint_rows_bits); + + // Deconstruct 'r' into representitive bits + let (r_row_constr, r_row_step) = r_row.split_at(constraint_rows_bits); + (total_cols_bits, steps_bits, constraint_rows_bits, r_row_constr, r_row_step) + }; let uniform_cols_bits = self.uniform_r1cs.num_vars.next_power_of_two().log_2(); - assert_eq!(r.len(), total_rows_bits + total_cols_bits); - assert_eq!(total_rows_bits - steps_bits, constraint_rows_bits); - - // Deconstruct 'r' into representitive bits - let (r_row, r_col) = r.split_at(total_rows_bits); - let (r_row_constr, r_row_step) = r_row.split_at(constraint_rows_bits); let (r_col_var, r_col_step) = r_col.split_at(uniform_cols_bits + 1); assert_eq!(r_row_step.len(), r_col_step.len()); diff --git a/jolt-core/src/r1cs/ops.rs b/jolt-core/src/r1cs/ops.rs index f31f9656a..9b3447402 100644 --- a/jolt-core/src/r1cs/ops.rs +++ b/jolt-core/src/r1cs/ops.rs @@ -108,6 +108,23 @@ impl LC { result } + // JP: Why does `evaluate` exist? + pub fn evaluate_row( + &self, + flattened_polynomials: &[&DensePolynomial], + row: usize, + ) -> F { + self.terms() + .iter() + .map(|term| match term.0 { + Variable::Input(var_index) | Variable::Auxiliary(var_index) => { + F::from_i64(term.1).mul_01_optimized(flattened_polynomials[var_index][row]) + } + Variable::Constant => F::from_i64(term.1), + }) + .sum() + } + pub fn evaluate_batch( &self, flattened_polynomials: &[&DensePolynomial], diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 1be9ee2ed..224579a81 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -115,7 +115,11 @@ where .map(|var| var.get_ref(polynomials)) .collect(); - let num_rounds_x = key.num_rows_total().log_2(); + let num_rounds_x = if cfg!(feature = "reorder") { + key.num_rows_bits() + } else { + key.num_rows_total().log_2() + }; let num_rounds_y = key.num_cols_total().log_2(); // outer sum-check @@ -156,12 +160,19 @@ where + r_inner_sumcheck_RLC * r_inner_sumcheck_RLC * claim_Cz; // this is the polynomial extended from the vector r_A * A(r_x, y) + r_B * B(r_x, y) + r_C * C(r_x, y) for all y - let num_steps_bits = constraint_builder - .uniform_repeat() - .next_power_of_two() - .ilog2(); - let (rx_con, rx_ts) = - outer_sumcheck_r.split_at(outer_sumcheck_r.len() - num_steps_bits as usize); + let (rx_con, rx_ts) = if cfg!(feature = "reorder") { + let num_constr_bits = constraint_builder.padded_rows_per_step().ilog2() as usize; + let (rx_ts, rx_con) = outer_sumcheck_r.split_at(outer_sumcheck_r.len() - num_constr_bits); + (rx_con, rx_ts) + } else { + let num_steps_bits = constraint_builder + .uniform_repeat() + .next_power_of_two() + .ilog2(); + let (rx_con, rx_ts) = + outer_sumcheck_r.split_at(outer_sumcheck_r.len() - num_steps_bits as usize); + (rx_con, rx_ts) + }; let mut poly_ABC = DensePolynomial::new(key.evaluate_r1cs_mle_rlc(rx_con, rx_ts, r_inner_sumcheck_RLC)); @@ -221,7 +232,11 @@ where PCS: CommitmentScheme, ProofTranscript: Transcript, { - let num_rounds_x = key.num_rows_total().log_2(); + let num_rounds_x = if cfg!(feature = "reorder") { + key.num_rows_bits() + } else { + key.num_rows_total().log_2() + }; let num_rounds_y = key.num_cols_total().log_2(); // outer sum-check @@ -271,8 +286,7 @@ where let eval_Z = key.evaluate_z_mle(&self.claimed_witness_evals, &inner_sumcheck_r); let r_y = inner_sumcheck_r.clone(); - let r = [r_x, r_y].concat(); - let (eval_a, eval_b, eval_c) = key.evaluate_r1cs_matrix_mles(&r); + let (eval_a, eval_b, eval_c) = key.evaluate_r1cs_matrix_mles(&r_x, &r_y); let left_expected = eval_a + r_inner_sumcheck_RLC * eval_b From 5e8bb6c8ac63f2b6bffb420e59f0649982a592be Mon Sep 17 00:00:00 2001 From: James Parker Date: Mon, 9 Dec 2024 15:19:54 -0500 Subject: [PATCH 2/2] Format --- jolt-core/src/r1cs/builder.rs | 6 ++-- jolt-core/src/r1cs/key.rs | 55 ++++++++++++++++++++++------------- jolt-core/src/r1cs/spartan.rs | 3 +- 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs index db3cb6c20..47f72f9b1 100644 --- a/jolt-core/src/r1cs/builder.rs +++ b/jolt-core/src/r1cs/builder.rs @@ -580,12 +580,12 @@ impl CombinedUniformBuilder usize { if cfg!(feature = "reorder") { - return self.uniform_repeat * self.padded_rows_per_step() + return self.uniform_repeat * self.padded_rows_per_step(); } self.offset_eq_constraint_rows() + self.uniform_repeat_constraint_rows() @@ -816,7 +816,7 @@ impl CombinedUniformBuilder UniformSpartanKey (F, F, F) { - let (total_cols_bits, steps_bits, constraint_rows_bits, r_row_constr, r_row_step) = if cfg!(feature = "reorder") { - let total_rows_bits = r_row.len(); - let total_cols_bits = r_col.len(); - let constraint_rows_bits = self.padded_row_constraint_per_step().log_2(); - let steps_bits: usize = total_rows_bits - constraint_rows_bits; - let (r_row_step, r_row_constr) = r_row.split_at(total_rows_bits - constraint_rows_bits); // TMP - (total_cols_bits, steps_bits, constraint_rows_bits, r_row_constr, r_row_step) - } else { - let total_rows_bits = self.num_rows_total().log_2(); - let total_cols_bits = self.num_cols_total().log_2(); - let steps_bits: usize = self.num_steps.log_2(); - let constraint_rows_bits = (self.uniform_r1cs.num_rows + 1).next_power_of_two().log_2(); - assert_eq!(r_row.len(), total_rows_bits); - assert_eq!(r_col.len(), total_cols_bits); - assert_eq!(total_rows_bits - steps_bits, constraint_rows_bits); - - // Deconstruct 'r' into representitive bits - let (r_row_constr, r_row_step) = r_row.split_at(constraint_rows_bits); - (total_cols_bits, steps_bits, constraint_rows_bits, r_row_constr, r_row_step) - }; + let (total_cols_bits, steps_bits, constraint_rows_bits, r_row_constr, r_row_step) = + if cfg!(feature = "reorder") { + let total_rows_bits = r_row.len(); + let total_cols_bits = r_col.len(); + let constraint_rows_bits = self.padded_row_constraint_per_step().log_2(); + let steps_bits: usize = total_rows_bits - constraint_rows_bits; + let (r_row_step, r_row_constr) = + r_row.split_at(total_rows_bits - constraint_rows_bits); // TMP + ( + total_cols_bits, + steps_bits, + constraint_rows_bits, + r_row_constr, + r_row_step, + ) + } else { + let total_rows_bits = self.num_rows_total().log_2(); + let total_cols_bits = self.num_cols_total().log_2(); + let steps_bits: usize = self.num_steps.log_2(); + let constraint_rows_bits = + (self.uniform_r1cs.num_rows + 1).next_power_of_two().log_2(); + assert_eq!(r_row.len(), total_rows_bits); + assert_eq!(r_col.len(), total_cols_bits); + assert_eq!(total_rows_bits - steps_bits, constraint_rows_bits); + + // Deconstruct 'r' into representitive bits + let (r_row_constr, r_row_step) = r_row.split_at(constraint_rows_bits); + ( + total_cols_bits, + steps_bits, + constraint_rows_bits, + r_row_constr, + r_row_step, + ) + }; let uniform_cols_bits = self.uniform_r1cs.num_vars.next_power_of_two().log_2(); let (r_col_var, r_col_step) = r_col.split_at(uniform_cols_bits + 1); assert_eq!(r_row_step.len(), r_col_step.len()); diff --git a/jolt-core/src/r1cs/spartan.rs b/jolt-core/src/r1cs/spartan.rs index 224579a81..f34b8e084 100644 --- a/jolt-core/src/r1cs/spartan.rs +++ b/jolt-core/src/r1cs/spartan.rs @@ -162,7 +162,8 @@ where // this is the polynomial extended from the vector r_A * A(r_x, y) + r_B * B(r_x, y) + r_C * C(r_x, y) for all y let (rx_con, rx_ts) = if cfg!(feature = "reorder") { let num_constr_bits = constraint_builder.padded_rows_per_step().ilog2() as usize; - let (rx_ts, rx_con) = outer_sumcheck_r.split_at(outer_sumcheck_r.len() - num_constr_bits); + let (rx_ts, rx_con) = + outer_sumcheck_r.split_at(outer_sumcheck_r.len() - num_constr_bits); (rx_con, rx_ts) } else { let num_steps_bits = constraint_builder