Skip to content

Commit

Permalink
tappsbt: add test vectors for PSBT encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
guggero committed Jun 21, 2023
1 parent ba87e04 commit d54ede9
Show file tree
Hide file tree
Showing 5 changed files with 1,070 additions and 53 deletions.
7 changes: 6 additions & 1 deletion tappsbt/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tappsbt

import (
"fmt"
"github.com/btcsuite/btcd/btcec/v2/schnorr"

"github.com/btcsuite/btcd/wire"
"github.com/lightninglabs/taproot-assets/address"
Expand Down Expand Up @@ -51,14 +52,18 @@ func FromAddresses(receiverAddrs []*address.Tap,
// index, but start at the first one indicated by the caller.
for idx := range receiverAddrs {
addr := receiverAddrs[idx]

schnorrInternalKey, _ := schnorr.ParsePubKey(
schnorr.SerializePubKey(&addr.InternalKey),
)
pkt.Outputs = append(pkt.Outputs, &VOutput{
Amount: addr.Amount,
Interactive: false,
AnchorOutputIndex: firstOutputIndex + uint32(idx),
ScriptKey: asset.NewScriptKey(
&addr.ScriptKey,
),
AnchorOutputInternalKey: &addr.InternalKey,
AnchorOutputInternalKey: schnorrInternalKey,
AnchorOutputTapscriptSibling: addr.TapscriptSibling,
})
}
Expand Down
210 changes: 160 additions & 50 deletions tappsbt/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,29 @@ package tappsbt

import (
"bytes"
"encoding/base64"
"encoding/hex"
"os"
"path/filepath"
"reflect"
"strings"
"testing"

"github.com/lightninglabs/taproot-assets/address"
"github.com/lightninglabs/taproot-assets/asset"
"github.com/lightninglabs/taproot-assets/internal/test"
"github.com/stretchr/testify/require"
)

var (
generatedTestVectorName = "psbt_encoding_generated.json"

allTestVectorFiles = []string{
generatedTestVectorName,
"psbt_encoding_error_cases.json",
}
)

// assertEqualPackets asserts that two packets are equal and prints a nice diff
// if they are not.
func assertEqualPackets(t *testing.T, expected, actual *VPacket) {
Expand Down Expand Up @@ -45,58 +57,83 @@ func assertEqualPackets(t *testing.T, expected, actual *VPacket) {
}
}

// TestNewFromRawBytes tests the decoding of a virtual packet from raw bytes.
func TestNewFromRawBytes(t *testing.T) {
t.Parallel()

pkg := RandPacket(t)
packet, err := pkg.EncodeAsPsbt()
require.NoError(t, err)

var buf bytes.Buffer
err = packet.Serialize(&buf)
require.NoError(t, err)

decoded, err := NewFromRawBytes(&buf, false)
require.NoError(t, err)

assertEqualPackets(t, pkg, decoded)
}

// TestNewFromPsbt tests the decoding of a virtual packet from a PSBT packet.
func TestNewFromPsbt(t *testing.T) {
t.Parallel()

pkg := RandPacket(t)
packet, err := pkg.EncodeAsPsbt()
require.NoError(t, err)

decoded, err := NewFromPsbt(packet)
require.NoError(t, err)

assertEqualPackets(t, pkg, decoded)
}

// TestMinimalContent tests the decoding of a virtual packet with the minimal
// amount of information set.
func TestMinimalContent(t *testing.T) {
t.Parallel()

addr, _, _ := address.RandAddr(t, testParams)

pkg, err := FromAddresses([]*address.Tap{addr.Tap}, 1)
require.NoError(t, err)
pkg.Outputs = append(pkg.Outputs, &VOutput{
ScriptKey: asset.RandScriptKey(t),
})
var buf bytes.Buffer
err = pkg.Serialize(&buf)
require.NoError(t, err)
// TestEncodingDecoding tests the decoding of a virtual packet from raw bytes.
func TestEncodingDecoding(t *testing.T) {
// Reset the random source to ensure that we get the same result for
// each run of this test.
test.ResetRand()

testVectors := &TestVectors{}
assertEncodingDecoding := func(comment string, pkg *VPacket) {
// Encode the packet as a PSBT packet then as base64.
packet, err := pkg.EncodeAsPsbt()
require.NoError(t, err)

var buf bytes.Buffer
err = packet.Serialize(&buf)
require.NoError(t, err)

testVectors.ValidTestCases = append(
testVectors.ValidTestCases, &ValidTestCase{
Packet: NewTestFromVPacket(t, pkg),
Expected: base64.StdEncoding.EncodeToString(
buf.Bytes(),
),
Comment: comment,
},
)

// Make sure we can read the packet back from the raw bytes.
decoded, err := NewFromRawBytes(&buf, false)
require.NoError(t, err)

assertEqualPackets(t, pkg, decoded)

// Also make sure we can decode the packet from the base PSBT.
decoded, err = NewFromPsbt(packet)
require.NoError(t, err)

assertEqualPackets(t, pkg, decoded)
}

decoded, err := NewFromRawBytes(&buf, false)
require.NoError(t, err)
testCases := []struct {
name string
pkg func(t *testing.T) *VPacket
}{{
name: "minimal packet",
pkg: func(t *testing.T) *VPacket {
addr, _, _ := address.RandAddr(t, testParams)

pkg, err := FromAddresses([]*address.Tap{addr.Tap}, 1)
require.NoError(t, err)
pkg.Outputs = append(pkg.Outputs, &VOutput{
ScriptKey: asset.RandScriptKey(t),
})

return pkg
},
}, {
name: "random packet",
pkg: func(t *testing.T) *VPacket {
return RandPacket(t)
},
}}

for _, testCase := range testCases {
testCase := testCase

success := t.Run(testCase.name, func(t *testing.T) {
pkg := testCase.pkg(t)
assertEncodingDecoding(testCase.name, pkg)
})
if !success {
return
}
}

assertEqualPackets(t, pkg, decoded)
// Write test vectors to file. This is a no-op if the "gen_test_vectors"
// build tag is not set.
test.WriteTestVectors(t, generatedTestVectorName, testVectors)
}

// TestDecodeBase64 tests the decoding of a virtual packet from a base64 string.
Expand Down Expand Up @@ -136,3 +173,76 @@ func TestDecodeHex(t *testing.T) {

require.Len(t, packet.Outputs, 2)
}

// TestBIPTestVectors tests that the BIP test vectors are passing.
func TestBIPTestVectors(t *testing.T) {
t.Parallel()

for idx := range allTestVectorFiles {
var (
fileName = allTestVectorFiles[idx]
testVectors = &TestVectors{}
)
test.ParseTestVectors(t, fileName, &testVectors)
t.Run(fileName, func(tt *testing.T) {
tt.Parallel()

runBIPTestVector(tt, testVectors)
})
}
}

// runBIPTestVector runs the tests in a single BIP test vector file.
func runBIPTestVector(t *testing.T, testVectors *TestVectors) {
for _, validCase := range testVectors.ValidTestCases {
validCase := validCase

t.Run(validCase.Comment, func(tt *testing.T) {
tt.Parallel()

p := validCase.Packet.ToVPacket(t)

packetString, err := p.B64Encode()
require.NoError(tt, err)

areEqual := validCase.Expected == packetString

// Create nice diff if things don't match.
if !areEqual {
expectedPacket, err := NewFromRawBytes(
strings.NewReader(validCase.Expected),
true,
)
require.NoError(tt, err)

require.Equal(tt, p, expectedPacket)

// Make sure we still fail the test.
require.Equal(
tt, validCase.Expected, packetString,
)
}

// We also want to make sure that the address is decoded
// correctly from the encoded TLV stream.
decoded, err := NewFromRawBytes(
strings.NewReader(validCase.Expected), true,
)
require.NoError(tt, err)

require.Equal(tt, p, decoded)
})
}

for _, invalidCase := range testVectors.ErrorTestCases {
invalidCase := invalidCase

t.Run(invalidCase.Comment, func(tt *testing.T) {
tt.Parallel()

require.PanicsWithValue(tt, invalidCase.Error, func() {
invalidCase.Packet.ToVPacket(tt)
})
})
}
}
Loading

0 comments on commit d54ede9

Please sign in to comment.