Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NONEVM-1104] - Implement CR namespace group address sharing and fix naming inconsistencies #1032

Merged
merged 9 commits into from
Feb 7, 2025
150 changes: 122 additions & 28 deletions integration-tests/relayinterface/chain_components_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/gagliardetto/solana-go/rpc"
"github.com/gagliardetto/solana-go/rpc/ws"
"github.com/gagliardetto/solana-go/text"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
Expand All @@ -37,6 +38,12 @@ import (
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
)

const (
AnyContractNameWithSharedAddress1 = AnyContractName + "Shared1"
AnyContractNameWithSharedAddress2 = AnyContractName + "Shared2"
AnyContractNameWithSharedAddress3 = AnyContractName + "Shared3"
)

func TestChainComponents(t *testing.T) {
t.Parallel()
helper := &helper{}
Expand Down Expand Up @@ -95,7 +102,80 @@ func DisableTests(it *SolanaChainComponentsInterfaceTester[*testing.T]) {
}

func RunChainComponentsSolanaTests[T TestingT[T]](t T, it *SolanaChainComponentsInterfaceTester[T]) {
RunContractReaderSolanaTests(t, it)
testCases := Testcase[T]{
Name: "Test address groups where first namespace shares address with second namespace",
Test: func(t T) {
ctx := tests.Context(t)
cfg := it.contractReaderConfig
cfg.AddressShareGroups = [][]string{{AnyContractNameWithSharedAddress1, AnyContractNameWithSharedAddress2, AnyContractNameWithSharedAddress3}}
cr := it.GetContractReaderWithCustomCfg(t, cfg)

t.Run("Namespace is part of an address share group that doesn't have a registered address and provides no address during Bind", func(t T) {
bound1 := []types.BoundContract{{
Name: AnyContractNameWithSharedAddress1,
}}
require.Error(t, cr.Bind(ctx, bound1))
})

addressToBeShared := it.Helper.CreateAccount(t, AnyValueToReadWithoutAnArgument).String()
t.Run("Namespace is part of an address share group that doesn't have a registered address and provides an address during Bind", func(t T) {
bound1 := []types.BoundContract{{Name: AnyContractNameWithSharedAddress1, Address: addressToBeShared}}

require.NoError(t, cr.Bind(ctx, bound1))

var prim uint64
require.NoError(t, cr.GetLatestValue(ctx, bound1[0].ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim))
assert.Equal(t, AnyValueToReadWithoutAnArgument, prim)
})

t.Run("Namespace is part of an address share group that has a registered address and provides that same address during Bind", func(t T) {
bound2 := []types.BoundContract{{
Name: AnyContractNameWithSharedAddress2,
Address: addressToBeShared}}
require.NoError(t, cr.Bind(ctx, bound2))

var prim uint64
require.NoError(t, cr.GetLatestValue(ctx, bound2[0].ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim))
assert.Equal(t, AnyValueToReadWithoutAnArgument, prim)
assert.Equal(t, addressToBeShared, bound2[0].Address)
})

t.Run("Namespace is part of an address share group that has a registered address and provides a wrong address during Bind", func(t T) {
key, err := solana.NewRandomPrivateKey()
require.NoError(t, err)

bound2 := []types.BoundContract{{
Name: AnyContractNameWithSharedAddress2,
Address: key.PublicKey().String()}}
require.Error(t, cr.Bind(ctx, bound2))
})

t.Run("Namespace is part of an address share group that has a registered address and provides no address during Bind", func(t T) {
bound3 := []types.BoundContract{{Name: AnyContractNameWithSharedAddress3}}
require.NoError(t, cr.Bind(ctx, bound3))

var prim uint64
require.NoError(t, cr.GetLatestValue(ctx, bound3[0].ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim))
assert.Equal(t, AnyValueToReadWithoutAnArgument, prim)
assert.Equal(t, addressToBeShared, bound3[0].Address)

// when run in a loop Bind address won't be set, so check if CR Method works without set address.
prim = 0
require.NoError(t, cr.GetLatestValue(ctx, types.BoundContract{
Address: "",
Name: AnyContractNameWithSharedAddress3,
}.ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim))
assert.Equal(t, AnyValueToReadWithoutAnArgument, prim)
})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a test that it handles this error condition properly?

// If they are Bound with a non-empty address string, an error will be thrown unless the address matches the address of the first Bound shared contract.
	AddressShareGroups [][]string `json:"addressShareGroups,omitempty"`

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added


t.Run("Namespace is not part of an address share group that has a registered address and provides no address during Bind", func(t T) {
require.Error(t, cr.Bind(ctx, []types.BoundContract{{Name: AnyContractName}}))
})
},
}

RunTests(t, it, []Testcase[T]{testCases})
RunContractReaderTests(t, it)
// Add ChainWriter tests here
}

Expand All @@ -104,20 +184,12 @@ func RunChainComponentsInLoopSolanaTests[T TestingT[T]](t T, it ChainComponentsI
// Add ChainWriter tests here
}

func RunContractReaderSolanaTests[T TestingT[T]](t T, it *SolanaChainComponentsInterfaceTester[T]) {
func RunContractReaderTests[T TestingT[T]](t T, it *SolanaChainComponentsInterfaceTester[T]) {
RunContractReaderInterfaceTests(t, it, false, true)

var testCases []Testcase[T]

RunTests(t, it, testCases)
}

func RunContractReaderInLoopTests[T TestingT[T]](t T, it ChainComponentsInterfaceTester[T]) {
RunContractReaderInterfaceTests(t, it, false, true)

var testCases []Testcase[T]

RunTests(t, it, testCases)
}

type SolanaChainComponentsInterfaceTesterHelper[T TestingT[T]] interface {
Expand All @@ -139,26 +211,28 @@ type SolanaChainComponentsInterfaceTester[T TestingT[T]] struct {
func (it *SolanaChainComponentsInterfaceTester[T]) Setup(t T) {
t.Cleanup(func() {})

it.contractReaderConfig = config.ContractReader{
Namespaces: map[string]config.ChainContractReader{
AnyContractName: {
IDL: mustUnmarshalIDL(t, string(it.Helper.GetJSONEncodedIDL(t))),
Reads: map[string]config.ReadDefinition{
MethodReturningUint64: {
ChainSpecificName: "DataAccount",
ReadType: config.Account,
OutputModifications: commoncodec.ModifiersConfig{
&commoncodec.PropertyExtractorConfig{FieldName: "U64Value"},
},
},
MethodReturningUint64Slice: {
ChainSpecificName: "DataAccount",
OutputModifications: commoncodec.ModifiersConfig{
&commoncodec.PropertyExtractorConfig{FieldName: "U64Slice"},
},
},
anyContractReadDef := config.ChainContractReader{
IDL: mustUnmarshalIDL(t, string(it.Helper.GetJSONEncodedIDL(t))),
Reads: map[string]config.ReadDefinition{
MethodReturningUint64: {
ChainSpecificName: "DataAccount",
ReadType: config.Account,
OutputModifications: commoncodec.ModifiersConfig{
&commoncodec.PropertyExtractorConfig{FieldName: "U64Value"},
},
},
MethodReturningUint64Slice: {
ChainSpecificName: "DataAccount",
OutputModifications: commoncodec.ModifiersConfig{
&commoncodec.PropertyExtractorConfig{FieldName: "U64Slice"},
},
},
},
}

it.contractReaderConfig = config.ContractReader{
Namespaces: map[string]config.ChainContractReader{
AnyContractName: anyContractReadDef,
AnySecondContractName: {
IDL: mustUnmarshalIDL(t, string(it.Helper.GetJSONEncodedIDL(t))),
Reads: map[string]config.ReadDefinition{
Expand All @@ -170,6 +244,10 @@ func (it *SolanaChainComponentsInterfaceTester[T]) Setup(t T) {
},
},
},
// these are for testing shared address groups
AnyContractNameWithSharedAddress1: anyContractReadDef,
AnyContractNameWithSharedAddress2: anyContractReadDef,
AnyContractNameWithSharedAddress3: anyContractReadDef,
},
}
}
Expand Down Expand Up @@ -208,6 +286,22 @@ func (it *SolanaChainComponentsInterfaceTester[T]) GetContractReader(t T) types.
return svc
}

func (it *SolanaChainComponentsInterfaceTester[T]) GetContractReaderWithCustomCfg(t T, cfg config.ContractReader) types.ContractReader {
ctx := it.Helper.Context(t)
if it.cr != nil {
return it.cr
}

svc, err := chainreader.NewContractReaderService(it.Helper.Logger(t), it.Helper.RPCClient(), cfg, nil)

require.NoError(t, err)
require.NoError(t, svc.Start(ctx))

it.cr = svc

return svc
}

func (it *SolanaChainComponentsInterfaceTester[T]) GetContractWriter(t T) types.ContractWriter {
return nil
}
Expand Down
34 changes: 17 additions & 17 deletions pkg/solana/chainreader/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ import (
)

type call struct {
ContractName, ReadName string
Params, ReturnVal any
Namespace, ReadName string
Params, ReturnVal any
}

type batchResultWithErr struct {
address string
contractName, readName string
returnVal any
err error
address string
namespace, readName string
returnVal any
err error
}

var (
Expand All @@ -30,16 +30,16 @@ type MultipleAccountGetter interface {
GetMultipleAccountData(context.Context, ...solana.PublicKey) ([][]byte, error)
}

func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindings namespaceBindings, batch []call) ([]batchResultWithErr, error) {
func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindingsRegistry bindingsRegistry, batch []call) ([]batchResultWithErr, error) {
// Create the list of public keys to fetch
keys := make([]solana.PublicKey, len(batch))
for idx, call := range batch {
binding, err := bindings.GetReadBinding(call.ContractName, call.ReadName)
rBinding, err := bindingsRegistry.GetReadBinding(call.Namespace, call.ReadName)
if err != nil {
return nil, err
}

key, err := binding.GetAddress(ctx, call.Params)
key, err := rBinding.GetAddress(ctx, call.Params)
if err != nil {
return nil, fmt.Errorf("failed to get address for %s account read: %w", call.ReadName, err)
}
Expand All @@ -57,10 +57,10 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin
// decode batch call results
for idx, call := range batch {
results[idx] = batchResultWithErr{
address: keys[idx].String(),
contractName: call.ContractName,
readName: call.ReadName,
returnVal: call.ReturnVal,
address: keys[idx].String(),
namespace: call.Namespace,
readName: call.ReadName,
returnVal: call.ReturnVal,
}

if data[idx] == nil || len(data[idx]) == 0 {
Expand All @@ -69,7 +69,7 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin
continue
}

binding, err := bindings.GetReadBinding(results[idx].contractName, results[idx].readName)
rBinding, err := bindingsRegistry.GetReadBinding(results[idx].namespace, results[idx].readName)
if err != nil {
results[idx].err = err

Expand All @@ -80,12 +80,12 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin
if !isValue {
results[idx].err = errors.Join(
results[idx].err,
binding.Decode(ctx, data[idx], results[idx].returnVal),
rBinding.Decode(ctx, data[idx], results[idx].returnVal),
)
continue
}

contractType, err := binding.CreateType(false)
contractType, err := rBinding.CreateType(false)
if err != nil {
results[idx].err = err

Expand All @@ -94,7 +94,7 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin

results[idx].err = errors.Join(
results[idx].err,
binding.Decode(ctx, data[idx], contractType),
rBinding.Decode(ctx, data[idx], contractType),
)

value, err := values.Wrap(contractType)
Expand Down
Loading
Loading