Skip to content

Commit

Permalink
Address some comments from ChiaMineJP.
Browse files Browse the repository at this point in the history
  • Loading branch information
richardkiss committed May 24, 2022
1 parent e92f68f commit e4bce12
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 31 deletions.
2 changes: 1 addition & 1 deletion clvm/SExp.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def nullp(self):
def as_int(self):
return int_from_bytes(self.atom)

def as_bin(self, *, allow_backrefs=False):
    """
    Serialize this object to its canonical binary form and return the bytes.

    `allow_backrefs`: keyword-only; forwarded to `sexp_to_stream`. When True,
    the serializer may emit back-references to previously-serialized subtrees,
    producing a more compact encoding.
    """
    f = io.BytesIO()
    sexp_to_stream(self, f, allow_backrefs=allow_backrefs)
    return f.getvalue()
Expand Down
19 changes: 19 additions & 0 deletions clvm/object_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,23 @@ class ObjectCache:
in a clvm object tree. It can be used to calculate the sha256 tree hash
for an object and save the hash for all the child objects for building
usage tables, for example.
It also allows a function that's defined recursively on a clvm tree to
have a non-recursive implementation (as it keeps a stack of uncached
objects locally).
"""
def __init__(self, f):
"""
`f`: Callable[ObjectCache, CLVMObject] -> Union[None, T]
The function `f` is expected to calculate its T value recursively based
on the T values for the left and right child for a pair. For an atom, the
function f must calculate the T value directly.
If a pair is passed and one of the children does not have its T value cached
in `ObjectCache` yet, return `None` and f will be called with each child in turn.
Don't recurse in f; that's part of the point of this function.
"""
self.f = f
self.lookup = dict()

Expand Down Expand Up @@ -42,6 +57,8 @@ def treehash(cache, obj):
"""
if obj.pair:
left, right = obj.pair

# ensure both `left` and `right` have cached values
if cache.contains(left) and cache.contains(right):
left_hash = cache.get(left)
right_hash = cache.get(right)
Expand All @@ -57,6 +74,8 @@ def serialized_length(cache, obj):
"""
if obj.pair:
left, right = obj.pair

# ensure both `left` and `right` have cached values
if cache.contains(left) and cache.contains(right):
left_length = cache.get(left)
right_length = cache.get(right)
Expand Down
61 changes: 34 additions & 27 deletions clvm/read_cache_lookup.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
from collections import Counter
from typing import Optional, List, Tuple

import hashlib


def hash_blobs(*blobs):
s = hashlib.sha256()
for blob in blobs:
s.update(blob)
return s.digest()


class ReadCacheLookup:
"""
When deserializing a clvm object, a stack of deserialized child objects
Expand All @@ -27,39 +21,41 @@ class ReadCacheLookup:
child objects that are transient, and no longer appear in the stack
at later times in the parsing. We don't want to waste time looking for
these objects that no longer exist, so we reference-count them.
All hashes correspond to sha256 tree hashes.
"""

def __init__(self):
"""
Create a new `ReadCacheLookup` object with just the null terminator
(ie. and empty list of objects).
(ie. an empty list of objects).
"""
self.root_hash = hash_blobs(b"\1")
self.root_hash = hashlib.sha256(b"\1").digest()
self.read_stack = []
self.count = Counter()
self.parent_lookup = {}
self.parent_paths_for_child = {}

def push(self, obj_hash):
def push(self, obj_hash: bytes) -> None:
"""
Note that an object with the given hash has just been pushed to
the read stack, and update the lookups as appropriate.
"""
# we add two new entries: the new root of the tree, and this object (by id)
# new_root: (obj_hash, old_root)
new_root_hash = hash_blobs(b"\2", obj_hash, self.root_hash)
new_root_hash = hashlib.sha256(b"\2" + obj_hash + self.root_hash).digest()

self.read_stack.append((obj_hash, self.root_hash))

self.count.update([obj_hash, new_root_hash])

new_parent_to_old_root = (new_root_hash, 0)
self.parent_lookup.setdefault(obj_hash, list()).append(new_parent_to_old_root)
self.parent_paths_for_child.setdefault(obj_hash, list()).append(new_parent_to_old_root)

new_parent_to_id = (new_root_hash, 1)
self.parent_lookup.setdefault(self.root_hash, list()).append(new_parent_to_id)
self.parent_paths_for_child.setdefault(self.root_hash, list()).append(new_parent_to_id)
self.root_hash = new_root_hash

def pop(self):
def pop(self) -> Tuple[bytes, bytes]:
"""
Note that the top object has just been popped from the read stack.
Return the 2-tuple of the child hashes.
Expand All @@ -70,7 +66,7 @@ def pop(self):
self.root_hash = item[1]
return item

def pop2_and_cons(self):
def pop2_and_cons(self) -> None:
"""
Note that a "pop-and-cons" operation has just happened. We remove
two objects, cons them together, and push the cons, updating
Expand All @@ -82,13 +78,13 @@ def pop2_and_cons(self):

self.count.update([left[0], right[0]])

new_root_hash = hash_blobs(b"\2", left[0], right[0])
new_root_hash = hashlib.sha256(b"\2" + left[0] + right[0]).digest()

self.parent_lookup.setdefault(left[0], list()).append((new_root_hash, 0))
self.parent_lookup.setdefault(right[0], list()).append((new_root_hash, 1))
self.parent_paths_for_child.setdefault(left[0], list()).append((new_root_hash, 0))
self.parent_paths_for_child.setdefault(right[0], list()).append((new_root_hash, 1))
self.push(new_root_hash)

def find_path(self, obj, serialized_length):
def find_path(self, obj_hash: bytes, serialized_length: int) -> Optional[bytes]:
"""
This function looks for a path from the root to a child node with a given hash
by using the read cache.
Expand All @@ -102,9 +98,9 @@ def find_path(self, obj, serialized_length):
# 1 byte for 0xfe, 1 min byte for savings

max_path_length = max_bytes_for_path_encoding * 8 - 1
seen_ids.add(obj)
seen_ids.add(obj_hash)

partial_paths = [(obj, [])]
partial_paths = [(obj_hash, [])]

while partial_paths:
new_partial_paths = []
Expand All @@ -113,10 +109,10 @@ def find_path(self, obj, serialized_length):
path.reverse()
return path_to_bytes(path)

parents = self.parent_lookup.get(node)
parent_paths = self.parent_paths_for_child.get(node)

if parents:
for (parent, direction) in parents:
if parent_paths:
for (parent, direction) in parent_paths:
if self.count[parent] > 0 and parent not in seen_ids:
new_path = list(path)
new_path.append(direction)
Expand All @@ -128,10 +124,21 @@ def find_path(self, obj, serialized_length):
return None


def path_to_bytes(path: List[int]) -> bytes:
    """
    Convert a list of 0/1 values to a path expected by clvm.

    The encoding is a single 1 "start" bit followed by the path bits in
    reverse order, packed big-endian into the minimal number of bytes.

    Examples:
    [] => bytes([0b1])
    [0] => bytes([0b10])
    [1] => bytes([0b11])
    [0, 0] => bytes([0b100])
    [0, 1] => bytes([0b110])
    [1, 0] => bytes([0b101])
    [1, 1] => bytes([0b111])
    [1, 0, 0] => bytes([0b1001])
    """
    # 1 sentinel bit + len(path) path bits, rounded up to whole bytes
    byte_count = (len(path) + 1 + 7) >> 3
    v = 1
    # path[0] ends up as the least-significant bit, matching the examples above
    for bit in reversed(path):
        v = (v << 1) | (1 if bit else 0)
    return v.to_bytes(byte_count, "big")
6 changes: 3 additions & 3 deletions clvm/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
CONS_BOX_MARKER = 0xFF


def sexp_to_byte_iterator(sexp, /, allow_backrefs=False) -> Iterator[bytes]:
def sexp_to_byte_iterator(sexp, *, allow_backrefs=False) -> Iterator[bytes]:
if allow_backrefs:
yield from sexp_to_byte_iterator_with_backrefs(sexp)
return
Expand Down Expand Up @@ -125,7 +125,7 @@ def atom_to_byte_iterator(as_atom):
yield as_atom


def sexp_to_stream(sexp, f, *, allow_backrefs=False):
    """
    Serialize `sexp` to the binary stream `f`.

    `allow_backrefs`: keyword-only; forwarded to `sexp_to_byte_iterator`.
    When True, the output may contain back-references to repeated subtrees.
    """
    for b in sexp_to_byte_iterator(sexp, allow_backrefs=allow_backrefs):
        f.write(b)

Expand Down Expand Up @@ -206,7 +206,7 @@ def _op_read_sexp_allow_backrefs(op_stack, val_stack, f, to_sexp):
return to_sexp((_atom_from_stream(f, b, to_sexp), val_stack))


def sexp_from_stream(f, to_sexp, /, allow_backrefs=False):
def sexp_from_stream(f, to_sexp, *, allow_backrefs=False):
op_stack = [_op_read_sexp_allow_backrefs if allow_backrefs else _op_read_sexp]
val_stack = to_sexp(b"")

Expand Down
52 changes: 52 additions & 0 deletions tests/test_object_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import io
import unittest

from clvm import to_sexp_f
from clvm.object_cache import ObjectCache, treehash, serialized_length
from clvm.serialize import sexp_from_stream

from clvm_tools.binutils import assemble

TEXT = b"the quick brown fox jumps over the lazy dogs"


def obj_from_hex(h):
    """Deserialize a clvm object from a hex string `h`."""
    stream = io.BytesIO(bytes.fromhex(h))
    return sexp_from_stream(stream, to_sexp_f)


class ObjectCacheTest(unittest.TestCase):
    # (obj_text, expected sha256 tree hash as hex, expected serialized length)
    CASES = [
        (
            "0x00",
            "47dc540c94ceb704a23875c11273e16bb0b8a87aed84de911f2133568115f254",
            1,
        ),
        (
            "0",
            "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a",
            1,
        ),
        (
            "foo",
            "0080b50a51ecd0ccfaaa4d49dba866fe58724f18445d30202bafb03e21eef6cb",
            4,
        ),
        (
            "(foo . bar)",
            "c518e45ae6a7b4146017b7a1d81639051b132f1f5572ce3088a3898a9ed1280b",
            9,
        ),
        (
            "(this is a longer test of a deeper tree)",
            "0a072d7d860d77d8e290ced0fdb29a271198ca3db54d701c45d831e3aae6422c",
            47,
        ),
    ]

    def check(self, obj_text, expected_hash, expected_length):
        """Assemble `obj_text` and verify both cached computations on it."""
        obj = assemble(obj_text)
        hash_cache = ObjectCache(treehash)
        self.assertEqual(hash_cache.get(obj).hex(), expected_hash)
        length_cache = ObjectCache(serialized_length)
        self.assertEqual(length_cache.get(obj), expected_length)

    def test_various(self):
        for obj_text, expected_hash, expected_length in self.CASES:
            self.check(obj_text, expected_hash, expected_length)

0 comments on commit e4bce12

Please sign in to comment.