From a9c0e7c194fabd5950505ca1d98b6f11d986cff2 Mon Sep 17 00:00:00 2001 From: ilija Date: Sun, 9 Feb 2025 22:33:38 +0100 Subject: [PATCH] lint --- .../relayinterface/chain_components_test.go | 70 ++--- .../chainreader/account_read_binding.go | 16 +- pkg/solana/chainreader/batch.go | 1 + pkg/solana/chainreader/bindings.go | 2 - pkg/solana/chainreader/bindings_test.go | 5 - pkg/solana/chainreader/chain_reader.go | 249 ++++++++++-------- pkg/solana/chainreader/chain_reader_test.go | 8 +- pkg/solana/chainreader/event_read_binding.go | 6 +- pkg/solana/config/chain_reader.go | 6 - 9 files changed, 191 insertions(+), 172 deletions(-) diff --git a/integration-tests/relayinterface/chain_components_test.go b/integration-tests/relayinterface/chain_components_test.go index b4502b2a6..ab13caa34 100644 --- a/integration-tests/relayinterface/chain_components_test.go +++ b/integration-tests/relayinterface/chain_components_test.go @@ -58,13 +58,13 @@ func TestChainComponents(t *testing.T) { helper := &helper{} helper.Init(t) - //t.Run("RunChainComponentsSolanaTests", func(t *testing.T) { - // t.Parallel() - // it := &SolanaChainComponentsInterfaceTester[*testing.T]{Helper: helper, testContext: make(map[string]uint64), testContextMu: &sync.RWMutex{}, testIdx: &atomic.Uint64{}} - // DisableTests(it) - // it.Setup(t) - // RunChainComponentsSolanaTests(t, it) - //}) + t.Run("RunChainComponentsSolanaTests", func(t *testing.T) { + t.Parallel() + it := &SolanaChainComponentsInterfaceTester[*testing.T]{Helper: helper, testContext: make(map[string]uint64), testContextMu: &sync.RWMutex{}, testIdx: &atomic.Uint64{}} + DisableTests(it) + it.Setup(t) + RunChainComponentsSolanaTests(t, it) + }) t.Run("RunChainComponentsInLoopSolanaTests", func(t *testing.T) { t.Parallel() @@ -209,31 +209,31 @@ type TimestampedUnixBig struct { func RunContractReaderInLoopTests[T WrappedTestingT[T]](t T, it ChainComponentsInterfaceTester[T]) { //RunContractReaderInterfaceTests(t, it, false, true) testCases := []Testcase[T]{ - //{ - // Name: ContractReaderGetLatestValueUsingMultiReader, - // Test: func(t T) { - // cr := it.GetContractReader(t) - // bindings := it.GetBindings(t) - // ctx := tests.Context(t) - // - // bound := BindingsByName(bindings, AnyContractName)[0] - // - // require.NoError(t, cr.Bind(ctx, bindings)) - // - // type MultiReadResult struct { - // A uint8 - // B int16 - // U string - // V bool - // } - // - // mRR := MultiReadResult{} - // require.NoError(t, cr.GetLatestValue(ctx, bound.ReadIdentifier(MultiRead), primitives.Unconfirmed, nil, &mRR)) - // - // expectedMRR := MultiReadResult{A: 1, B: 2, U: "Hello", V: true} - // require.Equal(t, expectedMRR, mRR) - // }, - //}, + { + Name: ContractReaderGetLatestValueUsingMultiReader, + Test: func(t T) { + cr := it.GetContractReader(t) + bindings := it.GetBindings(t) + ctx := tests.Context(t) + + bound := BindingsByName(bindings, AnyContractName)[0] + + require.NoError(t, cr.Bind(ctx, bindings)) + + type MultiReadResult struct { + A uint8 + B int16 + U string + V bool + } + + mRR := MultiReadResult{} + require.NoError(t, cr.GetLatestValue(ctx, bound.ReadIdentifier(MultiRead), primitives.Unconfirmed, nil, &mRR)) + + expectedMRR := MultiReadResult{A: 1, B: 2, U: "Hello", V: true} + require.Equal(t, expectedMRR, mRR) + }, + }, { Name: ContractReaderGetLatestValueUsingSplitParamsReader, Test: func(t T) { @@ -632,7 +632,7 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T }, &commoncodec.PropertyExtractorConfig{FieldName: "Response"}, }, - ReadType: config.AccountSplitParams, + ReadType: config.Account, }, MultiRead: { ChainSpecificName: "MultiRead1", @@ -644,10 +644,10 @@ func (it *SolanaChainComponentsInterfaceTester[T]) buildContractReaderConfig(t T { ChainSpecificName: "MultiRead2", PDADefinition: codec.PDATypeDef{Prefix: []byte("multi_read2")}, - ReadType: config.AccountPDA, + ReadType: config.Account, }, }}, - ReadType: config.AccountPDA, + ReadType: config.Account, }, MethodReturningUint64: uint64ReadDef, MethodReturningUint64Slice: { diff --git a/pkg/solana/chainreader/account_read_binding.go b/pkg/solana/chainreader/account_read_binding.go index e91aafae4..2104af87f 100644 --- a/pkg/solana/chainreader/account_read_binding.go +++ b/pkg/solana/chainreader/account_read_binding.go @@ -4,14 +4,12 @@ import ( "context" "errors" "fmt" - "slices" "github.com/gagliardetto/solana-go" commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/types/query" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" ) @@ -21,16 +19,16 @@ type accountReadBinding struct { namespace, genericName string codec types.RemoteCodec key solana.PublicKey - prefix []byte - readType config.ReadType + isPda bool // flag to signify whether or not the account read is for a PDA + prefix []byte // only used for PDA public key calculation } -func newAccountReadBinding(namespace, genericName string, prefix []byte, readType config.ReadType) *accountReadBinding { +func newAccountReadBinding(namespace, genericName string, prefix []byte, isPda bool) *accountReadBinding { return &accountReadBinding{ namespace: namespace, genericName: genericName, prefix: prefix, - readType: readType, + isPda: isPda, } } @@ -48,7 +46,7 @@ func (b *accountReadBinding) SetAddress(key solana.PublicKey) { func (b *accountReadBinding) GetAddress(ctx context.Context, params any) (solana.PublicKey, error) { // Return the bound key if normal account read - if !slices.Contains([]config.ReadType{config.AccountSplitParams, config.AccountPDA}, b.readType) { + if !b.isPda { return b.key, nil } // Calculate the public key if PDA account read @@ -71,10 +69,6 @@ func (b *accountReadBinding) Decode(ctx context.Context, bts []byte, outVal any) return b.codec.Decode(ctx, bts, outVal, codec.WrapItemType(false, b.namespace, b.genericName)) } -func (b *accountReadBinding) ReadType() config.ReadType { - return b.readType -} - // buildSeedsSlice encodes and builds the seedslist to calculate the PDA public key func (b *accountReadBinding) buildSeedsSlice(ctx context.Context, params any) ([][]byte, error) { flattenedSeeds := make([]byte, 0, solana.MaxSeeds*solana.MaxSeedLength) diff --git a/pkg/solana/chainreader/batch.go b/pkg/solana/chainreader/batch.go index 029e42c60..a3aa7c297 100644 --- a/pkg/solana/chainreader/batch.go +++ b/pkg/solana/chainreader/batch.go @@ -115,6 +115,7 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin decodeReturnVal(ctx, rBinding, data[idx], results[idx].returnVal), results[idx].err) } + return results, nil } diff --git a/pkg/solana/chainreader/bindings.go b/pkg/solana/chainreader/bindings.go index f3af4384b..b24a4191f 100644 --- a/pkg/solana/chainreader/bindings.go +++ b/pkg/solana/chainreader/bindings.go @@ -10,7 +10,6 @@ import ( commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/types/query" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" ) type readBinding interface { @@ -21,7 +20,6 @@ type readBinding interface { CreateType(bool) (any, error) Decode(context.Context, []byte, any) error QueryKey(context.Context, query.KeyFilter, query.LimitAndSort, any) ([]types.Sequence, error) - ReadType() config.ReadType } type bindingsRegistry struct { diff --git a/pkg/solana/chainreader/bindings_test.go b/pkg/solana/chainreader/bindings_test.go index c92d9a655..70e2d9c8c 100644 --- a/pkg/solana/chainreader/bindings_test.go +++ b/pkg/solana/chainreader/bindings_test.go @@ -12,7 +12,6 @@ import ( commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/types/query" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" ) func TestBindings_CreateType(t *testing.T) { @@ -49,10 +48,6 @@ type mockBinding struct { mock.Mock } -func (_m *mockBinding) ReadType() config.ReadType { - return config.ReadType(0) -} - func (_m *mockBinding) SetCodec(_ types.RemoteCodec) {} func (_m *mockBinding) SetAddress(_ solana.PublicKey) {} diff --git a/pkg/solana/chainreader/chain_reader.go b/pkg/solana/chainreader/chain_reader.go index 48e6cdc04..597921c3b 100644 --- a/pkg/solana/chainreader/chain_reader.go +++ b/pkg/solana/chainreader/chain_reader.go @@ -20,6 +20,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types/query" "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" "github.com/smartcontractkit/chainlink-common/pkg/values" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller" @@ -176,14 +177,8 @@ func (s *ContractReaderService) GetLatestValue(ctx context.Context, readIdentifi }, } - // TODO all of the err messages if values.multiRead[0] == "GetTokenPrices" { - err2 := s.handleGetTokenPricesGetLatestValue(ctx, params, values, returnVal) - if err2 != nil { - return err2 - } - - return nil + return s.handleGetTokenPricesGetLatestValue(ctx, params, values, returnVal) } results, err := doMethodBatchCall(ctx, s.client, s.bdRegistry, batch) @@ -202,98 +197,6 @@ func (s *ContractReaderService) GetLatestValue(ctx context.Context, readIdentifi return nil } -func (s *ContractReaderService) handleGetTokenPricesGetLatestValue(ctx context.Context, params any, values readValues, returnVal any) error { - val := reflect.ValueOf(params) - if val.Kind() == reflect.Ptr { - // Dereference so we can access fields - val = val.Elem() - } - - // Make sure we actually have a struct - if val.Kind() != reflect.Struct { - return fmt.Errorf("expected struct, got: %s", val.Kind()) - } - - // Attempt to get the "Tokens" field - field := val.FieldByName("Tokens") - if !field.IsValid() { - return fmt.Errorf("no field named 'Tokens' found") - } - - // Convert that field’s interface value into the correct type - tokens, ok := field.Interface().(*[][32]uint8) - if !ok { - return fmt.Errorf("'Tokens' field is not *[][32]uint8") - } - programAddress, err := solana.PublicKeyFromBase58(values.address) - if err != nil { - return fmt.Errorf("%w: failed to parse program address: %q for contract %q", types.ErrInvalidConfig, values.address, values.contract) - } - - var pdaAddresses []solana.PublicKey - for _, token := range *tokens { - tokenAddr := solana.PublicKeyFromBytes(token[:]) - if tokenAddr.IsOnCurve() && !tokenAddr.IsZero() { - return fmt.Errorf("bad token address: %v for contract: %q read: %q", tokenAddr, values.contract, values.multiRead[0]) - } - - pdaAddress, _, err := solana.FindProgramAddress([][]byte{[]byte("fee_billing_token_config"), tokenAddr.Bytes()}, programAddress) - if err != nil { - return fmt.Errorf("%w: failed find program address for PDA: %w", types.ErrInvalidConfig, err) - } - - pdaAddresses = append(pdaAddresses, pdaAddress) - } - - data, err := s.client.GetMultipleAccountData(ctx, pdaAddresses...) - if err != nil { - return err - } - - type TimestampedUnixBig struct { - Value *big.Int `json:"value"` - Timestamp uint32 `json:"timestamp"` - } - - returnSliceVal := reflect.ValueOf(returnVal) - if returnSliceVal.Kind() == reflect.Ptr { - returnSliceVal = returnSliceVal.Elem() - if returnSliceVal.Kind() == reflect.Ptr { - returnSliceVal = returnSliceVal.Elem() - } - if returnSliceVal.Kind() != reflect.Slice { - return fmt.Errorf("expected slice, got: %s", returnSliceVal.Kind()) - } - } - - elemType := returnSliceVal.Type().Elem() - newElem := reflect.New(elemType).Elem() - - for _, d := range data { - wrapper := fee_quoter.BillingTokenConfigWrapper{} - err = wrapper.UnmarshalWithDecoder(bin.NewBorshDecoder(d)) - if err != nil { - return err - } - v := big.NewInt(0) - v.SetBytes(wrapper.Config.UsdPerToken.Value[:]) - - valueField := newElem.FieldByName("Value") - if !valueField.IsValid() { - return errors.New("struct does not have a field named 'Value'") - } - valueField.Set(reflect.ValueOf(v)) - - timestampField := newElem.FieldByName("Timestamp") - if !timestampField.IsValid() { - return errors.New("struct does not have a field named 'Timestamp'") - } - timestampField.Set(reflect.ValueOf(uint32(wrapper.Config.UsdPerToken.Timestamp))) - returnSliceVal.Set(reflect.Append(returnSliceVal, newElem)) - } - return nil -} - // BatchGetLatestValues implements the types.ContractReader interface. func (s *ContractReaderService) BatchGetLatestValues(ctx context.Context, request types.BatchGetLatestValuesRequest) (types.BatchGetLatestValuesResult, error) { idxLookup := make(map[types.BoundContract][]int) @@ -444,7 +347,7 @@ func (s *ContractReaderService) initNamespace(namespaces map[string]config.Chain utils.InjectAddressModifier(read.InputModifications, read.OutputModifications) switch read.ReadType { - case config.Account, config.AccountPDA, config.AccountSplitParams: + case config.Account: idlDef, err := codec.FindDefinitionFromIDL(codec.ChainConfigTypeAccountDef, read.ChainSpecificName, nameSpaceDef.IDL) if err != nil { return err @@ -510,10 +413,10 @@ func (s *ContractReaderService) addAccountRead(namespace string, genericName str // Create PDA read binding if PDA prefix or seeds configs are populated if readDefinition.PDADefinition.Prefix != nil || len(readDefinition.PDADefinition.Seeds) > 0 { inputAccountIDLDef = readDefinition.PDADefinition - reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefinition.Prefix, readDefinition.ReadType) + reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefinition.Prefix, true) } else { inputAccountIDLDef = codec.NilIdlTypeDefTy - reader = newAccountReadBinding(namespace, genericName, nil, config.Account) + reader = newAccountReadBinding(namespace, genericName, nil, false) } if err := s.addCodecDef(true, namespace, genericName, idl, inputAccountIDLDef, readDefinition.InputModifications); err != nil { return err @@ -531,8 +434,8 @@ func (s *ContractReaderService) addMultiAccountRead(namespace string, readDefini return nil, err } - if mr.ReadType != config.AccountPDA { - return nil, fmt.Errorf("unexpected read type %q for multi read: %q in namespace: %q", mr.ReadType, mr.ChainSpecificName, namespace) + if mr.ReadType != config.Account { + return nil, fmt.Errorf("unexpected read type %q for dynamic hard coder read: %q in namespace: %q", mr.ReadType, mr.ChainSpecificName, namespace) } accountIDLDef, isOk := idlDef.(codec.IdlTypeDef) @@ -600,3 +503,141 @@ func applyIndexedFieldTuple(lookup map[string]uint64, subKeys [4][]string, conf subKeys[idx] = strings.Split(conf.OnChainPath, ".") } } + +func (s *ContractReaderService) handleGetTokenPricesGetLatestValue( + ctx context.Context, + params any, + values readValues, + returnVal any, +) error { + pdaAddresses, err := s.getPDAsForGetTokenPrices(params, values) + if err != nil { + return err + } + + data, err := s.client.GetMultipleAccountData(ctx, pdaAddresses...) + if err != nil { + return fmt.Errorf( + "for contract %q read %q: failed to get multiple account data: %w", + values.contract, values.multiRead[0], err, + ) + } + + // -------------- Fill out the returnVal slice with data -------------- + // can't typecast returnVal so we have to use reflection here + + // Ensure `returnVal` is a pointer to a slice we can populate. + returnSliceVal := reflect.ValueOf(returnVal) + if returnSliceVal.Kind() == reflect.Ptr { + returnSliceVal = returnSliceVal.Elem() + if returnSliceVal.Kind() == reflect.Ptr { + returnSliceVal = returnSliceVal.Elem() + } + } + if returnSliceVal.Kind() != reflect.Slice { + return fmt.Errorf( + "for contract %q read %q: expected `returnVal` to be a slice, got %s", + values.contract, values.multiRead[0], returnSliceVal.Kind(), + ) + } + + elemType := returnSliceVal.Type().Elem() + for _, d := range data { + var wrapper fee_quoter.BillingTokenConfigWrapper + if err = wrapper.UnmarshalWithDecoder(bin.NewBorshDecoder(d)); err != nil { + return fmt.Errorf( + "for contract %q read %q: failed to unmarshal account data: %w", + values.contract, values.multiRead[0], err, + ) + } + + newElem := reflect.New(elemType).Elem() + + valueField := newElem.FieldByName("Value") + if !valueField.IsValid() { + return fmt.Errorf( + "for contract %q read %q: struct type missing `Value` field", + values.contract, values.multiRead[0], + ) + } + valueField.Set(reflect.ValueOf(big.NewInt(0).SetBytes(wrapper.Config.UsdPerToken.Value[:]))) + + timestampField := newElem.FieldByName("Timestamp") + if !timestampField.IsValid() { + return fmt.Errorf( + "for contract %q read %q: struct type missing `Timestamp` field", + values.contract, values.multiRead[0], + ) + } + + // nolint:gosec + // G115: integer overflow conversion int64 -> uint32 + timestampField.Set(reflect.ValueOf(uint32(wrapper.Config.UsdPerToken.Timestamp))) + + returnSliceVal.Set(reflect.Append(returnSliceVal, newElem)) + } + + return nil +} + +func (s *ContractReaderService) getPDAsForGetTokenPrices(params any, values readValues) ([]solana.PublicKey, error) { + val := reflect.ValueOf(params) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + if val.Kind() != reflect.Struct { + return nil, fmt.Errorf( + "for contract %q read %q: expected `params` to be a struct, got %s", + values.contract, values.multiRead[0], val.Kind(), + ) + } + + field := val.FieldByName("Tokens") + if !field.IsValid() { + return nil, fmt.Errorf( + "for contract %q read %q: no field named 'Tokens' found in params", + values.contract, values.multiRead[0], + ) + } + + tokens, ok := field.Interface().(*[][32]uint8) + if !ok { + return nil, fmt.Errorf( + "for contract %q read %q: 'Tokens' field is not of type *[][32]uint8", + values.contract, values.multiRead[0], + ) + } + + programAddress, err := solana.PublicKeyFromBase58(values.address) + if err != nil { + return nil, fmt.Errorf( + "for contract %q read %q: %w (could not parse program address %q)", + values.contract, values.multiRead[0], types.ErrInvalidConfig, values.address, + ) + } + + // Build the PDA addresses for all tokens. + var pdaAddresses []solana.PublicKey + for _, token := range *tokens { + tokenAddr := solana.PublicKeyFromBytes(token[:]) + if tokenAddr.IsOnCurve() && !tokenAddr.IsZero() { + return nil, fmt.Errorf( + "for contract %q read %q: invalid token address %v (on-curve or zero)", + values.contract, values.multiRead[0], tokenAddr, + ) + } + + pdaAddress, _, err := solana.FindProgramAddress( + [][]byte{[]byte("fee_billing_token_config"), tokenAddr.Bytes()}, + programAddress, + ) + if err != nil { + return nil, fmt.Errorf( + "for contract %q read %q: %w (failed to find PDA for token %v)", + values.contract, values.multiRead[0], types.ErrInvalidConfig, tokenAddr, + ) + } + pdaAddresses = append(pdaAddresses, pdaAddress) + } + return pdaAddresses, nil +} diff --git a/pkg/solana/chainreader/chain_reader_test.go b/pkg/solana/chainreader/chain_reader_test.go index 8f82cf622..b8abc1813 100644 --- a/pkg/solana/chainreader/chain_reader_test.go +++ b/pkg/solana/chainreader/chain_reader_test.go @@ -352,11 +352,11 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { Seeds: []codec.PDASeed{ { Name: "PubKey", - Type: codec.IdlTypePublicKey, + Type: codec.IdlType{AsString: codec.IdlTypePublicKey}, }, { Name: "Uint64Seed", - Type: codec.IdlTypeU64, + Type: codec.IdlType{AsString: codec.IdlTypeU64}, }, }, }, @@ -373,7 +373,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { Seeds: []codec.PDASeed{ { Name: "PubKey", - Type: codec.IdlTypePublicKey, + Type: codec.IdlType{AsString: codec.IdlTypePublicKey}, }, }, }, @@ -433,7 +433,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { Seeds: []codec.PDASeed{ { Name: "PubKey", - Type: codec.IdlTypePublicKey, + Type: codec.IdlType{AsString: codec.IdlTypePublicKey}, }, }, }, diff --git a/pkg/solana/chainreader/event_read_binding.go b/pkg/solana/chainreader/event_read_binding.go index 0ee9f768f..8f8011463 100644 --- a/pkg/solana/chainreader/event_read_binding.go +++ b/pkg/solana/chainreader/event_read_binding.go @@ -12,8 +12,8 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/types/query" "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller" ) @@ -133,10 +133,6 @@ func (b *eventReadBinding) remapPrimitive(expression query.Expression) (query.Ex return comp, nil } -func (b *eventReadBinding) ReadType() config.ReadType { - return config.Event -} - func (b *eventReadBinding) encodeComparator(comparator *primitives.Comparator) (query.Expression, error) { subKeyIndex, ok := b.indexedSubKeys[comparator.Name] if !ok { diff --git a/pkg/solana/config/chain_reader.go b/pkg/solana/config/chain_reader.go index fc289aaac..4fbe7c420 100644 --- a/pkg/solana/config/chain_reader.go +++ b/pkg/solana/config/chain_reader.go @@ -56,8 +56,6 @@ type ReadType int const ( Account ReadType = iota - AccountPDA - AccountSplitParams Event ) @@ -67,10 +65,6 @@ func (r ReadType) String() string { return "Account" case Event: return "Event" - case AccountPDA: - return "AccountPDA" - case AccountSplitParams: - return "AccountSplitParams" default: return fmt.Sprintf("Unknown(%d)", r) }