diff --git a/contracts/generated/contract_reader_interface/Initialize.go b/contracts/generated/contract_reader_interface/Initialize.go index 03e13f579..5525400f8 100644 --- a/contracts/generated/contract_reader_interface/Initialize.go +++ b/contracts/generated/contract_reader_interface/Initialize.go @@ -19,14 +19,18 @@ type Initialize struct { // // [1] = [WRITE] data // - // [2] = [] systemProgram + // [2] = [WRITE] multiRead1 + // + // [3] = [WRITE] multiRead2 + // + // [4] = [] systemProgram ag_solanago.AccountMetaSlice `bin:"-" borsh_skip:"true"` } // NewInitializeInstructionBuilder creates a new `Initialize` instruction builder. func NewInitializeInstructionBuilder() *Initialize { nd := &Initialize{ - AccountMetaSlice: make(ag_solanago.AccountMetaSlice, 3), + AccountMetaSlice: make(ag_solanago.AccountMetaSlice, 5), } return nd } @@ -65,15 +69,37 @@ func (inst *Initialize) GetDataAccount() *ag_solanago.AccountMeta { return inst.AccountMetaSlice[1] } +// SetMultiRead1Account sets the "multiRead1" account. +func (inst *Initialize) SetMultiRead1Account(multiRead1 ag_solanago.PublicKey) *Initialize { + inst.AccountMetaSlice[2] = ag_solanago.Meta(multiRead1).WRITE() + return inst +} + +// GetMultiRead1Account gets the "multiRead1" account. +func (inst *Initialize) GetMultiRead1Account() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[2] +} + +// SetMultiRead2Account sets the "multiRead2" account. +func (inst *Initialize) SetMultiRead2Account(multiRead2 ag_solanago.PublicKey) *Initialize { + inst.AccountMetaSlice[3] = ag_solanago.Meta(multiRead2).WRITE() + return inst +} + +// GetMultiRead2Account gets the "multiRead2" account. +func (inst *Initialize) GetMultiRead2Account() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[3] +} + // SetSystemProgramAccount sets the "systemProgram" account. func (inst *Initialize) SetSystemProgramAccount(systemProgram ag_solanago.PublicKey) *Initialize { - inst.AccountMetaSlice[2] = ag_solanago.Meta(systemProgram) + inst.AccountMetaSlice[4] = ag_solanago.Meta(systemProgram) return inst } // GetSystemProgramAccount gets the "systemProgram" account. func (inst *Initialize) GetSystemProgramAccount() *ag_solanago.AccountMeta { - return inst.AccountMetaSlice[2] + return inst.AccountMetaSlice[4] } func (inst Initialize) Build() *Instruction { @@ -113,6 +139,12 @@ func (inst *Initialize) Validate() error { return errors.New("accounts.Data is not set") } if inst.AccountMetaSlice[2] == nil { + return errors.New("accounts.MultiRead1 is not set") + } + if inst.AccountMetaSlice[3] == nil { + return errors.New("accounts.MultiRead2 is not set") + } + if inst.AccountMetaSlice[4] == nil { return errors.New("accounts.SystemProgram is not set") } } @@ -134,10 +166,12 @@ func (inst *Initialize) EncodeToTree(parent ag_treeout.Branches) { }) // Accounts of the instruction: - instructionBranch.Child("Accounts[len=3]").ParentFunc(func(accountsBranch ag_treeout.Branches) { + instructionBranch.Child("Accounts[len=5]").ParentFunc(func(accountsBranch ag_treeout.Branches) { accountsBranch.Child(ag_format.Meta(" signer", inst.AccountMetaSlice[0])) accountsBranch.Child(ag_format.Meta(" data", inst.AccountMetaSlice[1])) - accountsBranch.Child(ag_format.Meta("systemProgram", inst.AccountMetaSlice[2])) + accountsBranch.Child(ag_format.Meta(" multiRead1", inst.AccountMetaSlice[2])) + accountsBranch.Child(ag_format.Meta(" multiRead2", inst.AccountMetaSlice[3])) + accountsBranch.Child(ag_format.Meta("systemProgram", inst.AccountMetaSlice[4])) }) }) }) @@ -178,11 +212,15 @@ func NewInitializeInstruction( // Accounts: signer ag_solanago.PublicKey, data ag_solanago.PublicKey, + multiRead1 ag_solanago.PublicKey, + multiRead2 ag_solanago.PublicKey, systemProgram ag_solanago.PublicKey) *Initialize { return NewInitializeInstructionBuilder(). SetTestIdx(testIdx). SetValue(value). SetSignerAccount(signer). SetDataAccount(data). + SetMultiRead1Account(multiRead1). + SetMultiRead2Account(multiRead2). SetSystemProgramAccount(systemProgram) } diff --git a/contracts/generated/contract_reader_interface/accounts.go b/contracts/generated/contract_reader_interface/accounts.go index 3a7197749..b314d90f2 100644 --- a/contracts/generated/contract_reader_interface/accounts.go +++ b/contracts/generated/contract_reader_interface/accounts.go @@ -353,3 +353,131 @@ func (obj *TestStruct) UnmarshalWithDecoder(decoder *ag_binary.Decoder) (err err } return nil } + +type MultiRead1 struct { + A uint8 + B int16 + C bool +} + +var MultiRead1Discriminator = [8]byte{15, 46, 242, 154, 22, 213, 170, 20} + +func (obj MultiRead1) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { + // Write account discriminator: + err = encoder.WriteBytes(MultiRead1Discriminator[:], false) + if err != nil { + return err + } + // Serialize `A` param: + err = encoder.Encode(obj.A) + if err != nil { + return err + } + // Serialize `B` param: + err = encoder.Encode(obj.B) + if err != nil { + return err + } + // Serialize `C` param: + err = encoder.Encode(obj.C) + if err != nil { + return err + } + return nil +} + +func (obj *MultiRead1) UnmarshalWithDecoder(decoder *ag_binary.Decoder) (err error) { + // Read and check account discriminator: + { + discriminator, err := decoder.ReadTypeID() + if err != nil { + return err + } + if !discriminator.Equal(MultiRead1Discriminator[:]) { + return fmt.Errorf( + "wrong discriminator: wanted %s, got %s", + "[15 46 242 154 22 213 170 20]", + fmt.Sprint(discriminator[:])) + } + } + // Deserialize `A`: + err = decoder.Decode(&obj.A) + if err != nil { + return err + } + // Deserialize `B`: + err = decoder.Decode(&obj.B) + if err != nil { + return err + } + // Deserialize `C`: + err = decoder.Decode(&obj.C) + if err != nil { + return err + } + return nil +} + +type MultiRead2 struct { + U string + V bool + W [2]uint64 +} + +var MultiRead2Discriminator = [8]byte{17, 116, 102, 101, 239, 43, 252, 86} + +func (obj MultiRead2) MarshalWithEncoder(encoder *ag_binary.Encoder) (err error) { + // Write account discriminator: + err = encoder.WriteBytes(MultiRead2Discriminator[:], false) + if err != nil { + return err + } + // Serialize `U` param: + err = encoder.Encode(obj.U) + if err != nil { + return err + } + // Serialize `V` param: + err = encoder.Encode(obj.V) + if err != nil { + return err + } + // Serialize `W` param: + err = encoder.Encode(obj.W) + if err != nil { + return err + } + return nil +} + +func (obj *MultiRead2) UnmarshalWithDecoder(decoder *ag_binary.Decoder) (err error) { + // Read and check account discriminator: + { + discriminator, err := decoder.ReadTypeID() + if err != nil { + return err + } + if !discriminator.Equal(MultiRead2Discriminator[:]) { + return fmt.Errorf( + "wrong discriminator: wanted %s, got %s", + "[17 116 102 101 239 43 252 86]", + fmt.Sprint(discriminator[:])) + } + } + // Deserialize `U`: + err = decoder.Decode(&obj.U) + if err != nil { + return err + } + // Deserialize `V`: + err = decoder.Decode(&obj.V) + if err != nil { + return err + } + // Deserialize `W`: + err = decoder.Decode(&obj.W) + if err != nil { + return err + } + return nil +} diff --git a/contracts/programs/contract-reader-interface/src/lib.rs b/contracts/programs/contract-reader-interface/src/lib.rs index 8d7bc3dc0..a3708578f 100644 --- a/contracts/programs/contract-reader-interface/src/lib.rs +++ b/contracts/programs/contract-reader-interface/src/lib.rs @@ -14,6 +14,16 @@ pub mod contract_reader_interface { account.idx = test_idx; account.bump = ctx.bumps.data; + let multi_read1 = &mut ctx.accounts.multi_read1; + multi_read1.a = 1; + multi_read1.b = 2; + multi_read1.c = true; + + let multi_read2 = &mut ctx.accounts.multi_read2; + multi_read2.u = "Hello".to_string(); + multi_read2.v = true; + multi_read2.w = [123, 456]; + Ok(()) } @@ -65,6 +75,22 @@ pub struct Initialize<'info> { bump)] pub data: Account<'info, DataAccount>, + #[account( + init_if_needed, + payer = signer, + space = size_of::() + 8, + seeds = [b"multi_read1"], + bump)] + pub multi_read1: Account<'info, MultiRead1>, + + #[account( + init_if_needed, + payer = signer, + space = size_of::() + 8, + seeds = [b"multi_read2"], + bump)] + pub multi_read2: Account<'info, MultiRead2>, + pub system_program: Program<'info, System>, } @@ -197,3 +223,17 @@ pub struct InnerStaticTestStruct { pub i: i64, pub a: Pubkey, } + +#[account] +pub struct MultiRead1 { + pub a: u8, + pub b: i16, + pub c: bool, +} + +#[account] +pub struct MultiRead2 { + pub u: String, + pub v: bool, + pub w: [u64; 2], +} diff --git a/integration-tests/relayinterface/chain_components_test.go b/integration-tests/relayinterface/chain_components_test.go index bb363d285..f0159d02a 100644 --- a/integration-tests/relayinterface/chain_components_test.go +++ b/integration-tests/relayinterface/chain_components_test.go @@ -193,7 +193,35 @@ func RunContractReaderTests[T WrappedTestingT[T]](t T, it *SolanaChainComponents } func RunContractReaderInLoopTests[T WrappedTestingT[T]](t T, it ChainComponentsInterfaceTester[T]) { - RunContractReaderInterfaceTests(t, it, false, true) + //RunContractReaderInterfaceTests(t, it, false, true) + testCases := []Testcase[T]{ + { + Name: ContractReaderGetLatestValueWithPrimitiveReturn, + Test: func(t T) { + cr := it.GetContractReader(t) + bindings := it.GetBindings(t) + ctx := tests.Context(t) + + bound := BindingsByName(bindings, AnyContractName)[0] + + require.NoError(t, cr.Bind(ctx, bindings)) + + type MultiReadResult struct { + A uint8 + B int16 + U string + V bool + } + + mRR := MultiReadResult{} + require.NoError(t, cr.GetLatestValue(ctx, bound.ReadIdentifier(MultiRead), primitives.Unconfirmed, nil, &mRR)) + + expectedMRR := MultiReadResult{A: 1, B: 2, U: "Hello", V: true} + require.Equal(t, expectedMRR, mRR) + }, + }, + } + RunTests(t, it, testCases) } type SolanaChainComponentsInterfaceTesterHelper[T WrappedTestingT[T]] interface { @@ -493,6 +521,8 @@ func (h *helper) runInitialize( SubmitTransactionToCW(t, &it, cw, MethodSettingStruct, storeStructArgs, types.BoundContract{Name: contractName, Address: programID.String()}, types.Finalized) } +const MultiRead = "MultiRead" + func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T) config.ContractReader { idx := it.getTestIdx(t.Name()) pdaDataPrefix := []byte("data") @@ -503,7 +533,7 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T uint64ReadDef := config.ReadDefinition{ ChainSpecificName: "DataAccount", ReadType: config.Account, - PDADefiniton: codec.PDATypeDef{ + PDADefinition: codec.PDATypeDef{ Prefix: pdaDataPrefix, }, OutputModifications: commoncodec.ModifiersConfig{ @@ -521,11 +551,28 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T AnyContractName: { IDL: mustUnmarshalIDL(t, string(it.Helper.GetPrimaryIDL(t))), Reads: map[string]config.ReadDefinition{ + MultiRead: { + ChainSpecificName: "MultiRead1", + PDADefinition: codec.PDATypeDef{ + Prefix: []byte("multi_read1"), + }, + OutputModifications: commoncodec.ModifiersConfig{ + &commoncodec.HardCodeModifierConfig{ + OffChainValues: map[string]any{"U": "", "V": false}, + }, + }, + MultiReader: &config.MultiReader{Reads: []config.ReadDefinition{ + { + ChainSpecificName: "MultiRead2", + PDADefinition: codec.PDATypeDef{Prefix: []byte("multi_read2")}, + }, + }}, + }, MethodReturningUint64: uint64ReadDef, MethodReturningUint64Slice: { ChainSpecificName: "DataAccount", ReadType: config.Account, - PDADefiniton: codec.PDATypeDef{ + PDADefinition: codec.PDATypeDef{ Prefix: pdaDataPrefix, }, OutputModifications: commoncodec.ModifiersConfig{ @@ -535,7 +582,7 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T MethodSettingUint64: { ChainSpecificName: "DataAccount", ReadType: config.Account, - PDADefiniton: codec.PDATypeDef{ + PDADefinition: codec.PDATypeDef{ Prefix: pdaDataPrefix, }, OutputModifications: commoncodec.ModifiersConfig{ @@ -545,7 +592,7 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T MethodReturningSeenStruct: { ChainSpecificName: "TestStruct", ReadType: config.Account, - PDADefiniton: codec.PDATypeDef{ + PDADefinition: codec.PDATypeDef{ Prefix: pdaStructDataPrefix, }, OutputModifications: commoncodec.ModifiersConfig{ @@ -567,7 +614,7 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T }, MethodTakingLatestParamsReturningTestStruct: { ChainSpecificName: "TestStruct", - PDADefiniton: codec.PDATypeDef{ + PDADefinition: codec.PDATypeDef{ Prefix: pdaStructDataPrefix, }, OutputModifications: commoncodec.ModifiersConfig{ @@ -594,7 +641,7 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T Reads: map[string]config.ReadDefinition{ MethodReturningUint64: { ChainSpecificName: "Data", - PDADefiniton: codec.PDATypeDef{ + PDADefinition: codec.PDATypeDef{ Prefix: pdaDataPrefix, }, OutputModifications: commoncodec.ModifiersConfig{ @@ -645,6 +692,30 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractWriterConfig(t T IsWritable: true, IsSigner: false, }, + chainwriter.PDALookups{ + Name: "MultiRead1", + PublicKey: chainwriter.AccountConstant{ + Name: "ProgramID", + Address: primaryProgramPubKey, + }, + Seeds: []chainwriter.Seed{ + {Static: []byte("multi_read1")}, + }, + IsWritable: true, + IsSigner: false, + }, + chainwriter.PDALookups{ + Name: "MultiRead2", + PublicKey: chainwriter.AccountConstant{ + Name: "ProgramID", + Address: primaryProgramPubKey, + }, + Seeds: []chainwriter.Seed{ + {Static: []byte("multi_read2")}, + }, + IsWritable: true, + IsSigner: false, + }, chainwriter.AccountConstant{ Name: "SystemProgram", Address: solana.SystemProgramID.String(), diff --git a/pkg/solana/chainreader/batch.go b/pkg/solana/chainreader/batch.go index af50f3a38..a3aa7c297 100644 --- a/pkg/solana/chainreader/batch.go +++ b/pkg/solana/chainreader/batch.go @@ -4,9 +4,11 @@ import ( "context" "errors" "fmt" + "strings" "github.com/gagliardetto/solana-go" + "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/values" ) @@ -30,20 +32,53 @@ type MultipleAccountGetter interface { GetMultipleAccountData(context.Context, ...solana.PublicKey) ([][]byte, error) } +// doMultiRead aggregate results from multiple PDAs from the same contract into one result. +func doMultiRead(ctx context.Context, client MultipleAccountGetter, bindings bindingsRegistry, rv readValues, returnValue any) error { + batch := make([]call, len(rv.multiRead)) + for idx, readName := range rv.multiRead { + batch[idx] = call{ + Namespace: rv.contract, + ReadName: readName, + ReturnVal: returnValue, + } + } + + results, err := doMethodBatchCall(ctx, client, bindings, batch) + if err != nil { + return err + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("failed to do a multiRead: %q on contract: %q with address: %q with: %d total calls:\n", rv.multiRead[0], rv.contract, rv.address, len(rv.multiRead))) + + var errCount int + for i, r := range results { + if r.err != nil { + errCount++ + sb.WriteString(fmt.Sprintf("- call: #%d with readName: %q and address: %q failed with err: %s\n", i+1, r.readName, r.address, r.err)) + } + } + + if errCount != 0 { + return errors.New(sb.String()) + } + + return nil +} + func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindingsRegistry bindingsRegistry, batch []call) ([]batchResultWithErr, error) { // Create the list of public keys to fetch keys := make([]solana.PublicKey, len(batch)) - for idx, call := range batch { - rBinding, err := bindingsRegistry.GetReadBinding(call.Namespace, call.ReadName) + for idx, batchCall := range batch { + rBinding, err := bindingsRegistry.GetReadBinding(batchCall.Namespace, batchCall.ReadName) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: read binding not found for contract: %q read: %q: %w", types.ErrInvalidConfig, batchCall.Namespace, batchCall.ReadName, err) } - key, err := rBinding.GetAddress(ctx, call.Params) + keys[idx], err = rBinding.GetAddress(ctx, batchCall.Params) if err != nil { - return nil, fmt.Errorf("failed to get address for %s account read: %w", call.ReadName, err) + return nil, fmt.Errorf("failed to get address for contract: %q read: %q: %w", batchCall.Namespace, batchCall.ReadName, err) } - keys[idx] = key } // Fetch the account data @@ -55,12 +90,12 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin results := make([]batchResultWithErr, len(batch)) // decode batch call results - for idx, call := range batch { + for idx, batchCall := range batch { results[idx] = batchResultWithErr{ address: keys[idx].String(), - namespace: call.Namespace, - readName: call.ReadName, - returnVal: call.ReturnVal, + namespace: batchCall.Namespace, + readName: batchCall.ReadName, + returnVal: batchCall.ReturnVal, } if data[idx] == nil || len(data[idx]) == 0 { @@ -76,36 +111,39 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin continue } - ptrToValue, isValue := call.ReturnVal.(*values.Value) - if !isValue { - results[idx].err = errors.Join( - results[idx].err, - rBinding.Decode(ctx, data[idx], results[idx].returnVal), - ) - continue - } - - contractType, err := rBinding.CreateType(false) - if err != nil { - results[idx].err = err + results[idx].err = errors.Join( + decodeReturnVal(ctx, rBinding, data[idx], results[idx].returnVal), + results[idx].err) + } - continue - } + return results, nil +} - results[idx].err = errors.Join( - results[idx].err, - rBinding.Decode(ctx, data[idx], contractType), - ) +// decodeReturnVal checks if returnVal is a *values.Value vs. a normal struct pointer, and decodes accordingly. +func decodeReturnVal(ctx context.Context, binding readBinding, raw []byte, returnVal any) error { + // If we are not dealing with a `*values.Value`, just decode directly. + ptrToValue, isValue := returnVal.(*values.Value) + if !isValue { + return binding.Decode(ctx, raw, returnVal) + } - value, err := values.Wrap(contractType) - if err != nil { - results[idx].err = errors.Join(results[idx].err, err) + // Otherwise, we need to create an intermediate type, decode into it, + // wrap it, and set it back into *values.Value + contractType, err := binding.CreateType(false) + if err != nil { + return err + } - continue - } + if err = binding.Decode(ctx, raw, contractType); err != nil { + return err + } - *ptrToValue = value + value, err := values.Wrap(contractType) + if err != nil { + return err } - return results, nil + *ptrToValue = value + + return nil } diff --git a/pkg/solana/chainreader/chain_reader.go b/pkg/solana/chainreader/chain_reader.go index 4eb17ef44..a7608e338 100644 --- a/pkg/solana/chainreader/chain_reader.go +++ b/pkg/solana/chainreader/chain_reader.go @@ -8,7 +8,6 @@ import ( "sync" "github.com/gagliardetto/solana-go" - "github.com/gagliardetto/solana-go/rpc" commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -154,13 +153,21 @@ func (s *ContractReaderService) GetLatestValue(ctx context.Context, readIdentifi values, ok := s.lookup.getContractForReadIdentifiers(readIdentifier) if !ok { - return fmt.Errorf("%w: no contract for read identifier %s", types.ErrInvalidType, readIdentifier) + return fmt.Errorf("%w: no contract for read identifier: %q", types.ErrInvalidType, readIdentifier) + } + + if len(values.multiRead) == 0 { + return fmt.Errorf("%w: no reads defined for readIdentifier: %q", types.ErrInvalidConfig, readIdentifier) + } + + if len(values.multiRead) > 1 { + return doMultiRead(ctx, s.client, s.bdRegistry, values, returnVal) } batch := []call{ { Namespace: values.contract, - ReadName: values.genericName, + ReadName: values.multiRead[0], Params: params, ReturnVal: returnVal, }, @@ -185,7 +192,7 @@ func (s *ContractReaderService) GetLatestValue(ctx context.Context, readIdentifi // BatchGetLatestValues implements the types.ContractReader interface. func (s *ContractReaderService) BatchGetLatestValues(ctx context.Context, request types.BatchGetLatestValuesRequest) (types.BatchGetLatestValuesResult, error) { idxLookup := make(map[types.BoundContract][]int) - batch := []call{} + var batch []call for bound, req := range request { idxLookup[bound] = make([]int, len(req)) @@ -300,7 +307,11 @@ func (s *ContractReaderService) CreateContractType(readIdentifier string, forEnc return nil, fmt.Errorf("%w: no contract for read identifier", types.ErrInvalidConfig) } - return s.bdRegistry.CreateType(values.contract, values.genericName, forEncoding) + if len(values.multiRead) == 0 { + return nil, fmt.Errorf("%w: no reads defined for read identifier", types.ErrInvalidConfig) + } + + return s.bdRegistry.CreateType(values.contract, values.multiRead[0], forEncoding) } func (s *ContractReaderService) addCodecDef(forEncoding bool, namespace, genericName string, idl codec.IDL, idlDefinition interface{}, modCfg commoncodec.ModifiersConfig) error { @@ -375,7 +386,16 @@ func (s *ContractReaderService) addAccountRead(namespace string, genericName str return err } - s.lookup.addReadNameForContract(namespace, genericName) + multiRead := []string{genericName} + if readDefinition.MultiReader != nil { + reads, err := s.addMultiAccountRead(namespace, readDefinition, idl) + if err != nil { + return err + } + multiRead = append(multiRead, reads...) + } + + s.lookup.addReadNameForContract(namespace, genericName, multiRead) var ( reader readBinding @@ -383,9 +403,9 @@ func (s *ContractReaderService) addAccountRead(namespace string, genericName str ) // Create PDA read binding if PDA prefix or seeds configs are populated - if readDefinition.PDADefiniton.Prefix != nil || len(readDefinition.PDADefiniton.Seeds) > 0 { - inputAccountIDLDef = readDefinition.PDADefiniton - reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefiniton.Prefix, true) + if readDefinition.PDADefinition.Prefix != nil || len(readDefinition.PDADefinition.Seeds) > 0 { + inputAccountIDLDef = readDefinition.PDADefinition + reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefinition.Prefix, true) } else { inputAccountIDLDef = codec.NilIdlTypeDefTy reader = newAccountReadBinding(namespace, genericName, nil, false) @@ -395,10 +415,35 @@ func (s *ContractReaderService) addAccountRead(namespace string, genericName str } s.bdRegistry.AddReadBinding(namespace, genericName, reader) - return nil } +func (s *ContractReaderService) addMultiAccountRead(namespace string, readDefinition config.ReadDefinition, idl codec.IDL) ([]string, error) { + var reads []string + for _, mr := range readDefinition.MultiReader.Reads { + idlDef, err := codec.FindDefinitionFromIDL(codec.ChainConfigTypeAccountDef, mr.ChainSpecificName, idl) + if err != nil { + return nil, err + } + + if mr.ReadType != config.Account { + return nil, fmt.Errorf("unexpected read type %q for dynamic hard coder read: %q in namespace: %q", mr.ReadType, mr.ChainSpecificName, namespace) + } + + accountIDLDef, isOk := idlDef.(codec.IdlTypeDef) + if !isOk { + return nil, fmt.Errorf("unexpected type %T from IDL definition for account read with chainSpecificName: %q, of type: %q", accountIDLDef, mr.ChainSpecificName, mr.ReadType) + } + + if err = s.addAccountRead(namespace, mr.ChainSpecificName, idl, accountIDLDef, mr); err != nil { + return nil, fmt.Errorf("failed to add multi-read %q: %w", mr.ChainSpecificName, err) + } + + reads = append(reads, mr.ChainSpecificName) + } + return reads, nil +} + func (s *ContractReaderService) addEventRead( namespace, genericName string, contractAddress solana.PublicKey, @@ -429,25 +474,6 @@ func (s *ContractReaderService) addEventRead( return nil } -type accountDataReader struct { - client *rpc.Client -} - -func NewAccountDataReader(client *rpc.Client) *accountDataReader { - return &accountDataReader{client: client} -} - -func (r *accountDataReader) ReadAll(ctx context.Context, pk solana.PublicKey, opts *rpc.GetAccountInfoOpts) ([]byte, error) { - result, err := r.client.GetAccountInfoWithOpts(ctx, pk, opts) - if err != nil { - return nil, err - } - - bts := result.Value.Data.GetBinary() - - return bts, nil -} - func toLPFilter( f *config.PollingFilter, address solana.PublicKey, diff --git a/pkg/solana/chainreader/chain_reader_test.go b/pkg/solana/chainreader/chain_reader_test.go index 987b3d256..016136461 100644 --- a/pkg/solana/chainreader/chain_reader_test.go +++ b/pkg/solana/chainreader/chain_reader_test.go @@ -387,7 +387,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { testReadDef := readDef - testReadDef.PDADefiniton = testCase.pdaDefinition + testReadDef.PDADefinition = testCase.pdaDefinition testReadDef.InputModifications = testCase.inputModifier testCodec, conf := newTestConfAndCodecWithInjectibleReadDef(t, PDAAccount, testReadDef) encoded, err := testCodec.Encode(ctx, expected, testutils.TestStructWithNestedStruct) @@ -428,7 +428,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { readDef := config.ReadDefinition{ ChainSpecificName: testutils.TestStructWithNestedStruct, ReadType: config.Account, - PDADefiniton: codec.PDATypeDef{ + PDADefinition: codec.PDATypeDef{ Prefix: prefixBytes, Seeds: []codec.PDASeed{ { diff --git a/pkg/solana/chainreader/lookup.go b/pkg/solana/chainreader/lookup.go index ccd7b44db..e51c2a878 100644 --- a/pkg/solana/chainreader/lookup.go +++ b/pkg/solana/chainreader/lookup.go @@ -7,17 +7,20 @@ import ( ) type readValues struct { - address string - contract string - genericName string + address string + contract string + // First read in multi read has type info that other sequential reads are filling out. + // this works by having hard coder codec modifier define fields that are filled out by subsequent reads. + multiRead []string } // lookup provides basic utilities for mapping a complete readIdentifier to // finite contract read information type lookup struct { mu sync.RWMutex - // contractReadNames maps a contract name to all available namePairs (method, log, event, etc.) - contractReadNames map[string][]string + // contractReadNames maps a program name to all available reads (accounts, PDAs, logs). + // Every key (generic read name) can be composed of multiple reads of the same program. Right now all of them have to be of same type (account, PDA or log). + contractReadNames map[string]map[string][]string // readIdentifiers maps from a complete readIdentifier string to finite read data // a readIdentifier is a combination of address, contract, and chainSpecificName as a concatenated string readIdentifiers map[string]readValues @@ -25,37 +28,42 @@ type lookup struct { func newLookup() *lookup { return &lookup{ - contractReadNames: make(map[string][]string), + contractReadNames: make(map[string]map[string][]string), readIdentifiers: make(map[string]readValues), } } -func (l *lookup) addReadNameForContract(contract string, genericName string) { +func (l *lookup) addReadNameForContract(contract, genericName string, multiRead []string) { l.mu.Lock() defer l.mu.Unlock() readNames, exists := l.contractReadNames[contract] if !exists { - readNames = []string{} + readNames = make(map[string][]string) } - l.contractReadNames[contract] = append(readNames, genericName) + readNames[genericName] = multiRead + + l.contractReadNames[contract] = readNames } func (l *lookup) bindAddressForContract(contract, address string) { l.mu.Lock() defer l.mu.Unlock() - for _, genericName := range l.contractReadNames[contract] { - readIdentifier := types.BoundContract{ - Address: address, - Name: contract, - }.ReadIdentifier(genericName) + for _, multiRead := range l.contractReadNames[contract] { + readIdentifier := "" + if len(multiRead) > 0 { + readIdentifier = types.BoundContract{ + Address: address, + Name: contract, + }.ReadIdentifier(multiRead[0]) + } l.readIdentifiers[readIdentifier] = readValues{ - address: address, - contract: contract, - genericName: genericName, + address: address, + contract: contract, + multiRead: multiRead, } } } @@ -64,11 +72,14 @@ func (l *lookup) unbindAddressForContract(contract, address string) { l.mu.Lock() defer l.mu.Unlock() - for _, genericName := range l.contractReadNames[contract] { - readIdentifier := types.BoundContract{ - Address: address, - Name: contract, - }.ReadIdentifier(genericName) + for _, multiRead := range l.contractReadNames[contract] { + readIdentifier := "" + if len(multiRead) > 0 { + readIdentifier = types.BoundContract{ + Address: address, + Name: contract, + }.ReadIdentifier(multiRead[0]) + } delete(l.readIdentifiers, readIdentifier) } diff --git a/pkg/solana/config/chain_reader.go b/pkg/solana/config/chain_reader.go index a1b6839a5..4fbe7c420 100644 --- a/pkg/solana/config/chain_reader.go +++ b/pkg/solana/config/chain_reader.go @@ -31,16 +31,23 @@ type ChainContractReader struct { // TODO ContractPollingFilter same as EVM? } +type MultiReader struct { + // Reads is a list of reads that is sequentially read to fill out a complete response for the parent read. + // Parent ReadDefinition has to define codec modifiers which adds fields that are to be filled out by the reads in Reads. + Reads []ReadDefinition `json:"reads,omitempty"` +} + type ReadDefinition struct { ChainSpecificName string `json:"chainSpecificName"` ReadType ReadType `json:"readType,omitempty"` InputModifications commoncodec.ModifiersConfig `json:"inputModifications,omitempty"` OutputModifications commoncodec.ModifiersConfig `json:"outputModifications,omitempty"` - PDADefiniton codec.PDATypeDef `json:"pdaDefinition,omitempty"` // Only used for PDA account reads - IndexedField0 *IndexedField `json:"indexedField0"` - IndexedField1 *IndexedField `json:"indexedField1"` - IndexedField2 *IndexedField `json:"indexedField2"` - IndexedField3 *IndexedField `json:"indexedField3"` + PDADefinition codec.PDATypeDef `json:"pdaDefinition,omitempty"` // Only used for PDA account reads + MultiReader *MultiReader + IndexedField0 *IndexedField `json:"indexedField0"` + IndexedField1 *IndexedField `json:"indexedField1"` + IndexedField2 *IndexedField `json:"indexedField2"` + IndexedField3 *IndexedField `json:"indexedField3"` // This will create a log poller filter for this event. *PollingFilter `json:"pollingFilter,omitempty"` }