From 05e34472abe287bf14c764b8d41a02c7ac7150d1 Mon Sep 17 00:00:00 2001 From: francis <28549164+FrancisLennon17@users.noreply.github.com> Date: Wed, 1 Jan 2025 11:33:24 +0000 Subject: [PATCH 1/2] fix: reduce calls to MinNamespace & MaxNamespace --- hasher.go | 82 ++++++++++++++++++++++++++++---------------------- hasher_test.go | 50 ------------------------------ nmt.go | 6 ++-- nmt_test.go | 1 + 4 files changed, 49 insertions(+), 90 deletions(-) diff --git a/hasher.go b/hasher.go index cb5b3fb8..fd8e05db 100644 --- a/hasher.go +++ b/hasher.go @@ -154,9 +154,9 @@ func (n *NmtHasher) BlockSize() int { func (n *NmtHasher) EmptyRoot() []byte { n.baseHasher.Reset() - emptyNs := bytes.Repeat([]byte{0}, int(n.NamespaceLen)) + emptyNs := bytes.Repeat([]byte{0}, int(n.NamespaceLen)*2) h := n.baseHasher.Sum(nil) - digest := append(append(emptyNs, emptyNs...), h...) + digest := append(emptyNs, h...) return digest } @@ -212,42 +212,59 @@ func (n *NmtHasher) MustHashLeaf(ndata []byte) []byte { return res } -// ValidateNodeFormat checks whether the supplied node conforms to the -// namespaced hash format and returns ErrInvalidNodeLen if not. -func (n *NmtHasher) ValidateNodeFormat(node []byte) (err error) { +type nsIDRange struct { + Min, Max namespace.ID +} + +// tryFetchNodeNSRange attempts to return the min and max namespace ids. +// It will return an ErrInvalidNodeLen | ErrInvalidNodeNamespaceOrder +// if the supplied node does not conform to the namespaced hash format. +func (n *NmtHasher) tryFetchNodeNSRange(node []byte) (nsIDRange, error) { expectedNodeLen := n.Size() nodeLen := len(node) if nodeLen != expectedNodeLen { - return fmt.Errorf("%w: got: %v, want %v", ErrInvalidNodeLen, nodeLen, expectedNodeLen) + return nsIDRange{}, fmt.Errorf("%w: got: %v, want %v", ErrInvalidNodeLen, nodeLen, expectedNodeLen) } // check the namespace order minNID := namespace.ID(MinNamespace(node, n.NamespaceSize())) maxNID := namespace.ID(MaxNamespace(node, n.NamespaceSize())) if maxNID.Less(minNID) { - return fmt.Errorf("%w: max namespace ID %d is less than min namespace ID %d ", ErrInvalidNodeNamespaceOrder, maxNID, minNID) + return nsIDRange{}, fmt.Errorf("%w: max namespace ID %d is less than min namespace ID %d ", ErrInvalidNodeNamespaceOrder, maxNID, minNID) } - return nil + return nsIDRange{Min: minNID, Max: maxNID}, nil +} + +// ValidateNodeFormat checks whether the supplied node conforms to the +// namespaced hash format and returns ErrInvalidNodeLen if not. +func (n *NmtHasher) ValidateNodeFormat(node []byte) (err error) { + _, err = n.tryFetchNodeNSRange(node) + return } -// validateSiblingsNamespaceOrder checks whether left and right as two sibling -// nodes in an NMT have correct namespace IDs relative to each other, more -// specifically, the maximum namespace ID of the left sibling should not exceed -// the minimum namespace ID of the right sibling. It returns ErrUnorderedSiblings error if the check fails. -func (n *NmtHasher) validateSiblingsNamespaceOrder(left, right []byte) (err error) { - if err := n.ValidateNodeFormat(left); err != nil { - return fmt.Errorf("%w: left node does not match the namesapce hash format", err) +// tryFetchLeftAndRightNSRange attempts to return the min/max namespace ids of both +// the left and right nodes. It verifies whether left +// and right comply by the namespace hash format, and are correctly ordered +// according to their namespace IDs. +func (n *NmtHasher) tryFetchLeftAndRightNSRanges(left, right []byte) ( + lNsRange nsIDRange, + rNsRange nsIDRange, + err error, +) { + lNsRange, err = n.tryFetchNodeNSRange(left) + if err != nil { + return } - if err := n.ValidateNodeFormat(right); err != nil { - return fmt.Errorf("%w: right node does not match the namesapce hash format", err) + rNsRange, err = n.tryFetchNodeNSRange(right) + if err != nil { + return } - leftMaxNs := namespace.ID(MaxNamespace(left, n.NamespaceSize())) - rightMinNs := namespace.ID(MinNamespace(right, n.NamespaceSize())) // check the namespace range of the left and right children - if rightMinNs.Less(leftMaxNs) { - return fmt.Errorf("%w: the maximum namespace of the left child %x is greater than the min namespace of the right child %x", ErrUnorderedSiblings, leftMaxNs, rightMinNs) + if rNsRange.Min.Less(lNsRange.Max) { + err = fmt.Errorf("%w: the maximum namespace of the left child %x is greater than the min namespace of the right child %x", ErrUnorderedSiblings, lNsRange.Max, rNsRange.Min) + return } - return nil + return } // ValidateNodes is a helper function to verify the @@ -255,13 +272,8 @@ func (n *NmtHasher) validateSiblingsNamespaceOrder(left, right []byte) (err erro // and right comply by the namespace hash format, and are correctly ordered // according to their namespace IDs. func (n *NmtHasher) ValidateNodes(left, right []byte) error { - if err := n.ValidateNodeFormat(left); err != nil { - return err - } - if err := n.ValidateNodeFormat(right); err != nil { - return err - } - return n.validateSiblingsNamespaceOrder(left, right) + _, _, err := n.tryFetchLeftAndRightNSRanges(left, right) + return err } // HashNode calculates a namespaced hash of a node using the supplied left and @@ -278,21 +290,19 @@ func (n *NmtHasher) ValidateNodes(left, right []byte) error { // If the namespace range of the right child is start=end=MAXNID, indicating that it represents the root of a subtree whose leaves all have the namespace ID of `MAXNID`, then exclude the right child from the namespace range calculation. Instead, // assign the namespace range of the left child as the parent's namespace range. func (n *NmtHasher) HashNode(left, right []byte) ([]byte, error) { - // validate the inputs - if err := n.ValidateNodes(left, right); err != nil { + // validate the inputs & fetch the namespace ranges + lRange, rRange, err := n.tryFetchLeftAndRightNSRanges(left, right) + if err != nil { return nil, err } h := n.baseHasher h.Reset() - leftMinNs, leftMaxNs := MinNamespace(left, n.NamespaceLen), MaxNamespace(left, n.NamespaceLen) - rightMinNs, rightMaxNs := MinNamespace(right, n.NamespaceLen), MaxNamespace(right, n.NamespaceLen) - // compute the namespace range of the parent node - minNs, maxNs := computeNsRange(leftMinNs, leftMaxNs, rightMinNs, rightMaxNs, n.ignoreMaxNs, n.precomputedMaxNs) + minNs, maxNs := computeNsRange(lRange.Min, lRange.Max, rRange.Min, rRange.Max, n.ignoreMaxNs, n.precomputedMaxNs) - res := make([]byte, 0) + res := make([]byte, 0, len(minNs)*2) res = append(res, minNs...) res = append(res, maxNs...) diff --git a/hasher_test.go b/hasher_test.go index 0bf1178e..8014b0b3 100644 --- a/hasher_test.go +++ b/hasher_test.go @@ -283,56 +283,6 @@ func TestHashNode_Error(t *testing.T) { } } -func TestValidateSiblings(t *testing.T) { - // create a dummy hash to use as the digest of the left and right child - randHash := createByteSlice(sha256.Size, 0x01) - - type children struct { - l []byte // namespace hash of the left child with the format of MinNs||MaxNs||h - r []byte // namespace hash of the right child with the format of MinNs||MaxNs||h - } - - tests := []struct { - name string - nidLen namespace.IDSize - children children - wantErr bool - }{ - { - "wrong left node format", 2, - children{concat([]byte{0, 0, 1, 1}, randHash[:len(randHash)-1]), concat([]byte{0, 0, 1, 1}, randHash)}, - true, - }, - { - "wrong right node format", 2, - children{concat([]byte{0, 0, 1, 1}, randHash), concat([]byte{0, 0, 1, 1}, randHash[:len(randHash)-1])}, - true, - }, - { - "left.maxNs>right.minNs", 2, - children{concat([]byte{0, 0, 1, 1}, randHash), concat([]byte{0, 0, 1, 1}, randHash)}, - true, - }, - { - "left.maxNs=right.minNs", 2, - children{concat([]byte{0, 0, 1, 1}, randHash), concat([]byte{1, 1, 2, 2}, randHash)}, - false, - }, - { - "left.maxNs Date: Fri, 17 Jan 2025 10:05:51 +0000 Subject: [PATCH 2/2] fix: review comments --- hasher.go | 32 ++++++++++++++++++-------------- nmt_test.go | 2 +- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/hasher.go b/hasher.go index fd8e05db..f0d174ff 100644 --- a/hasher.go +++ b/hasher.go @@ -154,9 +154,8 @@ func (n *NmtHasher) BlockSize() int { func (n *NmtHasher) EmptyRoot() []byte { n.baseHasher.Reset() - emptyNs := bytes.Repeat([]byte{0}, int(n.NamespaceLen)*2) h := n.baseHasher.Sum(nil) - digest := append(emptyNs, h...) + digest := append(make([]byte, int(n.NamespaceLen)*2), h...) return digest } @@ -212,6 +211,7 @@ func (n *NmtHasher) MustHashLeaf(ndata []byte) []byte { return res } +// nsIDRange represents the range of namespace IDs with minimum and maximum values. type nsIDRange struct { Min, Max namespace.ID } @@ -235,10 +235,10 @@ func (n *NmtHasher) tryFetchNodeNSRange(node []byte) (nsIDRange, error) { } // ValidateNodeFormat checks whether the supplied node conforms to the -// namespaced hash format and returns ErrInvalidNodeLen if not. -func (n *NmtHasher) ValidateNodeFormat(node []byte) (err error) { - _, err = n.tryFetchNodeNSRange(node) - return +// namespaced hash format and returns an error if not. +func (n *NmtHasher) ValidateNodeFormat(node []byte) error { + _, err := n.tryFetchNodeNSRange(node) + return err } // tryFetchLeftAndRightNSRange attempts to return the min/max namespace ids of both @@ -246,25 +246,29 @@ func (n *NmtHasher) ValidateNodeFormat(node []byte) (err error) { // and right comply by the namespace hash format, and are correctly ordered // according to their namespace IDs. func (n *NmtHasher) tryFetchLeftAndRightNSRanges(left, right []byte) ( - lNsRange nsIDRange, - rNsRange nsIDRange, - err error, + nsIDRange, + nsIDRange, + error, ) { + var lNsRange nsIDRange + var rNsRange nsIDRange + var err error + lNsRange, err = n.tryFetchNodeNSRange(left) if err != nil { - return + return lNsRange, rNsRange, err } rNsRange, err = n.tryFetchNodeNSRange(right) if err != nil { - return + return lNsRange, rNsRange, err } // check the namespace range of the left and right children if rNsRange.Min.Less(lNsRange.Max) { - err = fmt.Errorf("%w: the maximum namespace of the left child %x is greater than the min namespace of the right child %x", ErrUnorderedSiblings, lNsRange.Max, rNsRange.Min) - return + err = fmt.Errorf("%w: the min namespace ID of the right child %d is less than the max namespace ID of the left child %d", ErrUnorderedSiblings, rNsRange.Min, lNsRange.Max) } - return + + return lNsRange, rNsRange, err } // ValidateNodes is a helper function to verify the diff --git a/nmt_test.go b/nmt_test.go index 9728ff5d..1ffdccff 100644 --- a/nmt_test.go +++ b/nmt_test.go @@ -704,7 +704,7 @@ func BenchmarkComputeRoot(b *testing.B) { {"64-leaves", 64, 8, 256}, {"128-leaves", 128, 8, 256}, {"256-leaves", 256, 8, 256}, - //{"20k-leaves", 20000, 8, 512}, + {"20k-leaves", 20000, 8, 512}, } for _, tt := range tests {