diff --git a/pulsar/internal/connection.go b/pulsar/internal/connection.go
index c04021a68..c1f38f898 100644
--- a/pulsar/internal/connection.go
+++ b/pulsar/internal/connection.go
@@ -87,6 +87,7 @@ type Connection interface {
 	ID() string
 	GetMaxMessageSize() int32
 	Close()
+	WaitForClose() <-chan interface{}
 	IsProxied() bool
 }
 
@@ -1049,6 +1050,10 @@ func (c *connection) CheckIdle(maxIdleTime time.Duration) bool {
 	return time.Since(c.lastActive) > maxIdleTime
 }
 
+func (c *connection) WaitForClose() <-chan interface{} {
+	return c.closeCh
+}
+
 // Close closes the connection by
 // closing underlying socket connection and closeCh.
 // This also triggers callbacks to the ConnectionClosed listeners.
diff --git a/pulsar/internal/connection_pool.go b/pulsar/internal/connection_pool.go
index 3d718b75d..5f858b118 100644
--- a/pulsar/internal/connection_pool.go
+++ b/pulsar/internal/connection_pool.go
@@ -34,6 +34,9 @@ type ConnectionPool interface {
 	// GetConnection get a connection from ConnectionPool.
 	GetConnection(logicalAddr *url.URL, physicalAddr *url.URL) (Connection, error)
 
+	// GetConnections get all connections in the pool.
+	GetConnections() map[string]Connection
+
 	// Close all the connections in the pool
 	Close()
 }
@@ -124,6 +127,16 @@ func (p *connectionPool) GetConnection(logicalAddr *url.URL, physicalAddr *url.U
 	return conn, err
 }
 
+func (p *connectionPool) GetConnections() map[string]Connection {
+	p.Lock()
+	defer p.Unlock()
+	conns := make(map[string]Connection)
+	for k, c := range p.connections {
+		conns[k] = c
+	}
+	return conns
+}
+
 func (p *connectionPool) Close() {
 	p.Lock()
 	close(p.closeCh)
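The two additions above are small but load-bearing: `WaitForClose` gives any component a way to block until a connection's close channel fires, and `GetConnections` exposes a snapshot of the pool (used by the reconnect test below). A minimal sketch of how a watcher can consume `WaitForClose` — `watchConn`, `ownerCloseCh`, and `onClosed` are illustrative names, not part of this patch, and the snippet assumes access to the `internal` package:

```go
// watchConn is a hypothetical helper showing the intended use of
// Connection.WaitForClose: park a goroutine until either the owner
// shuts down (ownerCloseCh) or the connection itself closes.
func watchConn(conn internal.Connection, ownerCloseCh <-chan any, onClosed func()) {
	go func() {
		select {
		case <-ownerCloseCh:
			// Owner is closing; stop watching without firing the callback.
			return
		case <-conn.WaitForClose():
			// Socket closed underneath us; let the owner schedule a reconnect.
			onClosed()
		}
	}()
}
```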
diff --git a/pulsar/transaction_coordinator_client.go b/pulsar/transaction_coordinator_client.go
index c4b7a6a20..1449d698e 100644
--- a/pulsar/transaction_coordinator_client.go
+++ b/pulsar/transaction_coordinator_client.go
@@ -20,23 +20,311 @@ package pulsar
 
 import (
 	"context"
 	"strconv"
+	"strings"
 	"sync/atomic"
 	"time"
 
 	"github.com/apache/pulsar-client-go/pulsar/internal"
 	pb "github.com/apache/pulsar-client-go/pulsar/internal/pulsar_proto"
 	"github.com/apache/pulsar-client-go/pulsar/log"
+	"github.com/pkg/errors"
+	uAtomic "go.uber.org/atomic"
 	"google.golang.org/protobuf/proto"
 )
 
 type transactionCoordinatorClient struct {
 	client    *client
-	cons      []internal.Connection
+	handlers  []*transactionHandler
 	epoch     uint64
 	semaphore internal.Semaphore
 	log       log.Logger
 }
 
+type transactionHandler struct {
+	tc              *transactionCoordinatorClient
+	state           uAtomic.Int32
+	conn            uAtomic.Value
+	partition       uint64
+	closeCh         chan any
+	requestCh       chan any
+	connectClosedCh chan *connectionClosed
+	log             log.Logger
+}
+
+type txnHandlerState int
+
+const (
+	txnHandlerReady = iota
+	txnHandlerClosed
+)
+
+func (t *transactionHandler) getState() txnHandlerState {
+	return txnHandlerState(t.state.Load())
+}
+
+func (tc *transactionCoordinatorClient) newTransactionHandler(partition uint64) (*transactionHandler, error) {
+	handler := &transactionHandler{
+		tc:              tc,
+		partition:       partition,
+		closeCh:         make(chan any),
+		requestCh:       make(chan any),
+		connectClosedCh: make(chan *connectionClosed),
+		log:             tc.log.SubLogger(log.Fields{"txn handler partition": partition}),
+	}
+	err := handler.grabConn()
+	if err != nil {
+		return nil, err
+	}
+	go handler.runEventsLoop()
+	return handler, nil
+}
+
+func (t *transactionHandler) grabConn() error {
+	lr, err := t.tc.client.lookupService.Lookup(getTCAssignTopicName(t.partition))
+	if err != nil {
+		t.log.WithError(err).Warn("Failed to lookup the transaction_impl " +
+			"coordinator assign topic [" +
+			strconv.FormatUint(t.partition, 10) + "]")
+		return err
+	}
+
+	requestID := t.tc.client.rpcClient.NewRequestID()
+	cmdTCConnect := pb.CommandTcClientConnectRequest{
+		RequestId: proto.Uint64(requestID),
+		TcId:      proto.Uint64(t.partition),
+	}
+
+	res, err := t.tc.client.rpcClient.Request(lr.LogicalAddr, lr.PhysicalAddr, requestID,
+		pb.BaseCommand_TC_CLIENT_CONNECT_REQUEST, &cmdTCConnect)
+
+	if err != nil {
+		t.log.WithError(err).Error("Failed to connect transaction_impl coordinator " +
+			strconv.FormatUint(t.partition, 10))
+		return err
+	}
+
+	go func() {
+		select {
+		case <-t.closeCh:
+			return
+		case <-res.Cnx.WaitForClose():
+			t.connectClosedCh <- &connectionClosed{}
+		}
+	}()
+	t.conn.Store(res.Cnx)
+	t.log.Infof("Transaction handler with transaction coordinator id %d connected", t.partition)
+	return nil
+}
+
+func (t *transactionHandler) getConn() internal.Connection {
+	return t.conn.Load().(internal.Connection)
+}
+
+func (t *transactionHandler) runEventsLoop() {
+	for {
+		select {
+		case <-t.closeCh:
+			return
+		case req := <-t.requestCh:
+			switch r := req.(type) {
+			case *newTxnOp:
+				t.newTransaction(r)
+			case *addPublishPartitionOp:
+				t.addPublishPartitionToTxn(r)
+			case *addSubscriptionOp:
+				t.addSubscriptionToTxn(r)
+			case *endTxnOp:
+				t.endTxn(r)
+			}
+		case <-t.connectClosedCh:
+			t.reconnectToBroker()
+		}
+	}
+}
+
+func (t *transactionHandler) reconnectToBroker() {
+	var delayReconnectTime time.Duration
+	var defaultBackoff = internal.DefaultBackoff{}
+
+	for {
+		if t.getState() == txnHandlerClosed {
+			// The handler is already closing
+			t.log.Info("transaction handler is closed, exit reconnect")
+			return
+		}
+
+		delayReconnectTime = defaultBackoff.Next()
+
+		t.log.WithFields(log.Fields{
+			"delayReconnectTime": delayReconnectTime,
+		}).Info("Transaction handler will reconnect to the transaction coordinator")
+		time.Sleep(delayReconnectTime)
+
+		// double check
+		if t.getState() == txnHandlerClosed {
+			// Txn handler is already closing
+			t.log.Info("transaction handler is closed, exit reconnect")
+			return
+		}
+
+		err := t.grabConn()
+		if err == nil {
+			// Successfully reconnected
+			t.log.Info("Reconnected transaction handler to broker")
+			return
+		}
+		t.log.WithError(err).Error("Failed to create transaction handler at reconnect")
+		errMsg := err.Error()
+		if strings.Contains(errMsg, errMsgTopicNotFound) {
+			// when the topic is deleted, we should give up reconnecting.
+			t.log.Warn("Topic Not Found")
+			break
+		}
+	}
+}
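`reconnectToBroker` checks the handler state both before and after sleeping, so a `close()` that lands during the backoff delay still stops the loop. Reduced to its skeleton (an illustrative distillation, not code from the patch; `isClosed` and `attempt` are stand-ins):

```go
// retryUntilClosed distills the retry pattern used by reconnectToBroker:
// back off, re-check for shutdown, then retry until an attempt succeeds
// or the owner is closed. backoff.Next() returns an increasing delay.
func retryUntilClosed(isClosed func() bool, attempt func() error) {
	backoff := internal.DefaultBackoff{}
	for !isClosed() {
		time.Sleep(backoff.Next())
		if isClosed() {
			return // closed while we were sleeping
		}
		if err := attempt(); err == nil {
			return // reconnected successfully
		}
	}
}
```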
+
+func (t *transactionHandler) checkRetriableError(err error, op any) bool {
+	if err != nil && errors.Is(err, internal.ErrConnectionClosed) {
+		// We are in the event loop here, so put the request back on requestCh asynchronously to avoid deadlocking the loop.
+		go func() {
+			t.requestCh <- op
+		}()
+		return true
+	}
+	return false
+}
+
+type newTxnOp struct {
+	// Request
+	timeout time.Duration
+
+	// Response
+	errCh chan error
+	txnID *TxnID
+}
+
+func (t *transactionHandler) newTransaction(op *newTxnOp) {
+	requestID := t.tc.client.rpcClient.NewRequestID()
+	cmdNewTxn := &pb.CommandNewTxn{
+		RequestId:     proto.Uint64(requestID),
+		TcId:          &t.partition,
+		TxnTtlSeconds: proto.Uint64(uint64(op.timeout.Milliseconds())),
+	}
+	res, err := t.tc.client.rpcClient.RequestOnCnx(t.getConn(), requestID, pb.BaseCommand_NEW_TXN, cmdNewTxn)
+	if t.checkRetriableError(err, op) {
+		return
+	}
+	defer close(op.errCh)
+	if err != nil {
+		op.errCh <- err
+	} else if res.Response.NewTxnResponse.Error != nil {
+		op.errCh <- getErrorFromServerError(res.Response.NewTxnResponse.Error)
+	} else {
+		op.txnID = &TxnID{*res.Response.NewTxnResponse.TxnidMostBits,
+			*res.Response.NewTxnResponse.TxnidLeastBits}
+	}
+}
+
+type addPublishPartitionOp struct {
+	// Request
+	id         *TxnID
+	partitions []string
+
+	// Response
+	errCh chan error
+}
+
+func (t *transactionHandler) addPublishPartitionToTxn(op *addPublishPartitionOp) {
+	requestID := t.tc.client.rpcClient.NewRequestID()
+	cmdAddPartitions := &pb.CommandAddPartitionToTxn{
+		RequestId:      proto.Uint64(requestID),
+		TxnidMostBits:  proto.Uint64(op.id.MostSigBits),
+		TxnidLeastBits: proto.Uint64(op.id.LeastSigBits),
+		Partitions:     op.partitions,
+	}
+	res, err := t.tc.client.rpcClient.RequestOnCnx(t.getConn(), requestID,
+		pb.BaseCommand_ADD_PARTITION_TO_TXN, cmdAddPartitions)
+	if t.checkRetriableError(err, op) {
+		return
+	}
+	defer close(op.errCh)
+	if err != nil {
+		op.errCh <- err
+	} else if res.Response.AddPartitionToTxnResponse.Error != nil {
+		op.errCh <- getErrorFromServerError(res.Response.AddPartitionToTxnResponse.Error)
+	}
+}
+
+type addSubscriptionOp struct {
+	// Request
+	id           *TxnID
+	topic        string
+	subscription string
+
+	// Response
+	errCh chan error
+}
+
+func (t *transactionHandler) addSubscriptionToTxn(op *addSubscriptionOp) {
+	requestID := t.tc.client.rpcClient.NewRequestID()
+	sub := &pb.Subscription{
+		Topic:        &op.topic,
+		Subscription: &op.subscription,
+	}
+	cmdAddSubscription := &pb.CommandAddSubscriptionToTxn{
+		RequestId:      proto.Uint64(requestID),
+		TxnidMostBits:  proto.Uint64(op.id.MostSigBits),
+		TxnidLeastBits: proto.Uint64(op.id.LeastSigBits),
+		Subscription:   []*pb.Subscription{sub},
+	}
+	res, err := t.tc.client.rpcClient.RequestOnCnx(t.getConn(), requestID,
+		pb.BaseCommand_ADD_SUBSCRIPTION_TO_TXN, cmdAddSubscription)
+	if t.checkRetriableError(err, op) {
+		return
+	}
+	defer close(op.errCh)
+	if err != nil {
+		op.errCh <- err
+	} else if res.Response.AddSubscriptionToTxnResponse.Error != nil {
+		op.errCh <- getErrorFromServerError(res.Response.AddSubscriptionToTxnResponse.Error)
+	}
+}
+
+type endTxnOp struct {
+	// Request
+	id     *TxnID
+	action pb.TxnAction
+
+	// Response
+	errCh chan error
+}
+
+func (t *transactionHandler) endTxn(op *endTxnOp) {
+	requestID := t.tc.client.rpcClient.NewRequestID()
+	cmdEndTxn := &pb.CommandEndTxn{
+		RequestId:      proto.Uint64(requestID),
+		TxnAction:      &op.action,
+		TxnidMostBits:  proto.Uint64(op.id.MostSigBits),
+		TxnidLeastBits: proto.Uint64(op.id.LeastSigBits),
+	}
+	res, err := t.tc.client.rpcClient.RequestOnCnx(t.getConn(), requestID, pb.BaseCommand_END_TXN, cmdEndTxn)
+	if t.checkRetriableError(err, op) {
+		return
+	}
+	defer close(op.errCh)
+	if err != nil {
+		op.errCh <- err
+	} else if res.Response.EndTxnResponse.Error != nil {
+		op.errCh <- getErrorFromServerError(res.Response.EndTxnResponse.Error)
+	}
+}
+
+func (t *transactionHandler) close() {
+	if !t.state.CAS(txnHandlerReady, txnHandlerClosed) {
+		return
+	}
+	close(t.closeCh)
+}
+
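All four ops share one request/response convention: the caller allocates `errCh`, enqueues the op on `requestCh`, and blocks on the channel; the handler either sends an error or simply closes the channel, and that close is also what makes result fields such as `newTxnOp.txnID` safe to read. A sketch of the calling side under those assumptions — `submitNewTxn` is a hypothetical name, assumed to sit in the `pulsar` package alongside the handler:

```go
// submitNewTxn shows the caller's side of the op/errCh contract. The receive
// from errCh happens after the handler's writes to op, so op.txnID must only
// be read once the receive has completed.
func submitNewTxn(h *transactionHandler, timeout time.Duration) (*TxnID, error) {
	op := &newTxnOp{timeout: timeout, errCh: make(chan error)}
	h.requestCh <- op // hand off to the handler's event loop
	err := <-op.errCh // nil on success (channel closed), non-nil on failure
	return op.txnID, err
}
```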
 // TransactionCoordinatorAssign is the transaction_impl coordinator topic which is used to look up the broker
 // where the TC located.
 const TransactionCoordinatorAssign = "persistent://pulsar/system/transaction_coordinator_assign"
@@ -61,50 +349,25 @@ func (tc *transactionCoordinatorClient) start() error {
 	if err != nil {
 		return err
 	}
-	tc.cons = make([]internal.Connection, r.Partitions)
-
+	tc.handlers = make([]*transactionHandler, r.Partitions)
 	//Get connections with all transaction_impl coordinators which is synchronized
 	if r.Partitions <= 0 {
 		return ErrTransactionCoordinatorNotEnabled
 	}
 	for i := 0; i < r.Partitions; i++ {
-		err := tc.grabConn(uint64(i))
+		handler, err := tc.newTransactionHandler(uint64(i))
 		if err != nil {
+			tc.log.WithError(err).Errorf("Failed to create transaction handler %d", i)
 			return err
 		}
+		tc.handlers[uint64(i)] = handler
 	}
 	return nil
 }
 
-func (tc *transactionCoordinatorClient) grabConn(partition uint64) error {
-	lr, err := tc.client.lookupService.Lookup(getTCAssignTopicName(partition))
-	if err != nil {
-		tc.log.WithError(err).Warn("Failed to lookup the transaction_impl " +
-			"coordinator assign topic [" + strconv.FormatUint(partition, 10) + "]")
-		return err
-	}
-
-	requestID := tc.client.rpcClient.NewRequestID()
-	cmdTCConnect := pb.CommandTcClientConnectRequest{
-		RequestId: proto.Uint64(requestID),
-		TcId:      proto.Uint64(partition),
-	}
-
-	res, err := tc.client.rpcClient.Request(lr.LogicalAddr, lr.PhysicalAddr, requestID,
-		pb.BaseCommand_TC_CLIENT_CONNECT_REQUEST, &cmdTCConnect)
-
-	if err != nil {
-		tc.log.WithError(err).Error("Failed to connect transaction_impl coordinator " +
-			strconv.FormatUint(partition, 10))
-		return err
-	}
-	tc.cons[partition] = res.Cnx
-	return nil
-}
-
 func (tc *transactionCoordinatorClient) close() {
-	for _, con := range tc.cons {
-		con.Close()
+	for _, h := range tc.handlers {
+		h.close()
 	}
 }
 
@@ -113,24 +376,15 @@ func (tc *transactionCoordinatorClient) newTransaction(timeout time.Duration) (*
 	if err := tc.canSendRequest(); err != nil {
 		return nil, err
 	}
-	requestID := tc.client.rpcClient.NewRequestID()
-	nextTcID := tc.nextTCNumber()
-	cmdNewTxn := &pb.CommandNewTxn{
-		RequestId:     proto.Uint64(requestID),
-		TcId:          proto.Uint64(nextTcID),
-		TxnTtlSeconds: proto.Uint64(uint64(timeout.Milliseconds())),
-	}
-
-	res, err := tc.client.rpcClient.RequestOnCnx(tc.cons[nextTcID], requestID, pb.BaseCommand_NEW_TXN, cmdNewTxn)
-	tc.semaphore.Release()
-	if err != nil {
-		return nil, err
-	} else if res.Response.NewTxnResponse.Error != nil {
-		return nil, getErrorFromServerError(res.Response.NewTxnResponse.Error)
+	defer tc.semaphore.Release()
+	op := &newTxnOp{
+		timeout: timeout,
+		errCh:   make(chan error),
 	}
-
-	return &TxnID{*res.Response.NewTxnResponse.TxnidMostBits,
-		*res.Response.NewTxnResponse.TxnidLeastBits}, nil
+	tc.handlers[tc.nextTCNumber()].requestCh <- op
+	// Receive on errCh before reading op.txnID so the read is ordered after the handler's write.
+	err := <-op.errCh
+	return op.txnID, err
 }
 
 // addPublishPartitionToTxn register the partitions which published messages with the transactionImpl.
@@ -139,22 +393,14 @@ func (tc *transactionCoordinatorClient) addPublishPartitionToTxn(id *TxnID, part
 	if err := tc.canSendRequest(); err != nil {
 		return err
 	}
-	requestID := tc.client.rpcClient.NewRequestID()
-	cmdAddPartitions := &pb.CommandAddPartitionToTxn{
-		RequestId:      proto.Uint64(requestID),
-		TxnidMostBits:  proto.Uint64(id.MostSigBits),
-		TxnidLeastBits: proto.Uint64(id.LeastSigBits),
-		Partitions:     partitions,
-	}
-	res, err := tc.client.rpcClient.RequestOnCnx(tc.cons[id.MostSigBits], requestID,
-		pb.BaseCommand_ADD_PARTITION_TO_TXN, cmdAddPartitions)
-	tc.semaphore.Release()
-	if err != nil {
-		return err
-	} else if res.Response.AddPartitionToTxnResponse.Error != nil {
-		return getErrorFromServerError(res.Response.AddPartitionToTxnResponse.Error)
+	defer tc.semaphore.Release()
+	op := &addPublishPartitionOp{
+		id:         id,
+		partitions: partitions,
+		errCh:      make(chan error),
 	}
-	return nil
+	tc.handlers[id.MostSigBits].requestCh <- op
+	return <-op.errCh
 }
 
 // addSubscriptionToTxn register the subscription which acked messages with the transactionImpl.
@@ -163,26 +409,15 @@ func (tc *transactionCoordinatorClient) addSubscriptionToTxn(id *TxnID, topic st
 	if err := tc.canSendRequest(); err != nil {
 		return err
 	}
-	requestID := tc.client.rpcClient.NewRequestID()
-	sub := &pb.Subscription{
-		Topic:        &topic,
-		Subscription: &subscription,
-	}
-	cmdAddSubscription := &pb.CommandAddSubscriptionToTxn{
-		RequestId:      proto.Uint64(requestID),
-		TxnidMostBits:  proto.Uint64(id.MostSigBits),
-		TxnidLeastBits: proto.Uint64(id.LeastSigBits),
-		Subscription:   []*pb.Subscription{sub},
+	defer tc.semaphore.Release()
+	op := &addSubscriptionOp{
+		id:           id,
+		topic:        topic,
+		subscription: subscription,
+		errCh:        make(chan error),
 	}
-	res, err := tc.client.rpcClient.RequestOnCnx(tc.cons[id.MostSigBits], requestID,
-		pb.BaseCommand_ADD_SUBSCRIPTION_TO_TXN, cmdAddSubscription)
-	tc.semaphore.Release()
-	if err != nil {
-		return err
-	} else if res.Response.AddSubscriptionToTxnResponse.Error != nil {
-		return getErrorFromServerError(res.Response.AddSubscriptionToTxnResponse.Error)
-	}
-	return nil
+	tc.handlers[id.MostSigBits].requestCh <- op
+	return <-op.errCh
 }
 
 // endTxn commit or abort the transactionImpl.
@@ -190,21 +425,14 @@ func (tc *transactionCoordinatorClient) endTxn(id *TxnID, action pb.TxnAction) e
 	if err := tc.canSendRequest(); err != nil {
 		return err
 	}
-	requestID := tc.client.rpcClient.NewRequestID()
-	cmdEndTxn := &pb.CommandEndTxn{
-		RequestId:      proto.Uint64(requestID),
-		TxnAction:      &action,
-		TxnidMostBits:  proto.Uint64(id.MostSigBits),
-		TxnidLeastBits: proto.Uint64(id.LeastSigBits),
+	defer tc.semaphore.Release()
+	op := &endTxnOp{
+		id:     id,
+		action: action,
+		errCh:  make(chan error),
 	}
-	res, err := tc.client.rpcClient.RequestOnCnx(tc.cons[id.MostSigBits], requestID, pb.BaseCommand_END_TXN, cmdEndTxn)
-	tc.semaphore.Release()
-	if err != nil {
-		return err
-	} else if res.Response.EndTxnResponse.Error != nil {
-		return getErrorFromServerError(res.Response.EndTxnResponse.Error)
-	}
-	return nil
+	tc.handlers[id.MostSigBits].requestCh <- op
+	return <-op.errCh
 }
 
 func getTCAssignTopicName(partition uint64) string {
@@ -219,5 +447,5 @@ func (tc *transactionCoordinatorClient) canSendRequest() error {
 }
 
 func (tc *transactionCoordinatorClient) nextTCNumber() uint64 {
-	return atomic.AddUint64(&tc.epoch, 1) % uint64(len(tc.cons))
+	return atomic.AddUint64(&tc.epoch, 1) % uint64(len(tc.handlers))
 }
diff --git a/pulsar/transaction_test.go b/pulsar/transaction_test.go
index 74e8dd0c9..eb88c7067 100644
--- a/pulsar/transaction_test.go
+++ b/pulsar/transaction_test.go
@@ -241,10 +241,11 @@ func TestConsumeAndProduceWithTxn(t *testing.T) {
 		SubscriptionName: sub,
 	})
 	assert.NoError(t, err)
-	producer, _ := client.CreateProducer(ProducerOptions{
+	producer, err := client.CreateProducer(ProducerOptions{
 		Topic:       topic,
 		SendTimeout: 0,
 	})
+	assert.NoError(t, err)
 	// Step 3: Open a transaction, send 10 messages with the transaction and 10 messages without the transaction.
 	// Expectation: We can receive the 10 messages sent without a transaction and
 	// cannot receive the 10 messages sent with the transaction.
@@ -448,7 +449,7 @@ func consumerShouldNotReceiveMessage(t *testing.T, consumer Consumer) {
 	}
 }
 
-func TestAckChunkMessage(t *testing.T) {
+func TestTransactionAckChunkMessage(t *testing.T) {
 	topic := newTopicName()
 	sub := "my-sub"
 
@@ -539,3 +540,56 @@ func TestAckChunkMessage(t *testing.T) {
 	require.Nil(t, err)
 	consumerShouldNotReceiveMessage(t, consumer)
 }
+
+func TestTxnConnReconnect(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	topic := newTopicName()
+	_, cli := createTcClient(t)
+
+	txn, err := cli.NewTransaction(5 * time.Minute)
+	assert.NoError(t, err)
+
+	connections := cli.cnxPool.GetConnections()
+	for _, conn := range connections {
+		conn.Close()
+	}
+
+	err = txn.Commit(ctx)
+	assert.NoError(t, err)
+
+	txn, err = cli.NewTransaction(5 * time.Minute)
+	assert.NoError(t, err) // Assert that the transaction can be opened after the connections are reconnected
+
+	// Start a goroutine to periodically close connections
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-time.After(1 * time.Second):
+				connections := cli.cnxPool.GetConnections()
+				for _, conn := range connections {
+					conn.Close()
+				}
+			}
+		}
+	}()
+
+	producer, err := cli.CreateProducer(ProducerOptions{
+		Topic:       topic,
+		SendTimeout: 0,
+	})
+	assert.NoError(t, err)
+	for i := 0; i < 10; i++ {
+		_, err := producer.Send(context.Background(), &ProducerMessage{
+			Transaction: txn,
+			Payload:     make([]byte, 1024),
+		})
+		require.Nil(t, err)
+		time.Sleep(500 * time.Millisecond)
+	}
+	err = txn.Commit(context.Background())
+	assert.NoError(t, err)
+}
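For context, this is what the reconnect machinery buys an application: transactional work keeps going across dropped broker connections without the caller doing anything. A hedged end-to-end sketch of the public API the test exercises — the service URL and topic are placeholders:

```go
package main

import (
	"context"
	"time"

	"github.com/apache/pulsar-client-go/pulsar"
)

func produceWithTxn() error {
	client, err := pulsar.NewClient(pulsar.ClientOptions{
		URL:               "pulsar://localhost:6650", // placeholder address
		EnableTransaction: true,                      // required for NewTransaction
	})
	if err != nil {
		return err
	}
	defer client.Close()

	producer, err := client.CreateProducer(pulsar.ProducerOptions{
		Topic:       "my-topic", // placeholder topic
		SendTimeout: 0,          // the transaction tests above disable the send timeout
	})
	if err != nil {
		return err
	}
	defer producer.Close()

	txn, err := client.NewTransaction(5 * time.Minute)
	if err != nil {
		return err
	}
	if _, err := producer.Send(context.Background(), &pulsar.ProducerMessage{
		Transaction: txn,
		Payload:     []byte("hello"),
	}); err != nil {
		_ = txn.Abort(context.Background())
		return err
	}
	return txn.Commit(context.Background())
}
```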