add datalayer proof of inclusion generation (and other fixes to make it pass tests) #902

Merged
1 change: 1 addition & 0 deletions Cargo.lock
@@ -414,6 +414,7 @@ dependencies = [
"open",
"percent-encoding",
"pyo3",
"rand",
"rstest",
"thiserror",
"url",
1 change: 1 addition & 0 deletions crates/chia-datalayer/Cargo.toml
@@ -34,6 +34,7 @@ hex = { workspace = true }
hex-literal = { workspace = true }
open = { workspace = true }
percent-encoding = { workspace = true }
rand = { workspace = true }
rstest = { workspace = true }
url = { workspace = true }

254 changes: 244 additions & 10 deletions crates/chia-datalayer/src/merkle.rs
@@ -334,6 +334,74 @@ pub enum NodeType {
Leaf = 1,
}

#[cfg_attr(feature = "py-bindings", pyclass(get_all))]
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct ProofOfInclusionLayer {
pub other_hash_side: Side,
pub other_hash: Hash,
pub combined_hash: Hash,
}

#[cfg(feature = "py-bindings")]
#[pymethods]
impl ProofOfInclusionLayer {
#[new]
pub fn py_init(other_hash_side: Side, other_hash: Hash, combined_hash: Hash) -> PyResult<Self> {
Ok(Self {
other_hash_side,
other_hash,
combined_hash,
})
}
}

#[cfg_attr(feature = "py-bindings", pyclass(get_all))]
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct ProofOfInclusion {
pub node_hash: Hash,
pub layers: Vec<ProofOfInclusionLayer>,
}

impl ProofOfInclusion {
pub fn root_hash(&self) -> Hash {
if let Some(last) = self.layers.last() {
last.combined_hash
} else {
self.node_hash
}
}

pub fn valid(&self) -> bool {
let mut existing_hash = self.node_hash;

for layer in &self.layers {
let calculated_hash =
calculate_internal_hash(&existing_hash, layer.other_hash_side, &layer.other_hash);

if calculated_hash != layer.combined_hash {
return false;
}

existing_hash = calculated_hash;
}

existing_hash == self.root_hash()
}
}

#[cfg(feature = "py-bindings")]
#[pymethods]
impl ProofOfInclusion {
#[pyo3(name = "root_hash")]
pub fn py_root_hash(&self) -> Hash {
self.root_hash()
}
#[pyo3(name = "valid")]
pub fn py_valid(&self) -> bool {
self.valid()
}
}
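
The heart of the new type is `valid()`: it replays the proof bottom-up, folding each layer's sibling hash into a running hash on the recorded side and requiring the result to match that layer's stored `combined_hash`. Below is a self-contained sketch of the same walk; `combine` is a stand-in that just hashes the concatenated children, whereas the crate's actual `internal_hash` may differ (e.g. by a domain-separation prefix).

```rust
use sha2::{Digest, Sha256};

type Hash32 = [u8; 32];

#[derive(Clone, Copy)]
enum ProofSide {
    Left,
    Right,
}

// Stand-in combiner: SHA-256 over the concatenated child hashes. Treat this
// as illustrative; it is not necessarily the crate's `internal_hash`.
fn combine(left: &Hash32, right: &Hash32) -> Hash32 {
    let mut hasher = Sha256::new();
    hasher.update(left);
    hasher.update(right);
    hasher.finalize().into()
}

// Mirrors `ProofOfInclusion::valid()`: fold each layer's sibling hash into
// the running hash on the recorded side, and require every intermediate
// result to equal that layer's stored combined hash.
fn verify(leaf: Hash32, layers: &[(ProofSide, Hash32, Hash32)]) -> bool {
    let mut current = leaf;
    for (side, other, combined) in layers {
        let computed = match side {
            ProofSide::Left => combine(other, &current),
            ProofSide::Right => combine(&current, other),
        };
        if computed != *combined {
            return false;
        }
        current = computed;
    }
    true
}
```

Note that `root_hash()` is simply the last layer's `combined_hash` (or the leaf hash itself when there are no layers), so a proof that passes the per-layer checks necessarily reproduces the root.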

#[allow(clippy::needless_pass_by_value)]
fn sha256_num<T: ToBytes>(input: T) -> Hash {
let mut hasher = Sha256::new();
@@ -358,6 +426,13 @@ fn internal_hash(left_hash: &Hash, right_hash: &Hash) -> Hash {
Hash(Bytes32::new(hasher.finalize()))
}

pub fn calculate_internal_hash(hash: &Hash, other_hash_side: Side, other_hash: &Hash) -> Hash {
match other_hash_side {
Side::Left => internal_hash(other_hash, hash),
Side::Right => internal_hash(hash, other_hash),
}
}
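
`Side` records where the *other* (sibling) hash sits, so `Side::Left` routes the sibling into the left slot of `internal_hash` and the running hash into the right slot. A small in-crate sketch of that ordering (the inputs are arbitrary, and it assumes the module-private helpers `sha256_num` and `internal_hash` are in scope, as they are for tests in this module):

```rust
#[test]
fn side_ordering_sketch() {
    // Arbitrary inputs; only the argument order is of interest here.
    let running = sha256_num(1u64);
    let sibling = sha256_num(2u64);

    // A sibling recorded on the left feeds the left slot of the combiner...
    assert_eq!(
        calculate_internal_hash(&running, Side::Left, &sibling),
        internal_hash(&sibling, &running),
    );
    // ...and a sibling on the right feeds the right slot, exactly as
    // `valid()` replays each proof layer.
    assert_eq!(
        calculate_internal_hash(&running, Side::Right, &sibling),
        internal_hash(&running, &sibling),
    );
}
```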

#[cfg_attr(feature = "py-bindings", derive(PyJsonDict, PyStreamable))]
#[repr(u8)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Streamable)]
@@ -407,6 +482,16 @@ impl InternalNode {
Err(Error::IndexIsNotAChild(index))
}
}

pub fn get_sibling_side(&self, index: TreeIndex) -> Result<Side, Error> {
if self.left == index {
Ok(Side::Right)
} else if self.right == index {
Ok(Side::Left)
} else {
Err(Error::IndexIsNotAChild(index))
}
}
}

#[cfg_attr(
@@ -488,6 +573,15 @@ impl Node {
*leaf
}

fn expect_internal(&self, message: &str) -> InternalNode {
let Node::Internal(internal) = self else {
let message = message.replace("<<self>>", &format!("{self:?}"));
panic!("{}", message)
};

*internal
}

fn try_into_leaf(self) -> Result<LeafNode, Error> {
match self {
Node::Leaf(leaf) => Ok(leaf),
@@ -825,16 +919,17 @@ impl MerkleBlob {
Ok(new_leaf_index)
}

- pub fn batch_insert<I>(&mut self, mut keys_values_hashes: I) -> Result<(), Error>
- where
- I: Iterator<Item = ((KeyId, ValueId), Hash)>,
- {
+ pub fn batch_insert(
+ &mut self,
+ mut keys_values_hashes: Vec<((KeyId, ValueId), Hash)>,
+ ) -> Result<(), Error> {
// OPT: perhaps go back to taking an iterator?
// OPT: would it be worthwhile to hold the entire blocks?
let mut indexes = vec![];

if self.key_to_index.len() <= 1 {
for _ in 0..2 {
- let Some(((key, value), hash)) = keys_values_hashes.next() else {
+ let Some(((key, value), hash)) = keys_values_hashes.pop() else {
return Ok(());
};
self.insert(key, value, &hash, InsertLocation::Auto {})?;
@@ -978,6 +1073,7 @@ impl MerkleBlob {
panic!("not handling this case now...")
}
self.insert_entry_to_blob(old_leaf_parent, &parent)?;
self.mark_lineage_as_dirty(old_leaf_parent)?;
self.update_parent(old_leaf_index, Some(new_internal_node_index))?;

Ok(())
@@ -1012,14 +1108,14 @@

let Some(grandparent_index) = parent.parent.0 else {
sibling_block.node.set_parent(Parent(None));
- self.insert_entry_to_blob(TreeIndex(0), &sibling_block)?;

if let Node::Internal(node) = sibling_block.node {
sibling_block.metadata.dirty = true;
for child_index in [node.left, node.right] {
self.update_parent(child_index, Some(TreeIndex(0)))?;
}
};

+ self.insert_entry_to_blob(TreeIndex(0), &sibling_block)?;
self.free_indexes.insert(sibling_index);

return Ok(());
@@ -1148,7 +1244,7 @@ impl MerkleBlob {
let mut block = Block::from_bytes(self.get_block_bytes(this_index)?)?;

if block.metadata.dirty {
- return Ok(());
+ break;
}

block.metadata.dirty = true;
@@ -1377,6 +1473,38 @@ impl MerkleBlob {
.copied()
.ok_or(Error::UnknownKey(key))
}

pub fn get_proof_of_inclusion(&self, key: KeyId) -> Result<ProofOfInclusion, Error> {
let mut index = *self.key_to_index.get(&key).ok_or(Error::UnknownKey(key))?;

let node = self
.get_node(index)?
.expect_leaf("key to index mapping should only have leaves");

let parents = self.get_lineage_with_indexes(index)?;
let mut layers: Vec<ProofOfInclusionLayer> = Vec::new();
let mut parents_iter = parents.iter();
// first in the lineage is the index itself, second is the first parent
parents_iter.next();
for (next_index, parent) in parents_iter {
let parent = parent.expect_internal("all nodes after the first should be internal");
let sibling_index = parent.sibling_index(index)?;
let sibling_block = self.get_block(sibling_index)?;
let sibling = sibling_block.node;
let layer = ProofOfInclusionLayer {
other_hash_side: parent.get_sibling_side(index)?,
other_hash: sibling.hash(),
combined_hash: parent.hash,
};
layers.push(layer);
index = *next_index;
}

Ok(ProofOfInclusion {
node_hash: node.hash,
layers,
})
}
}
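
End to end, the new accessor composes with the existing insert path. A condensed in-crate usage sketch, modeled on the test at the bottom of this diff (the key, value, and hash literals are arbitrary):

```rust
#[test]
fn proof_usage_sketch() {
    let mut blob = MerkleBlob::new(Vec::new()).unwrap();

    // Insert a single leaf. Its proof has no layers, so `root_hash()`
    // degenerates to the leaf hash and `valid()` is trivially satisfied.
    let hash = sha256_num(42u64);
    blob.insert(KeyId(1), ValueId(2), &hash, InsertLocation::Auto {})
        .unwrap();

    let proof = blob.get_proof_of_inclusion(KeyId(1)).unwrap();
    assert!(proof.valid());

    // Unknown keys are rejected instead of yielding a bogus proof.
    assert!(blob.get_proof_of_inclusion(KeyId(99)).is_err());
}
```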

impl PartialEq for MerkleBlob {
@@ -1553,7 +1681,7 @@ impl MerkleBlob {
));
}

- self.batch_insert(&mut zip(keys_values, hashes))?;
+ self.batch_insert(zip(keys_values, hashes).collect())?;

Ok(())
}
@@ -1578,6 +1706,11 @@
pub fn py_get_key_index(&self, key: KeyId) -> PyResult<TreeIndex> {
Ok(self.get_key_index(key)?)
}

#[pyo3(name = "get_proof_of_inclusion")]
pub fn py_get_proof_of_inclusion(&self, key: KeyId) -> PyResult<ProofOfInclusion> {
Ok(self.get_proof_of_inclusion(key)?)
}
}

fn try_get_block(blob: &[u8], index: TreeIndex) -> Result<Block, Error> {
@@ -1779,6 +1912,9 @@ mod tests {
use super::*;
use crate::merkle::dot::DotLines;
use expect_test::{expect, Expect};
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rstest::{fixture, rstest};
use std::time::{Duration, Instant};

@@ -2289,7 +2425,7 @@
}

let before = blob.get_keys_values().unwrap();
- blob.batch_insert(batch.into_iter()).unwrap();
+ blob.batch_insert(batch).unwrap();
let after = blob.get_keys_values().unwrap();

open_dot(
@@ -2591,4 +2727,102 @@
assert!(free_indexes.contains(&new_index));
assert!(small_blob.free_indexes.is_empty());
}

fn generate_kvid(seed: i64) -> (KeyId, ValueId) {
let mut kv_ids: Vec<i64> = Vec::new();

for offset in 0..2 {
let seed_int = 2i64 * seed + offset;
let seed_bytes = seed_int.to_be_bytes();
let hash = sha256_bytes(&seed_bytes);
let hash_int = i64::from_be_bytes(hash.0[0..8].try_into().unwrap());
kv_ids.push(hash_int);
}

(KeyId(kv_ids[0]), ValueId(kv_ids[1]))
}

fn generate_hash(seed: i64) -> Hash {
let seed_bytes = seed.to_be_bytes();
sha256_bytes(&seed_bytes)
}

#[test]
fn test_proof_of_inclusion() {
let num_repeats = 10;
let mut seed = 0;

let mut random = StdRng::seed_from_u64(37);

let mut merkle_blob = MerkleBlob::new(Vec::new()).unwrap();
let mut keys_values: HashMap<KeyId, ValueId> = HashMap::new();

for repeats in 0..num_repeats {
let num_inserts = 1 + repeats * 100;
let num_deletes = 1 + repeats * 10;

let mut kv_ids: Vec<(KeyId, ValueId)> = Vec::new();
let mut hashes: Vec<Hash> = Vec::new();
for _ in 0..num_inserts {
seed += 1;
let (key, value) = generate_kvid(seed);
kv_ids.push((key, value));
hashes.push(generate_hash(seed));
keys_values.insert(key, value);
}

merkle_blob
.batch_insert(zip(kv_ids, hashes).collect())
.unwrap();
merkle_blob.calculate_lazy_hashes().unwrap();
for i in 0..(merkle_blob.blob.len() / BLOCK_SIZE) {
let node = merkle_blob.get_node(TreeIndex(i as u32)).unwrap();
println!("{i:05}: {node:?}");
}

for kv_id in keys_values.keys().copied() {
let proof_of_inclusion = match merkle_blob.get_proof_of_inclusion(kv_id) {
Ok(proof_of_inclusion) => proof_of_inclusion,
Err(error) => {
open_dot(merkle_blob.to_dot().unwrap().set_note(&error.to_string()));
panic!("here");
}
};
assert!(proof_of_inclusion.valid());
}

let mut delete_ordering: Vec<KeyId> = keys_values.keys().copied().collect();
delete_ordering.shuffle(&mut random);
delete_ordering = delete_ordering[0..num_deletes].to_vec();
for kv_id in delete_ordering.iter().copied() {
merkle_blob.delete(kv_id).unwrap();
keys_values.remove(&kv_id);
}

for kv_id in delete_ordering {
// mirrors the Python test, which expects an "unknown key" error for deleted keys
merkle_blob
.get_proof_of_inclusion(kv_id)
.expect_err("stuff");
}

let mut new_keys_values: HashMap<KeyId, ValueId> = HashMap::new();
for old_kv in keys_values.keys().copied() {
seed += 1;
let (_, value) = generate_kvid(seed);
let hash = generate_hash(seed);
merkle_blob.upsert(old_kv, value, &hash).unwrap();
new_keys_values.insert(old_kv, value);
}
if !merkle_blob.blob.is_empty() {
merkle_blob.calculate_lazy_hashes().unwrap();
}

keys_values = new_keys_values;
for kv_id in keys_values.keys().copied() {
let proof_of_inclusion = merkle_blob.get_proof_of_inclusion(kv_id).unwrap();
assert!(proof_of_inclusion.valid());
}
}
}
}