From 04abe5ff741e4d00cba4897ee65acadfdae7f1d9 Mon Sep 17 00:00:00 2001 From: ilija Date: Fri, 7 Feb 2025 19:13:31 +0100 Subject: [PATCH] Implement split params reading --- .../chainreader/account_read_binding.go | 16 +- pkg/solana/chainreader/batch.go | 179 +++++++++++++++++- pkg/solana/chainreader/bindings.go | 13 +- pkg/solana/chainreader/chain_reader.go | 4 +- pkg/solana/config/chain_reader.go | 9 + 5 files changed, 197 insertions(+), 24 deletions(-) diff --git a/pkg/solana/chainreader/account_read_binding.go b/pkg/solana/chainreader/account_read_binding.go index 69fdc1328..fdc0d281e 100644 --- a/pkg/solana/chainreader/account_read_binding.go +++ b/pkg/solana/chainreader/account_read_binding.go @@ -3,10 +3,12 @@ package chainreader import ( "context" "fmt" + "slices" "github.com/gagliardetto/solana-go" "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" ) @@ -16,16 +18,16 @@ type accountReadBinding struct { namespace, genericName string codec types.RemoteCodec key solana.PublicKey - isPda bool // flag to signify whether or not the account read is for a PDA - prefix []byte // only used for PDA public key calculation + prefix []byte + readType config.ReadType } -func newAccountReadBinding(namespace, genericName string, prefix []byte, isPda bool) *accountReadBinding { +func newAccountReadBinding(namespace, genericName string, prefix []byte, readType config.ReadType) *accountReadBinding { return &accountReadBinding{ namespace: namespace, genericName: genericName, prefix: prefix, - isPda: isPda, + readType: readType, } } @@ -41,7 +43,7 @@ func (b *accountReadBinding) SetAddress(key solana.PublicKey) { func (b *accountReadBinding) GetAddress(ctx context.Context, params any) (solana.PublicKey, error) { // Return the bound key if normal account read - if !b.isPda { + if !slices.Contains([]config.ReadType{config.AccountSplitParams, config.AccountPDA}, b.readType) { return b.key, nil } // Calculate the public key if PDA account read @@ -64,6 +66,10 @@ func (b *accountReadBinding) Decode(ctx context.Context, bts []byte, outVal any) return b.codec.Decode(ctx, bts, outVal, codec.WrapItemType(false, b.namespace, b.genericName, codec.ChainConfigTypeAccountDef)) } +func (b *accountReadBinding) ReadType() config.ReadType { + return b.readType +} + // buildSeedsSlice encodes and builds the seedslist to calculate the PDA public key func (b *accountReadBinding) buildSeedsSlice(ctx context.Context, params any) ([][]byte, error) { flattenedSeeds := make([]byte, 0, solana.MaxSeeds*solana.MaxSeedLength) diff --git a/pkg/solana/chainreader/batch.go b/pkg/solana/chainreader/batch.go index 630c2d3f8..37b9a2fab 100644 --- a/pkg/solana/chainreader/batch.go +++ b/pkg/solana/chainreader/batch.go @@ -4,12 +4,14 @@ import ( "context" "errors" "fmt" + "reflect" "strings" "github.com/gagliardetto/solana-go" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/values" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" ) type call struct { @@ -66,27 +68,150 @@ func doMultiRead(ctx context.Context, client MultipleAccountGetter, bindings nam return nil } +type resultIndex struct { + contractName, readName string + readType config.ReadType +} + func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindings namespaceBindings, batch []call) ([]batchResultWithErr, error) { + resultIndexes := make(map[int]resultIndex) + var regularBatch []call + var splitParamsBatch []call + // Create the list of public keys to fetch - keys := make([]solana.PublicKey, len(batch)) + regularKeys := make([]solana.PublicKey, len(batch)) for idx, batchCall := range batch { binding, err := bindings.GetReadBinding(batchCall.ContractName, batchCall.ReadName) if err != nil { return nil, fmt.Errorf("%w: read binding not found for contract: %q read: %q: %w", types.ErrInvalidConfig, batchCall.ContractName, batchCall.ReadName, err) } - keys[idx], err = binding.GetAddress(ctx, batchCall.Params) + resultIndexes[idx] = resultIndex{ + contractName: batchCall.ContractName, + readName: batchCall.ReadName, + readType: binding.ReadType()} + + if binding.ReadType() == config.AccountSplitParams { + splitParamsBatch = append(splitParamsBatch, batchCall) + } else { + regularKeys[idx], err = binding.GetAddress(ctx, batchCall.Params) + if err != nil { + return nil, fmt.Errorf("failed to get address for contract: %q read: %q: %w", batchCall.ContractName, batchCall.ReadName, err) + } + regularBatch = append(regularBatch, batchCall) + } + } + + var splitParamBatchCallsResults []batchResultWithErr + for _, callToSplit := range splitParamsBatch { + binding, err := bindings.GetReadBinding(callToSplit.ContractName, callToSplit.ReadName) + if err != nil { + return nil, fmt.Errorf("%w: read binding not found for contract: %q read: %q: %w", types.ErrInvalidConfig, callToSplit.ContractName, callToSplit.ReadName, err) + } + + results, err := doSplitParamsBatchCall(ctx, client, callToSplit, bindings, binding) if err != nil { - return nil, fmt.Errorf("failed to get address for contract: %q read: %q: %w", batchCall.ContractName, batchCall.ReadName, err) + return nil, err } + + splitParamBatchCallsResults = append(splitParamBatchCallsResults, results) } // Fetch the account data - data, err := client.GetMultipleAccountData(ctx, keys...) + data, err := client.GetMultipleAccountData(ctx, regularKeys...) if err != nil { return nil, err } + return mergeBatchResults(decodeBatchResults(ctx, batch, regularKeys, data, bindings), splitParamBatchCallsResults, resultIndexes) +} + +func mergeBatchResults(regularResults, splitParamBatchCallsResults []batchResultWithErr, resultIndexes map[int]resultIndex) ([]batchResultWithErr, error) { + finalResults := make([]batchResultWithErr, len(resultIndexes)) + seen := make(map[int]bool) + + for _, result := range regularResults { + for key, resIndex := range resultIndexes { + if result.contractName == resIndex.contractName && result.readName == resIndex.readName && resIndex.readType != config.AccountSplitParams { + if !seen[key] { + finalResults[key] = result + seen[key] = true + } + } + } + } + + for _, result := range splitParamBatchCallsResults { + for key, resIndex := range resultIndexes { + if result.contractName == resIndex.contractName && result.readName == resIndex.readName && resIndex.readType == config.AccountSplitParams { + if !seen[key] { + finalResults[key] = result + seen[key] = true + } + } + } + } + + var mergeResErr error + for key, val := range seen { + if !val { + mergeResErr = errors.Join(mergeResErr, fmt.Errorf("failed to find result for call: %v", resultIndexes[key])) + } + } + + if mergeResErr != nil { + return nil, mergeResErr + } + + if len(finalResults) != len(resultIndexes) { + return nil, fmt.Errorf("%w failed to mere batch results, final results length does not match batch length", types.ErrInternal) + } + + return finalResults, nil +} + +func doSplitParamsBatchCall(ctx context.Context, client MultipleAccountGetter, callToSplit call, bindings namespaceBindings, binding readBinding) (batchResultWithErr, error) { + sPBatch, err := getSplitParamsBatch(callToSplit) + if err != nil { + return batchResultWithErr{}, err + } + + var sPKeys []solana.PublicKey + for _, spCall := range sPBatch { + key, err := binding.GetAddress(ctx, spCall.Params) + if err != nil { + return batchResultWithErr{}, fmt.Errorf("failed to get address for contract: %q read: %q: %w", callToSplit.ContractName, callToSplit.ReadName, err) + } + sPKeys = append(sPKeys, key) + } + + data, err := client.GetMultipleAccountData(ctx, sPKeys...) + if err != nil { + return batchResultWithErr{}, err + } + + results := decodeBatchResults(ctx, sPBatch, sPKeys, data, bindings) + + if len(results) != len(sPBatch) { + return batchResultWithErr{}, fmt.Errorf("results length does not match split params batch length for contract: %q read: %q", callToSplit.ContractName, callToSplit.ReadName) + } + + var returnVal []any + var returnErr error + for _, res := range results { + returnVal = append(returnVal, res.returnVal) + returnErr = errors.Join(returnErr, res.err) + } + + return batchResultWithErr{ + contractName: callToSplit.ContractName, + readName: callToSplit.ReadName, + returnVal: returnVal, + err: returnErr, + }, nil +} + +func decodeBatchResults(ctx context.Context, batch []call, keys []solana.PublicKey, data [][]byte, bindings namespaceBindings) []batchResultWithErr { results := make([]batchResultWithErr, len(batch)) // decode batch call results @@ -115,8 +240,7 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin decodeReturnVal(ctx, binding, data[idx], results[idx].returnVal), results[idx].err) } - - return results, nil + return results } // decodeReturnVal checks if returnVal is a *values.Value vs. a normal struct pointer, and decodes accordingly. @@ -147,3 +271,46 @@ func decodeReturnVal(ctx context.Context, binding readBinding, raw []byte, retur return nil } + +func getSplitParamsBatch(c call) ([]call, error) { + sParams, isOk := extractSliceElements(c.Params) + if !isOk { + return nil, fmt.Errorf("failed to extract params slice elements for contract: %q split params read: %q", c.ContractName, c.ReadName) + } + + sReturnVals, isOk := extractSliceElements(c.ReturnVal) + if !isOk { + return nil, fmt.Errorf("failed to extract return values slice elements for contract: %q split params read: %q", c.ContractName, c.ReadName) + } + + if len(sParams) != len(sReturnVals) { + return nil, fmt.Errorf("params and return values slice lengths do not match for contract: %q split params read: %q", c.ContractName, c.ReadName) + } + + batch := make([]call, len(sParams)) + for idx := range sParams { + batch[idx] = call{ + ContractName: c.ContractName, + ReadName: c.ReadName, + Params: sParams[idx], + ReturnVal: sReturnVals[idx], + } + } + + return batch, nil +} + +func extractSliceElements(input any) ([]any, bool) { + rv := reflect.ValueOf(input) + if rv.Kind() != reflect.Slice { + return nil, false + } + + length := rv.Len() + elements := make([]any, length) + for i := 0; i < length; i++ { + elements[i] = rv.Index(i).Interface() + } + + return elements, true +} diff --git a/pkg/solana/chainreader/bindings.go b/pkg/solana/chainreader/bindings.go index b26cae62e..d175b252c 100644 --- a/pkg/solana/chainreader/bindings.go +++ b/pkg/solana/chainreader/bindings.go @@ -7,16 +7,7 @@ import ( "github.com/gagliardetto/solana-go" "github.com/smartcontractkit/chainlink-common/pkg/types" -) - -type ReadType int - -const ( - Log ReadType = iota - Account - AccountPDA - AccountMulti - AccountSplitParams + "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" ) type readBinding interface { @@ -25,7 +16,7 @@ type readBinding interface { SetCodec(types.RemoteCodec) CreateType(bool) (any, error) Decode(context.Context, []byte, any) error - ReadType() ReadType + ReadType() config.ReadType } // key is namespace diff --git a/pkg/solana/chainreader/chain_reader.go b/pkg/solana/chainreader/chain_reader.go index 3943834e1..5a4c279f5 100644 --- a/pkg/solana/chainreader/chain_reader.go +++ b/pkg/solana/chainreader/chain_reader.go @@ -313,10 +313,10 @@ func (s *SolanaChainReaderService) addAccountRead(namespace string, genericName // Create PDA read binding if PDA prefix or seeds configs are populated if len(readDefinition.PDADefinition.Prefix) > 0 || len(readDefinition.PDADefinition.Seeds) > 0 { inputAccountIDLDef = readDefinition.PDADefinition - reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefinition.Prefix, true) + reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefinition.Prefix, readDefinition.ReadType) } else { inputAccountIDLDef = codec.NilIdlTypeDefTy - reader = newAccountReadBinding(namespace, genericName, nil, false) + reader = newAccountReadBinding(namespace, genericName, nil, config.Account) } if err := s.addCodecDef(true, namespace, genericName, codec.ChainConfigTypeAccountDef, idl, inputAccountIDLDef, readDefinition.InputModifications); err != nil { return err diff --git a/pkg/solana/config/chain_reader.go b/pkg/solana/config/chain_reader.go index 475fe08a0..b83fe3bb5 100644 --- a/pkg/solana/config/chain_reader.go +++ b/pkg/solana/config/chain_reader.go @@ -40,6 +40,9 @@ type ReadType int const ( Account ReadType = iota + AccountPDA + AccountMulti + AccountSplitParams Event ) @@ -49,6 +52,12 @@ func (r ReadType) String() string { return "Account" case Event: return "Event" + case AccountPDA: + return "AccountPDA" + case AccountMulti: + return "AccountMulti" + case AccountSplitParams: + return "AccountSplitParams" default: return fmt.Sprintf("Unknown(%d)", r) }