From db39859249bf148cc80469e611d5a1b395080be8 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Tue, 11 Feb 2025 11:52:10 -0500 Subject: [PATCH] Refactored ChainWriter to use MultiNode client --- .../relayinterface/chain_components_test.go | 10 +-- .../relayinterface/lookups_test.go | 51 ++++++++------- pkg/solana/chain.go | 5 ++ pkg/solana/chainwriter/chain_writer.go | 10 +-- pkg/solana/chainwriter/chain_writer_test.go | 63 ++++++++++++------- pkg/solana/chainwriter/lookups.go | 12 ++-- pkg/solana/relay.go | 7 +-- 7 files changed, 95 insertions(+), 63 deletions(-) diff --git a/integration-tests/relayinterface/chain_components_test.go b/integration-tests/relayinterface/chain_components_test.go index 8bd2ffd18..2f1b0ae2a 100644 --- a/integration-tests/relayinterface/chain_components_test.go +++ b/integration-tests/relayinterface/chain_components_test.go @@ -281,7 +281,7 @@ type SolanaChainComponentsInterfaceTesterHelper[T WrappedTestingT[T]] interface GetSecondaryIDL(t T) []byte CreateAccount(t T, it SolanaChainComponentsInterfaceTester[T], contractName string, value uint64, testStruct TestStruct) solana.PublicKey TXM() *txm.TxManager - SolanaClient() *client.Client + MultiClient() *client.MultiClient } type WrappedTestingT[T any] interface { @@ -353,7 +353,7 @@ func (it *SolanaChainComponentsInterfaceTester[T]) GetContractReaderWithCustomCf func (it *SolanaChainComponentsInterfaceTester[T]) GetContractWriter(t T) types.ContractWriter { chainWriterConfig := it.buildContractWriterConfig(t) - cw, err := chainwriter.NewSolanaChainWriterService(it.Helper.Logger(t), it.Helper.SolanaClient(), *it.Helper.TXM(), nil, chainWriterConfig) + cw, err := chainwriter.NewSolanaChainWriterService(it.Helper.Logger(t), *it.Helper.MultiClient(), *it.Helper.TXM(), nil, chainWriterConfig) require.NoError(t, err) servicetest.Run(t, cw) @@ -459,8 +459,10 @@ func (h *helper) TXM() *txm.TxManager { return &h.txm } -func (h *helper) SolanaClient() *client.Client { - return h.sc +func (h *helper) MultiClient() *client.MultiClient { + return client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) { + return h.sc, nil + }) } func (h *helper) Context(t *testing.T) context.Context { diff --git a/integration-tests/relayinterface/lookups_test.go b/integration-tests/relayinterface/lookups_test.go index d35387eac..27ef3ea2e 100644 --- a/integration-tests/relayinterface/lookups_test.go +++ b/integration-tests/relayinterface/lookups_test.go @@ -1,6 +1,7 @@ package relayinterface import ( + "context" "testing" "time" @@ -49,7 +50,7 @@ func TestAccountContant(t *testing.T) { IsSigner: true, IsWritable: true, } - result, err := constantConfig.Resolve(tests.Context(t), nil, nil, nil) + result, err := constantConfig.Resolve(tests.Context(t), nil, nil, client.MultiClient{}) require.NoError(t, err) require.Equal(t, expectedMeta, result) }) @@ -77,7 +78,7 @@ func TestAccountLookups(t *testing.T) { IsSigner: chainwriter.MetaBool{Value: true}, IsWritable: chainwriter.MetaBool{Value: true}, } - result, err := lookupConfig.Resolve(ctx, testArgs, nil, nil) + result, err := lookupConfig.Resolve(ctx, testArgs, nil, client.MultiClient{}) require.NoError(t, err) require.Equal(t, expectedMeta, result) }) @@ -111,7 +112,7 @@ func TestAccountLookups(t *testing.T) { IsSigner: chainwriter.MetaBool{Value: true}, IsWritable: chainwriter.MetaBool{Value: true}, } - result, err := lookupConfig.Resolve(ctx, testArgs, nil, nil) + result, err := lookupConfig.Resolve(ctx, testArgs, nil, client.MultiClient{}) require.NoError(t, err) for i, meta := range result { require.Equal(t, expectedMeta[i], meta) @@ -132,7 +133,7 @@ func TestAccountLookups(t *testing.T) { IsSigner: chainwriter.MetaBool{Value: true}, IsWritable: chainwriter.MetaBool{Value: true}, } - _, err := lookupConfig.Resolve(ctx, testArgs, nil, nil) + _, err := lookupConfig.Resolve(ctx, testArgs, nil, client.MultiClient{}) require.Error(t, err) }) @@ -162,7 +163,7 @@ func TestAccountLookups(t *testing.T) { }, } - result, err := lookupConfig.Resolve(ctx, args, nil, nil) + result, err := lookupConfig.Resolve(ctx, args, nil, client.MultiClient{}) require.NoError(t, err) for i, meta := range result { @@ -200,7 +201,7 @@ func TestAccountLookups(t *testing.T) { Bitmaps: []uint64{5, 3}, } - _, err := lookupConfig.Resolve(ctx, args, nil, nil) + _, err := lookupConfig.Resolve(ctx, args, nil, client.MultiClient{}) require.Contains(t, err.Error(), "bitmap value is not a single value") }) @@ -227,7 +228,7 @@ func TestAccountLookups(t *testing.T) { }, } - _, err := lookupConfig.Resolve(ctx, args, nil, nil) + _, err := lookupConfig.Resolve(ctx, args, nil, client.MultiClient{}) require.Contains(t, err.Error(), "error reading bitmap from location") }) @@ -254,7 +255,7 @@ func TestAccountLookups(t *testing.T) { }, } - _, err := lookupConfig.Resolve(ctx, args, nil, nil) + _, err := lookupConfig.Resolve(ctx, args, nil, client.MultiClient{}) require.Contains(t, err.Error(), "invalid value format at path") }) } @@ -287,7 +288,7 @@ func TestPDALookups(t *testing.T) { IsWritable: true, } - result, err := pdaLookup.Resolve(ctx, nil, nil, nil) + result, err := pdaLookup.Resolve(ctx, nil, nil, client.MultiClient{}) require.NoError(t, err) require.Equal(t, expectedMeta, result) }) @@ -322,7 +323,7 @@ func TestPDALookups(t *testing.T) { "another_seed": seed2, } - result, err := pdaLookup.Resolve(ctx, args, nil, nil) + result, err := pdaLookup.Resolve(ctx, args, nil, client.MultiClient{}) require.NoError(t, err) require.Equal(t, expectedMeta, result) }) @@ -342,7 +343,7 @@ func TestPDALookups(t *testing.T) { "test_seed": []byte("data"), } - _, err := pdaLookup.Resolve(ctx, args, nil, nil) + _, err := pdaLookup.Resolve(ctx, args, nil, client.MultiClient{}) require.Error(t, err) require.Contains(t, err.Error(), "key not found") }) @@ -378,7 +379,7 @@ func TestPDALookups(t *testing.T) { "another_seed": seed2, } - result, err := pdaLookup.Resolve(ctx, args, nil, nil) + result, err := pdaLookup.Resolve(ctx, args, nil, client.MultiClient{}) require.NoError(t, err) require.Equal(t, expectedMeta, result) }) @@ -416,7 +417,7 @@ func TestPDALookups(t *testing.T) { "array_seed": arraySeed, } - result, err := pdaLookup.Resolve(ctx, args, nil, nil) + result, err := pdaLookup.Resolve(ctx, args, nil, client.MultiClient{}) require.NoError(t, err) require.Equal(t, expectedMeta, result) }) @@ -456,7 +457,7 @@ func TestPDALookups(t *testing.T) { "seed2": arraySeed2, } - result, err := pdaLookup.Resolve(ctx, args, nil, nil) + result, err := pdaLookup.Resolve(ctx, args, nil, client.MultiClient{}) require.NoError(t, err) require.Equal(t, expectedMeta, result) }) @@ -477,13 +478,17 @@ func TestLookupTables(t *testing.T) { solanaClient, err := client.NewClient(url, cfg, 5*time.Second, nil) require.NoError(t, err) + multiClient := *client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) { + return solanaClient, nil + }) + loader := solanautils.NewStaticLoader[client.ReaderWriter](solanaClient) mkey := keyMocks.NewSimpleKeystore(t) lggr := logger.Test(t) txm := txm.NewTxm("localnet", loader, nil, cfg, mkey, lggr) - cw, err := chainwriter.NewSolanaChainWriterService(nil, solanaClient, txm, nil, chainwriter.ChainWriterConfig{}) + cw, err := chainwriter.NewSolanaChainWriterService(nil, multiClient, txm, nil, chainwriter.ChainWriterConfig{}) require.NoError(t, err) t.Run("StaticLookup table resolves properly", func(t *testing.T) { @@ -657,6 +662,10 @@ func TestCreateATAs(t *testing.T) { solanaClient, err := client.NewClient(url, cfg, 5*time.Second, nil) require.NoError(t, err) + multiClient := *client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) { + return solanaClient, nil + }) + t.Run("returns no instructions when no ATA location is found", func(t *testing.T) { lookups := []chainwriter.ATALookup{ { @@ -679,7 +688,7 @@ func TestCreateATAs(t *testing.T) { }, } - ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer) + ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer) require.NoError(t, err) require.Empty(t, ataInstructions) }) @@ -704,7 +713,7 @@ func TestCreateATAs(t *testing.T) { "Addresses": {chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)}, } - _, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer) + _, err := chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer) require.Contains(t, err.Error(), "expected exactly one wallet address, got 2") }) @@ -728,7 +737,7 @@ func TestCreateATAs(t *testing.T) { "Addresses": {chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)}, } - _, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer) + _, err := chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer) require.Contains(t, err.Error(), "expected equal number of token programs and mints, got 1 tokenPrograms and 2 mints") }) @@ -760,7 +769,7 @@ func TestCreateATAs(t *testing.T) { }, } - ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer) + ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer) require.NoError(t, err) tx := solanautils.CreateTx(ctx, t, rpcClient, ataInstructions, sender, rpc.CommitmentFinalized) @@ -797,14 +806,14 @@ func TestCreateATAs(t *testing.T) { }, } - ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer) + ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer) require.NoError(t, err) solanautils.SendAndConfirm(ctx, t, rpcClient, ataInstructions, sender, rpc.CommitmentFinalized) require.True(t, checkIfATAExists(t, rpcClient, ataAddress)) // now, if we try to create the same ATA again, it should return no instructions - ataInstructions, err = chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer) + ataInstructions, err = chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer) require.NoError(t, err) require.Empty(t, ataInstructions) }) diff --git a/pkg/solana/chain.go b/pkg/solana/chain.go index 0202697b1..d7fd30297 100644 --- a/pkg/solana/chain.go +++ b/pkg/solana/chain.go @@ -54,6 +54,7 @@ type Chain interface { FeeEstimator() fees.Estimator // Reader returns a new Reader from the available list of nodes (if there are multiple, it will randomly select one) Reader() (client.Reader, error) + MultiClient() *client.MultiClient } // DefaultRequestTimeout is the default Solana client timeout. @@ -437,6 +438,10 @@ func (c *chain) Reader() (client.Reader, error) { return c.getClient(ctx) } +func (c *chain) MultiClient() *client.MultiClient { + return c.multiClient +} + func (c *chain) ChainID() string { return c.id } diff --git a/pkg/solana/chainwriter/chain_writer.go b/pkg/solana/chainwriter/chain_writer.go index 32cbd8e8f..88dfa7849 100644 --- a/pkg/solana/chainwriter/chain_writer.go +++ b/pkg/solana/chainwriter/chain_writer.go @@ -30,7 +30,7 @@ const ServiceName = "SolanaChainWriter" type SolanaChainWriterService struct { lggr logger.Logger - reader client.Reader + reader client.MultiClient txm txm.TxManager ge fees.Estimator config ChainWriterConfig @@ -68,7 +68,7 @@ type MethodConfig struct { ArgsTransform string } -func NewSolanaChainWriterService(logger logger.Logger, reader client.Reader, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) { +func NewSolanaChainWriterService(logger logger.Logger, reader client.MultiClient, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) { w := SolanaChainWriterService{ lggr: logger, reader: reader, @@ -151,7 +151,7 @@ for Solana transactions. It handles constant addresses, dynamic lookups, program ### Error Handling: - Errors are wrapped with the `debugID` for easier tracing. */ -func GetAddresses(ctx context.Context, args any, accounts []Lookup, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader) ([]*solana.AccountMeta, error) { +func GetAddresses(ctx context.Context, args any, accounts []Lookup, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.MultiClient) ([]*solana.AccountMeta, error) { var addresses []*solana.AccountMeta for _, accountConfig := range accounts { meta, err := accountConfig.Resolve(ctx, args, derivedTableMap, reader) @@ -227,7 +227,7 @@ func (s *SolanaChainWriterService) FilterLookupTableAddresses( // CreateATAs first checks if a specified location exists, then checks if the accounts derived from the // ATALookups in the ChainWriter's configuration exist on-chain and creates them if they do not. -func CreateATAs(ctx context.Context, args any, lookups []ATALookup, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader, idl string, feePayer solana.PublicKey) ([]solana.Instruction, error) { +func CreateATAs(ctx context.Context, args any, lookups []ATALookup, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.MultiClient, idl string, feePayer solana.PublicKey) ([]solana.Instruction, error) { createATAInstructions := []solana.Instruction{} for _, lookup := range lookups { // Check if location exists @@ -562,7 +562,7 @@ func (s *SolanaChainWriterService) loadTable(ctx context.Context, args any, rlt return resultMap, nil } -func getLookupTableAddresses(ctx context.Context, reader client.Reader, tableAddress solana.PublicKey) (solana.PublicKeySlice, error) { +func getLookupTableAddresses(ctx context.Context, reader client.MultiClient, tableAddress solana.PublicKey) (solana.PublicKeySlice, error) { // Fetch the account info for the static table accountInfo, err := reader.GetAccountInfoWithOpts(ctx, tableAddress, &rpc.GetAccountInfoOpts{ Encoding: "base64", diff --git a/pkg/solana/chainwriter/chain_writer_test.go b/pkg/solana/chainwriter/chain_writer_test.go index fb3edc472..61f70beac 100644 --- a/pkg/solana/chainwriter/chain_writer_test.go +++ b/pkg/solana/chainwriter/chain_writer_test.go @@ -2,6 +2,7 @@ package chainwriter_test import ( "bytes" + "context" _ "embed" "encoding/json" "errors" @@ -27,6 +28,7 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/monitoring/testutils" "github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" clientmocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/mocks" feemocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees/mocks" txmMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" @@ -47,13 +49,16 @@ func TestChainWriter_GetAddresses(t *testing.T) { // mock client rw := clientmocks.NewReaderWriter(t) + mc := *client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) { + return rw, nil + }) // mock estimator ge := feemocks.NewEstimator(t) // mock txm txm := txmMocks.NewTxManager(t) // initialize chain writer - cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, chainwriter.ChainWriterConfig{}) + cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), mc, txm, ge, chainwriter.ChainWriterConfig{}) require.NoError(t, err) // expected account meta for constant account @@ -165,7 +170,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { require.NoError(t, err) // Resolve account metas - accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, mc) require.NoError(t, err) // account metas should be returned in the same order as the provided account lookup configs @@ -209,7 +214,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { require.NoError(t, err) // Resolve account metas - accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, mc) require.NoError(t, err) require.Len(t, accounts, 2) @@ -233,7 +238,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { require.NoError(t, err) // Resolve account metas - accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, mc) require.NoError(t, err) require.Len(t, accounts, 3) @@ -258,7 +263,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { args := Arguments{} - accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, rw) + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, mc) require.NoError(t, err) require.Empty(t, accounts) }) @@ -275,7 +280,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { } args := Arguments{} - accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, rw) + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, mc) require.Error(t, err) require.Nil(t, accounts) }) @@ -293,7 +298,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { } args := Arguments{} - accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, rw) + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, mc) require.NoError(t, err) require.Empty(t, accounts) }) @@ -311,7 +316,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { } args := Arguments{} - accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, rw) + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, mc) require.Error(t, err) require.Nil(t, accounts) }) @@ -364,7 +369,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { args := Arguments{} - accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, rw) + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, mc) require.NoError(t, err) require.Empty(t, accounts) }) @@ -378,7 +383,7 @@ func TestChainWriter_GetAddresses(t *testing.T) { } args := Arguments{} - _, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, rw) + _, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, nil, mc) require.Error(t, err) }) }) @@ -389,13 +394,16 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) { // mock client rw := clientmocks.NewReaderWriter(t) + mc := *client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) { + return rw, nil + }) // mock estimator ge := feemocks.NewEstimator(t) // mock txm txm := txmMocks.NewTxManager(t) // initialize chain writer - cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, chainwriter.ChainWriterConfig{}) + cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), mc, txm, ge, chainwriter.ChainWriterConfig{}) require.NoError(t, err) programID := chainwriter.GetRandomPubKey(t) @@ -482,7 +490,7 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) { require.NoError(t, err) // Resolve account metas - accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, mc) require.NoError(t, err) // Filter the lookup table addresses based on which accounts are actually used @@ -504,7 +512,7 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) { require.NoError(t, err) // Resolve account metas - accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, mc) require.NoError(t, err) // Filter the lookup table addresses based on which accounts are actually used @@ -527,7 +535,7 @@ func TestChainWriter_FilterLookupTableAddresses(t *testing.T) { require.NoError(t, err) // Resolve account metas - accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, mc) require.NoError(t, err) // Filter the lookup table addresses based on which accounts are actually used @@ -542,6 +550,9 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { ctx := tests.Context(t) // mock client rw := clientmocks.NewReaderWriter(t) + mc := *client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) { + return rw, nil + }) // mock estimator ge := feemocks.NewEstimator(t) // mock txm @@ -654,7 +665,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { } // initialize chain writer - cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, cwConfig) + cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), mc, txm, ge, cwConfig) require.NoError(t, err) t.Run("fails with invalid ABI", func(t *testing.T) { @@ -671,7 +682,7 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { }, } - _, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, invalidCWConfig) + _, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), mc, txm, ge, invalidCWConfig) require.Error(t, err) }) @@ -836,6 +847,9 @@ func TestChainWriter_CCIPOfframp(t *testing.T) { ctx := tests.Context(t) // mock client rw := clientmocks.NewReaderWriter(t) + mc := *client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) { + return rw, nil + }) // mock estimator ge := feemocks.NewEstimator(t) @@ -843,7 +857,7 @@ func TestChainWriter_CCIPOfframp(t *testing.T) { // mock txm txm := txmMocks.NewTxManager(t) // initialize chain writer - cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, ccipCWConfig) + cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), mc, txm, ge, ccipCWConfig) require.NoError(t, err) recentBlockHash := solana.Hash{} @@ -907,7 +921,7 @@ func TestChainWriter_CCIPOfframp(t *testing.T) { // mock txm txm := txmMocks.NewTxManager(t) // initialize chain writer - cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, ccipCWConfig) + cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), mc, txm, ge, ccipCWConfig) require.NoError(t, err) recentBlockHash := solana.Hash{} @@ -957,13 +971,17 @@ func TestChainWriter_GetTransactionStatus(t *testing.T) { ctx := tests.Context(t) rw := clientmocks.NewReaderWriter(t) + mc := *client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) { + return rw, nil + }) + ge := feemocks.NewEstimator(t) // mock txm txm := txmMocks.NewTxManager(t) // initialize chain writer - cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, chainwriter.ChainWriterConfig{}) + cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), mc, txm, ge, chainwriter.ChainWriterConfig{}) require.NoError(t, err) t.Run("returns unknown with error if ID not found", func(t *testing.T) { @@ -1020,13 +1038,16 @@ func TestChainWriter_GetFeeComponents(t *testing.T) { ctx := tests.Context(t) rw := clientmocks.NewReaderWriter(t) + mc := *client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) { + return rw, nil + }) ge := feemocks.NewEstimator(t) ge.On("BaseComputeUnitPrice").Return(uint64(100)) // mock txm txm := txmMocks.NewTxManager(t) - cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, chainwriter.ChainWriterConfig{}) + cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), mc, txm, ge, chainwriter.ChainWriterConfig{}) require.NoError(t, err) t.Run("returns valid compute unit price", func(t *testing.T) { @@ -1037,7 +1058,7 @@ func TestChainWriter_GetFeeComponents(t *testing.T) { }) t.Run("fails if gas estimator not set", func(t *testing.T) { - cwNoEstimator, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, nil, chainwriter.ChainWriterConfig{}) + cwNoEstimator, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), mc, txm, nil, chainwriter.ChainWriterConfig{}) require.NoError(t, err) _, err = cwNoEstimator.GetFeeComponents(ctx) require.Error(t, err) diff --git a/pkg/solana/chainwriter/lookups.go b/pkg/solana/chainwriter/lookups.go index 635e7edd0..81482d1c3 100644 --- a/pkg/solana/chainwriter/lookups.go +++ b/pkg/solana/chainwriter/lookups.go @@ -15,7 +15,7 @@ import ( // Lookup is an interface that defines a method to resolve an address (or multiple addresses) from a given definition. type Lookup interface { - Resolve(ctx context.Context, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader) ([]*solana.AccountMeta, error) + Resolve(ctx context.Context, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.MultiClient) ([]*solana.AccountMeta, error) IsOptional() bool } @@ -108,7 +108,7 @@ type ATALookup struct { MintAddress Lookup } -func (ac AccountConstant) Resolve(_ context.Context, _ any, _ map[string]map[string][]*solana.AccountMeta, _ client.Reader) ([]*solana.AccountMeta, error) { +func (ac AccountConstant) Resolve(_ context.Context, _ any, _ map[string]map[string][]*solana.AccountMeta, _ client.MultiClient) ([]*solana.AccountMeta, error) { address, err := solana.PublicKeyFromBase58(ac.Address) if err != nil { return nil, fmt.Errorf("error getting account from constant: %w", err) @@ -122,7 +122,7 @@ func (ac AccountConstant) Resolve(_ context.Context, _ any, _ map[string]map[str }, nil } -func (al AccountLookup) Resolve(_ context.Context, args any, _ map[string]map[string][]*solana.AccountMeta, _ client.Reader) ([]*solana.AccountMeta, error) { +func (al AccountLookup) Resolve(_ context.Context, args any, _ map[string]map[string][]*solana.AccountMeta, _ client.MultiClient) ([]*solana.AccountMeta, error) { derivedValues, err := GetValuesAtLocation(args, al.Location) if err != nil { return nil, fmt.Errorf("error getting account from lookup: %w", err) @@ -181,7 +181,7 @@ func resolveBitMap(mb MetaBool, args any, length int) ([]bool, error) { return result, nil } -func (alt AccountsFromLookupTable) Resolve(_ context.Context, _ any, derivedTableMap map[string]map[string][]*solana.AccountMeta, _ client.Reader) ([]*solana.AccountMeta, error) { +func (alt AccountsFromLookupTable) Resolve(_ context.Context, _ any, derivedTableMap map[string]map[string][]*solana.AccountMeta, _ client.MultiClient) ([]*solana.AccountMeta, error) { // Fetch the inner map for the specified lookup table name innerMap, ok := derivedTableMap[alt.LookupTableName] if !ok { @@ -211,7 +211,7 @@ func (alt AccountsFromLookupTable) Resolve(_ context.Context, _ any, derivedTabl return result, nil } -func (pda PDALookups) Resolve(ctx context.Context, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader) ([]*solana.AccountMeta, error) { +func (pda PDALookups) Resolve(ctx context.Context, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.MultiClient) ([]*solana.AccountMeta, error) { publicKeys, err := GetAddresses(ctx, args, []Lookup{pda.PublicKey}, derivedTableMap, reader) if err != nil { return nil, fmt.Errorf("error getting public key for PDALookups: %w", err) @@ -291,7 +291,7 @@ func getSeedBytesCombinations( lookup PDALookups, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, - reader client.Reader, + reader client.MultiClient, ) ([][][]byte, error) { allCombinations := [][][]byte{ {}, diff --git a/pkg/solana/relay.go b/pkg/solana/relay.go index 07a0830d4..f6135c583 100644 --- a/pkg/solana/relay.go +++ b/pkg/solana/relay.go @@ -143,12 +143,7 @@ func (r *Relayer) NewContractWriter(_ context.Context, config []byte) (relaytype return nil, fmt.Errorf("failed to unmarshall chain writer config err: %s", err) } - solanaReader, err := r.chain.Reader() - if err != nil { - return nil, fmt.Errorf("failed to init solana reader err: %s", err) - } - - return chainwriter.NewSolanaChainWriterService(r.lggr, solanaReader, r.chain.TxManager(), r.chain.FeeEstimator(), cfg) + return chainwriter.NewSolanaChainWriterService(r.lggr, *r.chain.MultiClient(), r.chain.TxManager(), r.chain.FeeEstimator(), cfg) } func (r *Relayer) NewContractReader(_ context.Context, chainReaderConfig []byte) (relaytypes.ContractReader, error) {