diff --git a/nmt.go b/nmt.go index e9c318a2..f8f82f5f 100644 --- a/nmt.go +++ b/nmt.go @@ -107,9 +107,9 @@ type NamespacedMerkleTree struct { // namespaceRanges can be used to efficiently look up the range for an // existing namespace without iterating through the leaves. The map key is - // the string representation of a namespace.ID and the leafRange indicates + // the string representation of a namespace.ID and the LeafRange indicates // the range of the leaves matching that namespace ID in the tree - namespaceRanges map[string]leafRange + namespaceRanges map[string]LeafRange // minNID is the minimum namespace ID of the leaves minNID namespace.ID // maxNID is the maximum namespace ID of the leaves @@ -151,7 +151,7 @@ func New(h hash.Hash, setters ...Option) *NamespacedMerkleTree { visit: opts.NodeVisitor, leaves: make([][]byte, 0, opts.InitialCapacity), leafHashes: make([][]byte, 0, opts.InitialCapacity), - namespaceRanges: make(map[string]leafRange), + namespaceRanges: make(map[string]LeafRange), minNID: bytes.Repeat([]byte{0xFF}, int(opts.NamespaceIDSize)), maxNID: bytes.Repeat([]byte{0x00}, int(opts.NamespaceIDSize)), } @@ -436,7 +436,7 @@ func (n *NamespacedMerkleTree) foundInRange(nID namespace.ID) (found bool, start // This is a faster version of this code snippet: // https://github.com/celestiaorg/celestiaorg-prototype/blob/2aeca6f55ad389b9d68034a0a7038f80a8d2982e/simpleblock.go#L106-L117 foundRng, found := n.namespaceRanges[string(nID)] - return found, foundRng.start, foundRng.end + return found, foundRng.Start, foundRng.End } // NamespaceSize returns the underlying namespace size. Note that all namespaced @@ -590,14 +590,14 @@ func (n *NamespacedMerkleTree) updateNamespaceRanges() { lastNsStr := string(lastPushed[:n.treeHasher.NamespaceSize()]) lastRange, found := n.namespaceRanges[lastNsStr] if !found { - n.namespaceRanges[lastNsStr] = leafRange{ - start: lastIndex, - end: lastIndex + 1, + n.namespaceRanges[lastNsStr] = LeafRange{ + Start: lastIndex, + End: lastIndex + 1, } } else { - n.namespaceRanges[lastNsStr] = leafRange{ - start: lastRange.start, - end: lastRange.end + 1, + n.namespaceRanges[lastNsStr] = LeafRange{ + Start: lastRange.Start, + End: lastRange.End + 1, } } } @@ -644,11 +644,38 @@ func (n *NamespacedMerkleTree) updateMinMaxID(id namespace.ID) { } } -type leafRange struct { - // start and end denote the indices of a leaf in the tree. start ranges from - // 0 up to the total number of leaves minus 1 end ranges from 1 up to the - // total number of leaves end is non-inclusive - start, end int +// ComputeSubtreeRoot takes a leaf range and returns the corresponding subtree root. +// Also, it requires the start and end range to correctly reference an inner node. +// The provided range, defined by start and end, is end-exclusive. +func (n *NamespacedMerkleTree) ComputeSubtreeRoot(start, end int) ([]byte, error) { + if start < 0 { + return nil, fmt.Errorf("start %d shouldn't be strictly negative", start) + } + if end <= start { + return nil, fmt.Errorf("end %d should be stricly bigger than start %d", end, start) + } + uStart, err := safeIntToUint(start) + if err != nil { + return nil, err + } + uEnd, err := safeIntToUint(end) + if err != nil { + return nil, err + } + // check if the provided range correctly references an inner node. + // calculates the ideal tree from the provided range, and verifies if it is the same as the range + if idealTreeRange := nextSubtreeSize(uint64(uStart), uint64(uEnd)); end-start != idealTreeRange { + return nil, fmt.Errorf("the provided range [%d, %d) does not construct a valid subtree root range", start, end) + } + return n.computeRoot(start, end) +} + +type LeafRange struct { + // Start and End denote the indices of a leaf in the tree. + // Start ranges from 0 up to the total number of leaves minus 1. + // End ranges from 1 up to the total number of leaves. + // End is non-inclusive + Start, End int } // MinNamespace extracts the minimum namespace ID from a given namespace hash, diff --git a/nmt_test.go b/nmt_test.go index 6e0565e7..83076818 100644 --- a/nmt_test.go +++ b/nmt_test.go @@ -862,6 +862,20 @@ func exampleNMT(nidSize int, ignoreMaxNamespace bool, leavesNIDs ...byte) *Names return tree } +// exampleNMT2 Replica of exampleNMT except that it uses the namespace IDs in the +// leaves instead of the index. +func exampleNMT2(nidSize int, ignoreMaxNamespace bool, leavesNIDs ...byte) *NamespacedMerkleTree { + tree := New(sha256.New(), NamespaceIDSize(nidSize), IgnoreMaxNamespace(ignoreMaxNamespace)) + for _, nid := range leavesNIDs { + namespace := bytes.Repeat([]byte{nid}, nidSize) + d := append(namespace, []byte(fmt.Sprintf("leaf_%d", nid))...) + if err := tree.Push(d); err != nil { + panic(fmt.Sprintf("unexpected error: %v", err)) + } + } + return tree +} + func swap(slice [][]byte, i int, j int) { temp := slice[i] slice[i] = slice[j] @@ -1175,3 +1189,126 @@ func TestForcedOutOfOrderNamespacedMerkleTree(t *testing.T) { assert.NoError(t, err) } } + +func TestComputeSubtreeRoot(t *testing.T) { + n := exampleNMT2(1, true, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) + tests := []struct { + start, end int + tree *NamespacedMerkleTree + expectedRoot []byte + expectError bool + }{ + { + start: 0, + end: 16, + tree: n, + expectedRoot: func() []byte { + root, err := n.Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 0, + end: 8, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [0,8) coincides with the root of this tree + root, err := exampleNMT2(1, true, 0, 1, 2, 3, 4, 5, 6, 7).Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 8, + end: 16, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [8,16) coincides with the root of this tree + root, err := exampleNMT2(1, true, 8, 9, 10, 11, 12, 13, 14, 15).Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 8, + end: 12, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [8,12) coincides with the root of this tree + root, err := exampleNMT2(1, true, 8, 9, 10, 11).Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 4, + end: 8, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [4,8) coincides with the root of this tree + root, err := exampleNMT2(1, true, 4, 5, 6, 7).Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 4, + end: 6, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [4,6) coincides with the root of this tree + root, err := exampleNMT2(1, true, 4, 5).Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 4, + end: 5, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [4,5) coincides with the root of this tree + root, err := exampleNMT2(1, true, 4).Root() + require.NoError(t, err) + return root + }(), + }, + { // doesn't correctly reference an inner node + start: 2, + end: 6, + tree: n, + expectError: true, + }, + { + start: -1, // invalid start + end: 4, + tree: n, + expectError: true, + }, + { + start: 4, + end: 4, // start == end + tree: n, + expectError: true, + }, + { + start: 5, // start >= end + end: 4, + tree: n, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("treeSize=%d,start=%d,end=%d", tt.tree.Size(), tt.start, tt.end), func(t *testing.T) { + root, err := tt.tree.ComputeSubtreeRoot(tt.start, tt.end) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedRoot, root) + } + }) + } +} diff --git a/proof.go b/proof.go index 6e824e5e..bde67237 100644 --- a/proof.go +++ b/proof.go @@ -9,7 +9,7 @@ import ( "math/bits" "github.com/celestiaorg/nmt/namespace" - pb "github.com/celestiaorg/nmt/pb" + "github.com/celestiaorg/nmt/pb" ) var ( @@ -465,6 +465,187 @@ func (proof Proof) VerifyInclusion(h hash.Hash, nid namespace.ID, leavesWithoutN return res } +// VerifySubtreeRootInclusion verifies that a set of subtree roots is included in +// an NMT. +// Warning: This method is Celestia specific! Using it without verifying +// the following assumptions, can return unexpected errors, false positive/negatives: +// - The subtree roots are created according to the ADR-013 +// https://github.com/celestiaorg/celestia-app/blob/main/docs/architecture/adr-013-non-interactive-default-rules-for-zero-padding.md +// - The tree's number of leaves is a power of two +// The subtreeWidth is also defined in ADR-013. +// More information on the algorithm used can be found in the ToLeafRanges() method docs. +func (proof Proof) VerifySubtreeRootInclusion(nth *NmtHasher, subtreeRoots [][]byte, subtreeWidth int, root []byte) (bool, error) { + // check that the proof range is valid + if proof.Start() < 0 || proof.Start() >= proof.End() { + return false, fmt.Errorf("proof range [proof.start=%d, proof.end=%d) is not valid: %w", proof.Start(), proof.End(), ErrInvalidRange) + } + + // check that the root is valid w.r.t the NMT hasher + if err := nth.ValidateNodeFormat(root); err != nil { + return false, fmt.Errorf("root does not match the NMT hasher's hash format: %w", err) + } + // check that all the proof.Notes() are valid w.r.t the NMT hasher + for _, node := range proof.Nodes() { + if err := nth.ValidateNodeFormat(node); err != nil { + return false, fmt.Errorf("proof nodes do not match the NMT hasher's hash format: %w", err) + } + } + // check that all the subtree roots are valid w.r.t the NMT hasher + for _, subtreeRoot := range subtreeRoots { + if err := nth.ValidateNodeFormat(subtreeRoot); err != nil { + return false, fmt.Errorf("inner nodes does not match the NMT hasher's hash format: %w", err) + } + } + + // get the subtree roots leaf ranges + ranges, err := ToLeafRanges(proof.Start(), proof.End(), subtreeWidth) + if err != nil { + return false, err + } + + // check whether the number of ranges matches the number of subtree roots. + // if not, make an early return. + if len(subtreeRoots) != len(ranges) { + return false, fmt.Errorf("number of subtree roots %d is different than the number of the expected leaf ranges %d", len(subtreeRoots), len(ranges)) + } + + var computeRoot func(start, end int) ([]byte, error) + // computeRoot can return error iff the HashNode function fails while calculating the root + computeRoot = func(start, end int) ([]byte, error) { + // if the current range does not overlap with the proof range, pop and + // return a proof node if present, else return nil because subtree + // doesn't exist + if end <= proof.Start() || start >= proof.End() { + return popIfNonEmpty(&proof.nodes), nil + } + + if len(ranges) == 0 { + return nil, fmt.Errorf(fmt.Sprintf("expected to have a subtree root for range [%d, %d)", start, end)) + } + + if ranges[0].Start == start && ranges[0].End == end { + ranges = ranges[1:] + return popIfNonEmpty(&subtreeRoots), nil + } + + // Recursively get left and right subtree + k := getSplitPoint(end - start) + left, err := computeRoot(start, start+k) + if err != nil { + return nil, fmt.Errorf("failed to compute subtree root [%d, %d): %w", start, start+k, err) + } + right, err := computeRoot(start+k, end) + if err != nil { + return nil, fmt.Errorf("failed to compute subtree root [%d, %d): %w", start+k, end, err) + } + + // only right leaf/subtree can be non-existent + if right == nil { + return left, nil + } + hash, err := nth.HashNode(left, right) + if err != nil { + return nil, fmt.Errorf("failed to hash node: %w", err) + } + return hash, nil + } + + // estimate the leaf size of the subtree containing the proof range + proofRangeSubtreeEstimate := getSplitPoint(proof.End()) * 2 + if proofRangeSubtreeEstimate < 1 { + proofRangeSubtreeEstimate = 1 + } + rootHash, err := computeRoot(0, proofRangeSubtreeEstimate) + if err != nil { + return false, fmt.Errorf("failed to compute root [%d, %d): %w", 0, proofRangeSubtreeEstimate, err) + } + for i := 0; i < len(proof.Nodes()); i++ { + rootHash, err = nth.HashNode(rootHash, proof.Nodes()[i]) + if err != nil { + return false, fmt.Errorf("failed to hash node: %w", err) + } + } + + return bytes.Equal(rootHash, root), nil +} + +// ToLeafRanges returns the leaf ranges corresponding to the provided subtree roots. +// The proof range defined by proofStart and proofEnd is end exclusive. +// It uses the subtree root width to calculate the maximum number of leaves a subtree root can +// commit to. +// The subtree root width is defined as per ADR-013: +// https://github.com/celestiaorg/celestia-app/blob/main/docs/architecture/adr-013-non-interactive-default-rules-for-zero-padding.md +// This method assumes: +// - The subtree roots are created according to the ADR-013 non-interactive defaults rules +// - The tree's number of leaves is a power of two +// The algorithm is as follows: +// - Let `d` be `y - x` (the range of the proof). +// - `i` is the index of the next subtree root. +// - While `d != 0`: +// - Let `z` be the largest power of 2 that fits in `d`; here we are finding the range for the next subtree root. +// - The range for the next subtree root is `[x, x + z)`, i.e., `S_i` is the subtree root of leaves at indices `[x, x + z)`. +// - `d = d - z` (move past the first subtree root and its range). +// - `i = i + 1`. +// - Go back to the loop condition. +// +// Note: This method is Celestia specific. +func ToLeafRanges(proofStart, proofEnd, subtreeWidth int) ([]LeafRange, error) { + if proofStart < 0 { + return nil, fmt.Errorf("proof start %d shouldn't be strictly negative", proofStart) + } + if proofEnd <= proofStart { + return nil, fmt.Errorf("proof end %d should be stricly bigger than proof start %d", proofEnd, proofStart) + } + if subtreeWidth <= 0 { + return nil, fmt.Errorf("subtree root width cannot be negative %d", subtreeWidth) + } + currentStart := proofStart + currentLeafRange := proofEnd - proofStart + var ranges []LeafRange + maximumLeafRange := subtreeWidth + for currentLeafRange != 0 { + nextRange, err := nextLeafRange(currentStart, proofEnd, maximumLeafRange) + if err != nil { + return nil, err + } + ranges = append(ranges, nextRange) + currentStart = nextRange.End + currentLeafRange = currentLeafRange - nextRange.End + nextRange.Start + } + return ranges, nil +} + +// nextLeafRange takes a proof start, proof end, and the maximum range a subtree +// root can cover, and returns the corresponding subtree root range. +// Check ToLeafRanges() for more information on the algorithm used. +// The subtreeWidth is calculated using SubTreeWidth() method +// in celestiaorg/go-square/inclusion package. +// The subtreeWidth is a power of two. +// Also, the LeafRange values, i.e., the range size, are all powers of two. +// Note: This method is Celestia specific. +func nextLeafRange(currentStart, currentEnd, subtreeWidth int) (LeafRange, error) { + currentLeafRange := currentEnd - currentStart + minimum := minInt(currentLeafRange, subtreeWidth) + uMinimum, err := safeIntToUint(minimum) + if err != nil { + return LeafRange{}, fmt.Errorf("failed to convert subtree root range to Uint %w", err) + } + currentRange, err := largestPowerOfTwo(uMinimum) + if err != nil { + return LeafRange{}, err + } + return LeafRange{Start: currentStart, End: currentStart + currentRange}, nil +} + +// largestPowerOfTwo calculates the largest power of two +// that is smaller than 'bound' +func largestPowerOfTwo(bound uint) (int, error) { + if bound == 0 { + return 0, fmt.Errorf("bound cannot be equal to 0") + } + return 1 << (bits.Len(bound) - 1), nil +} + // ProtoToProof creates a proof from its proto representation. func ProtoToProof(protoProof pb.Proof) Proof { if protoProof.Start == 0 && protoProof.End == 0 { @@ -512,3 +693,17 @@ func popIfNonEmpty(s *[][]byte) []byte { } return nil } + +func safeIntToUint(val int) (uint, error) { + if val < 0 { + return 0, fmt.Errorf("cannot convert a negative int %d to uint", val) + } + return uint(val), nil +} + +func minInt(val1, val2 int) int { + if val1 > val2 { + return val2 + } + return val1 +} diff --git a/proof_test.go b/proof_test.go index 235a4037..b2ea98fe 100644 --- a/proof_test.go +++ b/proof_test.go @@ -3,6 +3,7 @@ package nmt import ( "bytes" "crypto/sha256" + "fmt" "hash" "testing" @@ -1124,3 +1125,706 @@ func Test_ProtoToProof(t *testing.T) { }) } } + +func TestLargestPowerOfTwo(t *testing.T) { + tests := []struct { + bound uint + expected int + expectError bool + }{ + {bound: 1, expected: 1}, + {bound: 2, expected: 2}, + {bound: 3, expected: 2}, + {bound: 4, expected: 4}, + {bound: 5, expected: 4}, + {bound: 6, expected: 4}, + {bound: 7, expected: 4}, + {bound: 8, expected: 8}, + {bound: 0, expectError: true}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("bound=%d", tt.bound), func(t *testing.T) { + result, err := largestPowerOfTwo(tt.bound) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestToLeafRanges(t *testing.T) { + tests := []struct { + proofStart, proofEnd, subtreeWidth int + expectedRanges []LeafRange + expectError bool + }{ + { + proofStart: 0, + proofEnd: 8, + subtreeWidth: 1, + expectedRanges: []LeafRange{ + {Start: 0, End: 1}, + {Start: 1, End: 2}, + {Start: 2, End: 3}, + {Start: 3, End: 4}, + {Start: 4, End: 5}, + {Start: 5, End: 6}, + {Start: 6, End: 7}, + {Start: 7, End: 8}, + }, + }, + { + proofStart: 0, + proofEnd: 9, + subtreeWidth: 1, + expectedRanges: []LeafRange{ + {Start: 0, End: 1}, + {Start: 1, End: 2}, + {Start: 2, End: 3}, + {Start: 3, End: 4}, + {Start: 4, End: 5}, + {Start: 5, End: 6}, + {Start: 6, End: 7}, + {Start: 7, End: 8}, + {Start: 8, End: 9}, + }, + }, + { + proofStart: 0, + proofEnd: 16, + subtreeWidth: 1, + expectedRanges: []LeafRange{ + {Start: 0, End: 1}, + {Start: 1, End: 2}, + {Start: 2, End: 3}, + {Start: 3, End: 4}, + {Start: 4, End: 5}, + {Start: 5, End: 6}, + {Start: 6, End: 7}, + {Start: 7, End: 8}, + {Start: 8, End: 9}, + {Start: 9, End: 10}, + {Start: 10, End: 11}, + {Start: 11, End: 12}, + {Start: 12, End: 13}, + {Start: 13, End: 14}, + {Start: 14, End: 15}, + {Start: 15, End: 16}, + }, + }, + { + proofStart: 0, + proofEnd: 100, + subtreeWidth: 2, + expectedRanges: func() []LeafRange { + var ranges []LeafRange + for i := 0; i < 100; i = i + 2 { + ranges = append(ranges, LeafRange{i, i + 2}) + } + return ranges + }(), + }, + { + proofStart: 0, + proofEnd: 150, + subtreeWidth: 4, + expectedRanges: func() []LeafRange { + var ranges []LeafRange + for i := 0; i < 148; i = i + 4 { + ranges = append(ranges, LeafRange{i, i + 4}) + } + ranges = append(ranges, LeafRange{ + Start: 148, + End: 150, + }) + return ranges + }(), + }, + { + proofStart: 0, + proofEnd: 400, + subtreeWidth: 8, + expectedRanges: func() []LeafRange { + var ranges []LeafRange + for i := 0; i < 400; i = i + 8 { + ranges = append(ranges, LeafRange{i, i + 8}) + } + return ranges + }(), + }, + { + proofStart: -1, + proofEnd: 0, + subtreeWidth: -1, + expectedRanges: nil, + expectError: true, + }, + { + proofStart: 0, + proofEnd: -1, + subtreeWidth: -1, + expectedRanges: nil, + expectError: true, + }, + { + proofStart: 0, + proofEnd: 0, + subtreeWidth: 2, + expectedRanges: nil, + expectError: true, + }, + { + proofStart: 0, + proofEnd: 0, + subtreeWidth: -1, + expectedRanges: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("proofStart=%d, proofEnd=%d, subtreeWidth=%d", tt.proofStart, tt.proofEnd, tt.subtreeWidth), func(t *testing.T) { + result, err := ToLeafRanges(tt.proofStart, tt.proofEnd, tt.subtreeWidth) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.True(t, compareRanges(result, tt.expectedRanges)) + } + }) + } +} + +func compareRanges(a, b []LeafRange) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func TestNextLeafRange(t *testing.T) { + tests := []struct { + currentStart, currentEnd int + // the maximum leaf range == subtree width used in these tests do not follow ADR-013 + // they're just used to try different test cases + subtreeRootMaximumLeafRange int + expectedRange LeafRange + expectError bool + }{ + { + currentStart: 0, + currentEnd: 8, + subtreeRootMaximumLeafRange: 4, + expectedRange: LeafRange{Start: 0, End: 4}, + }, + { + currentStart: 4, + currentEnd: 10, + subtreeRootMaximumLeafRange: 8, + expectedRange: LeafRange{Start: 4, End: 8}, + }, + { + currentStart: 4, + currentEnd: 20, + subtreeRootMaximumLeafRange: 16, + expectedRange: LeafRange{Start: 4, End: 20}, + }, + { + currentStart: 4, + currentEnd: 20, + subtreeRootMaximumLeafRange: 1, + expectedRange: LeafRange{Start: 4, End: 5}, + }, + { + currentStart: 4, + currentEnd: 20, + subtreeRootMaximumLeafRange: 2, + expectedRange: LeafRange{Start: 4, End: 6}, + }, + { + currentStart: 4, + currentEnd: 20, + subtreeRootMaximumLeafRange: 4, + expectedRange: LeafRange{Start: 4, End: 8}, + }, + { + currentStart: 4, + currentEnd: 20, + subtreeRootMaximumLeafRange: 8, + expectedRange: LeafRange{Start: 4, End: 12}, + }, + { + currentStart: 0, + currentEnd: 1, + subtreeRootMaximumLeafRange: 1, + expectedRange: LeafRange{Start: 0, End: 1}, + }, + { + currentStart: 0, + currentEnd: 16, + subtreeRootMaximumLeafRange: 16, + expectedRange: LeafRange{Start: 0, End: 16}, + }, + { + currentStart: 0, + currentEnd: 0, + subtreeRootMaximumLeafRange: 4, + expectError: true, + }, + { + currentStart: 5, + currentEnd: 2, + subtreeRootMaximumLeafRange: 4, + expectError: true, + }, + { + currentStart: 5, + currentEnd: 2, + subtreeRootMaximumLeafRange: 0, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("currentStart=%d, currentEnd=%d, subtreeRootMaximumLeafRange=%d", tt.currentStart, tt.currentEnd, tt.subtreeRootMaximumLeafRange), func(t *testing.T) { + result, err := nextLeafRange(tt.currentStart, tt.currentEnd, tt.subtreeRootMaximumLeafRange) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedRange, result) + } + }) + } +} + +func TestSafeIntToUint(t *testing.T) { + tests := []struct { + input int + expectedUint uint + expectedError error + }{ + { + input: 10, + expectedUint: 10, + expectedError: nil, + }, + { + input: 0, + expectedUint: 0, + expectedError: nil, + }, + { + input: -5, + expectedUint: 0, + expectedError: fmt.Errorf("cannot convert a negative int %d to uint", -5), + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("input=%d", tt.input), func(t *testing.T) { + result, err := safeIntToUint(tt.input) + if (err != nil) != (tt.expectedError != nil) || (err != nil && err.Error() != tt.expectedError.Error()) { + t.Errorf("expected error %v, got %v", tt.expectedError, err) + } + if result != tt.expectedUint { + t.Errorf("expected uint %v, got %v", tt.expectedUint, result) + } + }) + } +} + +func TestMinInt(t *testing.T) { + tests := []struct { + val1, val2 int + expected int + }{ + { + val1: 10, + val2: 20, + expected: 10, + }, + { + val1: -5, + val2: 6, + expected: -5, + }, + { + val1: 5, + val2: -6, + expected: -6, + }, + { + val1: -5, + val2: -6, + expected: -6, + }, + { + val1: 0, + val2: 0, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("val1=%d, val2=%d", tt.val1, tt.val2), func(t *testing.T) { + result := minInt(tt.val1, tt.val2) + if result != tt.expected { + t.Errorf("expected %d, got %d", tt.expected, result) + } + }) + } +} + +func TestVerifySubtreeRootInclusion(t *testing.T) { + tree := exampleNMT(1, true, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) + root, err := tree.Root() + require.NoError(t, err) + + nmthasher := tree.treeHasher + hasher := nmthasher.(*NmtHasher) + + tests := []struct { + proof Proof + subtreeRoots [][]byte + // the subtree widths used in these tests do not follow ADR-013 + // they're just used to try different test cases + subtreeWidth int + root []byte + validProof bool + expectError bool + }{ + { + proof: func() Proof { + p, err := tree.ProveRange(0, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(0, 8) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(0, 1) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(0, 1) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(0, 2) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(0, 2) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(2, 4) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(2, 4) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(0, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot1, err := tree.ComputeSubtreeRoot(0, 4) + require.NoError(t, err) + subtreeRoot2, err := tree.ComputeSubtreeRoot(4, 8) + require.NoError(t, err) + return [][]byte{subtreeRoot1, subtreeRoot2} + }(), + subtreeWidth: 4, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(0, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot1, err := tree.ComputeSubtreeRoot(0, 2) + require.NoError(t, err) + subtreeRoot2, err := tree.ComputeSubtreeRoot(2, 4) + require.NoError(t, err) + subtreeRoot3, err := tree.ComputeSubtreeRoot(4, 6) + require.NoError(t, err) + subtreeRoot4, err := tree.ComputeSubtreeRoot(6, 8) + require.NoError(t, err) + return [][]byte{subtreeRoot1, subtreeRoot2, subtreeRoot3, subtreeRoot4} + }(), + subtreeWidth: 2, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(0, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot1, err := tree.ComputeSubtreeRoot(0, 1) + require.NoError(t, err) + subtreeRoot2, err := tree.ComputeSubtreeRoot(1, 2) + require.NoError(t, err) + subtreeRoot3, err := tree.ComputeSubtreeRoot(2, 3) + require.NoError(t, err) + subtreeRoot4, err := tree.ComputeSubtreeRoot(3, 4) + require.NoError(t, err) + subtreeRoot5, err := tree.ComputeSubtreeRoot(4, 5) + require.NoError(t, err) + subtreeRoot6, err := tree.ComputeSubtreeRoot(5, 6) + require.NoError(t, err) + subtreeRoot7, err := tree.ComputeSubtreeRoot(6, 7) + require.NoError(t, err) + subtreeRoot8, err := tree.ComputeSubtreeRoot(7, 8) + require.NoError(t, err) + return [][]byte{subtreeRoot1, subtreeRoot2, subtreeRoot3, subtreeRoot4, subtreeRoot5, subtreeRoot6, subtreeRoot7, subtreeRoot8} + }(), + subtreeWidth: 1, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(4, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(4, 8) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(12, 14) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(12, 14) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(14, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(14, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(14, 15) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(14, 15) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(15, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(15, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: -3, // invalid subtree root width + root: root, + expectError: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(15, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot, subtreeRoot} // invalid number of subtree roots + }(), + subtreeWidth: 8, + root: root, + expectError: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(15, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: []byte("random root"), // invalid root format + expectError: true, + }, + { + proof: Proof{start: -1}, // invalid start + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + expectError: true, + }, + { + proof: Proof{end: 1, start: 2}, // invalid end + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + expectError: true, + }, + { + proof: Proof{ + start: 0, + end: 4, + nodes: [][]byte{[]byte("invalid proof node")}, // invalid proof node + }, + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + expectError: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(15, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: [][]byte{[]byte("invalid subtree root")}, // invalid subtree root + subtreeWidth: 8, + root: root, + expectError: true, + }, + + { + proof: func() Proof { + p, err := tree.ProveRange(0, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot1, err := tree.ComputeSubtreeRoot(0, 4) + require.NoError(t, err) + return [][]byte{subtreeRoot1} // will error because it requires the subtree root of [4,8) too + }(), + subtreeWidth: 4, + root: root, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("proofStart=%d, proofEnd=%d, subTreeWidth=%d", tt.proof.Start(), tt.proof.End(), tt.subtreeWidth), func(t *testing.T) { + result, err := tt.proof.VerifySubtreeRootInclusion(hasher, tt.subtreeRoots, tt.subtreeWidth, tt.root) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.validProof, result) + } + }) + } +}