Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Implement streaming version of BytecodePolynomials #451

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions jolt-core/src/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::field::JoltField;
use crate::host;
use crate::jolt::vm::rv32i_vm::{RV32IJoltVM, C, M};
use crate::jolt::vm::Jolt;
use crate::poly::commitment::commitment_scheme::CommitmentScheme;
use crate::poly::commitment::commitment_scheme::StreamingCommitmentScheme;
use crate::poly::commitment::hyperkzg::HyperKZG;
use crate::poly::commitment::hyrax::HyraxScheme;
use crate::poly::commitment::zeromorph::Zeromorph;
Expand Down Expand Up @@ -61,23 +61,23 @@ pub fn benchmarks(
/// Builds the benchmark tasks for the `fibonacci-guest` example.
///
/// Delegates to `prove_example` with a `u32` input of `9` and returns the
/// `(span, task)` pairs it produces.
fn fibonacci<F, PCS>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
PCS: CommitmentScheme<Field = F>,
PCS: StreamingCommitmentScheme<Field = F>,
{
prove_example::<u32, PCS, F>("fibonacci-guest", &9u32)
}

/// Builds the benchmark tasks for the `sha2-guest` example.
///
/// Delegates to `prove_example` with a 2048-byte input in which every byte is `5`
/// and returns the `(span, task)` pairs it produces.
fn sha2<F, PCS>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
PCS: CommitmentScheme<Field = F>,
PCS: StreamingCommitmentScheme<Field = F>,
{
prove_example::<Vec<u8>, PCS, F>("sha2-guest", &vec![5u8; 2048])
}

/// Builds the benchmark tasks for the `sha3-guest` example.
///
/// Delegates to `prove_example` with a 2048-byte input in which every byte is `5`
/// and returns the `(span, task)` pairs it produces.
fn sha3<F, PCS>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
PCS: CommitmentScheme<Field = F>,
PCS: StreamingCommitmentScheme<Field = F>,
{
prove_example::<Vec<u8>, PCS, F>("sha3-guest", &vec![5u8; 2048])
}
Expand All @@ -99,7 +99,7 @@ fn prove_example<T: Serialize, PCS, F>(
) -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
PCS: CommitmentScheme<Field = F>,
PCS: StreamingCommitmentScheme<Field = F>,
{
let mut tasks = Vec::new();
let mut program = host::Program::new(example_name);
Expand Down Expand Up @@ -149,7 +149,7 @@ where
fn sha2chain<F, PCS>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
PCS: CommitmentScheme<Field = F>,
PCS: StreamingCommitmentScheme<Field = F>,
{
let mut tasks = Vec::new();
let mut program = host::Program::new("sha2-chain-guest");
Expand Down
3 changes: 3 additions & 0 deletions jolt-core/src/field/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ pub trait JoltField:
fn to_u64(&self) -> Option<u64> {
unimplemented!("conversion to u64 not implemented");
}
/// Converts a `usize` into a field element.
///
/// The value is widened to `u64` and handed to `from_u64`; the result is
/// whatever `from_u64` returns for that value.
fn from_usize(val: usize) -> Option<Self> {
    let widened = val as u64;
    Self::from_u64(widened)
}
}

pub trait OptimizedMul<Rhs, Output>: Sized + Mul<Rhs, Output = Output> {
Expand Down
171 changes: 170 additions & 1 deletion jolt-core/src/jolt/vm/bytecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@ use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
#[cfg(test)]
use std::collections::HashSet;
use std::marker::PhantomData;

use crate::field::JoltField;
use crate::jolt::instruction::JoltInstructionSet;
use crate::lasso::memory_checking::{
Initializable, NoExogenousOpenings, StructuredPolynomialData, VerifierComputedOpening,
};
use crate::poly::commitment::commitment_scheme::{BatchType, CommitShape, CommitmentScheme};
use crate::poly::commitment::commitment_scheme::{BatchType, CommitShape, CommitmentScheme, StreamingCommitmentScheme};
use crate::poly::eq_poly::EqPolynomial;
use crate::utils::streaming::map_state;
use common::constants::{BYTES_PER_INSTRUCTION, RAM_START_ADDRESS, REGISTER_COUNT};
use common::rv_trace::ELFInstruction;
use common::to_ram_address;
Expand Down Expand Up @@ -95,6 +97,29 @@ impl<T: CanonicalSerialize + CanonicalDeserialize> StructuredPolynomialData<T>
pub type BytecodeProof<F, PCS> =
MemoryCheckingProof<F, PCS, BytecodeOpenings<F>, NoExogenousOpenings>;

/// Field-element form of one bytecode row, emitted for a single step of the
/// streaming bytecode polynomials.
///
/// `C` appears only through `PhantomData`; no value of that type is stored.
pub struct BytecodeRowStep<F: JoltField, C: CommitmentScheme<Field = F>> {
/// Zero-sized marker tying this step to the commitment scheme `C`.
_group: PhantomData<C>,

/// Instruction address as a field element. `StreamingBytecodePolynomials::new`
/// fills this from the trace row's address *after* compressing nonzero
/// addresses to `1 + (address - RAM_START_ADDRESS) / BYTES_PER_INSTRUCTION`.
pub(super) address: F,
/// Packed instruction/circuit flags, used for r1cs
pub(super) bitflags: F,
/// Index of the destination register for this instruction (0 if register is unused).
pub(super) rd: F,
/// Index of the first source register for this instruction (0 if register is unused).
pub(super) rs1: F,
/// Index of the second source register for this instruction (0 if register is unused).
pub(super) rs2: F,
/// "Immediate" value for this instruction (0 if unused).
pub(super) imm: F,
// NOTE(review): `virtual_sequence_remaining` is not carried in the streaming row;
// the original field declaration is kept below for reference.
// /// If this instruction is part of a "virtual sequence" (see Section 6.2 of the
// /// Jolt paper), then this contains the number of virtual instructions after this
// /// one in the sequence. I.e. if this is the last instruction in the sequence,
// /// `virtual_sequence_remaining` will be Some(0); if this is the penultimate instruction
// /// in the sequence, `virtual_sequence_remaining` will be Some(1); etc.
// virtual_sequence_remaining: Option<usize>,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BytecodeRow {
/// Memory address as read from the ELF.
Expand Down Expand Up @@ -210,6 +235,26 @@ pub fn random_bytecode_trace(
trace
}

/// Streaming counterpart of the bytecode polynomials: instead of materialising
/// every evaluation, it yields one `BytecodePolynomialStep` per trace step.
///
/// The lifetime `'a` bounds the boxed iterator, which borrows the preprocessing
/// data and the trace passed to `new`.
pub struct StreamingBytecodePolynomials<'a, F: JoltField, C: CommitmentScheme<Field = F>> {
/// Length of the polynomial (the number of trace steps the stream was built from).
length: usize,
/// Stream that builds the bytecode polynomial.
/// Boxed as a trait object so the concrete iterator type returned by
/// `map_state` does not have to be named in this struct.
polynomial_stream: Box<dyn Iterator<Item = BytecodePolynomialStep<F, C>> + 'a>,
}

/// One evaluation point (one trace step) of each streamed bytecode polynomial.
///
/// `C` appears only through `PhantomData`; no value of that type is stored.
pub struct BytecodePolynomialStep<F: JoltField, C: CommitmentScheme<Field = F>> {
/// Zero-sized marker tying this step to the commitment scheme `C`.
_group: PhantomData<C>,
/// MLE of read/write addresses. For offline memory checking, each read is paired with a "virtual" write,
/// so the read addresses and write addresses are the same.
pub(super) a_read_write: F,
/// MLE of read/write values. For offline memory checking, each read is paired with a "virtual" write,
/// so the read values and write values are the same. There are six values (address, bitflags, rd, rs1, rs2, imm)
/// associated with each memory address, so `v_read_write` carries six field elements per step
/// (one for each of the six `v_read_write` polynomials).
pub(super) v_read_write: BytecodeRowStep<F, C>,
/// MLE of the read timestamps.
pub(super) t_read: F,
}

#[derive(Clone)]
pub struct BytecodePreprocessing<F: JoltField> {
/// Size of the (padded) bytecode.
Expand Down Expand Up @@ -290,6 +335,80 @@ impl<F: JoltField> BytecodePreprocessing<F> {
}
}

impl<'a, F: JoltField, C: CommitmentScheme<Field = F>> StreamingBytecodePolynomials<'a, F, C> {
    /// Builds the stream of per-step bytecode polynomial evaluations for `trace`.
    ///
    /// For every trace step the stream:
    /// 1. rewrites a nonzero row address in place to its compressed form
    ///    `1 + (address - RAM_START_ADDRESS) / BYTES_PER_INSTRUCTION`,
    /// 2. looks up the virtual address for `(address, virtual_sequence_remaining)`
    ///    (with `None` treated as `0`) in `preprocessing.virtual_address_map`,
    /// 3. reads and then increments that virtual address's access counter,
    /// 4. emits the row values, the virtual address and the pre-increment counter
    ///    as field elements.
    ///
    /// Note that `trace` is mutated (step 1). NOTE(review): presumably `map_state`
    /// is lazy, so the mutation happens as the stream is consumed; confirm against
    /// `utils::streaming`.
    ///
    /// # Panics
    ///
    /// Panics if a nonzero address is below `RAM_START_ADDRESS` or is not a
    /// multiple of `BYTES_PER_INSTRUCTION`, if the virtual address lookup fails,
    /// or if any value cannot be converted into a field element.
    #[tracing::instrument(skip_all, name = "StreamingBytecodePolynomials::new")]
    pub fn new<InstructionSet: JoltInstructionSet>(
        preprocessing: &'a BytecodePreprocessing<F>,
        trace: &'a mut [JoltTraceStep<InstructionSet>],
    ) -> Self {
        // One access counter per (padded) bytecode slot, threaded through the stream as state.
        let read_counts: Vec<usize> = vec![0; preprocessing.code_size];
        let length = trace.len();

        let polynomial_stream = map_state(read_counts, trace.iter_mut(), |read_counts, step| {
            if !step.bytecode_row.address.is_zero() {
                assert!(step.bytecode_row.address >= RAM_START_ADDRESS as usize);
                assert!(step.bytecode_row.address % BYTES_PER_INSTRUCTION == 0);
                // Compress instruction address for more efficient commitment.
                let offset = step.bytecode_row.address - RAM_START_ADDRESS as usize;
                step.bytecode_row.address = 1 + offset / BYTES_PER_INSTRUCTION;
            }

            let row = &step.bytecode_row;

            let key = (row.address, row.virtual_sequence_remaining.unwrap_or(0));
            let slot = *preprocessing.virtual_address_map.get(&key).unwrap();

            // The read timestamp is the counter value *before* this access.
            let reads_so_far = read_counts[slot];
            read_counts[slot] = reads_so_far + 1;

            let v_read_write = BytecodeRowStep {
                _group: PhantomData,
                address: F::from_u64(row.address as u64).unwrap(),
                bitflags: F::from_u64(row.bitflags).unwrap(),
                rd: F::from_u64(row.rd).unwrap(),
                rs1: F::from_u64(row.rs1).unwrap(),
                rs2: F::from_u64(row.rs2).unwrap(),
                imm: F::from_u64(row.imm).unwrap(),
            };

            BytecodePolynomialStep {
                _group: PhantomData,
                a_read_write: F::from_usize(slot).unwrap(),
                v_read_write,
                t_read: F::from_usize(reads_so_far).unwrap(),
            }
        });

        Self {
            length,
            polynomial_stream: Box::new(polynomial_stream),
        }
    }

    /// Consumes the stream, folding every step into an accumulator that starts at `init`.
    pub fn fold<T, FN>(self, init: T, f: FN) -> T
    where
        FN: FnMut(T, BytecodePolynomialStep<F, C>) -> T,
    {
        let Self {
            polynomial_stream, ..
        } = self;
        polynomial_stream.fold(init, f)
    }

    /// Returns the number of evaluations of the polynomial.
    pub fn length(&self) -> usize {
        self.length
    }
}

impl<F: JoltField, PCS: CommitmentScheme<Field = F>> BytecodeProof<F, PCS> {
#[tracing::instrument(skip_all, name = "BytecodePolynomials::new")]
pub fn generate_witness<InstructionSet: JoltInstructionSet>(
Expand Down Expand Up @@ -469,6 +588,56 @@ impl<F: JoltField, PCS: CommitmentScheme<Field = F>> BytecodeProof<F, PCS> {
}
}

// TODO: Use BytecodeStuff? XXX
/// In-progress streaming commitments to the bytecode polynomials: one
/// commitment-scheme state per committed polynomial.
pub struct StreamingBytecodeCommitment<'a, C: StreamingCommitmentScheme> {
/// Streaming state for the `a_read_write` polynomial.
a_read_write: C::State<'a>,
/// Streaming states for the six `v_read_write` polynomials, in the order
/// address, bitflags, rd, rs1, rs2, imm.
v_read_write: [C::State<'a>; 6],
/// Streaming state for the `t_read` polynomial.
t_read: C::State<'a>,
}

impl<'a, C: StreamingCommitmentScheme> StreamingBytecodeCommitment<'a, C> {
    /// Initialize a streaming computation of a commitment.
    ///
    /// A single state is created with `C::initialize(size, setup, batch_type)` and
    /// cloned for each of the eight committed polynomials (`a_read_write`, the six
    /// `v_read_write` components and `t_read`).
    pub fn initialize(size: usize, setup: &'a C::Setup, batch_type: &BatchType) -> Self {
        let a_read_write = C::initialize(size, setup, batch_type);
        let v_read_write = std::array::from_fn(|_| a_read_write.clone());
        let t_read = a_read_write.clone();

        StreamingBytecodeCommitment {
            a_read_write,
            v_read_write,
            t_read,
        }
    }

    /// Process one step to compute the commitment.
    ///
    /// Feeds each value of `step` into the state of the polynomial it belongs to
    /// and returns the advanced states. The states are updated in the order
    /// `a_read_write`, the six `v_read_write` components
    /// (address, bitflags, rd, rs1, rs2, imm), then `t_read`.
    pub fn process(self, step: &BytecodePolynomialStep<C::Field, C>) -> Self {
        let row = &step.v_read_write;
        // This runs once per trace step, so destructure the fixed-size array instead of
        // collecting into a `Vec` and converting back: no heap allocation and no
        // runtime length check that could panic.
        let [address, bitflags, rd, rs1, rs2, imm] = self.v_read_write;

        StreamingBytecodeCommitment {
            a_read_write: C::process(self.a_read_write, step.a_read_write),
            v_read_write: [
                C::process(address, row.address),
                C::process(bitflags, row.bitflags),
                C::process(rd, row.rd),
                C::process(rs1, row.rs1),
                C::process(rs2, row.rs2),
                C::process(imm, row.imm),
            ],
            t_read: C::process(self.t_read, step.t_read),
        }
    }

    /// Return the trace commitments.
    ///
    /// The returned vector has eight entries in this order: `a_read_write`,
    /// `t_read`, then the six `v_read_write` commitments
    /// (address, bitflags, rd, rs1, rs2, imm). Callers index into it positionally.
    pub fn finalize(self) -> Vec<C::Commitment> {
        [
            C::finalize(self.a_read_write),
            C::finalize(self.t_read),
        ]
        .into_iter()
        .chain(self.v_read_write.into_iter().map(|vrw| C::finalize(vrw)))
        .collect()
    }
}

impl<F, PCS> MemoryCheckingProver<F, PCS> for BytecodeProof<F, PCS>
where
F: JoltField,
Expand Down
52 changes: 47 additions & 5 deletions jolt-core/src/jolt/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::field::JoltField;
use crate::poly::opening_proof::{
ProverOpeningAccumulator, ReducedOpeningProof, VerifierOpeningAccumulator,
};
use crate::r1cs::inputs::StreamingR1CSCommitment;
use crate::r1cs::constraints::R1CSConstraints;
use crate::r1cs::spartan::{self, UniformSpartanProof};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
Expand All @@ -21,14 +22,20 @@ use crate::jolt::{
VirtualInstructionSequence,
},
subtable::JoltSubtableSet,
vm::timestamp_range_check::TimestampValidityProof,
vm::{
bytecode::{
StreamingBytecodeCommitment,
StreamingBytecodePolynomials,
},
timestamp_range_check::TimestampValidityProof,
},
};
use crate::lasso::memory_checking::{
Initializable, MemoryCheckingProver, MemoryCheckingVerifier, StructuredPolynomialData,
};
use crate::poly::commitment::commitment_scheme::{BatchType, CommitmentScheme};
use crate::poly::commitment::commitment_scheme::{BatchType, CommitmentScheme, StreamingCommitmentScheme};
use crate::poly::dense_mlpoly::DensePolynomial;
use crate::r1cs::inputs::{ConstraintInput, R1CSPolynomials, R1CSProof, R1CSStuff};
use crate::r1cs::inputs::{ConstraintInput, R1CSPolynomials, R1CSProof, R1CSStuff, StreamingR1CSPolynomials};
use crate::utils::errors::ProofVerifyError;
use crate::utils::thread::drop_in_background_thread;
use crate::utils::transcript::{AppendToTranscript, ProofTranscript};
Expand Down Expand Up @@ -241,7 +248,7 @@ impl<F: JoltField> JoltPolynomials<F> {
}
}

pub trait Jolt<F: JoltField, PCS: CommitmentScheme<Field = F>, const C: usize, const M: usize> {
pub trait Jolt<F: JoltField, PCS: StreamingCommitmentScheme<Field = F>, const C: usize, const M: usize> {
type InstructionSet: JoltInstructionSet;
type Subtables: JoltSubtableSet<F>;
type Constraints: R1CSConstraints<C, F>;
Expand Down Expand Up @@ -332,10 +339,11 @@ pub trait Jolt<F: JoltField, PCS: CommitmentScheme<Field = F>, const C: usize, c
) {
let trace_length = trace.len();
let padded_trace_length = trace_length.next_power_of_two();
println!("Trace length: {}", trace_length);

JoltTraceStep::pad(&mut trace);

let mut trace2 = trace.clone();

let mut transcript = ProofTranscript::new(b"Jolt transcript");
Self::fiat_shamir_preamble(&mut transcript, &program_io, trace_length);

Expand Down Expand Up @@ -380,6 +388,27 @@ pub trait Jolt<F: JoltField, PCS: CommitmentScheme<Field = F>, const C: usize, c
<Self::Constraints as R1CSConstraints<C, F>>::Inputs,
>(&trace);

let streaming_bytecode_polynomials = StreamingBytecodePolynomials::<F, PCS>::new(&preprocessing.bytecode, &mut trace2);
let initialized_commitment = StreamingBytecodeCommitment::initialize(streaming_bytecode_polynomials.length(), &preprocessing.generators, &BatchType::Big);
// JP: `fold` likely isn't sufficient since we need to extract the internal state.
let streaming_trace_commitments =
streaming_bytecode_polynomials.fold(initialized_commitment, |state, step| {
StreamingBytecodeCommitment::process(state, &step)
});
let bytecode_commitments = StreamingBytecodeCommitment::finalize(streaming_trace_commitments);

let streaming_r1cs_polynomials = StreamingR1CSPolynomials::<F>::new::<
C,
M,
Self::InstructionSet,
<Self::Constraints as R1CSConstraints<C, F>>::Inputs,
>(&trace2);
let r1cs_commitments = StreamingR1CSCommitment::<C, PCS>::initialize(&streaming_r1cs_polynomials, &preprocessing.generators, &BatchType::Big);
let r1cs_commitments = streaming_r1cs_polynomials.fold(r1cs_commitments, |state, step| {
StreamingR1CSCommitment::process(state, &step)
});
let r1cs_commitments = StreamingR1CSCommitment::finalize(r1cs_commitments);

let mut jolt_polynomials = JoltPolynomials {
bytecode: bytecode_polynomials,
read_write_memory: memory_polynomials,
Expand All @@ -388,9 +417,22 @@ pub trait Jolt<F: JoltField, PCS: CommitmentScheme<Field = F>, const C: usize, c
r1cs: r1cs_polynomials,
};


r1cs_builder.compute_aux(&mut jolt_polynomials);

let jolt_commitments = jolt_polynomials.commit::<C, PCS>(&preprocessing);
// TODO: Temporary sanity checks that the streaming commitments match the batch commitments; remove me XXX
assert_eq!(bytecode_commitments[0], jolt_commitments.bytecode.a_read_write);
assert_eq!(bytecode_commitments[1], jolt_commitments.bytecode.t_read);
assert_eq!(bytecode_commitments[2], jolt_commitments.bytecode.v_read_write[0]);
assert_eq!(bytecode_commitments[3], jolt_commitments.bytecode.v_read_write[1]);
assert_eq!(bytecode_commitments[4], jolt_commitments.bytecode.v_read_write[2]);
assert_eq!(bytecode_commitments[5], jolt_commitments.bytecode.v_read_write[3]);
assert_eq!(bytecode_commitments[6], jolt_commitments.bytecode.v_read_write[4]);
assert_eq!(bytecode_commitments[7], jolt_commitments.bytecode.v_read_write[5]);
assert_eq!(r1cs_commitments.0, jolt_commitments.r1cs.chunks_x);
assert_eq!(r1cs_commitments.1, jolt_commitments.r1cs.chunks_y);
assert_eq!(r1cs_commitments.2, jolt_commitments.r1cs.circuit_flags);

transcript.append_scalar(&spartan_key.vk_digest);

Expand Down
4 changes: 2 additions & 2 deletions jolt-core/src/jolt/vm/rv32i_vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::jolt::subtable::{
truncate_overflow::TruncateOverflowSubtable, xor::XorSubtable, JoltSubtableSet, LassoSubtable,
SubtableId,
};
use crate::poly::commitment::commitment_scheme::CommitmentScheme;
use crate::poly::commitment::commitment_scheme::StreamingCommitmentScheme;

/// Generates an enum out of a list of JoltInstruction types. All JoltInstruction methods
/// are callable on the enum type via enum_dispatch.
Expand Down Expand Up @@ -176,7 +176,7 @@ pub const M: usize = 1 << 16;
impl<F, PCS> Jolt<F, PCS, C, M> for RV32IJoltVM
where
F: JoltField,
PCS: CommitmentScheme<Field = F>,
PCS: StreamingCommitmentScheme<Field = F>,
{
type InstructionSet = RV32I;
type Subtables = RV32ISubtables<F>;
Expand Down
Loading