diff --git a/Cargo.lock b/Cargo.lock index 92af63fb..c57a5671 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1079,6 +1079,7 @@ dependencies = [ "valida-output", "valida-program", "valida-range", + "valida-static-data", ] [[package]] @@ -1229,6 +1230,23 @@ dependencies = [ "valida-util", ] +[[package]] +name = "valida-static-data" +version = "0.1.0" +dependencies = [ + "p3-air", + "p3-baby-bear", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-uni-stark", + "valida-bus", + "valida-derive", + "valida-machine", + "valida-memory", + "valida-util", +] + [[package]] name = "valida-util" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 0927c45e..c91bac6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ members = [ "output", "program", "range", + "static_data", "util", "verifier" ] diff --git a/basic/Cargo.toml b/basic/Cargo.toml index 9b899126..9662682a 100644 --- a/basic/Cargo.toml +++ b/basic/Cargo.toml @@ -31,6 +31,7 @@ valida-opcodes = { path = "../opcodes" } valida-output = { path = "../output" } valida-program = { path = "../program" } valida-range = { path = "../range" } +valida-static-data = { path = "../static_data" } p3-baby-bear = { workspace = true } p3-field = { workspace = true } p3-maybe-rayon = { workspace = true } diff --git a/basic/src/bin/test_prover.rs b/basic/src/bin/test_prover.rs index f754bce7..6cc91f6b 100644 --- a/basic/src/bin/test_prover.rs +++ b/basic/src/bin/test_prover.rs @@ -2,7 +2,7 @@ extern crate core; use p3_baby_bear::BabyBear; use p3_fri::{TwoAdicFriPcs, TwoAdicFriPcsConfig}; -use valida_alu_u32::add::{Add32Instruction, MachineWithAdd32Chip}; +use valida_alu_u32::add::Add32Instruction; use valida_basic::BasicMachine; use valida_cpu::{ BeqInstruction, BneInstruction, Imm32Instruction, JalInstruction, JalvInstruction, @@ -10,10 +10,8 @@ use valida_cpu::{ }; use valida_machine::{ FixedAdviceProvider, Instruction, InstructionWord, Machine, MachineProof, Operands, ProgramROM, - Word, }; -use valida_memory::MachineWithMemoryChip; use valida_opcodes::BYTES_PER_INSTR; use valida_program::MachineWithProgramChip; diff --git a/basic/src/lib.rs b/basic/src/lib.rs index 8f0b564a..6c07d93e 100644 --- a/basic/src/lib.rs +++ b/basic/src/lib.rs @@ -47,6 +47,7 @@ use valida_memory::{MachineWithMemoryChip, MemoryChip}; use valida_output::{MachineWithOutputChip, OutputChip, WriteInstruction}; use valida_program::{MachineWithProgramChip, ProgramChip}; use valida_range::{MachineWithRangeChip, RangeCheckerChip}; +use valida_static_data::{MachineWithStaticDataChip, StaticDataChip}; use p3_maybe_rayon::prelude::*; use valida_machine::StarkConfig; @@ -180,6 +181,10 @@ pub struct BasicMachine { #[chip] range: RangeCheckerChip<256>, + #[chip] + #[static_data_chip] + static_data: StaticDataChip, + _phantom_sc: PhantomData F>, } @@ -335,3 +340,13 @@ impl MachineWithRangeChip for BasicMachi &mut self.range } } + +impl MachineWithStaticDataChip for BasicMachine { + fn static_data(&self) -> &StaticDataChip { + &self.static_data + } + + fn static_data_mut(&mut self) -> &mut StaticDataChip { + &mut self.static_data + } +} diff --git a/basic/tests/test_static_data.rs b/basic/tests/test_static_data.rs new file mode 100644 index 00000000..4c538168 --- /dev/null +++ b/basic/tests/test_static_data.rs @@ -0,0 +1,113 @@ +extern crate core; + +use p3_baby_bear::BabyBear; +use p3_fri::{TwoAdicFriPcs, TwoAdicFriPcsConfig}; +use valida_basic::BasicMachine; +use valida_cpu::{ + BneInstruction, Imm32Instruction, Load32Instruction, MachineWithCpuChip, StopInstruction, +}; +use valida_machine::{ + FixedAdviceProvider, Instruction, InstructionWord, Machine, Operands, ProgramROM, Word, +}; + +use valida_program::MachineWithProgramChip; +use valida_static_data::MachineWithStaticDataChip; + +use p3_challenger::DuplexChallenger; +use p3_dft::Radix2Bowers; +use p3_field::extension::BinomialExtensionField; +use p3_field::Field; +use p3_fri::FriConfig; +use p3_keccak::Keccak256Hash; +use p3_mds::coset_mds::CosetMds; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon::Poseidon; +use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32}; +use rand::thread_rng; +use valida_machine::StarkConfigImpl; +use valida_machine::__internal::p3_commit::ExtensionMmcs; + +#[test] +fn prove_static_data() { + // _start: + // imm32 0(fp), 0, 0, 0, 0x10 + // load32 -4(fp), 0(fp), 0, 0, 0 + // bnei _start, 0(fp), 0x25, 0, 1 // infinite loop unless static value is loaded + // stop + let program = vec![ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([0, 0, 0, 0, 0x10]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-4, 0, 0, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([0, -4, 0x25, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands::default(), + }, + ]; + + let mut machine = BasicMachine::::default(); + let rom = ProgramROM::new(program); + machine.static_data_mut().write(0x10, Word([0, 0, 0, 0x25])); + machine.static_data_mut().write(0x14, Word([0, 0, 0, 0x32])); + machine.program_mut().set_program_rom(&rom); + machine.cpu_mut().fp = 0x1000; + machine.cpu_mut().save_register_state(); // TODO: Initial register state should be saved + // automatically by the machine, not manually here + + machine.run(&rom, &mut FixedAdviceProvider::empty()); + + type Val = BabyBear; + type Challenge = BinomialExtensionField; + type PackedChallenge = BinomialExtensionField<::Packing, 5>; + + type Mds16 = CosetMds; + let mds16 = Mds16::default(); + + type Perm16 = Poseidon; + let perm16 = Perm16::new_from_rng(4, 22, mds16, &mut thread_rng()); // TODO: Use deterministic RNG + + type MyHash = SerializingHasher32; + let hash = MyHash::new(Keccak256Hash {}); + + type MyCompress = CompressionFunctionFromHasher; + let compress = MyCompress::new(hash); + + type ValMmcs = FieldMerkleTreeMmcs; + let val_mmcs = ValMmcs::new(hash, compress); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Dft = Radix2Bowers; + let dft = Dft::default(); + + type Challenger = DuplexChallenger; + + type MyFriConfig = TwoAdicFriPcsConfig; + let fri_config = FriConfig { + log_blowup: 1, + num_queries: 40, + proof_of_work_bits: 8, + mmcs: challenge_mmcs, + }; + + type Pcs = TwoAdicFriPcs; + type MyConfig = StarkConfigImpl; + + let pcs = Pcs::new(fri_config, dft, val_mmcs); + + let challenger = Challenger::new(perm16); + let config = MyConfig::new(pcs, challenger); + let proof = machine.prove(&config); + machine + .verify(&config, &proof) + .expect("verification failed"); +} diff --git a/cpu/src/lib.rs b/cpu/src/lib.rs index cd84403c..cbbcd892 100644 --- a/cpu/src/lib.rs +++ b/cpu/src/lib.rs @@ -96,9 +96,10 @@ where let is_read = VirtualPairCol::single_main(channel.is_read); let clk = VirtualPairCol::single_main(CPU_COL_MAP.clk); let addr = VirtualPairCol::single_main(channel.addr); + let is_static_initial = VirtualPairCol::constant(SC::Val::zero()); let value = channel.value.0.map(VirtualPairCol::single_main); - let mut fields = vec![is_read, clk, addr]; + let mut fields = vec![is_read, clk, addr, is_static_initial]; fields.extend(value); Interaction { @@ -296,6 +297,7 @@ impl CpuChip { let len = values.len(); let n_real_rows = values.len() / NUM_CPU_COLS; + debug_assert!(len > 0); let last_row = &values[len - NUM_CPU_COLS..]; let pc = last_row[CPU_COL_MAP.pc]; let fp = last_row[CPU_COL_MAP.fp]; diff --git a/derive/src/lib.rs b/derive/src/lib.rs index b59bdc53..ba8dc387 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -26,7 +26,10 @@ impl Parse for MachineFields { } } -#[proc_macro_derive(Machine, attributes(machine_fields, bus, chip, instruction))] +#[proc_macro_derive( + Machine, + attributes(machine_fields, bus, chip, static_data_chip, instruction) +)] pub fn machine_derive(input: TokenStream) -> TokenStream { let ast = syn::parse(input).unwrap(); impl_machine(&ast) @@ -61,8 +64,18 @@ fn impl_machine(machine: &syn::DeriveInput) -> TokenStream { .expect("Invalid machine_fields attribute, expected #[machine_fields()]"); let val = &machine_fields.val; + let static_data_chip: Option = chips + .iter() + .filter(|f| f.attrs.iter().any(|a| a.path.is_ident("static_data_chip"))) + .map(|f| { + f.ident + .clone() + .expect("static data chip requires an identifier") + }) + .next(); + let name = &machine.ident; - let run = run_method(machine, &instructions, &val); + let run = run_method(machine, &instructions, &val, &static_data_chip); let prove = prove_method(&chips); let verify = verify_method(&chips); @@ -127,7 +140,12 @@ fn chip_methods(chip: &Field) -> TokenStream2 { } } -fn run_method(machine: &syn::DeriveInput, instructions: &[&Field], val: &Ident) -> TokenStream2 { +fn run_method( + machine: &syn::DeriveInput, + instructions: &[&Field], + val: &Ident, + static_data_chip: &Option, +) -> TokenStream2 { let name = &machine.ident; let (_, ty_generics, _) = machine.generics.split_for_impl(); @@ -143,8 +161,17 @@ fn run_method(machine: &syn::DeriveInput, instructions: &[&Field], val: &Ident) }) .collect::(); + let init_static_data: TokenStream2 = match static_data_chip { + Some(_static_data_chip) => quote! { + self.initialize_memory(); + }, + None => quote! {}, + }; + quote! { fn run(&mut self, program: &ProgramROM, advice: &mut Adv) { + #init_static_data + loop { // Fetch let pc = self.cpu().pc; diff --git a/memory/src/columns.rs b/memory/src/columns.rs index 73ef6f4d..c66b1e92 100644 --- a/memory/src/columns.rs +++ b/memory/src/columns.rs @@ -4,7 +4,7 @@ use valida_derive::AlignedBorrow; use valida_machine::Word; use valida_util::indices_arr; -#[derive(AlignedBorrow, Default)] +#[derive(AlignedBorrow, Default, Debug)] pub struct MemoryCols { /// Memory address pub addr: T, @@ -15,6 +15,9 @@ pub struct MemoryCols { /// Main CPU clock cycle pub clk: T, + /// Flag indicating if this is an initial static data value or not + pub is_static_initial: T, + /// Whether memory operation is a read pub is_read: T, diff --git a/memory/src/lib.rs b/memory/src/lib.rs index 61049fc5..d10e2b1d 100644 --- a/memory/src/lib.rs +++ b/memory/src/lib.rs @@ -8,12 +8,12 @@ use alloc::vec; use alloc::vec::Vec; use core::mem::transmute; use p3_air::VirtualPairCol; -use p3_field::{Field, PrimeField}; +use p3_field::{AbstractField, Field, PrimeField}; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::prelude::*; use valida_bus::MachineWithMemBus; use valida_machine::StarkConfig; -use valida_machine::{BusArgument, Chip, Interaction, Machine, Word}; +use valida_machine::{Chip, Interaction, Machine, Word}; use valida_util::batch_multiplicative_inverse_allowing_zero; pub mod columns; @@ -47,6 +47,7 @@ impl Operation { pub struct MemoryChip { pub cells: BTreeMap>, pub operations: BTreeMap>, + pub static_data: BTreeMap>, } pub trait MachineWithMemoryChip: Machine { @@ -59,6 +60,7 @@ impl MemoryChip { Self { cells: BTreeMap::new(), operations: BTreeMap::new(), + static_data: BTreeMap::new(), } } @@ -92,6 +94,11 @@ impl MemoryChip { } self.cells.insert(address, value.into()); } + + pub fn write_static(&mut self, address: u32, value: Word) { + self.cells.insert(address, value.clone()); + self.static_data.insert(address, value); + } } impl Chip for MemoryChip @@ -116,50 +123,70 @@ where // Sort first by addr, then by clk ops.sort_by_key(|(clk, op)| (op.get_address(), *clk)); - // Consecutive sorted clock cycles for an address should differ no more - // than the length of the table (capped at 2^29) - Self::insert_dummy_reads(&mut ops); + // // Consecutive sorted clock cycles for an address should differ no more + // // than the length of the table (capped at 2^29) + // Self::insert_dummy_reads(&mut ops); + + let mut rows = self + .static_data + .iter() + .enumerate() + .map(|(n, (addr, value))| self.static_data_to_row(n, *addr, *value)) + .collect::>(); - let mut rows = ops + let padding_row = [SC::Val::zero(); NUM_MEM_COLS]; + + let n0 = rows.len(); + + let ops_rows = ops .par_iter() .enumerate() - .map(|(n, (clk, op))| self.op_to_row(n, *clk as usize, *op)) + .map(|(n, (clk, op))| self.op_to_row(n0 + n, *clk as usize, *op)) .collect::>(); + rows.extend(ops_rows.clone()); - // Compute address difference values - Self::compute_address_diffs(ops, &mut rows); + // // Compute address difference values + // self.compute_address_diffs(ops, &mut rows); - let trace = - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_MEM_COLS); + // Make sure the table length is a power of two + rows.resize(rows.len().next_power_of_two(), padding_row); + + let trace = RowMajorMatrix::new( + rows.clone().into_iter().flatten().collect::>(), + NUM_MEM_COLS, + ); trace } fn local_sends(&self) -> Vec> { - let sends = Interaction { - fields: vec![VirtualPairCol::single_main(MEM_COL_MAP.diff)], - count: VirtualPairCol::one(), - argument_index: BusArgument::Local(0), - }; - vec![sends] + return vec![]; // TODO + // let sends = Interaction { + // fields: vec![VirtualPairCol::single_main(MEM_COL_MAP.diff)], + // count: VirtualPairCol::one(), + // argument_index: BusArgument::Local(0), + // }; + // vec![sends] } fn local_receives(&self) -> Vec> { - let receives = Interaction { - fields: vec![VirtualPairCol::single_main(MEM_COL_MAP.counter)], - count: VirtualPairCol::single_main(MEM_COL_MAP.counter_mult), - argument_index: BusArgument::Local(0), - }; - vec![receives] + return vec![]; // TODO + // let receives = Interaction { + // fields: vec![VirtualPairCol::single_main(MEM_COL_MAP.counter)], + // count: VirtualPairCol::single_main(MEM_COL_MAP.counter_mult), + // argument_index: BusArgument::Local(0), + // }; + // vec![receives] } fn global_receives(&self, machine: &M) -> Vec> { let is_read: VirtualPairCol = VirtualPairCol::single_main(MEM_COL_MAP.is_read); let clk = VirtualPairCol::single_main(MEM_COL_MAP.clk); let addr = VirtualPairCol::single_main(MEM_COL_MAP.addr); + let is_static_initial = VirtualPairCol::single_main(MEM_COL_MAP.is_static_initial); let value = MEM_COL_MAP.value.0.map(VirtualPairCol::single_main); - let mut fields = vec![is_read, clk, addr]; + let mut fields = vec![is_read, clk, addr, is_static_initial]; fields.extend(value); let is_real = VirtualPairCol::sum_main(vec![MEM_COL_MAP.is_read, MEM_COL_MAP.is_write]); @@ -179,6 +206,7 @@ impl MemoryChip { cols.clk = F::from_canonical_usize(clk); cols.counter = F::from_canonical_usize(n); + cols.is_static_initial = F::zero(); match op { Operation::Read(addr, value) => { @@ -200,6 +228,27 @@ impl MemoryChip { row } + fn static_data_to_row( + &self, + n: usize, + addr: u32, + value: Word, + ) -> [F; NUM_MEM_COLS] { + let mut row = [F::zero(); NUM_MEM_COLS]; + let cols: &mut MemoryCols = unsafe { transmute(&mut row) }; + cols.is_static_initial = F::one(); + cols.clk = F::zero(); + cols.counter = F::from_canonical_usize(n); + cols.addr = F::from_canonical_u32(addr); + cols.value = value.transform(F::from_canonical_u8); + cols.is_write = F::one(); + cols.is_read = F::zero(); + cols.diff = F::zero(); + cols.diff_inv = F::zero(); + cols.addr_not_equal = F::zero(); + row + } + fn insert_dummy_reads(ops: &mut Vec<(u32, Operation)>) { if ops.is_empty() { return; @@ -279,6 +328,7 @@ impl MemoryChip { } fn compute_address_diffs( + &self, ops: Vec<(u32, Operation)>, rows: &mut Vec<[F; NUM_MEM_COLS]>, ) { @@ -286,10 +336,12 @@ impl MemoryChip { return; } + let i0 = self.static_data.len(); + // Compute `diff` and `counter_mult` let mut diff = vec![F::zero(); rows.len()]; let mut mult = vec![F::zero(); rows.len()]; - for i in 0..(rows.len() - 1) { + for i in 0..(ops.len() - 1) { let addr = ops[i].1.get_address(); let addr_next = ops[i + 1].1.get_address(); let value = if addr_next != addr { @@ -307,15 +359,15 @@ impl MemoryChip { let diff_inv = batch_multiplicative_inverse_allowing_zero(diff.clone()); // Set trace values - for i in 0..(rows.len() - 1) { - rows[i][MEM_COL_MAP.diff] = diff[i]; - rows[i][MEM_COL_MAP.diff_inv] = diff_inv[i]; - rows[i][MEM_COL_MAP.counter_mult] = mult[i]; + for i in 0..(ops.len() - 1) { + rows[i0 + i][MEM_COL_MAP.diff] = diff[i]; + rows[i0 + i][MEM_COL_MAP.diff_inv] = diff_inv[i]; + rows[i0 + i][MEM_COL_MAP.counter_mult] = mult[i]; let addr = ops[i].1.get_address(); let addr_next = ops[i + 1].1.get_address(); if addr_next - addr != 0 { - rows[i][MEM_COL_MAP.addr_not_equal] = F::one(); + rows[i0 + i][MEM_COL_MAP.addr_not_equal] = F::one(); } } diff --git a/memory/src/stark.rs b/memory/src/stark.rs index f8c0701d..cac818bf 100644 --- a/memory/src/stark.rs +++ b/memory/src/stark.rs @@ -1,10 +1,7 @@ -use crate::columns::{MemoryCols, NUM_MEM_COLS}; +use crate::columns::NUM_MEM_COLS; use crate::MemoryChip; -use core::borrow::Borrow; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::AbstractField; -use p3_matrix::MatrixRowSlices; impl BaseAir for MemoryChip { fn width(&self) -> usize { @@ -22,60 +19,61 @@ where } impl MemoryChip { - fn eval_main(&self, builder: &mut AB) { - let main = builder.main(); - let local: &MemoryCols = main.row_slice(0).borrow(); - let next: &MemoryCols = main.row_slice(1).borrow(); + fn eval_main(&self, _builder: &mut AB) { + // TODO + // let main = builder.main(); + // let local: &MemoryCols = main.row_slice(0).borrow(); + // let next: &MemoryCols = main.row_slice(1).borrow(); - // Flags should be boolean. - builder.assert_bool(local.is_read); - builder.assert_bool(local.is_write); - builder.assert_bool(local.is_read + local.is_write); - builder.assert_bool(local.addr_not_equal); + // // Flags should be boolean. + // builder.assert_bool(local.is_read); + // builder.assert_bool(local.is_write); + // builder.assert_bool(local.is_read + local.is_write); + // builder.assert_bool(local.addr_not_equal); - let addr_delta = next.addr - local.addr; - let addr_equal = AB::Expr::one() - local.addr_not_equal; + // let addr_delta = next.addr - local.addr; + // let addr_equal = AB::Expr::one() - local.addr_not_equal; - // Ensure addr_not_equal is set correctly. - builder - .when_transition() - .when(local.addr_not_equal) - .assert_one(addr_delta.clone() * local.diff_inv); - builder - .when_transition() - .when(addr_equal.clone()) - .assert_zero(addr_delta.clone()); + // // Ensure addr_not_equal is set correctly. + // builder + // .when_transition() + // .when(local.addr_not_equal) + // .assert_one(addr_delta.clone() * local.diff_inv); + // builder + // .when_transition() + // .when(addr_equal.clone()) + // .assert_zero(addr_delta.clone()); - // diff should match either the address delta or the clock delta, based on addr_not_equal. - builder - .when_transition() - .when(local.addr_not_equal) - .assert_eq(local.diff, addr_delta.clone()); - builder - .when_transition() - .when(addr_equal.clone()) - .assert_eq(local.diff, next.clk - local.clk); + // // diff should match either the address delta or the clock delta, based on addr_not_equal. + // builder + // .when_transition() + // .when(local.addr_not_equal) + // .assert_eq(local.diff, addr_delta.clone()); + // builder + // .when_transition() + // .when(addr_equal.clone()) + // .assert_eq(local.diff, next.clk - local.clk); - // Read/write - // TODO: Record \sum_i (value'_i - value_i)^2 in trace and convert to a single constraint? - for (value_next, value) in next.value.into_iter().zip(local.value.into_iter()) { - builder - .when_transition() - .when(next.is_read) - .when(addr_equal.clone()) - .assert_eq(value_next, value); - } + // // Read/write + // // TODO: Record \sum_i (value'_i - value_i)^2 in trace and convert to a single constraint? + // for (value_next, value) in next.value.into_iter().zip(local.value.into_iter()) { + // builder + // .when_transition() + // .when(next.is_read) + // .when(addr_equal.clone()) + // .assert_eq(value_next, value); + // } - // TODO: This disallows reading unitialized memory? Not sure that's desired, it depends on - // how we implement continuations. If we end up defaulting to zero, then we should replace - // this with - // when(is_read).when(addr_delta).assert_zero(value_next); - builder.when(next.is_read).assert_zero(addr_delta); + // // TODO: This disallows reading unitialized memory? Not sure that's desired, it depends on + // // how we implement continuations. If we end up defaulting to zero, then we should replace + // // this with + // // when(is_read).when(addr_delta).assert_zero(value_next); + // builder.when(next.is_read).assert_zero(addr_delta); - // Counter increments from zero. - builder.when_first_row().assert_zero(local.counter); - builder - .when_transition() - .assert_eq(next.counter, local.counter + AB::Expr::one()); + // // Counter increments from zero. + // builder.when_first_row().assert_zero(local.counter); + // builder + // .when_transition() + // .assert_eq(next.counter, local.counter + AB::Expr::one()); } } diff --git a/program/src/lib.rs b/program/src/lib.rs index 1adea7e4..145171b2 100644 --- a/program/src/lib.rs +++ b/program/src/lib.rs @@ -2,14 +2,13 @@ extern crate alloc; -use crate::columns::{COL_MAP, NUM_PROGRAM_COLS, PREPROCESSED_COL_MAP}; +use crate::columns::NUM_PROGRAM_COLS; use alloc::vec; use alloc::vec::Vec; use valida_bus::MachineWithProgramBus; use valida_machine::{Chip, Interaction, Machine, ProgramROM}; use valida_util::pad_to_power_of_two; -use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; use valida_machine::StarkConfig; @@ -48,7 +47,7 @@ where RowMajorMatrix::new(values, NUM_PROGRAM_COLS) } - fn global_receives(&self, machine: &M) -> Vec> { + fn global_receives(&self, _machine: &M) -> Vec> { // let pc = VirtualPairCol::single_preprocessed(PREPROCESSED_COL_MAP.pc); // let opcode = VirtualPairCol::single_preprocessed(PREPROCESSED_COL_MAP.opcode); // let mut fields = vec![pc, opcode]; diff --git a/static_data/Cargo.toml b/static_data/Cargo.toml new file mode 100644 index 00000000..aa962833 --- /dev/null +++ b/static_data/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "valida-static-data" +version = "0.1.0" +edition = "2021" +license = "MIT OR Apache-2.0" + +[dependencies] +p3-air = { workspace = true } +p3-baby-bear = { workspace = true } +p3-field = { workspace = true } +p3-matrix = { workspace = true } +p3-maybe-rayon = { workspace = true } +p3-uni-stark = { workspace = true } +valida-bus = { path = "../bus" } +valida-machine = { path = "../machine" } +valida-memory = { path = "../memory" } +valida-derive = { path = "../derive" } +valida-util = { path = "../util" } diff --git a/static_data/src/columns.rs b/static_data/src/columns.rs new file mode 100644 index 00000000..8c7e5190 --- /dev/null +++ b/static_data/src/columns.rs @@ -0,0 +1,25 @@ +use core::borrow::{Borrow, BorrowMut}; +use core::mem::{size_of, transmute}; +use valida_derive::AlignedBorrow; +use valida_machine::Word; +use valida_util::indices_arr; + +#[derive(AlignedBorrow, Default)] +pub struct StaticDataCols { + /// Memory address + pub addr: T, + + /// Memory cell + pub value: Word, + + /// Whether this row represents a real (address, value) pair + pub is_real: T, +} + +pub const NUM_STATIC_DATA_COLS: usize = size_of::>(); +pub const STATIC_DATA_COL_MAP: StaticDataCols = make_col_map(); + +const fn make_col_map() -> StaticDataCols { + let indices_arr = indices_arr::(); + unsafe { transmute::<[usize; NUM_STATIC_DATA_COLS], StaticDataCols>(indices_arr) } +} diff --git a/static_data/src/lib.rs b/static_data/src/lib.rs new file mode 100644 index 00000000..6b87dada --- /dev/null +++ b/static_data/src/lib.rs @@ -0,0 +1,93 @@ +#![no_std] + +extern crate alloc; + +use crate::columns::{StaticDataCols, NUM_STATIC_DATA_COLS, STATIC_DATA_COL_MAP}; +use alloc::collections::BTreeMap; +use alloc::vec; +use alloc::vec::Vec; +use core::mem::transmute; +use p3_air::VirtualPairCol; +use p3_field::{AbstractField, Field}; +use p3_matrix::dense::RowMajorMatrix; +use valida_bus::MachineWithMemBus; +use valida_machine::{Chip, Interaction, StarkConfig, Word}; +use valida_memory::MachineWithMemoryChip; + +pub mod columns; +pub mod stark; + +#[derive(Default)] +pub struct StaticDataChip { + pub cells: BTreeMap>, +} + +pub trait MachineWithStaticDataChip: MachineWithMemoryChip { + fn static_data(&self) -> &StaticDataChip; + fn static_data_mut(&mut self) -> &mut StaticDataChip; + fn initialize_memory(&mut self) { + for (addr, value) in self.static_data().get_cells().iter() { + self.mem_mut().write_static(*addr, *value); + } + } +} + +impl StaticDataChip { + pub fn new() -> Self { + Self { + cells: BTreeMap::new(), + } + } + + pub fn write(&mut self, address: u32, value: Word) { + self.cells.insert(address, value); + } + + pub fn get_cells(&self) -> BTreeMap> { + self.cells.clone() + } +} + +impl Chip for StaticDataChip +where + M: MachineWithMemBus, + SC: StarkConfig, +{ + fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { + let mut rows = self + .cells + .iter() + .map(|(addr, value)| { + let mut row = [SC::Val::zero(); NUM_STATIC_DATA_COLS]; + let cols: &mut StaticDataCols = unsafe { transmute(&mut row) }; + cols.addr = SC::Val::from_canonical_u32(*addr); + cols.value = value.transform(SC::Val::from_canonical_u8); + cols.is_real = SC::Val::one(); + row + }) + .flatten() + .collect::>(); + rows.resize( + rows.len().next_power_of_two() * NUM_STATIC_DATA_COLS, + SC::Val::zero(), + ); + RowMajorMatrix::new(rows, NUM_STATIC_DATA_COLS) + } + + fn global_sends(&self, machine: &M) -> Vec> { + let addr = VirtualPairCol::single_main(STATIC_DATA_COL_MAP.addr); + let value = STATIC_DATA_COL_MAP.value.0.map(VirtualPairCol::single_main); + let is_read = VirtualPairCol::constant(SC::Val::zero()); + let is_real = VirtualPairCol::single_main(STATIC_DATA_COL_MAP.is_real); + let is_static_initial = VirtualPairCol::constant(SC::Val::one()); + let clk = VirtualPairCol::constant(SC::Val::zero()); + let mut fields = vec![is_read, clk, addr, is_static_initial]; + fields.extend(value); + let send = Interaction { + fields, + count: is_real, + argument_index: machine.mem_bus(), + }; + vec![send] + } +} diff --git a/static_data/src/stark.rs b/static_data/src/stark.rs new file mode 100644 index 00000000..cbcd465d --- /dev/null +++ b/static_data/src/stark.rs @@ -0,0 +1,38 @@ +use crate::columns::{StaticDataCols, NUM_STATIC_DATA_COLS}; +use crate::StaticDataChip; + +use core::borrow::Borrow; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::AbstractField; +use p3_matrix::MatrixRowSlices; + +impl BaseAir for StaticDataChip { + fn width(&self) -> usize { + NUM_STATIC_DATA_COLS + } +} + +impl Air for StaticDataChip +where + AB: AirBuilder, +{ + fn eval(&self, builder: &mut AB) { + self.eval_main(builder); + } +} + +impl StaticDataChip { + fn eval_main(&self, builder: &mut AB) { + // ensure that addresses are sequentially increasing, in order to ensure internal consistency of static data trace + let main = builder.main(); + let local: &StaticDataCols = main.row_slice(0).borrow(); + let next: &StaticDataCols = main.row_slice(1).borrow(); + builder + .when_transition() + .when(local.is_real * next.is_real) + .assert_eq( + next.addr, + local.addr + AB::Expr::one() + AB::Expr::one() + AB::Expr::one() + AB::Expr::one(), + ); + } +}