diff --git a/types/row_proof.go b/types/row_proof.go index 2498f32f7d..8c50d90e35 100644 --- a/types/row_proof.go +++ b/types/row_proof.go @@ -34,6 +34,25 @@ func (rp RowProof) Validate(root []byte) error { if len(rp.Proofs) != len(rp.RowRoots) { return fmt.Errorf("the number of proofs %d must equal the number of row roots %d", len(rp.Proofs), len(rp.RowRoots)) } + if len(rp.Proofs) == 0 { + return fmt.Errorf("empty proofs") + } + firstProofIndex := rp.Proofs[0].Index + for i := 1; i < len(rp.Proofs); i++ { + if rp.Proofs[0].Total != rp.Proofs[i].Total { + return errors.New("proofs should have the same total") + } + if rp.Proofs[i].Index != firstProofIndex+1 { + return errors.New("proof indexes are not sequential") + } + firstProofIndex++ + } + if int64(rp.StartRow) != rp.Proofs[0].Index { + return fmt.Errorf("invalid start row") + } + if int64(rp.EndRow) != rp.Proofs[len(rp.Proofs)-1].Index { + return fmt.Errorf("invalid end row") + } if !rp.VerifyProof(root) { return errors.New("row proof failed to verify") } diff --git a/types/row_proof_test.go b/types/row_proof_test.go index 5026c2f19c..3649fa5380 100644 --- a/types/row_proof_test.go +++ b/types/row_proof_test.go @@ -59,6 +59,46 @@ func TestRowProofValidate(t *testing.T) { root: root, wantErr: true, }, + { + name: "proof with different total", + rp: func() RowProof { + proof := validRowProof() + proof.Proofs[0].Total += 1 + return proof + }(), + root: root, + wantErr: true, + }, + { + name: "proof with inconsequential indexes", + rp: func() RowProof { + proof := validRowProof() + proof.Proofs[0].Index -= 1 + return proof + }(), + root: root, + wantErr: true, + }, + { + name: "invalid start row", + rp: func() RowProof { + proof := validRowProof() + proof.StartRow += 1 + return proof + }(), + root: root, + wantErr: true, + }, + { + name: "invalid end row", + rp: func() RowProof { + proof := validRowProof() + proof.EndRow += 1 + return proof + }(), + root: root, + wantErr: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) {