Skip to content

Commit

Permalink
Create blake component that uses GKR for lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Aug 28, 2024
1 parent 4f24c54 commit a956cef
Show file tree
Hide file tree
Showing 25 changed files with 1,065 additions and 81 deletions.
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl<const N: usize> LookupElements<N> {
}
pub fn combine<F: Copy, EF>(&self, values: &[F]) -> EF
where
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<Output = EF>,
{
EF::from(values[0])
+ values[1..]
Expand Down
4 changes: 4 additions & 0 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ mod tests {
let GkrArtifact {
ood_point: r,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -354,6 +355,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -391,6 +393,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -427,6 +430,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down
4 changes: 4 additions & 0 deletions crates/prover/src/core/backend/simd/lookups/gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -590,6 +591,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -629,6 +631,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down Expand Up @@ -666,6 +669,7 @@ mod tests {
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
gate_by_instance: _,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

Expand Down
13 changes: 12 additions & 1 deletion crates/prover/src/core/lookups/gkr_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use itertools::Itertools;
use num_traits::{One, Zero};
use thiserror::Error;

use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask};
use super::gkr_verifier::{Gate, GkrArtifact, GkrBatchProof, GkrMask};
use super::mle::{Mle, MleOps};
use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, random_linear_combination, UnivariatePoly};
Expand Down Expand Up @@ -409,6 +409,16 @@ pub fn prove_batch<B: GkrOps>(
.collect_vec();
let n_layers = *n_layers_by_instance.iter().max().unwrap();

let gate_by_instance = input_layer_by_instance
.iter()
.map(|l| match l {
Layer::GrandProduct(_) => Gate::GrandProduct,
Layer::LogUpGeneric { .. }
| Layer::LogUpMultiplicities { .. }
| Layer::LogUpSingles { .. } => Gate::LogUp,
})
.collect();

// Evaluate all instance circuits and collect the layer values.
let mut layers_by_instance = input_layer_by_instance
.into_iter()
Expand Down Expand Up @@ -502,6 +512,7 @@ pub fn prove_batch<B: GkrOps>(

let artifact = GkrArtifact {
ood_point,
gate_by_instance,
claims_to_verify_by_instance,
n_variables_by_instance: n_layers_by_instance,
};
Expand Down
109 changes: 108 additions & 1 deletion crates/prover/src/core/lookups/gkr_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ pub fn partially_verify_batch(

Ok(GkrArtifact {
ood_point,
gate_by_instance,
claims_to_verify_by_instance,
n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(),
})
Expand All @@ -162,12 +163,114 @@ pub struct GkrBatchProof {
pub struct GkrArtifact {
/// Out-of-domain (OOD) point for evaluating columns in the input layer.
pub ood_point: Vec<SecureField>,
/// The gate of each instance.
pub gate_by_instance: Vec<Gate>,
/// The claimed evaluation at `ood_point` for each column in the input layer of each instance.
pub claims_to_verify_by_instance: Vec<Vec<SecureField>>,
/// The number of variables that interpolate the input layer of each instance.
pub n_variables_by_instance: Vec<usize>,
}

impl GkrArtifact {
pub fn ood_point(&self, instance_n_variables: usize) -> &[SecureField] {
&self.ood_point[self.ood_point.len() - instance_n_variables..]
}
}

pub struct LookupArtifactInstanceIter<'proof, 'artifact> {
instance: usize,
gkr_proof: &'proof GkrBatchProof,
gkr_artifact: &'artifact GkrArtifact,
}

impl<'proof, 'artifact> LookupArtifactInstanceIter<'proof, 'artifact> {
pub fn new(gkr_proof: &'proof GkrBatchProof, gkr_artifact: &'artifact GkrArtifact) -> Self {
Self {
instance: 0,
gkr_proof,
gkr_artifact,
}
}
}

impl<'proof, 'artifact> Iterator for LookupArtifactInstanceIter<'proof, 'artifact> {
type Item = LookupArtifactInstance;

fn next(&mut self) -> Option<LookupArtifactInstance> {
if self.instance >= self.gkr_proof.output_claims_by_instance.len() {
return None;
}

let instance = self.instance;
let input_n_variables = self.gkr_artifact.n_variables_by_instance[instance];
let eval_point = self.gkr_artifact.ood_point(input_n_variables).to_vec();
let output_claim = &*self.gkr_proof.output_claims_by_instance[instance];
let input_claims = &*self.gkr_artifact.claims_to_verify_by_instance[instance];
let gate = self.gkr_artifact.gate_by_instance[instance];

let res = Some(match gate {
Gate::LogUp => {
let [numerator, denominator] = output_claim.try_into().unwrap();
let claimed_sum = Fraction::new(numerator, denominator);
let [input_numerators_claim, input_denominators_claim] =
input_claims.try_into().unwrap();

LookupArtifactInstance::LogUp(LogUpArtifactInstance {
eval_point,
input_n_variables,
input_numerators_claim,
input_denominators_claim,
claimed_sum,
})
}
Gate::GrandProduct => {
let [claimed_product] = output_claim.try_into().unwrap();
let [input_claim] = input_claims.try_into().unwrap();

LookupArtifactInstance::GrandProduct(GrandProductArtifactInstance {
eval_point,
input_n_variables,
input_claim,
claimed_product,
})
}
});

self.instance += 1;
res
}
}

// TODO: Consider making the GKR artifact just a Vec<LookupArtifactInstance>.
pub enum LookupArtifactInstance {
GrandProduct(GrandProductArtifactInstance),
LogUp(LogUpArtifactInstance),
}

pub struct GrandProductArtifactInstance {
/// GKR input layer eval point.
pub eval_point: Vec<SecureField>,
/// Number of variables the MLE in the GKR input layer had.
pub input_n_variables: usize,
/// Claimed input MLE evaluation at `eval_point`.
pub input_claim: SecureField,
/// Output claim from the circuit.
pub claimed_product: SecureField,
}

pub struct LogUpArtifactInstance {
/// GKR input layer eval point.
pub eval_point: Vec<SecureField>,
/// Number of variables the MLEs in the GKR input layer had.
pub input_n_variables: usize,
/// Claimed input numerators MLE evaluation at `eval_point`.
pub input_numerators_claim: SecureField,
/// Claimed input denominators MLE evaluation at `eval_point`.
pub input_denominators_claim: SecureField,
/// Output claim from the circuit.
pub claimed_sum: Fraction<SecureField, SecureField>,
}

/// Defines how a circuit operates locally on two input rows to produce a single output row.
/// This local 2-to-1 constraint is what gives the whole circuit its "binary tree" structure.
///
Expand All @@ -176,7 +279,7 @@ pub struct GkrArtifact {
/// circuit) GKR prover implementations.
///
/// [Thaler13]: https://eprint.iacr.org/2013/351.pdf
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Gate {
LogUp,
GrandProduct,
Expand Down Expand Up @@ -305,11 +408,13 @@ mod tests {

let GkrArtifact {
ood_point,
gate_by_instance,
claims_to_verify_by_instance,
n_variables_by_instance,
} = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?;

assert_eq!(n_variables_by_instance, [LOG_N, LOG_N]);
assert_eq!(gate_by_instance, [Gate::LogUp, Gate::LogUp]);
assert_eq!(proof.output_claims_by_instance.len(), 2);
assert_eq!(claims_to_verify_by_instance.len(), 2);
assert_eq!(proof.output_claims_by_instance[0], &[product0]);
Expand Down Expand Up @@ -338,11 +443,13 @@ mod tests {

let GkrArtifact {
ood_point,
gate_by_instance,
claims_to_verify_by_instance,
n_variables_by_instance,
} = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?;

assert_eq!(n_variables_by_instance, [LOG_N0, LOG_N1]);
assert_eq!(gate_by_instance, [Gate::LogUp, Gate::LogUp]);
assert_eq!(proof.output_claims_by_instance.len(), 2);
assert_eq!(claims_to_verify_by_instance.len(), 2);
assert_eq!(proof.output_claims_by_instance[0], &[product0]);
Expand Down
31 changes: 14 additions & 17 deletions crates/prover/src/examples/blake/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ impl BlakeStatement0 {
}

pub struct AllElements {
blake_elements: BlakeElements,
round_elements: RoundElements,
xor_elements: BlakeXorElements,
pub blake_elements: BlakeElements,
pub round_elements: RoundElements,
pub xor_elements: BlakeXorElements,
}
impl AllElements {
pub fn draw(channel: &mut impl Channel) -> Self {
Expand Down Expand Up @@ -223,7 +223,7 @@ where
{
assert!(log_size >= LOG_N_LANES);
assert_eq!(
ROUND_LOG_SPLIT.map(|x| (1 << x)).into_iter().sum::<u32>() as usize,
ROUND_LOG_SPLIT.map(|x| 1 << x).iter().sum::<usize>(),
N_ROUNDS
);

Expand All @@ -240,7 +240,7 @@ where
span.exit();

// Prepare inputs.
let blake_inputs = (0..(1 << (log_size - LOG_N_LANES)))
let blake_inputs = (0..1 << (log_size - LOG_N_LANES))
.map(|i| {
let v = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j) as u32)); 16];
let m = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j + 1) as u32)); 16];
Expand Down Expand Up @@ -282,18 +282,15 @@ where

// Trace commitment.
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(
chain![
scheduler_trace,
round_traces.into_iter().flatten(),
xor_trace12,
xor_trace9,
xor_trace8,
xor_trace7,
xor_trace4,
]
.collect_vec(),
);
tree_builder.extend_evals(chain![
scheduler_trace,
round_traces.into_iter().flatten(),
xor_trace12,
xor_trace9,
xor_trace8,
xor_trace7,
xor_trace4,
]);
tree_builder.commit(channel);
span.exit();

Expand Down
Loading

0 comments on commit a956cef

Please sign in to comment.