diff --git a/arbitrator/prover/src/machine.rs b/arbitrator/prover/src/machine.rs index 3f162da29..5c79c6612 100644 --- a/arbitrator/prover/src/machine.rs +++ b/arbitrator/prover/src/machine.rs @@ -805,6 +805,7 @@ pub struct Machine { type FrameStackHash = Bytes32; type ValueStackHash = Bytes32; +type MultiStackHash = Bytes32; type InterStackHash = Bytes32; pub(crate) fn hash_stack(stack: I, prefix: &str) -> Bytes32 @@ -859,6 +860,13 @@ fn hash_stack_frame_stack(frames: &[StackFrame]) -> FrameStackHash { hash_stack(frames.iter().map(|f| f.hash()), "Stack frame stack:") } +fn hash_multistack(multistack: &[&[T]], stack_hasher: F) -> MultiStackHash +where + F: Fn(&[T]) -> Bytes32, +{ + hash_stack(multistack.iter().map(|v| stack_hasher(v)), "cothread:") +} + #[must_use] fn prove_window(items: &[T], stack_hasher: F, encoder: G) -> Vec where @@ -901,6 +909,49 @@ where data } +// prove_multistacks encodes proof for multistacks: +// - Proof of first(main) if not cothread otherwise last +// - Hash of first if cothread, otherwise last +// - Recursive hash of the rest +// If length is < 1, hash of last element is assumed 0xff..f, same for hash +// of in-between stacks ([2nd..last)). +// Accepts prover function so that it can work both for proving stack and window. +#[must_use] +fn prove_multistack( + cothread: bool, + items: Vec<&[T]>, + stack_hasher: F, + multistack_hasher: MF, + prover: fn(&[T]) -> Vec, +) -> Vec +where + F: Fn(&[T]) -> Bytes32, + MF: Fn(&[&[T]], F) -> Bytes32, +{ + let mut data = Vec::with_capacity(33); + + if cothread { + data.extend(prover(items.last().unwrap())); + data.extend(stack_hasher(items.first().unwrap())) + } else { + data.extend(prover(items.first().unwrap())); + + let last_hash = if items.len() > 1 { + stack_hasher(items.last().unwrap()) + } else { + Machine::NO_STACK_HASH + }; + data.extend(last_hash); + } + let hash: Bytes32 = if items.len() > 2 { + multistack_hasher(&items[1..items.len() - 1], stack_hasher) + } else { + Machine::NO_STACK_HASH + }; + data.extend(hash); + data +} + #[must_use] fn exec_ibin_op(a: T, b: T, op: IBinOpType) -> Option where @@ -969,6 +1020,7 @@ pub fn get_empty_preimage_resolver() -> PreimageResolver { impl Machine { pub const MAX_STEPS: u64 = 1 << 43; + pub const NO_STACK_HASH: Bytes32 = Bytes32([255_u8; 32]); pub fn from_paths( library_paths: &[PathBuf], @@ -2448,36 +2500,82 @@ impl Machine { self.get_modules_merkle().root() } - fn stack_hashes( - &self, - ) -> ( - FrameStackHash, - ValueStackHash, - InterStackHash, - ) { + fn stack_hashes(&self) -> (FrameStackHash, ValueStackHash, InterStackHash) { macro_rules! compute { ($stack:expr, $prefix:expr) => {{ let frames = $stack.iter().map(|v| v.hash()); hash_stack(frames, concat!($prefix, " stack:")) }}; } - let frame_stack = compute!(self.get_frame_stack(), "Stack frame"); - let value_stack = compute!(self.get_data_stack(), "Value"); + // compute_multistack returns the hash of multistacks as follows: + // Keccak( + // "multistack:" + // + hash_stack(first_stack) + // + hash_stack(last_stack) + // + Keccak("cothread:" + 2nd_stack+Keccak("cothread:" + 3drd_stack + ...) + // ) + macro_rules! compute_multistack { + ($field:expr, $stacks:expr, $prefix:expr, $hasher: expr) => {{ + let first_elem = *$stacks.first().unwrap(); + let first_hash = hash_stack( + first_elem.iter().map(|v| v.hash()), + concat!($prefix, " stack:"), + ); + + let last_elem = *$stacks.last().unwrap(); + let last_hash = if $stacks.len() == 0 { + panic!("Stacks size is 0") + } else { + hash_stack( + last_elem.iter().map(|v| v.hash()), + concat!($prefix, " stack:"), + ) + }; + + // Hash of stacks [2nd..last) or 0xfff...f if len <= 2. + let mut hash = if $stacks.len() <= 2 { + Machine::NO_STACK_HASH + } else { + hash_multistack(&$stacks[1..$stacks.len() - 2], $hasher) + }; + + hash = Keccak256::new() + .chain("multistack:") + .chain(first_hash) + .chain(last_hash) + .chain(hash) + .finalize() + .into(); + hash + }}; + } + let frame_stacks = compute_multistack!( + |x| x.frame_stack, + self.get_frame_stacks(), + "Stack frame", + hash_stack_frame_stack + ); + let value_stacks = compute_multistack!( + |x| x.value_stack, + self.get_data_stacks(), + "Value", + hash_value_stack + ); let inter_stack = compute!(self.internal_stack, "Value"); - (frame_stack, value_stack, inter_stack) + (frame_stacks, value_stacks, inter_stack) } pub fn hash(&self) -> Bytes32 { let mut h = Keccak256::new(); match self.status { MachineStatus::Running => { - let (frame_stack, value_stack, inter_stack) = self.stack_hashes(); + let (frame_stacks, value_stacks, inter_stack) = self.stack_hashes(); h.update(b"Machine running:"); - h.update(value_stack); + h.update(value_stacks); h.update(inter_stack); - h.update(frame_stack); + h.update(frame_stacks); h.update(self.global_state.hash()); h.update(self.pc.module.to_be_bytes()); h.update(self.pc.func.to_be_bytes()); @@ -2515,12 +2613,13 @@ impl Machine { panic!("WASM validation failed: {text}"); }}; } - - out!(prove_stack( - self.get_data_stack(), - STACK_PROVING_DEPTH, + out!(prove_multistack( + self.cothread, + self.get_data_stacks(), hash_value_stack, - |v| v.serialize_for_proof(), + hash_multistack, + |stack| prove_stack(stack, STACK_PROVING_DEPTH, hash_value_stack, |v| v + .serialize_for_proof()), )); out!(prove_stack( @@ -2530,10 +2629,16 @@ impl Machine { |v| v.serialize_for_proof(), )); - out!(prove_window( - self.get_frame_stack(), + out!(prove_multistack( + self.cothread, + self.get_frame_stacks(), hash_stack_frame_stack, - StackFrame::serialize_for_proof, + hash_multistack, + |stack| prove_window( + stack, + hash_stack_frame_stack, + StackFrame::serialize_for_proof + ), )); out!(self.global_state.hash()); @@ -2545,6 +2650,8 @@ impl Machine { let mod_merkle = self.get_modules_merkle(); out!(mod_merkle.root()); + out!(if self.cothread { [1_u8; 1] } else { [0_u8; 1] }); + // End machine serialization, serialize module let module = &self.modules[self.pc.module()]; @@ -2745,6 +2852,23 @@ impl Machine { out!(mod_merkle.prove_any(leaf + 1)); } } + PopCoThread => { + out!(hash_value_stack(self.get_data_stacks().last().unwrap())); + out!(match self.get_data_stacks().len() { + len if len > 1 => + hash_multistack(&self.get_data_stacks()[..len - 1], hash_value_stack), + _ => Machine::NO_STACK_HASH, + }); + + out!(hash_stack_frame_stack( + self.get_frame_stacks().last().unwrap() + )); + out!(match self.get_frame_stacks().len() { + len if len > 1 => + hash_multistack(&self.get_frame_stacks()[..len - 1], hash_stack_frame_stack), + _ => Machine::NO_STACK_HASH, + }); + } _ => {} } data @@ -2757,6 +2881,10 @@ impl Machine { } } + pub fn get_data_stacks(&self) -> Vec<&[Value]> { + self.value_stacks.iter().map(|v| v.as_slice()).collect() + } + fn get_frame_stack(&self) -> &[StackFrame] { match self.cothread { false => &self.frame_stacks[0], @@ -2764,6 +2892,13 @@ impl Machine { } } + fn get_frame_stacks(&self) -> Vec<&[StackFrame]> { + self.frame_stacks + .iter() + .map(|v: &Vec| v.as_slice()) + .collect() + } + pub fn get_internals_stack(&self) -> &[Value] { &self.internal_stack }