diff --git a/Cargo.lock b/Cargo.lock
index 973a4c319..a72ea6727 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -414,6 +414,7 @@ dependencies = [
  "open",
  "percent-encoding",
  "pyo3",
+ "rand",
  "rstest",
  "thiserror",
  "url",
diff --git a/crates/chia-datalayer/Cargo.toml b/crates/chia-datalayer/Cargo.toml
index b938f12ed..afa4fcd5a 100644
--- a/crates/chia-datalayer/Cargo.toml
+++ b/crates/chia-datalayer/Cargo.toml
@@ -34,6 +34,7 @@
 hex = { workspace = true }
 hex-literal = { workspace = true }
 open = { workspace = true }
 percent-encoding = { workspace = true }
+rand = { workspace = true }
 rstest = { workspace = true }
 url = { workspace = true }
diff --git a/crates/chia-datalayer/src/merkle.rs b/crates/chia-datalayer/src/merkle.rs
index 91285f12c..6b387e834 100644
--- a/crates/chia-datalayer/src/merkle.rs
+++ b/crates/chia-datalayer/src/merkle.rs
@@ -334,6 +334,74 @@ pub enum NodeType {
     Leaf = 1,
 }
 
+#[cfg_attr(feature = "py-bindings", pyclass(get_all))]
+#[derive(Clone, Debug, Hash, Eq, PartialEq)]
+pub struct ProofOfInclusionLayer {
+    pub other_hash_side: Side,
+    pub other_hash: Hash,
+    pub combined_hash: Hash,
+}
+
+#[cfg(feature = "py-bindings")]
+#[pymethods]
+impl ProofOfInclusionLayer {
+    #[new]
+    pub fn py_init(other_hash_side: Side, other_hash: Hash, combined_hash: Hash) -> PyResult<Self> {
+        Ok(Self {
+            other_hash_side,
+            other_hash,
+            combined_hash,
+        })
+    }
+}
+
+#[cfg_attr(feature = "py-bindings", pyclass(get_all))]
+#[derive(Clone, Debug, Hash, Eq, PartialEq)]
+pub struct ProofOfInclusion {
+    pub node_hash: Hash,
+    pub layers: Vec<ProofOfInclusionLayer>,
+}
+
+impl ProofOfInclusion {
+    pub fn root_hash(&self) -> Hash {
+        if let Some(last) = self.layers.last() {
+            last.combined_hash
+        } else {
+            self.node_hash
+        }
+    }
+
+    pub fn valid(&self) -> bool {
+        let mut existing_hash = self.node_hash;
+
+        for layer in &self.layers {
+            let calculated_hash =
+                calculate_internal_hash(&existing_hash, layer.other_hash_side, &layer.other_hash);
+
+            if calculated_hash != layer.combined_hash {
+                return false;
+            }
+
+            existing_hash = calculated_hash;
+        }
+
+        existing_hash == self.root_hash()
+    }
+}
+
+#[cfg(feature = "py-bindings")]
+#[pymethods]
+impl ProofOfInclusion {
+    #[pyo3(name = "root_hash")]
+    pub fn py_root_hash(&self) -> Hash {
+        self.root_hash()
+    }
+
+    #[pyo3(name = "valid")]
+    pub fn py_valid(&self) -> bool {
+        self.valid()
+    }
+}
+
 #[allow(clippy::needless_pass_by_value)]
 fn sha256_num<T: ToBytes>(input: T) -> Hash {
     let mut hasher = Sha256::new();
@@ -358,6 +426,13 @@ fn internal_hash(left_hash: &Hash, right_hash: &Hash) -> Hash {
     Hash(Bytes32::new(hasher.finalize()))
 }
 
+pub fn calculate_internal_hash(hash: &Hash, other_hash_side: Side, other_hash: &Hash) -> Hash {
+    match other_hash_side {
+        Side::Left => internal_hash(other_hash, hash),
+        Side::Right => internal_hash(hash, other_hash),
+    }
+}
+
 #[cfg_attr(feature = "py-bindings", derive(PyJsonDict, PyStreamable))]
 #[repr(u8)]
 #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Streamable)]
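A note on the orientation in `calculate_internal_hash` above: it rebuilds a parent hash from one child plus that child's sibling, so a proof layer only has to record which side the sibling sat on. A minimal, self-contained sketch of the same folding logic (`combine` is a toy stand-in for the crate's `internal_hash`, whose exact preimage layout is not shown in this diff):

```rust
use sha2::{Digest, Sha256};

#[derive(Clone, Copy)]
enum Side {
    Left,
    Right,
}

// Stand-in for the crate's internal_hash; the real preimage layout
// (prefixes, domain separation) lives in merkle.rs and is not shown here.
fn combine(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
    let mut hasher = Sha256::new();
    hasher.update(left);
    hasher.update(right);
    hasher.finalize().into()
}

// Fold a leaf hash up through (sibling side, sibling hash) layers to a root,
// mirroring ProofOfInclusion::valid minus the per-layer combined_hash check.
fn fold_to_root(leaf: [u8; 32], layers: &[(Side, [u8; 32])]) -> [u8; 32] {
    layers.iter().fold(leaf, |current, (side, sibling)| match side {
        Side::Left => combine(sibling, &current),
        Side::Right => combine(&current, sibling),
    })
}
```

Keeping the side enum in the layer (rather than, say, an index parity bit) is what lets `valid()` stay a single forward fold over the layers.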
@@ -407,6 +482,16 @@ impl InternalNode {
             Err(Error::IndexIsNotAChild(index))
         }
     }
+
+    pub fn get_sibling_side(&self, index: TreeIndex) -> Result<Side, Error> {
+        if self.left == index {
+            Ok(Side::Right)
+        } else if self.right == index {
+            Ok(Side::Left)
+        } else {
+            Err(Error::IndexIsNotAChild(index))
+        }
+    }
 }
 
 #[cfg_attr(
@@ -488,6 +573,15 @@ impl Node {
         *leaf
     }
 
+    fn expect_internal(&self, message: &str) -> InternalNode {
+        let Node::Internal(internal) = self else {
+            let message = message.replace("<>", &format!("{self:?}"));
+            panic!("{}", message)
+        };
+
+        *internal
+    }
+
     fn try_into_leaf(self) -> Result<LeafNode, Error> {
         match self {
             Node::Leaf(leaf) => Ok(leaf),
@@ -825,16 +919,17 @@ impl MerkleBlob {
         Ok(new_leaf_index)
     }
 
-    pub fn batch_insert<I>(&mut self, mut keys_values_hashes: I) -> Result<(), Error>
-    where
-        I: Iterator<Item = ((KeyId, ValueId), Hash)>,
-    {
+    pub fn batch_insert(
+        &mut self,
+        mut keys_values_hashes: Vec<((KeyId, ValueId), Hash)>,
+    ) -> Result<(), Error> {
+        // OPT: perhaps go back to taking an iterator?
         // OPT: would it be worthwhile to hold the entire blocks?
         let mut indexes = vec![];
 
         if self.key_to_index.len() <= 1 {
             for _ in 0..2 {
-                let Some(((key, value), hash)) = keys_values_hashes.next() else {
+                let Some(((key, value), hash)) = keys_values_hashes.pop() else {
                     return Ok(());
                 };
                 self.insert(key, value, &hash, InsertLocation::Auto {})?;
@@ -978,6 +1073,7 @@ impl MerkleBlob {
             panic!("not handling this case now...")
         }
         self.insert_entry_to_blob(old_leaf_parent, &parent)?;
+        self.mark_lineage_as_dirty(old_leaf_parent)?;
         self.update_parent(old_leaf_index, Some(new_internal_node_index))?;
 
         Ok(())
     }
@@ -1012,14 +1108,14 @@ impl MerkleBlob {
         let Some(grandparent_index) = parent.parent.0 else {
             sibling_block.node.set_parent(Parent(None));
-            self.insert_entry_to_blob(TreeIndex(0), &sibling_block)?;
-
             if let Node::Internal(node) = sibling_block.node {
+                sibling_block.metadata.dirty = true;
                 for child_index in [node.left, node.right] {
                     self.update_parent(child_index, Some(TreeIndex(0)))?;
                 }
             };
+            self.insert_entry_to_blob(TreeIndex(0), &sibling_block)?;
 
             self.free_indexes.insert(sibling_index);
 
             return Ok(());
@@ -1148,7 +1244,7 @@ impl MerkleBlob {
             let mut block = Block::from_bytes(self.get_block_bytes(this_index)?)?;
 
             if block.metadata.dirty {
-                return Ok(());
+                break;
             }
 
             block.metadata.dirty = true;
@@ -1377,6 +1473,38 @@ impl MerkleBlob {
             .copied()
             .ok_or(Error::UnknownKey(key))
     }
+
+    pub fn get_proof_of_inclusion(&self, key: KeyId) -> Result<ProofOfInclusion, Error> {
+        let mut index = *self.key_to_index.get(&key).ok_or(Error::UnknownKey(key))?;
+
+        let node = self
+            .get_node(index)?
+            .expect_leaf("key to index mapping should only have leaves");
+
+        let parents = self.get_lineage_with_indexes(index)?;
+        let mut layers: Vec<ProofOfInclusionLayer> = Vec::new();
+        let mut parents_iter = parents.iter();
+        // first in the lineage is the index itself, second is the first parent
+        parents_iter.next();
+        for (next_index, parent) in parents_iter {
+            let parent = parent.expect_internal("all nodes after the first should be internal");
+            let sibling_index = parent.sibling_index(index)?;
+            let sibling_block = self.get_block(sibling_index)?;
+            let sibling = sibling_block.node;
+            let layer = ProofOfInclusionLayer {
+                other_hash_side: parent.get_sibling_side(index)?,
+                other_hash: sibling.hash(),
+                combined_hash: parent.hash,
+            };
+            layers.push(layer);
+            index = *next_index;
+        }
+
+        Ok(ProofOfInclusion {
+            node_hash: node.hash,
+            layers,
+        })
+    }
 }
 
 impl PartialEq for MerkleBlob {
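A hedged usage sketch of the proof getter added just above, modeled on the test added further down. It assumes `MerkleBlob`, `KeyId`, `ValueId`, `Hash`, `InsertLocation`, and `Error` are all exported from the `chia_datalayer` crate root (this diff does not show the re-exports), and `hash_for` is a caller-supplied leaf hasher standing in for the private sha256 helpers in merkle.rs:

```rust
use chia_datalayer::{Error, Hash, InsertLocation, KeyId, MerkleBlob, ValueId};

fn prove_one_key(hash_for: impl Fn(i64) -> Hash) -> Result<(), Error> {
    let mut blob = MerkleBlob::new(Vec::new())?;
    for n in 0..4 {
        // Insert a few leaves; Auto lets the blob pick placement.
        blob.insert(KeyId(n), ValueId(n), &hash_for(n), InsertLocation::Auto {})?;
    }
    // Internal hashes are lazy; they must be computed before proofs are taken.
    blob.calculate_lazy_hashes()?;

    let proof = blob.get_proof_of_inclusion(KeyId(2))?;
    assert!(proof.valid());
    Ok(())
}
```

Note the ordering dependency: `batch_insert` and `delete` mark lineages dirty, so `calculate_lazy_hashes` has to run before `get_proof_of_inclusion`, otherwise the recorded `combined_hash` values may be stale.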
@@ -1553,7 +1681,7 @@ impl MerkleBlob {
             ));
         }
 
-        self.batch_insert(&mut zip(keys_values, hashes))?;
+        self.batch_insert(zip(keys_values, hashes).collect())?;
 
         Ok(())
     }
@@ -1578,6 +1706,11 @@ impl MerkleBlob {
     pub fn py_get_key_index(&self, key: KeyId) -> PyResult<TreeIndex> {
         Ok(self.get_key_index(key)?)
     }
+
+    #[pyo3(name = "get_proof_of_inclusion")]
+    pub fn py_get_proof_of_inclusion(&self, key: KeyId) -> PyResult<ProofOfInclusion> {
+        Ok(self.get_proof_of_inclusion(key)?)
+    }
 }
 
 fn try_get_block(blob: &[u8], index: TreeIndex) -> Result<Block, Error> {
@@ -1779,6 +1912,9 @@ mod tests {
     use super::*;
     use crate::merkle::dot::DotLines;
     use expect_test::{expect, Expect};
+    use rand::rngs::StdRng;
+    use rand::seq::SliceRandom;
+    use rand::SeedableRng;
     use rstest::{fixture, rstest};
     use std::time::{Duration, Instant};
@@ -2289,7 +2425,7 @@
         }
 
         let before = blob.get_keys_values().unwrap();
-        blob.batch_insert(batch.into_iter()).unwrap();
+        blob.batch_insert(batch).unwrap();
         let after = blob.get_keys_values().unwrap();
 
         open_dot(
@@ -2591,4 +2727,102 @@
         assert!(free_indexes.contains(&new_index));
         assert!(small_blob.free_indexes.is_empty());
     }
+
+    fn generate_kvid(seed: i64) -> (KeyId, ValueId) {
+        let mut kv_ids: Vec<i64> = Vec::new();
+
+        for offset in 0..2 {
+            let seed_int = 2i64 * seed + offset;
+            let seed_bytes = seed_int.to_be_bytes();
+            let hash = sha256_bytes(&seed_bytes);
+            let hash_int = i64::from_be_bytes(hash.0[0..8].try_into().unwrap());
+            kv_ids.push(hash_int);
+        }
+
+        (KeyId(kv_ids[0]), ValueId(kv_ids[1]))
+    }
+
+    fn generate_hash(seed: i64) -> Hash {
+        let seed_bytes = seed.to_be_bytes();
+        sha256_bytes(&seed_bytes)
+    }
+
+    #[test]
+    fn test_proof_of_inclusion() {
+        let num_repeats = 10;
+        let mut seed = 0;
+
+        let mut random = StdRng::seed_from_u64(37);
+
+        let mut merkle_blob = MerkleBlob::new(Vec::new()).unwrap();
+        let mut keys_values: HashMap<KeyId, ValueId> = HashMap::new();
+
+        for repeats in 0..num_repeats {
+            let num_inserts = 1 + repeats * 100;
+            let num_deletes = 1 + repeats * 10;
+
+            let mut kv_ids: Vec<(KeyId, ValueId)> = Vec::new();
+            let mut hashes: Vec<Hash> = Vec::new();
+            for _ in 0..num_inserts {
+                seed += 1;
+                let (key, value) = generate_kvid(seed);
+                kv_ids.push((key, value));
+                hashes.push(generate_hash(seed));
+                keys_values.insert(key, value);
+            }
+
+            merkle_blob
+                .batch_insert(zip(kv_ids, hashes).collect())
+                .unwrap();
+            merkle_blob.calculate_lazy_hashes().unwrap();
+
+            for i in 0..(merkle_blob.blob.len() / BLOCK_SIZE) {
+                let node = merkle_blob.get_node(TreeIndex(i as u32)).unwrap();
+                println!("{i:05}: {node:?}");
+            }
+
+            for kv_id in keys_values.keys().copied() {
+                let proof_of_inclusion = match merkle_blob.get_proof_of_inclusion(kv_id) {
+                    Ok(proof_of_inclusion) => proof_of_inclusion,
+                    Err(error) => {
+                        open_dot(merkle_blob.to_dot().unwrap().set_note(&error.to_string()));
+                        panic!("failed to get proof of inclusion: {error}");
+                    }
+                };
+                assert!(proof_of_inclusion.valid());
+            }
+
+            let mut delete_ordering: Vec<KeyId> = keys_values.keys().copied().collect();
+            delete_ordering.shuffle(&mut random);
+            delete_ordering = delete_ordering[0..num_deletes].to_vec();
+            for kv_id in delete_ordering.iter().copied() {
+                merkle_blob.delete(kv_id).unwrap();
+                keys_values.remove(&kv_id);
+            }
+
+            for kv_id in delete_ordering {
+                // matching the Python test, requesting a proof for a deleted key must fail
+                merkle_blob
+                    .get_proof_of_inclusion(kv_id)
+                    .expect_err("proof of inclusion should fail for a deleted key");
+            }
+
+            let mut new_keys_values: HashMap<KeyId, ValueId> = HashMap::new();
+            for old_kv in keys_values.keys().copied() {
+                seed += 1;
+                let (_, value) = generate_kvid(seed);
+                let hash = generate_hash(seed);
+                merkle_blob.upsert(old_kv, value, &hash).unwrap();
+                new_keys_values.insert(old_kv, value);
+            }
+            if !merkle_blob.blob.is_empty() {
+                merkle_blob.calculate_lazy_hashes().unwrap();
+            }
+
+            keys_values = new_keys_values;
+            for kv_id in keys_values.keys().copied() {
+                let proof_of_inclusion = merkle_blob.get_proof_of_inclusion(kv_id).unwrap();
+                assert!(proof_of_inclusion.valid());
+            }
+        }
+    }
 }
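One edge case that falls out of `root_hash()` and is exercised implicitly by the first iteration of `test_proof_of_inclusion` above (a single insert): a proof for the only leaf in a tree carries no layers, so validation degenerates to comparing the leaf hash with itself. A small sketch using the types added in this diff:

```rust
// Sketch: a single-leaf tree yields a layerless proof whose root is the
// leaf hash itself, so valid() succeeds without performing any hashing.
fn single_leaf_proof(node_hash: Hash) -> ProofOfInclusion {
    let proof = ProofOfInclusion {
        node_hash,
        layers: Vec::new(),
    };
    assert_eq!(proof.root_hash(), node_hash); // root_hash() falls back to node_hash
    assert!(proof.valid()); // zero layers to fold
    proof
}
```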
diff --git a/wheel/python/chia_rs/datalayer.pyi b/wheel/python/chia_rs/datalayer.pyi
index 32f746f4e..14797cd99 100644
--- a/wheel/python/chia_rs/datalayer.pyi
+++ b/wheel/python/chia_rs/datalayer.pyi
@@ -177,6 +177,22 @@ class LeafNode:
 
     def replace(self, *, key: KeyId = ..., value: ValueId = ...) -> Self: ...
 
+@final
+class ProofOfInclusionLayer:
+    def __init__(self, other_hash_side: uint8, other_hash: bytes32, combined_hash: bytes32) -> None: ...
+    other_hash_side: uint8
+    other_hash: bytes32
+    combined_hash: bytes32
+
+@final
+class ProofOfInclusion:
+    node_hash: bytes32
+    # children before parents
+    layers: list[ProofOfInclusionLayer]
+
+    def root_hash(self) -> bytes32: ...
+    def valid(self) -> bool: ...
+
 @final
 class MerkleBlob:
     @property
@@ -202,10 +218,11 @@ class MerkleBlob:
     def get_nodes_with_indexes(self) -> list[tuple[TreeIndex, Union[InternalNode, LeafNode]]]: ...
     def empty(self) -> bool: ...
    def get_root_hash(self) -> bytes32: ...
-    def batch_insert(self, keys_values: list[tuple[int64, int64]], hashes: list[bytes32]): ...
+    def batch_insert(self, keys_values: list[tuple[KeyId, ValueId]], hashes: list[bytes32]): ...
     def get_hash_at_index(self, index: TreeIndex): ...
     def get_keys_values(self) -> dict[KeyId, ValueId]: ...
     def get_key_index(self, key: KeyId) -> TreeIndex: ...
+    def get_proof_of_inclusion(self, key: KeyId) -> ProofOfInclusion: ...
 
     def __len__(self) -> int: ...
diff --git a/wheel/src/api.rs b/wheel/src/api.rs
index aac067935..82b314eb7 100644
--- a/wheel/src/api.rs
+++ b/wheel/src/api.rs
@@ -649,6 +649,8 @@ pub fn add_datalayer_submodule(py: Python<'_>, parent: &Bound<'_, PyModule>) -> PyResult<()> {
     datalayer.add_class::<MerkleBlob>()?;
     datalayer.add_class::<InternalNode>()?;
     datalayer.add_class::<LeafNode>()?;
+    datalayer.add_class::<ProofOfInclusionLayer>()?;
+    datalayer.add_class::<ProofOfInclusion>()?;
 
     datalayer.add("BLOCK_SIZE", BLOCK_SIZE)?;
     datalayer.add("DATA_SIZE", DATA_SIZE)?;
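For readers new to the binding layer: the two `add_class` lines above are what make the new `#[pyclass]` types importable from `chia_rs.datalayer`, matching the two stub classes added to datalayer.pyi. A minimal pyo3 sketch of the same registration shape (module and class names here are illustrative, not from this PR, and the API assumed is the modern `Bound`-based pyo3 used elsewhere in this diff):

```rust
use pyo3::prelude::*;

// get_all generates Python getters for every field, as with the proof classes.
#[pyclass(get_all)]
#[derive(Clone)]
struct Example {
    value: u64,
}

// Mirrors the shape of add_datalayer_submodule: build a child module,
// register classes and constants on it, then attach it to the parent.
fn add_example_submodule(py: Python<'_>, parent: &Bound<'_, PyModule>) -> PyResult<()> {
    let child = PyModule::new(py, "example")?;
    child.add_class::<Example>()?;
    child.add("ANSWER", 42)?;
    parent.add_submodule(&child)?;
    Ok(())
}
```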