diff --git a/pkg/solana/chainreader/bindings.go b/pkg/solana/chainreader/bindings.go index ada654756..48d84b349 100644 --- a/pkg/solana/chainreader/bindings.go +++ b/pkg/solana/chainreader/bindings.go @@ -137,8 +137,12 @@ func (r *bindingsRegistry) GetReaders(namespace string) ([]readBinding, error) { } func (r *bindingsRegistry) Bind(ctx context.Context, reg filterRegistrar, binding types.BoundContract) error { - r.mu.RLock() - defer r.mu.RUnlock() + r.mu.Lock() + defer r.mu.Unlock() + + if err := r.handleAddressSharing(binding); err != nil { + return err + } namespace, nbsExist := r.namespaceBindings[binding.Name] if !nbsExist { @@ -150,35 +154,25 @@ func (r *bindingsRegistry) Bind(ctx context.Context, reg filterRegistrar, bindin return err } - if err := errors.Join( + return errors.Join( namespace.Bind(ctx, reg, address), namespace.BindReaders(ctx, address), - ); err != nil { - return err - } - - return nil + ) } -func (r *bindingsRegistry) Unbind(ctx context.Context, reg filterRegistrar, bindings []types.BoundContract) error { +func (r *bindingsRegistry) Unbind(ctx context.Context, reg filterRegistrar, binding types.BoundContract) error { r.mu.RLock() defer r.mu.RUnlock() - for _, binding := range bindings { - namespace, nbsExist := r.namespaceBindings[binding.Name] - if !nbsExist { - return fmt.Errorf("%w: no namespace named %s", types.ErrInvalidConfig, binding.Name) - } - - if err := errors.Join( - namespace.Unbind(ctx, reg), - namespace.UnbindReaders(ctx), - ); err != nil { - return err - } + namespace, nbsExist := r.namespaceBindings[binding.Name] + if !nbsExist { + return fmt.Errorf("%w: no namespace named %s", types.ErrInvalidConfig, binding.Name) } - return nil + return errors.Join( + namespace.Unbind(ctx, reg), + namespace.UnbindReaders(ctx), + ) } func (r *bindingsRegistry) CreateType(namespace, readName string, forEncoding bool) (any, error) { @@ -191,6 +185,9 @@ func (r *bindingsRegistry) CreateType(namespace, readName string, forEncoding bo } func (r *bindingsRegistry) initAddressSharing(addressShareGroups [][]string) error { + r.mu.Lock() + defer r.mu.Unlock() + r.addressShareGroups = make(map[string]*addressShareGroup) for _, group := range addressShareGroups { @@ -220,6 +217,32 @@ func (r *bindingsRegistry) getShareGroup(nameSpace string) (*addressShareGroup, return shareGroup, sharesAddress } +func (r *bindingsRegistry) handleAddressSharing(boundContract types.BoundContract) error { + shareGroup, isInAGroup := r.getShareGroup(boundContract.Name) + if !isInAGroup { + return nil + } + + shareGroup.mux.Lock() + defer shareGroup.mux.Unlock() + + // set shared address to the binding address + if shareGroup.address.IsZero() { + key, err := solana.PublicKeyFromBase58(boundContract.Address) + if err != nil { + return err + } + + r.addressShareGroups[boundContract.Name].address, shareGroup.address = key, key + } else if boundContract.Address != shareGroup.address.String() && boundContract.Address != "" { + return fmt.Errorf("namespace: %q shares address: %q with namespaceBindings: %v and cannot be bound with a new address: %s", boundContract.Name, shareGroup.address, shareGroup.group, boundContract.Address) + } + + boundContract.Address = shareGroup.address.String() + + return nil +} + type namespaceBinding struct { // static data name string @@ -234,6 +257,7 @@ func newNamespaceBinding(namespace string) *namespaceBinding { return &namespaceBinding{ name: namespace, readers: make(map[string]readBinding), + bound: make(map[solana.PublicKey]bool), } } @@ -256,9 +280,6 @@ func (b *namespaceBinding) SetModifiers(modifier commoncodec.Modifier) { } func (b *namespaceBinding) Bind(ctx context.Context, reg filterRegistrar, address solana.PublicKey) error { - b.mu.RLock() - defer b.mu.RUnlock() - if b.bindingExists(address) { return nil } @@ -269,6 +290,9 @@ func (b *namespaceBinding) Bind(ctx context.Context, reg filterRegistrar, addres } func (b *namespaceBinding) BindReaders(ctx context.Context, address solana.PublicKey) error { + b.mu.RLock() + defer b.mu.RUnlock() + var err error for _, rb := range b.readers { @@ -388,27 +412,3 @@ func (b *namespaceBinding) unsetBinding() { b.bound = make(map[solana.PublicKey]bool) } - -func (b *bindingsRegistry) handleAddressSharing(boundContract *types.BoundContract) error { - shareGroup, isInAGroup := b.getShareGroup(boundContract.Name) - if !isInAGroup { - return nil - } - - shareGroup.mux.Lock() - defer shareGroup.mux.Unlock() - - // set shared address to the binding address - if shareGroup.address.IsZero() { - key, err := solana.PublicKeyFromBase58(boundContract.Address) - if err != nil { - return err - } - b.addressShareGroups[boundContract.Name].address, shareGroup.address = key, key - } else if boundContract.Address != shareGroup.address.String() && boundContract.Address != "" { - return fmt.Errorf("namespace: %q shares address: %q with namespaceBindings: %v and cannot be bound with a new address: %s", boundContract.Name, shareGroup.address, shareGroup.group, boundContract.Address) - } - - boundContract.Address = shareGroup.address.String() - return nil -} diff --git a/pkg/solana/chainreader/chain_reader.go b/pkg/solana/chainreader/chain_reader.go index bc0431f2a..b51858b69 100644 --- a/pkg/solana/chainreader/chain_reader.go +++ b/pkg/solana/chainreader/chain_reader.go @@ -348,14 +348,25 @@ func (s *ContractReaderService) Bind(ctx context.Context, bindings []types.Bound return fmt.Errorf("failed to add address response hard coder modifier for contract: %q, : %w", bindings[i].Name, err) } } + return nil } // Unbind implements the types.ContractReader interface and allows existing contract namespaceBindings to be removed // from the service. func (s *ContractReaderService) Unbind(ctx context.Context, bindings []types.BoundContract) error { - // TODO: unbind is incomplete - return s.bdRegistry.Unbind(ctx, s.reader, bindings) + for i := range bindings { + if err := s.bdRegistry.Unbind(ctx, s.reader, bindings[i]); err != nil { + return err + } + + s.lookup.unbindAddressForContract(bindings[i].Name, bindings[i].Address) + + // also unbind an empty address if a share group exists + s.lookup.unbindAddressForContract(bindings[i].Name, "") + } + + return nil } // CreateContractType implements the ContractTypeProvider interface and allows the chain reader @@ -638,24 +649,6 @@ func toLPFilter( } } -// injectAddressModifier injects AddressModifier into OutputModifications. -// This is necessary because AddressModifier cannot be serialized and must be applied at runtime. -func injectAddressModifier(inputModifications, outputModifications commoncodec.ModifiersConfig) { - for i, modConfig := range inputModifications { - if addrModifierConfig, ok := modConfig.(*commoncodec.AddressBytesToStringModifierConfig); ok { - addrModifierConfig.Modifier = codec.SolanaAddressModifier{} - outputModifications[i] = addrModifierConfig - } - } - - for i, modConfig := range outputModifications { - if addrModifierConfig, ok := modConfig.(*commoncodec.AddressBytesToStringModifierConfig); ok { - addrModifierConfig.Modifier = codec.SolanaAddressModifier{} - outputModifications[i] = addrModifierConfig - } - } -} - type accountDataReader struct { client *rpc.Client } diff --git a/pkg/solana/chainreader/event_read_binding.go b/pkg/solana/chainreader/event_read_binding.go index 1661c7ef0..22ed05cd0 100644 --- a/pkg/solana/chainreader/event_read_binding.go +++ b/pkg/solana/chainreader/event_read_binding.go @@ -101,11 +101,7 @@ func (b *eventReadBinding) Unbind(ctx context.Context) error { b.unsetBinding() - if err := b.Unregister(ctx); err != nil { - return err - } - - return nil + return b.Unregister(ctx) } func (b *eventReadBinding) Register(ctx context.Context) error {