Skip to content

Commit

Permalink
Change order of elements in hashing/proving
Browse files Browse the repository at this point in the history
  • Loading branch information
anodar committed Jan 29, 2024
1 parent e925b7c commit 4b32ed3
Showing 1 changed file with 56 additions and 44 deletions.
100 changes: 56 additions & 44 deletions arbitrator/prover/src/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,15 +904,14 @@ where
data
}

// Prove stacks encodes proof for stack of stacks.
// prove_multistacks encodes proof for stack of stacks.
// Layout of the proof is:
// - size of a stack of stacks
// - proof of first stack
// - proof of last stack
// - hash of everything in between
// - Proof of first(main) stack
// - Hash of last stack (or 0 if size is 1)
// - Recursive hash of the rest
// Accepts prover function so that it can work both for proving stack and window.
#[must_use]
fn prove_stacks<T, F>(
fn prove_multistack<T, F>(
items: Vec<&[T]>,
stack_hasher: F,
prover: fn(&[T])-> Vec<u8>,
Expand All @@ -921,19 +920,25 @@ where
F: Fn(&[T]) -> Bytes32,
{
let mut data = Vec::with_capacity(33);
// Add Stack size.
data.extend(Bytes32::from(items.len()));
// Hash go thread (first stack).
// Proof of the first stack (main).
data.extend(prover(items.first().unwrap()));
// Hash of the last element or 0 if size is 1.
let mut last_hash = Bytes32::default();
if items.len() > 1 {
// Hash last co thread (currently executing).
data.extend(prover(items.last().unwrap()));
last_hash = stack_hasher(items.last().unwrap());
}
// Hash stacks in between.
for st in &items[1..items.len() - 1] {
let hash = stack_hasher(st);
data.extend(hash.as_ref());
data.extend(last_hash);
let mut hash: Bytes32 = Bytes32::default();

for st in &items[1..items.len()-1]{
hash = Keccak256::new()
.chain("cothread: ")
.chain(stack_hasher(st))
.chain(hash)
.finalize()
.into();
}
data.extend(hash);
data
}

Expand Down Expand Up @@ -2502,7 +2507,7 @@ impl Machine {
self.get_modules_merkle().root()
}

fn stack_hashes(
fn multistack_hash(
&self,
) -> (
FrameStackHash,
Expand All @@ -2517,34 +2522,41 @@ impl Machine {
hash_stack_with_heights(frames, &heights, concat!($prefix, " stack:"))
}};
}

// compute_slices works similarly as compute, except for vector of slices.
// Instead of constructing vector of each Value's hash, it constructs
// hash of each slice (&[Value]). Rest is the same.
macro_rules! compute_vec {
macro_rules! compute_multistack {
($field:expr, $stacks:expr, $prefix:expr) => {{
let heights: Vec<_> = self.guards.iter().map($field).collect();

// Map each &[Value] slice to its hash using `hash_stack` and collect these hashes into a vector
// Reorder stacks for hashing to follow same order as in proof:
// [0, n-1, 1, 2, ..., n - 2].
let ordered_stacks: Vec<_> = match $stacks.len() {
0 => Vec::new(),
1 => vec![*$stacks.first().unwrap()],
_ => std::iter::once(*$stacks.first().unwrap())
.chain(std::iter::once(*$stacks.last().unwrap()))
.chain($stacks.iter().skip(1).take($stacks.len() - 2).cloned())
.collect()
let first_elem = *$stacks.first().unwrap();
let last_elem = *$stacks.last().unwrap();
let first_hash = hash_stack(first_elem.iter().map(|v|v.hash()), $prefix);
let last_hash = match $stacks.len() {
0 => panic!("Stacks size is 0"),
1 => Bytes32::default(),
_ => hash_stack(last_elem.iter().map(|v|v.hash()), $prefix)
};
let slice_hashes: Vec<_> = ordered_stacks.iter()
.map(|slice| hash_stack(slice.iter().map(|v| v.hash()), concat!($prefix, " stack:")))
.collect();
hash_stack_with_heights(slice_hashes, &heights, concat!($prefix, " stack:"))

let mut hash = Keccak256::new()
.chain($prefix)
.chain("multistack: ")
.chain(first_hash)
.chain(last_hash)
.finalize()
.into();
if $stacks.len() > 2 {
for item in $stacks.iter().skip(1).take($stacks.len() - 2) {
hash = Keccak256::new()
.chain("cothread: ")
.chain(hash_stack(item.iter().map(|v| v.hash()), $prefix))
.chain(hash)
.finalize()
.into();
}
}
hash
}};
}
let (frame_stack, frames) =
compute_vec!(|x| x.frame_stack, self.get_frame_stacks(), "Stack frame");
let (value_stack, values) = compute_vec!(|x| x.value_stack, self.get_data_stacks(), "Value");
let frame_stack = compute_multistack!(|x| x.frame_stack, self.get_frame_stacks(), "Stack frame");
let value_stack = compute_multistack!(|x| x.value_stack, self.get_data_stacks(), "Value");
let (_, frames) = compute!(|x| x.frame_stack, self.get_frame_stack(), "Stack frame");
let (_, values) = compute!(|x| x.value_stack, self.get_data_stack(), "Value");
let (inter_stack, inters) = compute!(|x| x.inter_stack, self.internal_stack, "Value");

let pcs = self.guards.iter().map(|x| x.on_error);
Expand All @@ -2563,7 +2575,7 @@ impl Machine {
let mut h = Keccak256::new();
match self.status {
MachineStatus::Running => {
let (frame_stack, value_stack, inter_stack, guards) = self.stack_hashes();
let (frame_stack, value_stack, inter_stack, guards) = self.multistack_hash();

h.update(b"Machine running:");
h.update(value_stack);
Expand Down Expand Up @@ -2611,7 +2623,7 @@ impl Machine {
panic!("WASM validation failed: {text}");
}};
}
out!(prove_stacks(
out!(prove_multistack(
self.get_data_stacks(),
hash_value_stack,
|stack| prove_stack(
Expand All @@ -2629,7 +2641,7 @@ impl Machine {
|v| v.serialize_for_proof(),
));

out!(prove_stacks(
out!(prove_multistack(
self.get_frame_stacks(),
hash_stack_frame_stack,
|stack| prove_window(
Expand Down Expand Up @@ -2857,7 +2869,7 @@ impl Machine {

fn prove_guards(&self) -> Vec<u8> {
let mut data = Vec::with_capacity(34); // size in the empty case
let guards = self.stack_hashes().3;
let guards = self.multistack_hash().3;
let empty = self.guards.is_empty();

data.push(empty as u8);
Expand Down

0 comments on commit 4b32ed3

Please sign in to comment.