diff --git a/hasher.go b/hasher.go index 5dd64704..cb5b3fb8 100644 --- a/hasher.go +++ b/hasher.go @@ -14,7 +14,7 @@ const ( NodePrefix = 1 ) -var _ hash.Hash = (*Hasher)(nil) +var _ hash.Hash = (*NmtHasher)(nil) var ( ErrUnorderedSiblings = errors.New("NMT sibling nodes should be ordered lexicographically by namespace IDs") @@ -23,7 +23,24 @@ var ( ErrInvalidNodeNamespaceOrder = errors.New("invalid NMT node namespace order") ) -type Hasher struct { +// Hasher describes the interface nmts use to hash leafs and nodes. +// +// Note: it is not advised to create alternative hashers if following the +// specification is desired. The main reason this exists is to not follow the +// specification for testing purposes. +type Hasher interface { + IsMaxNamespaceIDIgnored() bool + NamespaceSize() namespace.IDSize + HashLeaf(data []byte) ([]byte, error) + HashNode(leftChild, rightChild []byte) ([]byte, error) + EmptyRoot() []byte +} + +var _ Hasher = &NmtHasher{} + +// NmtHasher is the default hasher. It follows the description of the original +// hashing function described in the LazyLedger white paper. +type NmtHasher struct { //nolint:revive baseHasher hash.Hash NamespaceLen namespace.IDSize @@ -41,16 +58,16 @@ type Hasher struct { data []byte // written data of the NMT node } -func (n *Hasher) IsMaxNamespaceIDIgnored() bool { +func (n *NmtHasher) IsMaxNamespaceIDIgnored() bool { return n.ignoreMaxNs } -func (n *Hasher) NamespaceSize() namespace.IDSize { +func (n *NmtHasher) NamespaceSize() namespace.IDSize { return n.NamespaceLen } -func NewNmtHasher(baseHasher hash.Hash, nidLen namespace.IDSize, ignoreMaxNamespace bool) *Hasher { - return &Hasher{ +func NewNmtHasher(baseHasher hash.Hash, nidLen namespace.IDSize, ignoreMaxNamespace bool) *NmtHasher { + return &NmtHasher{ baseHasher: baseHasher, NamespaceLen: nidLen, ignoreMaxNs: ignoreMaxNamespace, @@ -59,7 +76,7 @@ func NewNmtHasher(baseHasher hash.Hash, nidLen namespace.IDSize, ignoreMaxNamesp } // Size returns the number of bytes Sum will return. -func (n *Hasher) Size() int { +func (n *NmtHasher) Size() int { return n.baseHasher.Size() + int(n.NamespaceLen)*2 } @@ -69,7 +86,7 @@ func (n *Hasher) Size() int { // write is allowed. // It panics if more than one single write is attempted. // If the data does not match the format of an NMT non-leaf node or leaf node, an error will be returned. -func (n *Hasher) Write(data []byte) (int, error) { +func (n *NmtHasher) Write(data []byte) (int, error) { if n.data != nil { panic("only a single Write is allowed") } @@ -101,7 +118,7 @@ func (n *Hasher) Write(data []byte) (int, error) { // Sum computes the hash. Does not append the given suffix, violating the // interface. // It may panic if the data being hashed is invalid. This should never happen since the Write method refuses an invalid data and errors out. -func (n *Hasher) Sum([]byte) []byte { +func (n *NmtHasher) Sum([]byte) []byte { switch n.tp { case LeafPrefix: res, err := n.HashLeaf(n.data) @@ -125,17 +142,17 @@ func (n *Hasher) Sum([]byte) []byte { } // Reset resets the Hash to its initial state. -func (n *Hasher) Reset() { +func (n *NmtHasher) Reset() { n.tp, n.data = 255, nil // reset with an invalid node type, as zero value is a valid Leaf n.baseHasher.Reset() } // BlockSize returns the hash's underlying block size. -func (n *Hasher) BlockSize() int { +func (n *NmtHasher) BlockSize() int { return n.baseHasher.BlockSize() } -func (n *Hasher) EmptyRoot() []byte { +func (n *NmtHasher) EmptyRoot() []byte { n.baseHasher.Reset() emptyNs := bytes.Repeat([]byte{0}, int(n.NamespaceLen)) h := n.baseHasher.Sum(nil) @@ -145,7 +162,7 @@ func (n *Hasher) EmptyRoot() []byte { } // ValidateLeaf verifies if data is namespaced and returns an error if not. -func (n *Hasher) ValidateLeaf(data []byte) (err error) { +func (n *NmtHasher) ValidateLeaf(data []byte) (err error) { nidSize := int(n.NamespaceSize()) lenData := len(data) if lenData < nidSize { @@ -160,7 +177,7 @@ func (n *Hasher) ValidateLeaf(data []byte) (err error) { // leaves minNs = maxNs = ns(leaf) = leaf[:NamespaceLen]. HashLeaf can return the ErrInvalidNodeLen error if the input is not namespaced. // //nolint:errcheck -func (n *Hasher) HashLeaf(ndata []byte) ([]byte, error) { +func (n *NmtHasher) HashLeaf(ndata []byte) ([]byte, error) { h := n.baseHasher h.Reset() @@ -187,7 +204,7 @@ func (n *Hasher) HashLeaf(ndata []byte) ([]byte, error) { // MustHashLeaf is a wrapper around HashLeaf that panics if an error is // encountered. The ndata must be a valid leaf node. -func (n *Hasher) MustHashLeaf(ndata []byte) []byte { +func (n *NmtHasher) MustHashLeaf(ndata []byte) []byte { res, err := n.HashLeaf(ndata) if err != nil { panic(err) @@ -197,7 +214,7 @@ func (n *Hasher) MustHashLeaf(ndata []byte) []byte { // ValidateNodeFormat checks whether the supplied node conforms to the // namespaced hash format and returns ErrInvalidNodeLen if not. -func (n *Hasher) ValidateNodeFormat(node []byte) (err error) { +func (n *NmtHasher) ValidateNodeFormat(node []byte) (err error) { expectedNodeLen := n.Size() nodeLen := len(node) if nodeLen != expectedNodeLen { @@ -216,7 +233,7 @@ func (n *Hasher) ValidateNodeFormat(node []byte) (err error) { // nodes in an NMT have correct namespace IDs relative to each other, more // specifically, the maximum namespace ID of the left sibling should not exceed // the minimum namespace ID of the right sibling. It returns ErrUnorderedSiblings error if the check fails. -func (n *Hasher) validateSiblingsNamespaceOrder(left, right []byte) (err error) { +func (n *NmtHasher) validateSiblingsNamespaceOrder(left, right []byte) (err error) { if err := n.ValidateNodeFormat(left); err != nil { return fmt.Errorf("%w: left node does not match the namesapce hash format", err) } @@ -237,7 +254,7 @@ func (n *Hasher) validateSiblingsNamespaceOrder(left, right []byte) (err error) // validity of the inputs of HashNode. It verifies whether left // and right comply by the namespace hash format, and are correctly ordered // according to their namespace IDs. -func (n *Hasher) ValidateNodes(left, right []byte) error { +func (n *NmtHasher) ValidateNodes(left, right []byte) error { if err := n.ValidateNodeFormat(left); err != nil { return err } @@ -260,7 +277,7 @@ func (n *Hasher) ValidateNodes(left, right []byte) error { // slightly changes. Let MAXNID be the maximum possible namespace ID value i.e., 2^NamespaceIDSize-1. // If the namespace range of the right child is start=end=MAXNID, indicating that it represents the root of a subtree whose leaves all have the namespace ID of `MAXNID`, then exclude the right child from the namespace range calculation. Instead, // assign the namespace range of the left child as the parent's namespace range. -func (n *Hasher) HashNode(left, right []byte) ([]byte, error) { +func (n *NmtHasher) HashNode(left, right []byte) ([]byte, error) { // validate the inputs if err := n.ValidateNodes(left, right); err != nil { return nil, err diff --git a/hasher_test.go b/hasher_test.go index 8227ed3e..0bf1178e 100644 --- a/hasher_test.go +++ b/hasher_test.go @@ -600,7 +600,7 @@ func TestWrite_Err(t *testing.T) { tests := []struct { name string - hasher *Hasher + hasher *NmtHasher data []byte wantErr bool errType error @@ -639,7 +639,7 @@ func TestSum_Err(t *testing.T) { tests := []struct { name string - hasher *Hasher + hasher *NmtHasher data []byte nodeType byte wantWriteErr bool diff --git a/nmt.go b/nmt.go index 415e7247..5b897f00 100644 --- a/nmt.go +++ b/nmt.go @@ -39,6 +39,7 @@ type Options struct { // in the "Hasher. IgnoreMaxNamespace bool NodeVisitor NodeVisitorFn + Hasher Hasher } type Option func(*Options) @@ -82,8 +83,15 @@ func NodeVisitor(nodeVisitorFn NodeVisitorFn) Option { } } +// CustomHasher replaces the default hasher. +func CustomHasher(h Hasher) Option { + return func(o *Options) { + o.Hasher = h + } +} + type NamespacedMerkleTree struct { - treeHasher *Hasher + treeHasher Hasher visit NodeVisitorFn // just cache stuff until we pass in a store and keep all nodes in there @@ -128,9 +136,18 @@ func New(h hash.Hash, setters ...Option) *NamespacedMerkleTree { for _, setter := range setters { setter(opts) } - treeHasher := NewNmtHasher(h, opts.NamespaceIDSize, opts.IgnoreMaxNamespace) + + // first create the default hasher using the updated options + hasher := NewNmtHasher(h, opts.NamespaceIDSize, opts.IgnoreMaxNamespace) + opts.Hasher = hasher + + // set the options a second time to replace the hasher if needed + for _, setter := range setters { + setter(opts) + } + return &NamespacedMerkleTree{ - treeHasher: treeHasher, + treeHasher: opts.Hasher, visit: opts.NodeVisitor, leaves: make([][]byte, 0, opts.InitialCapacity), leafHashes: make([][]byte, 0, opts.InitialCapacity), @@ -491,6 +508,27 @@ func (n *NamespacedMerkleTree) MaxNamespace() (namespace.ID, error) { return MaxNamespace(r, n.NamespaceSize()), nil } +// ForceAddLeaf adds a namespaced data to the tree without validating its +// namespace ID. This method should only be used by tests that are attempting to +// create out of order trees. The default hasher will fail for trees that are +// out of order. +func (n *NamespacedMerkleTree) ForceAddLeaf(leaf namespace.PrefixedData) error { + nID := namespace.ID(leaf[:n.NamespaceSize()]) + // compute the leaf hash + res, err := n.treeHasher.HashLeaf(leaf) + if err != nil { + return err + } + + // update relevant "caches": + n.leaves = append(n.leaves, leaf) + n.leafHashes = append(n.leafHashes, res) + n.updateNamespaceRanges() + n.updateMinMaxID(nID) + n.rawRoot = nil + return nil +} + // computeRoot calculates the namespace Merkle root for a tree/sub-tree that // encompasses the leaves within the range of [start, end). // Any errors returned by this method are irrecoverable and indicate an illegal state of the tree (n). diff --git a/nmt_test.go b/nmt_test.go index b1acac27..c4907b5b 100644 --- a/nmt_test.go +++ b/nmt_test.go @@ -553,6 +553,19 @@ func TestNodeVisitor(t *testing.T) { } } +func TestCustomHasher(t *testing.T) { + type customHasher struct { + *NmtHasher + } + + h := customHasher{NewNmtHasher(sha256.New(), namespace.IDSize(8), true)} + + tree := New(sha256.New(), NamespaceIDSize(8), IgnoreMaxNamespace(true), CustomHasher(h)) + + _, ok := tree.treeHasher.(customHasher) + require.True(t, ok) +} + func TestNamespacedMerkleTree_ProveErrors(t *testing.T) { tests := []struct { name string @@ -1138,3 +1151,19 @@ func TestEmptyRoot_NMT(t *testing.T) { assert.True(t, bytes.Equal(gotEmptyRoot, expectedEmptyRoot)) } + +func TestForcedOutOfOrderNamespacedMerkleTree(t *testing.T) { + data := [][]byte{ + append(namespace.ID{0}, []byte("leaf_0")...), + append(namespace.ID{2}, []byte("leaf_1")...), + append(namespace.ID{1}, []byte("leaf_2")...), + append(namespace.ID{1}, []byte("leaf_3")...), + } + nidSize := 1 + tree := New(sha256.New(), NamespaceIDSize(nidSize)) + + for _, d := range data { + err := tree.ForceAddLeaf(d) + assert.NoError(t, err) + } +} diff --git a/proof.go b/proof.go index 2db25890..2a1037cf 100644 --- a/proof.go +++ b/proof.go @@ -228,7 +228,7 @@ func (proof Proof) VerifyNamespace(h hash.Hash, nID namespace.ID, leaves [][]byt // the completeness of the proof by verifying that there is no leaf in the // tree represented by the root parameter that matches the namespace ID nID // outside the leafHashes list. -func (proof Proof) VerifyLeafHashes(nth *Hasher, verifyCompleteness bool, nID namespace.ID, leafHashes [][]byte, root []byte) (bool, error) { +func (proof Proof) VerifyLeafHashes(nth *NmtHasher, verifyCompleteness bool, nID namespace.ID, leafHashes [][]byte, root []byte) (bool, error) { // check that the proof range is valid if proof.Start() < 0 || proof.Start() >= proof.End() { return false, fmt.Errorf("proof range [proof.start=%d, proof.end=%d) is not valid: %w", proof.Start(), proof.End(), ErrInvalidRange) diff --git a/proof_test.go b/proof_test.go index 291de098..df0f90fa 100644 --- a/proof_test.go +++ b/proof_test.go @@ -276,7 +276,8 @@ func TestVerifyLeafHashes_Err(t *testing.T) { // create a sample tree nameIDSize := 2 nmt := exampleNMT(nameIDSize, true, 1, 2, 3, 4, 5, 6, 7, 8) - hasher := nmt.treeHasher + nmthasher := nmt.treeHasher + hasher := nmthasher.(*NmtHasher) root, err := nmt.Root() require.NoError(t, err) @@ -330,7 +331,7 @@ func TestVerifyLeafHashes_Err(t *testing.T) { tests := []struct { name string proof Proof - Hasher *Hasher + Hasher *NmtHasher verifyCompleteness bool nID namespace.ID leafHashes [][]byte