Skip to content

Commit

Permalink
Create blake component that uses GKR for lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Aug 26, 2024
1 parent ddd533f commit d1436e5
Show file tree
Hide file tree
Showing 20 changed files with 489 additions and 60 deletions.
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl<const N: usize> LookupElements<N> {
}
pub fn combine<F: Copy, EF>(&self, values: &[F]) -> EF
where
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<Output = EF>,
{
EF::from(values[0])
+ values[1..]
Expand Down
6 changes: 6 additions & 0 deletions crates/prover/src/core/lookups/gkr_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ pub struct GkrArtifact {
pub n_variables_by_instance: Vec<usize>,
}

impl GkrArtifact {
pub fn ood_point(&self, instance_n_variables: usize) -> &[SecureField] {
&self.ood_point[self.ood_point.len() - instance_n_variables..]
}
}

/// Defines how a circuit operates locally on two input rows to produce a single output row.
/// This local 2-to-1 constraint is what gives the whole circuit its "binary tree" structure.
///
Expand Down
31 changes: 14 additions & 17 deletions crates/prover/src/examples/blake/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ impl BlakeStatement0 {
}

pub struct AllElements {
blake_elements: BlakeElements,
round_elements: RoundElements,
xor_elements: BlakeXorElements,
pub blake_elements: BlakeElements,
pub round_elements: RoundElements,
pub xor_elements: BlakeXorElements,
}
impl AllElements {
pub fn draw(channel: &mut impl Channel) -> Self {
Expand Down Expand Up @@ -223,7 +223,7 @@ where
{
assert!(log_size >= LOG_N_LANES);
assert_eq!(
ROUND_LOG_SPLIT.map(|x| (1 << x)).into_iter().sum::<u32>() as usize,
ROUND_LOG_SPLIT.map(|x| 1 << x).iter().sum::<usize>(),
N_ROUNDS
);

Expand All @@ -240,7 +240,7 @@ where
span.exit();

// Prepare inputs.
let blake_inputs = (0..(1 << (log_size - LOG_N_LANES)))
let blake_inputs = (0..1 << (log_size - LOG_N_LANES))
.map(|i| {
let v = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j) as u32)); 16];
let m = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j + 1) as u32)); 16];
Expand Down Expand Up @@ -282,18 +282,15 @@ where

// Trace commitment.
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(
chain![
scheduler_trace,
round_traces.into_iter().flatten(),
xor_trace12,
xor_trace9,
xor_trace8,
xor_trace7,
xor_trace4,
]
.collect_vec(),
);
tree_builder.extend_evals(chain![
scheduler_trace,
round_traces.into_iter().flatten(),
xor_trace12,
xor_trace9,
xor_trace8,
xor_trace7,
xor_trace4,
]);
tree_builder.commit(channel);
span.exit();

Expand Down
44 changes: 22 additions & 22 deletions crates/prover/src/examples/blake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,28 @@ use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::FieldExpOps;

mod air;
mod round;
mod scheduler;
mod xor_table;
pub mod air;
pub mod round;
pub mod scheduler;
pub mod xor_table;

const STATE_SIZE: usize = 16;
const MESSAGE_SIZE: usize = 16;
const N_FELTS_IN_U32: usize = 2;
const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32;
pub const STATE_SIZE: usize = 16;
pub const MESSAGE_SIZE: usize = 16;
pub const N_FELTS_IN_U32: usize = 2;
pub const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32;

// Parameters for Blake2s. Change these for blake3.
const N_ROUNDS: usize = 10;
pub const N_ROUNDS: usize = 10;
/// A splitting N_ROUNDS into several powers of 2.
const ROUND_LOG_SPLIT: [u32; 2] = [3, 1];
pub const ROUND_LOG_SPLIT: [u32; 2] = [3, 1];

#[derive(Default)]
struct XorAccums {
xor12: XorAccumulator<12, 4>,
xor9: XorAccumulator<9, 2>,
xor8: XorAccumulator<8, 2>,
xor7: XorAccumulator<7, 2>,
xor4: XorAccumulator<4, 0>,
pub struct XorAccums {
pub xor12: XorAccumulator<12, 4>,
pub xor9: XorAccumulator<9, 2>,
pub xor8: XorAccumulator<8, 2>,
pub xor7: XorAccumulator<7, 2>,
pub xor4: XorAccumulator<4, 0>,
}
impl XorAccums {
fn add_input(&mut self, w: u32, a: u32x16, b: u32x16) {
Expand All @@ -50,11 +50,11 @@ impl XorAccums {

#[derive(Clone)]
pub struct BlakeXorElements {
xor12: XorElements,
xor9: XorElements,
xor8: XorElements,
xor7: XorElements,
xor4: XorElements,
pub xor12: XorElements,
pub xor9: XorElements,
pub xor8: XorElements,
pub xor7: XorElements,
pub xor4: XorElements,
}
impl BlakeXorElements {
fn draw(channel: &mut impl Channel) -> Self {
Expand All @@ -75,7 +75,7 @@ impl BlakeXorElements {
xor4: XorElements::dummy(),
}
}
fn get(&self, w: u32) -> &XorElements {
pub fn get(&self, w: u32) -> &XorElements {
match w {
12 => &self.xor12,
9 => &self.xor9,
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/examples/blake/round/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ use crate::examples::blake::{to_felts, XorAccums, N_ROUND_INPUT_FELTS, STATE_SIZ
pub struct BlakeRoundLookupData {
/// A vector of (w, [a_col, b_col, c_col]) for each xor lookup.
/// w is the xor width. c_col is the xor col of a_col and b_col.
xor_lookups: Vec<(u32, [BaseColumn; 3])>,
pub xor_lookups: Vec<(u32, [BaseColumn; 3])>,
/// A column of round lookup values (v_in, v_out, m).
round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS],
pub round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS],
}

pub struct TraceGenerator {
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/round/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod constraints;
mod gen;

pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput};
pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput, BlakeRoundLookupData};
use num_traits::Zero;

use super::{BlakeXorElements, N_ROUND_INPUT_FELTS};
Expand Down
8 changes: 4 additions & 4 deletions crates/prover/src/examples/blake/scheduler/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub fn gen_trace(
.map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) })
.collect_vec();

for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
for vec_row in 0..1 << (log_size - LOG_N_LANES) {
let mut col_index = 0;

let mut write_u32_array = |x: [u32x16; STATE_SIZE], col_index: &mut usize| {
Expand Down Expand Up @@ -125,11 +125,11 @@ pub fn gen_interaction_trace(

let mut logup_gen = LogupTraceGenerator::new(log_size);

for [l0, l1] in lookup_data.round_lookups.array_chunks::<2>() {
for [l0, l1] in lookup_data.round_lookups.array_chunks() {
let mut col_gen = logup_gen.new_col();

#[allow(clippy::needless_range_loop)]
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
for vec_row in 0..1 << (log_size - LOG_N_LANES) {
let p0: PackedSecureField =
round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row]));
let p1: PackedSecureField =
Expand All @@ -145,7 +145,7 @@ pub fn gen_interaction_trace(
// with the entire blake lookup.
let mut col_gen = logup_gen.new_col();
#[allow(clippy::needless_range_loop)]
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
for vec_row in 0..1 << (log_size - LOG_N_LANES) {
let p_blake: PackedSecureField = blake_lookup_elements.combine(
&lookup_data
.blake_lookups
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod constraints;
mod gen;

use constraints::eval_blake_scheduler_constraints;
pub use gen::{gen_interaction_trace, gen_trace, BlakeInput};
pub use gen::{gen_interaction_trace, gen_trace, BlakeInput, BlakeSchedulerLookupData};
use num_traits::Zero;

use super::round::RoundElements;
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/xor_table/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn generate_interaction_trace<const ELEM_BITS: u32, const EXPAND_BITS: u32>(

// Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES.
#[allow(clippy::needless_range_loop)]
for vec_row in 0..(1 << (column_bits::<ELEM_BITS, EXPAND_BITS>() - LOG_N_LANES)) {
for vec_row in 0..1 << (column_bits::<ELEM_BITS, EXPAND_BITS>() - LOG_N_LANES) {
// vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl.
// Extract al, blh from vec_row.
let al = vec_row >> (limb_bits - LOG_N_LANES);
Expand Down
6 changes: 4 additions & 2 deletions crates/prover/src/examples/blake/xor_table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ use std::simd::u32x16;

use itertools::Itertools;
use num_traits::Zero;
pub use r#gen::{generate_constant_trace, generate_interaction_trace, generate_trace};
pub use r#gen::{
generate_constant_trace, generate_interaction_trace, generate_trace, XorTableLookupData,
};

use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
Expand All @@ -37,7 +39,7 @@ pub fn trace_sizes<const ELEM_BITS: u32, const EXPAND_BITS: u32>() -> TreeVec<Ve
.map_cols(|_| column_bits::<ELEM_BITS, EXPAND_BITS>())
}

const fn limb_bits<const ELEM_BITS: u32, const EXPAND_BITS: u32>() -> u32 {
pub const fn limb_bits<const ELEM_BITS: u32, const EXPAND_BITS: u32>() -> u32 {
ELEM_BITS - EXPAND_BITS
}
pub const fn column_bits<const ELEM_BITS: u32, const EXPAND_BITS: u32>() -> u32 {
Expand Down
Loading

0 comments on commit d1436e5

Please sign in to comment.