Skip to content

Commit

Permalink
Merge pull request #902 from Chia-Network/add_datalayer_proof_of_incl…
Browse files Browse the repository at this point in the history
…usion

add datalayer proof of inclusion generation (and other fixes to make it pass tests)
  • Loading branch information
altendky authored Feb 18, 2025
2 parents b352414 + d5846c0 commit cbbfa26
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 11 deletions.
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ dependencies = [
"open",
"percent-encoding",
"pyo3",
"rand",
"rstest",
"thiserror",
"url",
Expand Down
1 change: 1 addition & 0 deletions crates/chia-datalayer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ hex = { workspace = true }
hex-literal = { workspace = true }
open = { workspace = true }
percent-encoding = { workspace = true }
rand = { workspace = true }
rstest = { workspace = true }
url = { workspace = true }

Expand Down
254 changes: 244 additions & 10 deletions crates/chia-datalayer/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,74 @@ pub enum NodeType {
Leaf = 1,
}

/// One step of a proof of inclusion: the sibling hash to combine with the
/// hash computed so far, plus the hash that combination is expected to yield.
#[cfg_attr(feature = "py-bindings", pyclass(get_all))]
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct ProofOfInclusionLayer {
/// Side on which `other_hash` is placed when the two hashes are combined.
pub other_hash_side: Side,
/// Hash of the sibling node at this layer.
pub other_hash: Hash,
/// Hash expected from combining the running hash with `other_hash`.
pub combined_hash: Hash,
}

// Python-facing constructor, only compiled with the `py-bindings` feature.
#[cfg(feature = "py-bindings")]
#[pymethods]
impl ProofOfInclusionLayer {
/// Builds a layer directly from its three fields; this never fails, the
/// `PyResult` wrapper always carries `Ok`.
#[new]
pub fn py_init(other_hash_side: Side, other_hash: Hash, combined_hash: Hash) -> PyResult<Self> {
Ok(Self {
other_hash_side,
other_hash,
combined_hash,
})
}
}

/// A proof that a node hash is included under a root hash: the starting
/// node hash followed by the layers to combine it with, in order.
#[cfg_attr(feature = "py-bindings", pyclass(get_all))]
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct ProofOfInclusion {
/// Hash the proof starts from.
pub node_hash: Hash,
/// Layers applied in sequence, starting from `node_hash`; the last
/// layer's `combined_hash` is the root hash.
pub layers: Vec<ProofOfInclusionLayer>,
}

impl ProofOfInclusion {
    /// Returns the root hash this proof leads to.
    ///
    /// That is the combined hash of the final layer, or the node hash itself
    /// when the proof has no layers.
    pub fn root_hash(&self) -> Hash {
        self.layers
            .last()
            .map_or(self.node_hash, |final_layer| final_layer.combined_hash)
    }

    /// Checks that the proof is internally consistent.
    ///
    /// Starting from the node hash, each layer's sibling hash is combined
    /// with the running hash and the result must equal that layer's
    /// `combined_hash`. The hash reached at the end must equal `root_hash()`.
    pub fn valid(&self) -> bool {
        let mut running_hash = self.node_hash;

        // `all` stops at the first layer whose recorded hash does not match
        let every_layer_matches = self.layers.iter().all(|layer| {
            running_hash =
                calculate_internal_hash(&running_hash, layer.other_hash_side, &layer.other_hash);
            running_hash == layer.combined_hash
        });

        every_layer_matches && running_hash == self.root_hash()
    }
}

// Python-facing wrappers, only compiled with the `py-bindings` feature.
#[cfg(feature = "py-bindings")]
#[pymethods]
impl ProofOfInclusion {
/// Exposed to Python as `root_hash`; delegates to `root_hash()`.
#[pyo3(name = "root_hash")]
pub fn py_root_hash(&self) -> Hash {
self.root_hash()
}
/// Exposed to Python as `valid`; delegates to `valid()`.
#[pyo3(name = "valid")]
pub fn py_valid(&self) -> bool {
self.valid()
}
}

#[allow(clippy::needless_pass_by_value)]
fn sha256_num<T: ToBytes>(input: T) -> Hash {
let mut hasher = Sha256::new();
Expand All @@ -358,6 +426,13 @@ fn internal_hash(left_hash: &Hash, right_hash: &Hash) -> Hash {
Hash(Bytes32::new(hasher.finalize()))
}

/// Combines `hash` with `other_hash` into an internal node hash.
///
/// `other_hash_side` says which side `other_hash` sits on: for `Side::Left`
/// it is hashed as the left child and `hash` as the right, and the reverse
/// for `Side::Right`.
pub fn calculate_internal_hash(hash: &Hash, other_hash_side: Side, other_hash: &Hash) -> Hash {
    let (left_hash, right_hash) = match other_hash_side {
        Side::Left => (other_hash, hash),
        Side::Right => (hash, other_hash),
    };

    internal_hash(left_hash, right_hash)
}

#[cfg_attr(feature = "py-bindings", derive(PyJsonDict, PyStreamable))]
#[repr(u8)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Streamable)]
Expand Down Expand Up @@ -407,6 +482,16 @@ impl InternalNode {
Err(Error::IndexIsNotAChild(index))
}
}

/// Returns the side on which the sibling of child `index` sits.
///
/// # Errors
///
/// Returns `Error::IndexIsNotAChild` when `index` is neither the left nor
/// the right child of this node.
pub fn get_sibling_side(&self, index: TreeIndex) -> Result<Side, Error> {
    // the sibling is always on the opposite side from the given child
    if self.left == index {
        return Ok(Side::Right);
    }
    if self.right == index {
        return Ok(Side::Left);
    }

    Err(Error::IndexIsNotAChild(index))
}
}

#[cfg_attr(
Expand Down Expand Up @@ -488,6 +573,15 @@ impl Node {
*leaf
}

/// Returns a copy of the contained internal node.
///
/// # Panics
///
/// Panics with `message` when this node is not internal. Any `<<self>>`
/// placeholder in `message` is replaced by the debug rendering of the node.
fn expect_internal(&self, message: &str) -> InternalNode {
    match self {
        Node::Internal(internal) => *internal,
        _ => {
            let rendered_self = format!("{self:?}");
            let message = message.replace("<<self>>", &rendered_self);
            panic!("{}", message)
        }
    }
}

fn try_into_leaf(self) -> Result<LeafNode, Error> {
match self {
Node::Leaf(leaf) => Ok(leaf),
Expand Down Expand Up @@ -825,16 +919,17 @@ impl MerkleBlob {
Ok(new_leaf_index)
}

pub fn batch_insert<I>(&mut self, mut keys_values_hashes: I) -> Result<(), Error>
where
I: Iterator<Item = ((KeyId, ValueId), Hash)>,
{
pub fn batch_insert(
&mut self,
mut keys_values_hashes: Vec<((KeyId, ValueId), Hash)>,
) -> Result<(), Error> {
// OPT: perhaps go back to taking an iterator?
// OPT: would it be worthwhile to hold the entire blocks?
let mut indexes = vec![];

if self.key_to_index.len() <= 1 {
for _ in 0..2 {
let Some(((key, value), hash)) = keys_values_hashes.next() else {
let Some(((key, value), hash)) = keys_values_hashes.pop() else {
return Ok(());
};
self.insert(key, value, &hash, InsertLocation::Auto {})?;
Expand Down Expand Up @@ -978,6 +1073,7 @@ impl MerkleBlob {
panic!("not handling this case now...")
}
self.insert_entry_to_blob(old_leaf_parent, &parent)?;
self.mark_lineage_as_dirty(old_leaf_parent)?;
self.update_parent(old_leaf_index, Some(new_internal_node_index))?;

Ok(())
Expand Down Expand Up @@ -1012,14 +1108,14 @@ impl MerkleBlob {

let Some(grandparent_index) = parent.parent.0 else {
sibling_block.node.set_parent(Parent(None));
self.insert_entry_to_blob(TreeIndex(0), &sibling_block)?;

if let Node::Internal(node) = sibling_block.node {
sibling_block.metadata.dirty = true;
for child_index in [node.left, node.right] {
self.update_parent(child_index, Some(TreeIndex(0)))?;
}
};

self.insert_entry_to_blob(TreeIndex(0), &sibling_block)?;
self.free_indexes.insert(sibling_index);

return Ok(());
Expand Down Expand Up @@ -1148,7 +1244,7 @@ impl MerkleBlob {
let mut block = Block::from_bytes(self.get_block_bytes(this_index)?)?;

if block.metadata.dirty {
return Ok(());
break;
}

block.metadata.dirty = true;
Expand Down Expand Up @@ -1377,6 +1473,38 @@ impl MerkleBlob {
.copied()
.ok_or(Error::UnknownKey(key))
}

pub fn get_proof_of_inclusion(&self, key: KeyId) -> Result<ProofOfInclusion, Error> {
let mut index = *self.key_to_index.get(&key).ok_or(Error::UnknownKey(key))?;

let node = self
.get_node(index)?
.expect_leaf("key to index mapping should only have leaves");

let parents = self.get_lineage_with_indexes(index)?;
let mut layers: Vec<ProofOfInclusionLayer> = Vec::new();
let mut parents_iter = parents.iter();
// first in the lineage is the index itself, second is the first parent
parents_iter.next();
for (next_index, parent) in parents_iter {
let parent = parent.expect_internal("all nodes after the first should be internal");
let sibling_index = parent.sibling_index(index)?;
let sibling_block = self.get_block(sibling_index)?;
let sibling = sibling_block.node;
let layer = ProofOfInclusionLayer {
other_hash_side: parent.get_sibling_side(index)?,
other_hash: sibling.hash(),
combined_hash: parent.hash,
};
layers.push(layer);
index = *next_index;
}

Ok(ProofOfInclusion {
node_hash: node.hash,
layers,
})
}
}

impl PartialEq for MerkleBlob {
Expand Down Expand Up @@ -1553,7 +1681,7 @@ impl MerkleBlob {
));
}

self.batch_insert(&mut zip(keys_values, hashes))?;
self.batch_insert(zip(keys_values, hashes).collect())?;

Ok(())
}
Expand All @@ -1578,6 +1706,11 @@ impl MerkleBlob {
pub fn py_get_key_index(&self, key: KeyId) -> PyResult<TreeIndex> {
Ok(self.get_key_index(key)?)
}

/// Exposed to Python as `get_proof_of_inclusion`; delegates to
/// `get_proof_of_inclusion`, with `?` converting its error into the
/// `PyResult` error type.
#[pyo3(name = "get_proof_of_inclusion")]
pub fn py_get_proof_of_inclusion(&self, key: KeyId) -> PyResult<ProofOfInclusion> {
Ok(self.get_proof_of_inclusion(key)?)
}
}

fn try_get_block(blob: &[u8], index: TreeIndex) -> Result<Block, Error> {
Expand Down Expand Up @@ -1779,6 +1912,9 @@ mod tests {
use super::*;
use crate::merkle::dot::DotLines;
use expect_test::{expect, Expect};
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rstest::{fixture, rstest};
use std::time::{Duration, Instant};

Expand Down Expand Up @@ -2289,7 +2425,7 @@ mod tests {
}

let before = blob.get_keys_values().unwrap();
blob.batch_insert(batch.into_iter()).unwrap();
blob.batch_insert(batch).unwrap();
let after = blob.get_keys_values().unwrap();

open_dot(
Expand Down Expand Up @@ -2591,4 +2727,102 @@ mod tests {
assert!(free_indexes.contains(&new_index));
assert!(small_blob.free_indexes.is_empty());
}

/// Derives a deterministic key/value id pair from `seed`.
///
/// The key id comes from hashing `2 * seed` and the value id from hashing
/// `2 * seed + 1`; each id is the first eight bytes of the hash read as a
/// big-endian `i64`.
fn generate_kvid(seed: i64) -> (KeyId, ValueId) {
    let derive_id = |offset: i64| -> i64 {
        let digest = sha256_bytes(&(2i64 * seed + offset).to_be_bytes());
        let leading_bytes: [u8; 8] = digest.0[0..8].try_into().unwrap();
        i64::from_be_bytes(leading_bytes)
    };

    let key_id = derive_id(0);
    let value_id = derive_id(1);

    (KeyId(key_id), ValueId(value_id))
}

/// Derives a deterministic hash from `seed` by hashing its big-endian bytes.
fn generate_hash(seed: i64) -> Hash {
    sha256_bytes(&seed.to_be_bytes())
}

/// Exercises proof-of-inclusion generation across repeated rounds of batch
/// inserts, random deletes, and upserts of every remaining key.
///
/// After each mutation phase every live key must yield a valid proof, and
/// every deleted key must be reported as unknown.
#[test]
fn test_proof_of_inclusion() {
    /// Returns the map's keys in ascending order. `HashMap` iteration order
    /// is arbitrary, so sorting keeps the seeded run reproducible.
    fn sorted_keys(keys_values: &HashMap<KeyId, ValueId>) -> Vec<KeyId> {
        let mut keys: Vec<KeyId> = keys_values.keys().copied().collect();
        keys.sort_unstable_by_key(|key| key.0);
        keys
    }

    let num_repeats = 10;
    let mut seed = 0;

    let mut random = StdRng::seed_from_u64(37);

    let mut merkle_blob = MerkleBlob::new(Vec::new()).unwrap();
    let mut keys_values: HashMap<KeyId, ValueId> = HashMap::new();

    for repeats in 0..num_repeats {
        let num_inserts = 1 + repeats * 100;
        let num_deletes = 1 + repeats * 10;

        let mut batch: Vec<((KeyId, ValueId), Hash)> = Vec::with_capacity(num_inserts);
        for _ in 0..num_inserts {
            seed += 1;
            let (key, value) = generate_kvid(seed);
            batch.push(((key, value), generate_hash(seed)));
            keys_values.insert(key, value);
        }

        merkle_blob.batch_insert(batch).unwrap();
        merkle_blob.calculate_lazy_hashes().unwrap();
        for i in 0..(merkle_blob.blob.len() / BLOCK_SIZE) {
            let index = TreeIndex(u32::try_from(i).expect("block count should fit in a u32"));
            let node = merkle_blob.get_node(index).unwrap();
            println!("{i:05}: {node:?}");
        }

        for kv_id in keys_values.keys().copied() {
            let proof_of_inclusion = match merkle_blob.get_proof_of_inclusion(kv_id) {
                Ok(proof_of_inclusion) => proof_of_inclusion,
                Err(error) => {
                    open_dot(merkle_blob.to_dot().unwrap().set_note(&error.to_string()));
                    panic!("failed to get proof of inclusion after insert: {error}");
                }
            };
            assert!(proof_of_inclusion.valid());
        }

        // sort before shuffling so the seeded rng always picks the same keys
        let mut delete_ordering = sorted_keys(&keys_values);
        delete_ordering.shuffle(&mut random);
        delete_ordering.truncate(num_deletes);
        for kv_id in delete_ordering.iter().copied() {
            merkle_blob.delete(kv_id).unwrap();
            keys_values.remove(&kv_id);
        }

        // a deleted key must fail specifically as an unknown key, not with any error
        for kv_id in delete_ordering {
            let result = merkle_blob.get_proof_of_inclusion(kv_id);
            assert!(
                matches!(&result, Err(Error::UnknownKey(key)) if *key == kv_id),
                "deleted key should be reported as unknown, got: {result:?}",
            );
        }

        let mut new_keys_values: HashMap<KeyId, ValueId> =
            HashMap::with_capacity(keys_values.len());
        // sorted so each key is paired with the same seed on every run
        for old_kv in sorted_keys(&keys_values) {
            seed += 1;
            let (_, value) = generate_kvid(seed);
            let hash = generate_hash(seed);
            merkle_blob.upsert(old_kv, value, &hash).unwrap();
            new_keys_values.insert(old_kv, value);
        }
        if !merkle_blob.blob.is_empty() {
            merkle_blob.calculate_lazy_hashes().unwrap();
        }

        keys_values = new_keys_values;
        for kv_id in keys_values.keys().copied() {
            let proof_of_inclusion = merkle_blob.get_proof_of_inclusion(kv_id).unwrap();
            assert!(proof_of_inclusion.valid());
        }
    }
}
}
Loading

0 comments on commit cbbfa26

Please sign in to comment.