Skip to content

Commit

Permalink
Implement split params reading
Browse files Browse the repository at this point in the history
  • Loading branch information
ilija42 committed Feb 7, 2025
1 parent 98e102c commit 04abe5f
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 24 deletions.
16 changes: 11 additions & 5 deletions pkg/solana/chainreader/account_read_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package chainreader
import (
"context"
"fmt"
"slices"

"github.com/gagliardetto/solana-go"

"github.com/smartcontractkit/chainlink-common/pkg/types"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec"
)
Expand All @@ -16,16 +18,16 @@ type accountReadBinding struct {
namespace, genericName string
codec types.RemoteCodec
key solana.PublicKey
isPda bool // flag to signify whether or not the account read is for a PDA
prefix []byte // only used for PDA public key calculation
prefix []byte
readType config.ReadType
}

func newAccountReadBinding(namespace, genericName string, prefix []byte, isPda bool) *accountReadBinding {
func newAccountReadBinding(namespace, genericName string, prefix []byte, readType config.ReadType) *accountReadBinding {
return &accountReadBinding{
namespace: namespace,
genericName: genericName,
prefix: prefix,
isPda: isPda,
readType: readType,
}
}

Expand All @@ -41,7 +43,7 @@ func (b *accountReadBinding) SetAddress(key solana.PublicKey) {

func (b *accountReadBinding) GetAddress(ctx context.Context, params any) (solana.PublicKey, error) {
// Return the bound key if normal account read
if !b.isPda {
if !slices.Contains([]config.ReadType{config.AccountSplitParams, config.AccountPDA}, b.readType) {
return b.key, nil
}
// Calculate the public key if PDA account read
Expand All @@ -64,6 +66,10 @@ func (b *accountReadBinding) Decode(ctx context.Context, bts []byte, outVal any)
return b.codec.Decode(ctx, bts, outVal, codec.WrapItemType(false, b.namespace, b.genericName, codec.ChainConfigTypeAccountDef))
}

func (b *accountReadBinding) ReadType() config.ReadType {
return b.readType
}

// buildSeedsSlice encodes and builds the seedslist to calculate the PDA public key
func (b *accountReadBinding) buildSeedsSlice(ctx context.Context, params any) ([][]byte, error) {
flattenedSeeds := make([]byte, 0, solana.MaxSeeds*solana.MaxSeedLength)
Expand Down
179 changes: 173 additions & 6 deletions pkg/solana/chainreader/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"context"
"errors"
"fmt"
"reflect"
"strings"

"github.com/gagliardetto/solana-go"

"github.com/smartcontractkit/chainlink-common/pkg/types"
"github.com/smartcontractkit/chainlink-common/pkg/values"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
)

type call struct {
Expand Down Expand Up @@ -66,27 +68,150 @@ func doMultiRead(ctx context.Context, client MultipleAccountGetter, bindings nam
return nil
}

type resultIndex struct {
contractName, readName string
readType config.ReadType
}

func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindings namespaceBindings, batch []call) ([]batchResultWithErr, error) {
resultIndexes := make(map[int]resultIndex)
var regularBatch []call
var splitParamsBatch []call

// Create the list of public keys to fetch
keys := make([]solana.PublicKey, len(batch))
regularKeys := make([]solana.PublicKey, len(batch))
for idx, batchCall := range batch {
binding, err := bindings.GetReadBinding(batchCall.ContractName, batchCall.ReadName)
if err != nil {
return nil, fmt.Errorf("%w: read binding not found for contract: %q read: %q: %w", types.ErrInvalidConfig, batchCall.ContractName, batchCall.ReadName, err)
}

keys[idx], err = binding.GetAddress(ctx, batchCall.Params)
resultIndexes[idx] = resultIndex{
contractName: batchCall.ContractName,
readName: batchCall.ReadName,
readType: binding.ReadType()}

if binding.ReadType() == config.AccountSplitParams {
splitParamsBatch = append(splitParamsBatch, batchCall)
} else {
regularKeys[idx], err = binding.GetAddress(ctx, batchCall.Params)
if err != nil {
return nil, fmt.Errorf("failed to get address for contract: %q read: %q: %w", batchCall.ContractName, batchCall.ReadName, err)
}
regularBatch = append(regularBatch, batchCall)
}
}

var splitParamBatchCallsResults []batchResultWithErr
for _, callToSplit := range splitParamsBatch {
binding, err := bindings.GetReadBinding(callToSplit.ContractName, callToSplit.ReadName)
if err != nil {
return nil, fmt.Errorf("%w: read binding not found for contract: %q read: %q: %w", types.ErrInvalidConfig, callToSplit.ContractName, callToSplit.ReadName, err)
}

results, err := doSplitParamsBatchCall(ctx, client, callToSplit, bindings, binding)
if err != nil {
return nil, fmt.Errorf("failed to get address for contract: %q read: %q: %w", batchCall.ContractName, batchCall.ReadName, err)
return nil, err
}

splitParamBatchCallsResults = append(splitParamBatchCallsResults, results)
}

// Fetch the account data
data, err := client.GetMultipleAccountData(ctx, keys...)
data, err := client.GetMultipleAccountData(ctx, regularKeys...)
if err != nil {
return nil, err
}

return mergeBatchResults(decodeBatchResults(ctx, batch, regularKeys, data, bindings), splitParamBatchCallsResults, resultIndexes)
}

func mergeBatchResults(regularResults, splitParamBatchCallsResults []batchResultWithErr, resultIndexes map[int]resultIndex) ([]batchResultWithErr, error) {
finalResults := make([]batchResultWithErr, len(resultIndexes))
seen := make(map[int]bool)

for _, result := range regularResults {
for key, resIndex := range resultIndexes {
if result.contractName == resIndex.contractName && result.readName == resIndex.readName && resIndex.readType != config.AccountSplitParams {
if !seen[key] {
finalResults[key] = result
seen[key] = true
}
}
}
}

for _, result := range splitParamBatchCallsResults {
for key, resIndex := range resultIndexes {
if result.contractName == resIndex.contractName && result.readName == resIndex.readName && resIndex.readType == config.AccountSplitParams {
if !seen[key] {
finalResults[key] = result
seen[key] = true
}
}
}
}

var mergeResErr error
for key, val := range seen {
if !val {
mergeResErr = errors.Join(mergeResErr, fmt.Errorf("failed to find result for call: %v", resultIndexes[key]))
}
}

if mergeResErr != nil {
return nil, mergeResErr
}

if len(finalResults) != len(resultIndexes) {
return nil, fmt.Errorf("%w failed to mere batch results, final results length does not match batch length", types.ErrInternal)
}

return finalResults, nil
}

func doSplitParamsBatchCall(ctx context.Context, client MultipleAccountGetter, callToSplit call, bindings namespaceBindings, binding readBinding) (batchResultWithErr, error) {
sPBatch, err := getSplitParamsBatch(callToSplit)
if err != nil {
return batchResultWithErr{}, err
}

var sPKeys []solana.PublicKey
for _, spCall := range sPBatch {
key, err := binding.GetAddress(ctx, spCall.Params)
if err != nil {
return batchResultWithErr{}, fmt.Errorf("failed to get address for contract: %q read: %q: %w", callToSplit.ContractName, callToSplit.ReadName, err)
}
sPKeys = append(sPKeys, key)
}

data, err := client.GetMultipleAccountData(ctx, sPKeys...)
if err != nil {
return batchResultWithErr{}, err
}

results := decodeBatchResults(ctx, sPBatch, sPKeys, data, bindings)

if len(results) != len(sPBatch) {
return batchResultWithErr{}, fmt.Errorf("results length does not match split params batch length for contract: %q read: %q", callToSplit.ContractName, callToSplit.ReadName)
}

var returnVal []any
var returnErr error
for _, res := range results {
returnVal = append(returnVal, res.returnVal)
returnErr = errors.Join(returnErr, res.err)
}

return batchResultWithErr{
contractName: callToSplit.ContractName,
readName: callToSplit.ReadName,
returnVal: returnVal,
err: returnErr,
}, nil
}

func decodeBatchResults(ctx context.Context, batch []call, keys []solana.PublicKey, data [][]byte, bindings namespaceBindings) []batchResultWithErr {
results := make([]batchResultWithErr, len(batch))

// decode batch call results
Expand Down Expand Up @@ -115,8 +240,7 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin
decodeReturnVal(ctx, binding, data[idx], results[idx].returnVal),
results[idx].err)
}

return results, nil
return results
}

// decodeReturnVal checks if returnVal is a *values.Value vs. a normal struct pointer, and decodes accordingly.
Expand Down Expand Up @@ -147,3 +271,46 @@ func decodeReturnVal(ctx context.Context, binding readBinding, raw []byte, retur

return nil
}

func getSplitParamsBatch(c call) ([]call, error) {
sParams, isOk := extractSliceElements(c.Params)
if !isOk {
return nil, fmt.Errorf("failed to extract params slice elements for contract: %q split params read: %q", c.ContractName, c.ReadName)
}

sReturnVals, isOk := extractSliceElements(c.ReturnVal)
if !isOk {
return nil, fmt.Errorf("failed to extract return values slice elements for contract: %q split params read: %q", c.ContractName, c.ReadName)
}

if len(sParams) != len(sReturnVals) {
return nil, fmt.Errorf("params and return values slice lengths do not match for contract: %q split params read: %q", c.ContractName, c.ReadName)
}

batch := make([]call, len(sParams))
for idx := range sParams {
batch[idx] = call{
ContractName: c.ContractName,
ReadName: c.ReadName,
Params: sParams[idx],
ReturnVal: sReturnVals[idx],
}
}

return batch, nil
}

func extractSliceElements(input any) ([]any, bool) {
rv := reflect.ValueOf(input)
if rv.Kind() != reflect.Slice {
return nil, false
}

length := rv.Len()
elements := make([]any, length)
for i := 0; i < length; i++ {
elements[i] = rv.Index(i).Interface()
}

return elements, true
}
13 changes: 2 additions & 11 deletions pkg/solana/chainreader/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,7 @@ import (
"github.com/gagliardetto/solana-go"

"github.com/smartcontractkit/chainlink-common/pkg/types"
)

type ReadType int

const (
Log ReadType = iota
Account
AccountPDA
AccountMulti
AccountSplitParams
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
)

type readBinding interface {
Expand All @@ -25,7 +16,7 @@ type readBinding interface {
SetCodec(types.RemoteCodec)
CreateType(bool) (any, error)
Decode(context.Context, []byte, any) error
ReadType() ReadType
ReadType() config.ReadType
}

// key is namespace
Expand Down
4 changes: 2 additions & 2 deletions pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,10 @@ func (s *SolanaChainReaderService) addAccountRead(namespace string, genericName
// Create PDA read binding if PDA prefix or seeds configs are populated
if len(readDefinition.PDADefinition.Prefix) > 0 || len(readDefinition.PDADefinition.Seeds) > 0 {
inputAccountIDLDef = readDefinition.PDADefinition
reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefinition.Prefix, true)
reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefinition.Prefix, readDefinition.ReadType)
} else {
inputAccountIDLDef = codec.NilIdlTypeDefTy
reader = newAccountReadBinding(namespace, genericName, nil, false)
reader = newAccountReadBinding(namespace, genericName, nil, config.Account)
}
if err := s.addCodecDef(true, namespace, genericName, codec.ChainConfigTypeAccountDef, idl, inputAccountIDLDef, readDefinition.InputModifications); err != nil {
return err
Expand Down
9 changes: 9 additions & 0 deletions pkg/solana/config/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ type ReadType int

const (
Account ReadType = iota
AccountPDA
AccountMulti
AccountSplitParams
Event
)

Expand All @@ -49,6 +52,12 @@ func (r ReadType) String() string {
return "Account"
case Event:
return "Event"
case AccountPDA:
return "AccountPDA"
case AccountMulti:
return "AccountMulti"
case AccountSplitParams:
return "AccountSplitParams"
default:
return fmt.Sprintf("Unknown(%d)", r)
}
Expand Down

0 comments on commit 04abe5f

Please sign in to comment.