From f0a6db8d65fecafaf72d10b2861e81f65ef75df9 Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Fri, 14 Feb 2025 11:19:59 -0600 Subject: [PATCH] use pointer on binding to achieve side-effect of updating address --- pkg/solana/chainreader/bindings.go | 41 +++++++++++++++------ pkg/solana/chainreader/chain_reader.go | 27 +++++++++----- pkg/solana/chainreader/chain_reader_test.go | 1 + 3 files changed, 48 insertions(+), 21 deletions(-) diff --git a/pkg/solana/chainreader/bindings.go b/pkg/solana/chainreader/bindings.go index 48d84b349..a62ce79c0 100644 --- a/pkg/solana/chainreader/bindings.go +++ b/pkg/solana/chainreader/bindings.go @@ -42,10 +42,24 @@ type readBinding interface { type addressShareGroup struct { address solana.PublicKey - mux sync.Mutex + mux sync.RWMutex group []string } +func (g *addressShareGroup) getAddress() solana.PublicKey { + g.mux.RLock() + defer g.mux.RUnlock() + + return g.address +} + +func (g *addressShareGroup) setAddress(addr solana.PublicKey) { + g.mux.Lock() + defer g.mux.Unlock() + + g.address = addr +} + type bindingsRegistry struct { mu sync.RWMutex // key is namespace @@ -136,7 +150,14 @@ func (r *bindingsRegistry) GetReaders(namespace string) ([]readBinding, error) { return rBindings.GetReaders() } -func (r *bindingsRegistry) Bind(ctx context.Context, reg filterRegistrar, binding types.BoundContract) error { +// Bind has a side-effect of updating the bound address to a group shared address. +// +// DO NOT CHANGE binding from pointer type. +func (r *bindingsRegistry) Bind(ctx context.Context, reg filterRegistrar, binding *types.BoundContract) error { + if binding == nil { + return fmt.Errorf("%w: bound contract is nil", types.ErrInvalidType) + } + r.mu.Lock() defer r.mu.Unlock() @@ -217,28 +238,26 @@ func (r *bindingsRegistry) getShareGroup(nameSpace string) (*addressShareGroup, return shareGroup, sharesAddress } -func (r *bindingsRegistry) handleAddressSharing(boundContract types.BoundContract) error { +func (r *bindingsRegistry) handleAddressSharing(boundContract *types.BoundContract) error { shareGroup, isInAGroup := r.getShareGroup(boundContract.Name) if !isInAGroup { return nil } - shareGroup.mux.Lock() - defer shareGroup.mux.Unlock() - // set shared address to the binding address - if shareGroup.address.IsZero() { + if shareGroup.getAddress().IsZero() { key, err := solana.PublicKeyFromBase58(boundContract.Address) if err != nil { return err } - r.addressShareGroups[boundContract.Name].address, shareGroup.address = key, key - } else if boundContract.Address != shareGroup.address.String() && boundContract.Address != "" { - return fmt.Errorf("namespace: %q shares address: %q with namespaceBindings: %v and cannot be bound with a new address: %s", boundContract.Name, shareGroup.address, shareGroup.group, boundContract.Address) + shareGroup.setAddress(key) + } else if boundContract.Address != shareGroup.getAddress().String() && boundContract.Address != "" { + return fmt.Errorf("namespace: %q shares address: %q with namespaceBindings: %v and cannot be bound with a new address: %s", boundContract.Name, shareGroup.getAddress(), shareGroup.group, boundContract.Address) } - boundContract.Address = shareGroup.address.String() + // side-effect of updating the bound contract address to group-shared address + boundContract.Address = shareGroup.getAddress().String() return nil } diff --git a/pkg/solana/chainreader/chain_reader.go b/pkg/solana/chainreader/chain_reader.go index b51858b69..bff9715fa 100644 --- a/pkg/solana/chainreader/chain_reader.go +++ b/pkg/solana/chainreader/chain_reader.go @@ -321,22 +321,25 @@ func (s *ContractReaderService) QueryKey(ctx context.Context, contract types.Bou return sequenceOfValues, nil } -// Bind implements the types.ContractReader interface and allows new contract namespaceBindings to be added -// to the service. +// Bind implements the types.ContractReader interface and allows new contract namespaceBindings to be added to the +// service. +// +// Bind has a side-effect of updating a binding with a shared address if the bound contract has been configured to be +// part of a share group. func (s *ContractReaderService) Bind(ctx context.Context, bindings []types.BoundContract) error { - for i := range bindings { - if err := s.bdRegistry.Bind(ctx, s.reader, bindings[i]); err != nil { + for idx := range bindings { + if err := s.bdRegistry.Bind(ctx, s.reader, &bindings[idx]); err != nil { return err } - s.lookup.bindAddressForContract(bindings[i].Name, bindings[i].Address) + s.lookup.bindAddressForContract(bindings[idx].Name, bindings[idx].Address) // also bind with an empty address so that we can look up the contract without providing address when calling CR methods - if sg, isInAShareGroup := s.bdRegistry.getShareGroup(bindings[i].Name); isInAShareGroup { - s.lookup.bindAddressForContract(bindings[i].Name, "") + if sg, isInAShareGroup := s.bdRegistry.getShareGroup(bindings[idx].Name); isInAShareGroup { + s.lookup.bindAddressForContract(bindings[idx].Name, "") for _, namespace := range sg.group { - if err := s.addAddressResponseHardCoderModifier(namespace, bindings[i].Address); err != nil { + if err := s.addAddressResponseHardCoderModifier(namespace, bindings[idx].Address); err != nil { return fmt.Errorf("failed to add address response hard coder modifier for contract: %q, : %w", namespace, err) } } @@ -344,8 +347,8 @@ func (s *ContractReaderService) Bind(ctx context.Context, bindings []types.Bound return nil } - if err := s.addAddressResponseHardCoderModifier(bindings[i].Name, bindings[i].Address); err != nil { - return fmt.Errorf("failed to add address response hard coder modifier for contract: %q, : %w", bindings[i].Name, err) + if err := s.addAddressResponseHardCoderModifier(bindings[idx].Name, bindings[idx].Address); err != nil { + return fmt.Errorf("failed to add address response hard coder modifier for contract: %q, : %w", bindings[idx].Name, err) } } @@ -830,6 +833,10 @@ func setPollingFilterOverrides(common *config.PollingFilter, overrides ...*confi allOverrides := append([]*config.PollingFilter{common}, overrides...) for _, override := range allOverrides { + if override == nil { + continue + } + valOfO := reflect.Indirect(reflect.ValueOf(override)) for idx := range valOfF.Type().NumField() { name := valOfO.Type().Field(idx).Name diff --git a/pkg/solana/chainreader/chain_reader_test.go b/pkg/solana/chainreader/chain_reader_test.go index 104b0cd81..eb905b6bf 100644 --- a/pkg/solana/chainreader/chain_reader_test.go +++ b/pkg/solana/chainreader/chain_reader_test.go @@ -178,6 +178,7 @@ func TestSolanaChainReaderService_Start(t *testing.T) { return service.ErrNotStarted }()) er.On("Start", mock.Anything).Maybe().Return(tt.StartError) + er.On("HasFilter", mock.Anything, mock.Anything).Return(false).Maybe() er.On("RegisterFilter", mock.Anything, mock.Anything).Maybe().Return(tt.RegisterFilterError) require.NoError(t, svc.Bind(ctx, []types.BoundContract{{Address: pk.String(), Name: "myChainReader"}}))