Skip to content

Commit

Permalink
use pointer on binding to achieve side-effect of updating address
Browse files Browse the repository at this point in the history
  • Loading branch information
EasterTheBunny committed Feb 14, 2025
1 parent 73dca39 commit 41a4874
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 21 deletions.
41 changes: 30 additions & 11 deletions pkg/solana/chainreader/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,24 @@ type readBinding interface {

type addressShareGroup struct {
address solana.PublicKey
mux sync.Mutex
mux sync.RWMutex
group []string
}

func (g *addressShareGroup) getAddress() solana.PublicKey {
g.mux.RLock()
defer g.mux.RUnlock()

return g.address
}

func (g *addressShareGroup) setAddress(addr solana.PublicKey) {
g.mux.Lock()
defer g.mux.Unlock()

g.address = addr
}

type bindingsRegistry struct {
mu sync.RWMutex
// key is namespace
Expand Down Expand Up @@ -136,7 +150,14 @@ func (r *bindingsRegistry) GetReaders(namespace string) ([]readBinding, error) {
return rBindings.GetReaders()
}

func (r *bindingsRegistry) Bind(ctx context.Context, reg filterRegistrar, binding types.BoundContract) error {
// Bind has a side-effect of updating the bound address to a group shared address.
//
// DO NOT CHANGE binding from pointer type.
func (r *bindingsRegistry) Bind(ctx context.Context, reg filterRegistrar, binding *types.BoundContract) error {
if binding == nil {
return fmt.Errorf("%w: bound contract is nil", types.ErrInvalidType)
}

r.mu.Lock()
defer r.mu.Unlock()

Expand Down Expand Up @@ -217,28 +238,26 @@ func (r *bindingsRegistry) getShareGroup(nameSpace string) (*addressShareGroup,
return shareGroup, sharesAddress
}

func (r *bindingsRegistry) handleAddressSharing(boundContract types.BoundContract) error {
func (r *bindingsRegistry) handleAddressSharing(boundContract *types.BoundContract) error {
shareGroup, isInAGroup := r.getShareGroup(boundContract.Name)
if !isInAGroup {
return nil
}

shareGroup.mux.Lock()
defer shareGroup.mux.Unlock()

// set shared address to the binding address
if shareGroup.address.IsZero() {
if shareGroup.getAddress().IsZero() {
key, err := solana.PublicKeyFromBase58(boundContract.Address)
if err != nil {
return err
}

r.addressShareGroups[boundContract.Name].address, shareGroup.address = key, key
} else if boundContract.Address != shareGroup.address.String() && boundContract.Address != "" {
return fmt.Errorf("namespace: %q shares address: %q with namespaceBindings: %v and cannot be bound with a new address: %s", boundContract.Name, shareGroup.address, shareGroup.group, boundContract.Address)
shareGroup.setAddress(key)
} else if boundContract.Address != shareGroup.getAddress().String() && boundContract.Address != "" {
return fmt.Errorf("namespace: %q shares address: %q with namespaceBindings: %v and cannot be bound with a new address: %s", boundContract.Name, shareGroup.getAddress(), shareGroup.group, boundContract.Address)
}

boundContract.Address = shareGroup.address.String()
// side-effect of updating the bound contract address to group-shared address
boundContract.Address = shareGroup.getAddress().String()

return nil
}
Expand Down
27 changes: 17 additions & 10 deletions pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,31 +321,34 @@ func (s *ContractReaderService) QueryKey(ctx context.Context, contract types.Bou
return sequenceOfValues, nil
}

// Bind implements the types.ContractReader interface and allows new contract namespaceBindings to be added
// to the service.
// Bind implements the types.ContractReader interface and allows new contract namespaceBindings to be added to the
// service.
//
// Bind has a side-effect of updating a binding with a shared address if the bound contract has been configured to be
// part of a share group.
func (s *ContractReaderService) Bind(ctx context.Context, bindings []types.BoundContract) error {
for i := range bindings {
if err := s.bdRegistry.Bind(ctx, s.reader, bindings[i]); err != nil {
for idx := range bindings {
if err := s.bdRegistry.Bind(ctx, s.reader, &bindings[idx]); err != nil {
return err
}

s.lookup.bindAddressForContract(bindings[i].Name, bindings[i].Address)
s.lookup.bindAddressForContract(bindings[idx].Name, bindings[idx].Address)

// also bind with an empty address so that we can look up the contract without providing address when calling CR methods
if sg, isInAShareGroup := s.bdRegistry.getShareGroup(bindings[i].Name); isInAShareGroup {
s.lookup.bindAddressForContract(bindings[i].Name, "")
if sg, isInAShareGroup := s.bdRegistry.getShareGroup(bindings[idx].Name); isInAShareGroup {
s.lookup.bindAddressForContract(bindings[idx].Name, "")

for _, namespace := range sg.group {
if err := s.addAddressResponseHardCoderModifier(namespace, bindings[i].Address); err != nil {
if err := s.addAddressResponseHardCoderModifier(namespace, bindings[idx].Address); err != nil {
return fmt.Errorf("failed to add address response hard coder modifier for contract: %q, : %w", namespace, err)
}
}

return nil
}

if err := s.addAddressResponseHardCoderModifier(bindings[i].Name, bindings[i].Address); err != nil {
return fmt.Errorf("failed to add address response hard coder modifier for contract: %q, : %w", bindings[i].Name, err)
if err := s.addAddressResponseHardCoderModifier(bindings[idx].Name, bindings[idx].Address); err != nil {
return fmt.Errorf("failed to add address response hard coder modifier for contract: %q, : %w", bindings[idx].Name, err)
}
}

Expand Down Expand Up @@ -830,6 +833,10 @@ func setPollingFilterOverrides(common *config.PollingFilter, overrides ...*confi
allOverrides := append([]*config.PollingFilter{common}, overrides...)

for _, override := range allOverrides {
if override == nil {
continue
}

valOfO := reflect.Indirect(reflect.ValueOf(override))
for idx := range valOfF.Type().NumField() {
name := valOfO.Type().Field(idx).Name
Expand Down
1 change: 1 addition & 0 deletions pkg/solana/chainreader/chain_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ func TestSolanaChainReaderService_Start(t *testing.T) {
return service.ErrNotStarted
}())
er.On("Start", mock.Anything).Maybe().Return(tt.StartError)
er.On("HasFilter", mock.Anything, mock.Anything).Return(false).Maybe()
er.On("RegisterFilter", mock.Anything, mock.Anything).Maybe().Return(tt.RegisterFilterError)

require.NoError(t, svc.Bind(ctx, []types.BoundContract{{Address: pk.String(), Name: "myChainReader"}}))
Expand Down

0 comments on commit 41a4874

Please sign in to comment.