Skip to content

Commit

Permalink
Refactored ChainWriter to use MultiNode client
Browse files Browse the repository at this point in the history
  • Loading branch information
silaslenihan committed Feb 11, 2025
1 parent b72f722 commit 027e6ce
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 82 deletions.
10 changes: 6 additions & 4 deletions integration-tests/relayinterface/chain_components_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ type SolanaChainComponentsInterfaceTesterHelper[T WrappedTestingT[T]] interface
GetSecondaryIDL(t T) []byte
CreateAccount(t T, it SolanaChainComponentsInterfaceTester[T], contractName string, value uint64, testStruct TestStruct) solana.PublicKey
TXM() *txm.TxManager
SolanaClient() *client.Client
MultiClient() *client.MultiClient
}

type WrappedTestingT[T any] interface {
Expand Down Expand Up @@ -379,7 +379,7 @@ func (it *SolanaChainComponentsInterfaceTester[T]) GetContractReaderWithCustomCf

func (it *SolanaChainComponentsInterfaceTester[T]) GetContractWriter(t T) types.ContractWriter {
chainWriterConfig := it.buildContractWriterConfig(t)
cw, err := chainwriter.NewSolanaChainWriterService(it.Helper.Logger(t), it.Helper.SolanaClient(), *it.Helper.TXM(), nil, chainWriterConfig)
cw, err := chainwriter.NewSolanaChainWriterService(it.Helper.Logger(t), *it.Helper.MultiClient(), *it.Helper.TXM(), nil, chainWriterConfig)
require.NoError(t, err)

servicetest.Run(t, cw)
Expand Down Expand Up @@ -485,8 +485,10 @@ func (h *helper) TXM() *txm.TxManager {
return &h.txm
}

func (h *helper) SolanaClient() *client.Client {
return h.sc
func (h *helper) MultiClient() *client.MultiClient {
return client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) {
return h.sc, nil
})
}

func (h *helper) Context(t *testing.T) context.Context {
Expand Down
51 changes: 30 additions & 21 deletions integration-tests/relayinterface/lookups_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package relayinterface

import (
"context"
"testing"
"time"

Expand Down Expand Up @@ -49,7 +50,7 @@ func TestAccountContant(t *testing.T) {
IsSigner: true,
IsWritable: true,
}
result, err := constantConfig.Resolve(tests.Context(t), nil, nil, nil)
result, err := constantConfig.Resolve(tests.Context(t), nil, nil, client.MultiClient{})
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})
Expand Down Expand Up @@ -77,7 +78,7 @@ func TestAccountLookups(t *testing.T) {
IsSigner: chainwriter.MetaBool{Value: true},
IsWritable: chainwriter.MetaBool{Value: true},
}
result, err := lookupConfig.Resolve(ctx, testArgs, nil, nil)
result, err := lookupConfig.Resolve(ctx, testArgs, nil, client.MultiClient{})
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})
Expand Down Expand Up @@ -111,7 +112,7 @@ func TestAccountLookups(t *testing.T) {
IsSigner: chainwriter.MetaBool{Value: true},
IsWritable: chainwriter.MetaBool{Value: true},
}
result, err := lookupConfig.Resolve(ctx, testArgs, nil, nil)
result, err := lookupConfig.Resolve(ctx, testArgs, nil, client.MultiClient{})
require.NoError(t, err)
for i, meta := range result {
require.Equal(t, expectedMeta[i], meta)
Expand All @@ -132,7 +133,7 @@ func TestAccountLookups(t *testing.T) {
IsSigner: chainwriter.MetaBool{Value: true},
IsWritable: chainwriter.MetaBool{Value: true},
}
_, err := lookupConfig.Resolve(ctx, testArgs, nil, nil)
_, err := lookupConfig.Resolve(ctx, testArgs, nil, client.MultiClient{})
require.Error(t, err)
})

Expand Down Expand Up @@ -162,7 +163,7 @@ func TestAccountLookups(t *testing.T) {
},
}

result, err := lookupConfig.Resolve(ctx, args, nil, nil)
result, err := lookupConfig.Resolve(ctx, args, nil, client.MultiClient{})
require.NoError(t, err)

for i, meta := range result {
Expand Down Expand Up @@ -200,7 +201,7 @@ func TestAccountLookups(t *testing.T) {
Bitmaps: []uint64{5, 3},
}

_, err := lookupConfig.Resolve(ctx, args, nil, nil)
_, err := lookupConfig.Resolve(ctx, args, nil, client.MultiClient{})
require.Contains(t, err.Error(), "bitmap value is not a single value")
})

Expand All @@ -227,7 +228,7 @@ func TestAccountLookups(t *testing.T) {
},
}

_, err := lookupConfig.Resolve(ctx, args, nil, nil)
_, err := lookupConfig.Resolve(ctx, args, nil, client.MultiClient{})
require.Contains(t, err.Error(), "error reading bitmap from location")
})

Expand All @@ -254,7 +255,7 @@ func TestAccountLookups(t *testing.T) {
},
}

_, err := lookupConfig.Resolve(ctx, args, nil, nil)
_, err := lookupConfig.Resolve(ctx, args, nil, client.MultiClient{})
require.Contains(t, err.Error(), "invalid value format at path")
})
}
Expand Down Expand Up @@ -287,7 +288,7 @@ func TestPDALookups(t *testing.T) {
IsWritable: true,
}

result, err := pdaLookup.Resolve(ctx, nil, nil, nil)
result, err := pdaLookup.Resolve(ctx, nil, nil, client.MultiClient{})
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})
Expand Down Expand Up @@ -322,7 +323,7 @@ func TestPDALookups(t *testing.T) {
"another_seed": seed2,
}

result, err := pdaLookup.Resolve(ctx, args, nil, nil)
result, err := pdaLookup.Resolve(ctx, args, nil, client.MultiClient{})
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})
Expand All @@ -342,7 +343,7 @@ func TestPDALookups(t *testing.T) {
"test_seed": []byte("data"),
}

_, err := pdaLookup.Resolve(ctx, args, nil, nil)
_, err := pdaLookup.Resolve(ctx, args, nil, client.MultiClient{})
require.Error(t, err)
require.Contains(t, err.Error(), "key not found")
})
Expand Down Expand Up @@ -378,7 +379,7 @@ func TestPDALookups(t *testing.T) {
"another_seed": seed2,
}

result, err := pdaLookup.Resolve(ctx, args, nil, nil)
result, err := pdaLookup.Resolve(ctx, args, nil, client.MultiClient{})
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})
Expand Down Expand Up @@ -416,7 +417,7 @@ func TestPDALookups(t *testing.T) {
"array_seed": arraySeed,
}

result, err := pdaLookup.Resolve(ctx, args, nil, nil)
result, err := pdaLookup.Resolve(ctx, args, nil, client.MultiClient{})
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})
Expand Down Expand Up @@ -456,7 +457,7 @@ func TestPDALookups(t *testing.T) {
"seed2": arraySeed2,
}

result, err := pdaLookup.Resolve(ctx, args, nil, nil)
result, err := pdaLookup.Resolve(ctx, args, nil, client.MultiClient{})
require.NoError(t, err)
require.Equal(t, expectedMeta, result)
})
Expand All @@ -477,13 +478,17 @@ func TestLookupTables(t *testing.T) {
solanaClient, err := client.NewClient(url, cfg, 5*time.Second, nil)
require.NoError(t, err)

multiClient := *client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) {
return solanaClient, nil
})

loader := solanautils.NewStaticLoader[client.ReaderWriter](solanaClient)
mkey := keyMocks.NewSimpleKeystore(t)
lggr := logger.Test(t)

txm := txm.NewTxm("localnet", loader, nil, cfg, mkey, lggr)

cw, err := chainwriter.NewSolanaChainWriterService(nil, solanaClient, txm, nil, chainwriter.ChainWriterConfig{})
cw, err := chainwriter.NewSolanaChainWriterService(nil, multiClient, txm, nil, chainwriter.ChainWriterConfig{})
require.NoError(t, err)

t.Run("StaticLookup table resolves properly", func(t *testing.T) {
Expand Down Expand Up @@ -657,6 +662,10 @@ func TestCreateATAs(t *testing.T) {
solanaClient, err := client.NewClient(url, cfg, 5*time.Second, nil)
require.NoError(t, err)

multiClient := *client.NewMultiClient(func(context.Context) (client.ReaderWriter, error) {
return solanaClient, nil
})

t.Run("returns no instructions when no ATA location is found", func(t *testing.T) {
lookups := []chainwriter.ATALookup{
{
Expand All @@ -679,7 +688,7 @@ func TestCreateATAs(t *testing.T) {
},
}

ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer)
require.NoError(t, err)
require.Empty(t, ataInstructions)
})
Expand All @@ -704,7 +713,7 @@ func TestCreateATAs(t *testing.T) {
"Addresses": {chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)},
}

_, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
_, err := chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer)
require.Contains(t, err.Error(), "expected exactly one wallet address, got 2")
})

Expand All @@ -728,7 +737,7 @@ func TestCreateATAs(t *testing.T) {
"Addresses": {chainwriter.GetRandomPubKey(t), chainwriter.GetRandomPubKey(t)},
}

_, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
_, err := chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer)
require.Contains(t, err.Error(), "expected equal number of token programs and mints, got 1 tokenPrograms and 2 mints")
})

Expand Down Expand Up @@ -760,7 +769,7 @@ func TestCreateATAs(t *testing.T) {
},
}

ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer)
require.NoError(t, err)

tx := solanautils.CreateTx(ctx, t, rpcClient, ataInstructions, sender, rpc.CommitmentFinalized)
Expand Down Expand Up @@ -797,14 +806,14 @@ func TestCreateATAs(t *testing.T) {
},
}

ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
ataInstructions, err := chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer)
require.NoError(t, err)

solanautils.SendAndConfirm(ctx, t, rpcClient, ataInstructions, sender, rpc.CommitmentFinalized)
require.True(t, checkIfATAExists(t, rpcClient, ataAddress))

// now, if we try to create the same ATA again, it should return no instructions
ataInstructions, err = chainwriter.CreateATAs(ctx, args, lookups, nil, solanaClient, testContractIDL, feePayer)
ataInstructions, err = chainwriter.CreateATAs(ctx, args, lookups, nil, multiClient, testContractIDL, feePayer)
require.NoError(t, err)
require.Empty(t, ataInstructions)
})
Expand Down
5 changes: 5 additions & 0 deletions pkg/solana/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type Chain interface {
FeeEstimator() fees.Estimator
// Reader returns a new Reader from the available list of nodes (if there are multiple, it will randomly select one)
Reader() (client.Reader, error)
MultiClient() *client.MultiClient
}

// DefaultRequestTimeout is the default Solana client timeout.
Expand Down Expand Up @@ -437,6 +438,10 @@ func (c *chain) Reader() (client.Reader, error) {
return c.getClient(ctx)
}

func (c *chain) MultiClient() *client.MultiClient {
return c.multiClient
}

func (c *chain) ChainID() string {
return c.id
}
Expand Down
Loading

0 comments on commit 027e6ce

Please sign in to comment.