Skip to content

Commit

Permalink
Implement multi read
Browse files Browse the repository at this point in the history
  • Loading branch information
ilija42 committed Feb 4, 2025
1 parent 2425295 commit ee7ff3a
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 96 deletions.
92 changes: 52 additions & 40 deletions pkg/solana/chainreader/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import (
)

type call struct {
ContractName, ReadName string
Params, ReturnVal any
ContractName string
// MultiRead contains all the sequential reads needed to populate ReturnVal.
MultiRead []string
Params, ReturnVal any
}

type batchResultWithErr struct {
Expand All @@ -34,14 +36,18 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin
// Create the list of public keys to fetch
keys := make([]solana.PublicKey, len(batch))
for idx, call := range batch {
binding, err := bindings.GetReadBinding(call.ContractName, call.ReadName)
if len(call.MultiRead) == 0 {
return nil, fmt.Errorf("no read specified for: %q", call.ContractName)
}

binding, err := bindings.GetReadBinding(call.ContractName, call.MultiRead[0])
if err != nil {
return nil, err
}

key, err := binding.GetAddress(ctx, call.Params)
if err != nil {
return nil, fmt.Errorf("failed to get address for %s account read: %w", call.ReadName, err)
return nil, fmt.Errorf("failed to get address for %s account read: %w", call.MultiRead, err)
}
keys[idx] = key
}
Expand All @@ -56,55 +62,61 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin

// decode batch call results
for idx, call := range batch {
results[idx] = batchResultWithErr{
address: keys[idx].String(),
contractName: call.ContractName,
readName: call.ReadName,
returnVal: call.ReturnVal,
if len(batch) > 1 && len(call.MultiRead) > 1 {
return nil, fmt.Errorf("batch call does not support multiple reads")
}

if data[idx] == nil || len(data[idx]) == 0 {
results[idx].err = ErrMissingAccountData
for _, read := range call.MultiRead {
results[idx] = batchResultWithErr{
address: keys[idx].String(),
contractName: call.ContractName,
readName: read,
returnVal: call.ReturnVal,
}

continue
}
if data[idx] == nil || len(data[idx]) == 0 {
results[idx].err = ErrMissingAccountData

binding, err := bindings.GetReadBinding(results[idx].contractName, results[idx].readName)
if err != nil {
results[idx].err = err
continue
}

continue
}
binding, err := bindings.GetReadBinding(results[idx].contractName, results[idx].readName)
if err != nil {
results[idx].err = err

continue
}

ptrToValue, isValue := call.ReturnVal.(*values.Value)
if !isValue {
results[idx].err = errors.Join(
results[idx].err,
binding.Decode(ctx, data[idx], results[idx].returnVal),
)
continue
}

contractType, err := binding.CreateType(false)
if err != nil {
results[idx].err = err

continue
}

ptrToValue, isValue := call.ReturnVal.(*values.Value)
if !isValue {
results[idx].err = errors.Join(
results[idx].err,
binding.Decode(ctx, data[idx], results[idx].returnVal),
binding.Decode(ctx, data[idx], contractType),
)
continue
}

contractType, err := binding.CreateType(false)
if err != nil {
results[idx].err = err

continue
}

results[idx].err = errors.Join(
results[idx].err,
binding.Decode(ctx, data[idx], contractType),
)
value, err := values.Wrap(contractType)
if err != nil {
results[idx].err = errors.Join(results[idx].err, err)

value, err := values.Wrap(contractType)
if err != nil {
results[idx].err = errors.Join(results[idx].err, err)
continue
}

continue
*ptrToValue = value
}

*ptrToValue = value
}

return results, nil
Expand Down
80 changes: 49 additions & 31 deletions pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ import (
"fmt"
"sync"

"github.com/gagliardetto/solana-go"
"github.com/gagliardetto/solana-go/rpc"

commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/services"
Expand Down Expand Up @@ -124,7 +121,7 @@ func (s *SolanaChainReaderService) GetLatestValue(ctx context.Context, readIdent
batch := []call{
{
ContractName: values.contract,
ReadName: values.genericName,
MultiRead: values.multiRead,
Params: params,
ReturnVal: returnVal,
},
Expand All @@ -149,7 +146,7 @@ func (s *SolanaChainReaderService) GetLatestValue(ctx context.Context, readIdent
// BatchGetLatestValues implements the types.ContractReader interface.
func (s *SolanaChainReaderService) BatchGetLatestValues(ctx context.Context, request types.BatchGetLatestValuesRequest) (types.BatchGetLatestValuesResult, error) {
idxLookup := make(map[types.BoundContract][]int)
batch := []call{}
var batch []call

for bound, req := range request {
idxLookup[bound] = make([]int, len(req))
Expand All @@ -158,7 +155,7 @@ func (s *SolanaChainReaderService) BatchGetLatestValues(ctx context.Context, req
idxLookup[bound][idx] = len(batch)
batch = append(batch, call{
ContractName: bound.Name,
ReadName: readReq.ReadName,
MultiRead: []string{readReq.ReadName},
Params: readReq.Params,
ReturnVal: readReq.ReturnVal,
})
Expand Down Expand Up @@ -227,7 +224,11 @@ func (s *SolanaChainReaderService) CreateContractType(readIdentifier string, for
return nil, fmt.Errorf("%w: no contract for read identifier", types.ErrInvalidConfig)
}

return s.bindings.CreateType(values.contract, values.genericName, forEncoding)
if len(values.multiRead) == 0 {
return nil, fmt.Errorf("%w: no reads defined for read identifier", types.ErrInvalidConfig)
}

return s.bindings.CreateType(values.contract, values.multiRead[0], forEncoding)
}

func (s *SolanaChainReaderService) addCodecDef(forEncoding bool, namespace, genericName string, readType codec.ChainConfigType, idl codec.IDL, idlDefinition interface{}, modCfg commoncodec.ModifiersConfig) error {
Expand Down Expand Up @@ -288,14 +289,23 @@ func (s *SolanaChainReaderService) addAccountRead(namespace string, genericName
return err
}

s.lookup.addReadNameForContract(namespace, genericName)
multiRead := []string{genericName}
if readDefinition.DynamicHardCoder != nil {
reads, err := s.addMultiAccountRead(namespace, readDefinition, idl)
if err != nil {
return err
}
multiRead = append(multiRead, reads...)
}

s.lookup.addReadNameForContract(namespace, genericName, multiRead)

var reader readBinding
var inputAccountIDLDef interface{}
// Create PDA read binding if PDA prefix or seeds configs are populated
if len(readDefinition.PDADefiniton.Prefix) > 0 || len(readDefinition.PDADefiniton.Seeds) > 0 {
inputAccountIDLDef = readDefinition.PDADefiniton
reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefiniton.Prefix, true)
if len(readDefinition.PDADefinition.Prefix) > 0 || len(readDefinition.PDADefinition.Seeds) > 0 {
inputAccountIDLDef = readDefinition.PDADefinition
reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefinition.Prefix, true)
} else {
inputAccountIDLDef = codec.NilIdlTypeDefTy
reader = newAccountReadBinding(namespace, genericName, "", false)
Expand All @@ -304,10 +314,37 @@ func (s *SolanaChainReaderService) addAccountRead(namespace string, genericName
return err
}
s.bindings.AddReadBinding(namespace, genericName, reader)

return nil
}

func (s *SolanaChainReaderService) addMultiAccountRead(namespace string, readDefinition config.ReadDefinition, idl codec.IDL) ([]string, error) {
var reads []string
if readDefinition.DynamicHardCoder.MultiReader != nil {
for _, mr := range readDefinition.DynamicHardCoder.MultiReader {
idlDef, err := codec.FindDefinitionFromIDL(codec.ChainConfigTypeAccountDef, mr.ChainSpecificName, idl)
if err != nil {
return nil, err
}

if mr.ReadType != config.Account {
return nil, fmt.Errorf("unexpected read type %q for dynamic hard coder read: %q in namespace: %q", mr.ReadType, mr.ChainSpecificName, namespace)
}

accountIDLDef, isOk := idlDef.(codec.IdlTypeDef)
if !isOk {
return nil, fmt.Errorf("unexpected type %T from IDL definition for account read with chainSpecificName: %q, of type: %q", accountIDLDef, mr.ChainSpecificName, mr.ReadType)
}

if err = s.addAccountRead(namespace, mr.ChainSpecificName, idl, accountIDLDef, mr); err != nil {
return nil, fmt.Errorf("failed to add multi-read %q: %w", mr.ChainSpecificName, err)
}

reads = append(reads, mr.ChainSpecificName)
}
}
return reads, nil
}

// injectAddressModifier injects AddressModifier into OutputModifications.
// This is necessary because AddressModifier cannot be serialized and must be applied at runtime.
func injectAddressModifier(inputModifications, outputModifications commoncodec.ModifiersConfig) {
Expand All @@ -325,22 +362,3 @@ func injectAddressModifier(inputModifications, outputModifications commoncodec.M
}
}
}

type accountDataReader struct {
client *rpc.Client
}

func NewAccountDataReader(client *rpc.Client) *accountDataReader {
return &accountDataReader{client: client}
}

func (r *accountDataReader) ReadAll(ctx context.Context, pk solana.PublicKey, opts *rpc.GetAccountInfoOpts) ([]byte, error) {
result, err := r.client.GetAccountInfoWithOpts(ctx, pk, opts)
if err != nil {
return nil, err
}

bts := result.Value.Data.GetBinary()

return bts, nil
}
4 changes: 2 additions & 2 deletions pkg/solana/chainreader/chain_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
testReadDef := readDef
testReadDef.PDADefiniton = testCase.pdaDefinition
testReadDef.PDADefinition = testCase.pdaDefinition
testReadDef.InputModifications = testCase.inputModifier
testCodec, conf := newTestConfAndCodecWithInjectibleReadDef(t, PDAAccount, testReadDef)
encoded, err := testCodec.Encode(ctx, expected, testutils.TestStructWithNestedStruct)
Expand Down Expand Up @@ -428,7 +428,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) {
readDef := config.ReadDefinition{
ChainSpecificName: testutils.TestStructWithNestedStruct,
ReadType: config.Account,
PDADefiniton: codec.PDATypeDef{
PDADefinition: codec.PDATypeDef{
Prefix: prefixString,
Seeds: []codec.PDASeed{
{
Expand Down
55 changes: 33 additions & 22 deletions pkg/solana/chainreader/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,55 +7,63 @@ import (
)

type readValues struct {
address string
contract string
genericName string
address string
contract string
// First read in multi read has type info that other sequential reads are filling out.
// this works by having hard coder codec modifier define fields that are filled out by subsequent reads.
multiRead []string
}

// lookup provides basic utilities for mapping a complete readIdentifier to
// finite contract read information
type lookup struct {
mu sync.RWMutex
// contractReadNames maps a contract name to all available namePairs (method, log, event, etc.)
contractReadNames map[string][]string
// contractReadNames maps a program name to all available reads (accounts, PDAs, logs).
// Every key (generic read name) can be composed of multiple reads of the same program. Right now all of them have to be of same type (account, PDA or log).
contractReadNames map[string]map[string][]string
// readIdentifiers maps from a complete readIdentifier string to finite read data
// a readIdentifier is a combination of address, contract, and chainSpecificName as a concatenated string
readIdentifiers map[string]readValues
}

func newLookup() *lookup {
return &lookup{
contractReadNames: make(map[string][]string),
contractReadNames: make(map[string]map[string][]string),
readIdentifiers: make(map[string]readValues),
}
}

func (l *lookup) addReadNameForContract(contract string, genericName string) {
func (l *lookup) addReadNameForContract(contract, genericName string, multiRead []string) {
l.mu.Lock()
defer l.mu.Unlock()

readNames, exists := l.contractReadNames[contract]
if !exists {
readNames = []string{}
readNames = make(map[string][]string)
}

l.contractReadNames[contract] = append(readNames, genericName)
readNames[genericName] = multiRead

l.contractReadNames[contract] = readNames
}

func (l *lookup) bindAddressForContract(contract, address string) {
l.mu.Lock()
defer l.mu.Unlock()

for _, genericName := range l.contractReadNames[contract] {
readIdentifier := types.BoundContract{
Address: address,
Name: contract,
}.ReadIdentifier(genericName)
for _, multiRead := range l.contractReadNames[contract] {
readIdentifier := ""
if len(multiRead) > 0 {
readIdentifier = types.BoundContract{
Address: address,
Name: contract,
}.ReadIdentifier(multiRead[0])
}

l.readIdentifiers[readIdentifier] = readValues{
address: address,
contract: contract,
genericName: genericName,
address: address,
contract: contract,
multiRead: multiRead,
}
}
}
Expand All @@ -64,11 +72,14 @@ func (l *lookup) unbindAddressForContract(contract, address string) {
l.mu.Lock()
defer l.mu.Unlock()

for _, genericName := range l.contractReadNames[contract] {
readIdentifier := types.BoundContract{
Address: address,
Name: contract,
}.ReadIdentifier(genericName)
for _, multiRead := range l.contractReadNames[contract] {
readIdentifier := ""
if len(multiRead) > 0 {
readIdentifier = types.BoundContract{
Address: address,
Name: contract,
}.ReadIdentifier(multiRead[0])
}

delete(l.readIdentifiers, readIdentifier)
}
Expand Down
Loading

0 comments on commit ee7ff3a

Please sign in to comment.