Reorder R1CS constraints #554

Closed · wants to merge 2 commits
2 changes: 2 additions & 0 deletions jolt-core/Cargo.toml
@@ -105,8 +105,10 @@ default = [
"ark-ff/asm",
"host",
"rayon",
"reorder",
]
host = ["dep:reqwest", "dep:tokio"]
reorder = []

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
memory-stats = "1.0.0"
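The new `reorder` feature is empty and enabled by default; it acts purely as a compile-time switch, and the builder and key code below select the new constraint ordering with `cfg!(feature = "reorder")` rather than attribute gating. A minimal sketch of that pattern (not part of the diff, names illustrative only):

```rust
// Sketch only: `cfg!` evaluates to a compile-time boolean, so the untaken branch
// is dead code, but both branches must still type-check (unlike `#[cfg(...)]`,
// which removes the code entirely).
fn rows(uniform_repeat: usize, padded_rows_per_step: usize, legacy_rows: usize) -> usize {
    if cfg!(feature = "reorder") {
        uniform_repeat * padded_rows_per_step
    } else {
        legacy_rows
    }
}
```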
1 change: 1 addition & 0 deletions jolt-core/src/jolt/vm/mod.rs
@@ -378,6 +378,7 @@ where
let padded_trace_length = trace_length.next_power_of_two();
println!("Trace length: {}", trace_length);

// TODO(JP): Drop padding on number of steps
JoltTraceStep::pad(&mut trace);

let mut transcript = ProofTranscript::new(b"Jolt transcript");
176 changes: 174 additions & 2 deletions jolt-core/src/r1cs/builder.rs
@@ -516,12 +516,27 @@ impl OffsetEqConstraint {
}
}

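/// Evaluates one side of an offset constraint at `step`. `offset.0 == true` means the
/// LC refers to the *next* step; when there is no next step (the final row), only the
/// LC's constant term contributes.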
pub(crate) fn eval_offset_lc<F: JoltField>(
offset: &OffsetLC,
flattened_polynomials: &[&DensePolynomial<F>],
step: usize,
next_step_m: Option<usize>,
) -> F {
if !offset.0 {
offset.1.evaluate_row(flattened_polynomials, step)
} else if let Some(next_step) = next_step_m {
offset.1.evaluate_row(flattened_polynomials, next_step)
} else {
offset.1.constant_term_field()
}
}

// TODO(sragss): Detailed documentation with wiki.
pub struct CombinedUniformBuilder<const C: usize, F: JoltField, I: ConstraintInput> {
uniform_builder: R1CSBuilder<C, F, I>,

/// Padded to the nearest power of 2
- uniform_repeat: usize,
uniform_repeat: usize, // TODO(JP): Remove padding of steps

offset_equality_constraints: Vec<OffsetEqConstraint>,
}
@@ -554,12 +569,25 @@ impl<const C: usize, F: JoltField, I: ConstraintInput> CombinedUniformBuilder<C,
self.uniform_repeat * self.uniform_builder.constraints.len()
}

// TODO: #[cfg(not(feature = "reorder"))]
pub(super) fn offset_eq_constraint_rows(&self) -> usize {
self.uniform_repeat * self.offset_equality_constraints.len()
}

/// Number of constraint rows per step, padded to the next power of two.
// #[cfg(feature = "reorder")]
pub(super) fn padded_rows_per_step(&self) -> usize {
let num_constraints =
self.uniform_builder.constraints.len() + self.offset_equality_constraints.len();
num_constraints.next_power_of_two()
}

/// Total number of rows used across all repeated constraints. Not padded to nearest power of two.
pub(super) fn constraint_rows(&self) -> usize {
if cfg!(feature = "reorder") {
return self.uniform_repeat * self.padded_rows_per_step();
}

self.offset_eq_constraint_rows() + self.uniform_repeat_constraint_rows()
}

@@ -635,18 +663,162 @@ impl<const C: usize, F: JoltField, I: ConstraintInput> CombinedUniformBuilder<C,
NonUniformR1CS { constraints }
}

// #[cfg(feature = "reorder")]
fn compute_spartan_Xz<
U: for<'a> Fn(&'a Constraint) -> &'a LC + Send + Sync + Copy + 'static,
O: Fn(&[&DensePolynomial<F>], &OffsetEqConstraint, usize, Option<usize>) -> F
+ Send
+ Sync
+ Copy
+ 'static,
>(
&self,
flattened_polynomials: &[&DensePolynomial<F>], // N variables of (S steps)
uniform_constraint: U,
offset_constraint: O,
) -> Vec<(F, usize)> {
let num_steps = flattened_polynomials[0].len();
let padded_num_constraints = self.padded_rows_per_step();

// Filter out constraints that won't contribute ahead of time.
let filtered_uniform_constraints = self
.uniform_builder
.constraints
.iter()
.enumerate()
.filter_map(|(index, constraint)| {
let lc = uniform_constraint(constraint);
if lc.terms().is_empty() {
None
} else {
Some((index, lc))
}
})
.collect::<Vec<_>>();

(0..num_steps)
.into_par_iter()
.flat_map_iter(|step_index| {
let next_step_index_m = if step_index + 1 < num_steps {
Some(step_index + 1)
} else {
None
};

// uniform_constraints
let uniform_constraints = filtered_uniform_constraints.par_iter().flat_map(
move |(constraint_index, lc)| {
// Evaluate a constraint on a given step.
let item = lc.evaluate_row(flattened_polynomials, step_index);
if !item.is_zero() {
let global_index =
step_index * padded_num_constraints + constraint_index;
Some((item, global_index))
} else {
None
}
},
);

// offset_equality_constraints
// (a - b) * condition == 0
// For the final step we do not compute the offset terms and assume the condition evaluates to 0
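// In the closures passed from compute_spartan_Az_Bz_Cz below, these rows contribute
// (a - b) to Az and the condition to Bz, while Cz stays zero, giving the row
// (a - b) * condition == 0.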
let non_uniform_constraints = self
.offset_equality_constraints
.par_iter()
.enumerate()
.flat_map(move |(constr_i, constr)| {
let xz = offset_constraint(
flattened_polynomials,
constr,
step_index,
next_step_index_m,
);
let global_index = step_index * padded_num_constraints
+ self.uniform_builder.constraints.len()
+ constr_i;
if !xz.is_zero() {
Some((xz, global_index))
} else {
None
}
});

uniform_constraints
.chain(non_uniform_constraints)
.collect::<Vec<_>>()
})
.collect()
}
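// Row layout under `reorder` (as computed above): constraint `c` of step `s` lands at
// global row `s * padded_rows_per_step() + c`, with the offset-equality constraints
// placed after the uniform ones within each step's padded block. Illustrative sizes
// only: 30 uniform + 4 offset constraints pad to 64 rows per step.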

#[tracing::instrument(skip_all)]
pub fn compute_spartan_Az_Bz_Cz<
PCS: CommitmentScheme<ProofTranscript, Field = F>,
ProofTranscript: Transcript,
>(
&self,
- flattened_polynomials: &[&DensePolynomial<F>],
flattened_polynomials: &[&DensePolynomial<F>], // N variables of (S steps)
) -> (
SparsePolynomial<F>,
SparsePolynomial<F>,
SparsePolynomial<F>,
) {
if cfg!(feature = "reorder") {
let span = tracing::span!(tracing::Level::DEBUG, "uniform and non-uniform constraints");
let _enter = span.enter();

let az_sparse = self.compute_spartan_Xz(
flattened_polynomials,
|constraint: &Constraint| &constraint.a,
|flattened_polynomials, constr, step_index, next_step_index_m| {
let eq_a_eval = eval_offset_lc(
&constr.a,
flattened_polynomials,
step_index,
next_step_index_m,
);
let eq_b_eval = eval_offset_lc(
&constr.b,
flattened_polynomials,
step_index,
next_step_index_m,
);
let az = eq_a_eval - eq_b_eval;
az
},
);
let bz_sparse = self.compute_spartan_Xz(
flattened_polynomials,
|constraint: &Constraint| &constraint.b,
|flattened_polynomials, constr, step_index, next_step_index_m| {
let condition_eval = eval_offset_lc(
&constr.cond,
flattened_polynomials,
step_index,
next_step_index_m,
);
let bz = condition_eval;
bz
},
);
let cz_sparse = self.compute_spartan_Xz(
flattened_polynomials,
|constraint: &Constraint| &constraint.c,
|_, _, _, _| F::zero(),
);
drop(_enter);

let num_vars = self.constraint_rows().next_power_of_two().log_2();
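// Under `reorder`, `num_vars` above is log2(uniform_repeat * padded_rows_per_step())
// (both currently powers of two), which should line up with
// `UniformSpartanKey::num_rows_bits` on the key side.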
let az_poly = SparsePolynomial::new(num_vars, az_sparse);
let bz_poly = SparsePolynomial::new(num_vars, bz_sparse);
let cz_poly = SparsePolynomial::new(num_vars, cz_sparse);

#[cfg(test)]
self.assert_valid(flattened_polynomials, &az_poly, &bz_poly, &cz_poly);

return (az_poly, bz_poly, cz_poly);
}

let uniform_constraint_rows = self.uniform_repeat_constraint_rows();

// uniform_constraints: Xz[0..uniform_constraint_rows]
78 changes: 63 additions & 15 deletions jolt-core/src/r1cs/key.rs
@@ -25,6 +25,7 @@ pub struct UniformSpartanKey<const C: usize, I: ConstraintInput, F: JoltField> {
pub offset_eq_r1cs: NonUniformR1CS<F>,

/// Number of constraints across all steps padded to nearest power of 2
// TODO: #[cfg(not(feature = "reorder"))]
pub num_cons_total: usize,

/// Number of steps padded to the nearest power of 2
@@ -112,6 +113,11 @@ impl<F: JoltField> NonUniformR1CS<F> {

(eq_constants, condition_constants)
}

/// Unpadded number of non-uniform constraints.
fn num_constraints(&self) -> usize {
self.constraints.len()
}
}

/// Represents a single constraint row where the variables are either from the current step (offset = false)
@@ -138,15 +144,17 @@ impl<const C: usize, F: JoltField, I: ConstraintInput> UniformSpartanKey<C, I, F
let uniform_r1cs = constraint_builder.materialize_uniform();
let offset_eq_r1cs = constraint_builder.materialize_offset_eq();

// TODO: #[cfg(not(feature = "reorder"))]
let total_rows = constraint_builder.constraint_rows().next_power_of_two();
- let num_steps = constraint_builder.uniform_repeat().next_power_of_two();
let num_steps = constraint_builder.uniform_repeat().next_power_of_two(); // TODO(JP): The number of steps no longer needs to be padded.

let vk_digest = Self::digest(&uniform_r1cs, &offset_eq_r1cs, num_steps);

Self {
_inputs: PhantomData,
uniform_r1cs,
offset_eq_r1cs,
// TODO: #[cfg(not(feature = "reorder"))]
num_cons_total: total_rows,
num_steps,
vk_digest,
@@ -167,10 +175,25 @@ impl<const C: usize, F: JoltField, I: ConstraintInput> UniformSpartanKey<C, I, F
2 * self.num_vars_total()
}

// TODO: #[cfg(not(feature = "reorder"))]
pub fn num_rows_total(&self) -> usize {
self.num_cons_total
}

/// Padded number of constraint rows per step.
// #[cfg(feature = "reorder")]
pub fn padded_row_constraint_per_step(&self) -> usize {
// JP: This is redundant with `padded_rows_per_step`. Can we reuse that instead?
(self.uniform_r1cs.num_rows + self.offset_eq_r1cs.num_constraints()).next_power_of_two()
}

/// Number of bits needed for all rows.
// #[cfg(feature = "reorder")]
pub fn num_rows_bits(&self) -> usize {
let row_count = self.num_steps * self.padded_row_constraint_per_step();
row_count.next_power_of_two().log_2()
}
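// Example for `num_rows_bits`, with illustrative sizes only: num_steps = 1 << 20 and
// 64 padded rows per step give row_count = 1 << 26, so this returns 26.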

/// Evaluates A(r_x, y) + r_rlc * B(r_x, y) + r_rlc^2 * C(r_x, y) where r_x = r_constr || r_step for all y.
#[tracing::instrument(skip_all, name = "UniformSpartanKey::evaluate_r1cs_mle_rlc")]
pub fn evaluate_r1cs_mle_rlc(&self, r_constr: &[F], r_step: &[F], r_rlc: F) -> Vec<F> {
@@ -213,9 +236,9 @@ impl<const C: usize, F: JoltField, I: ConstraintInput> UniformSpartanKey<C, I, F
};

let (eq_constants, condition_constants) = self.offset_eq_r1cs.constants();
- let sm_a_r = compute_repeated(&self.uniform_r1cs.a, Some(eq_constants));
- let sm_b_r = compute_repeated(&self.uniform_r1cs.b, Some(condition_constants));
- let sm_c_r = compute_repeated(&self.uniform_r1cs.c, None);
let sm_a_r = compute_repeated(&self.uniform_r1cs.a, Some(eq_constants)); // V var entries
let sm_b_r = compute_repeated(&self.uniform_r1cs.b, Some(condition_constants)); // V var entries
let sm_c_r = compute_repeated(&self.uniform_r1cs.c, None); // V var entries

let r_rlc_sq = r_rlc.square();
let sm_rlc = sm_a_r
@@ -306,18 +329,43 @@ impl<const C: usize, F: JoltField, I: ConstraintInput> UniformSpartanKey<C, I, F

/// Evaluates A(r), B(r), C(r) efficiently using their small uniform representations.
#[tracing::instrument(skip_all, name = "UniformSpartanKey::evaluate_r1cs_matrix_mles")]
- pub fn evaluate_r1cs_matrix_mles(&self, r: &[F]) -> (F, F, F) {
- let total_rows_bits = self.num_rows_total().log_2();
- let total_cols_bits = self.num_cols_total().log_2();
- let steps_bits: usize = self.num_steps.log_2();
- let constraint_rows_bits = (self.uniform_r1cs.num_rows + 1).next_power_of_two().log_2();
pub fn evaluate_r1cs_matrix_mles(&self, r_row: &[F], r_col: &[F]) -> (F, F, F) {
let (total_cols_bits, steps_bits, constraint_rows_bits, r_row_constr, r_row_step) =
if cfg!(feature = "reorder") {
let total_rows_bits = r_row.len();
let total_cols_bits = r_col.len();
let constraint_rows_bits = self.padded_row_constraint_per_step().log_2();
let steps_bits: usize = total_rows_bits - constraint_rows_bits;
let (r_row_step, r_row_constr) =
r_row.split_at(total_rows_bits - constraint_rows_bits); // TMP
(
total_cols_bits,
steps_bits,
constraint_rows_bits,
r_row_constr,
r_row_step,
)
} else {
let total_rows_bits = self.num_rows_total().log_2();
let total_cols_bits = self.num_cols_total().log_2();
let steps_bits: usize = self.num_steps.log_2();
let constraint_rows_bits =
(self.uniform_r1cs.num_rows + 1).next_power_of_two().log_2();
assert_eq!(r_row.len(), total_rows_bits);
assert_eq!(r_col.len(), total_cols_bits);
assert_eq!(total_rows_bits - steps_bits, constraint_rows_bits);

// Deconstruct 'r_row' into representative bits
let (r_row_constr, r_row_step) = r_row.split_at(constraint_rows_bits);
(
total_cols_bits,
steps_bits,
constraint_rows_bits,
r_row_constr,
r_row_step,
)
};
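// Note the reversed split order under `reorder`: the step variables of `r_row` come
// before the constraint variables, matching the step-major row layout, whereas the
// legacy path keeps the constraint variables first.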
let uniform_cols_bits = self.uniform_r1cs.num_vars.next_power_of_two().log_2();
- assert_eq!(r.len(), total_rows_bits + total_cols_bits);
- assert_eq!(total_rows_bits - steps_bits, constraint_rows_bits);

- // Deconstruct 'r' into representative bits
- let (r_row, r_col) = r.split_at(total_rows_bits);
- let (r_row_constr, r_row_step) = r_row.split_at(constraint_rows_bits);
let (r_col_var, r_col_step) = r_col.split_at(uniform_cols_bits + 1);
assert_eq!(r_row_step.len(), r_col_step.len());

17 changes: 17 additions & 0 deletions jolt-core/src/r1cs/ops.rs
@@ -108,6 +108,23 @@ impl LC {
result
}

// JP: Why does `evaluate` exist?
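/// Evaluates this LC at a single `row` of the flattened witness polynomials:
/// sums `coeff * poly[var][row]` over variable terms, with constant terms
/// contributing their coefficient directly.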
pub fn evaluate_row<F: JoltField>(
&self,
flattened_polynomials: &[&DensePolynomial<F>],
row: usize,
) -> F {
self.terms()
.iter()
.map(|term| match term.0 {
Variable::Input(var_index) | Variable::Auxiliary(var_index) => {
F::from_i64(term.1).mul_01_optimized(flattened_polynomials[var_index][row])
}
Variable::Constant => F::from_i64(term.1),
})
.sum()
}

pub fn evaluate_batch<F: JoltField>(
&self,
flattened_polynomials: &[&DensePolynomial<F>],