Skip to content

Commit

Permalink
Prevent from sending a message to an already closed channel. (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
YoEight authored Nov 12, 2024
1 parent 8206ac8 commit 88d8c74
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 37 deletions.
20 changes: 10 additions & 10 deletions esdb/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ func (client *Client) AppendToStream(

appendOperation, err := streamsClient.Append(ctx, callOptions...)
if err != nil {
err = client.grpcClient.handleError(handle, headers, trailers, err)
err = client.grpcClient.handleError(handle, trailers, err)
return nil, fmt.Errorf("could not construct append operation. Reason: %w", err)
}

header := toAppendHeader(streamID, opts.ExpectedRevision)

if err := appendOperation.Send(header); err != nil {
err = client.grpcClient.handleError(handle, headers, trailers, err)
err = client.grpcClient.handleError(handle, trailers, err)
return nil, fmt.Errorf("could not send append request header. Reason: %w", err)
}

Expand All @@ -81,14 +81,14 @@ func (client *Client) AppendToStream(
}

if err = appendOperation.Send(appendRequest); err != nil {
err = client.grpcClient.handleError(handle, headers, trailers, err)
err = client.grpcClient.handleError(handle, trailers, err)
return nil, fmt.Errorf("could not send append request. Reason: %w", err)
}
}

response, err := appendOperation.CloseAndRecv()
if err != nil {
return nil, client.grpcClient.handleError(handle, headers, trailers, err)
return nil, client.grpcClient.handleError(handle, trailers, err)
}

result := response.GetResult()
Expand Down Expand Up @@ -249,7 +249,7 @@ func (client *Client) DeleteStream(
deleteRequest := toDeleteRequest(streamID, opts.ExpectedRevision)
deleteResponse, err := streamsClient.Delete(ctx, deleteRequest, callOptions...)
if err != nil {
err = client.grpcClient.handleError(handle, headers, trailers, err)
err = client.grpcClient.handleError(handle, trailers, err)
return nil, fmt.Errorf("failed to perform delete, details: %w", err)
}

Expand Down Expand Up @@ -280,7 +280,7 @@ func (client *Client) TombstoneStream(
tombstoneResponse, err := streamsClient.Tombstone(ctx, tombstoneRequest, callOptions...)

if err != nil {
err = client.grpcClient.handleError(handle, headers, trailers, err)
err = client.grpcClient.handleError(handle, trailers, err)
return nil, fmt.Errorf("failed to perform delete, details: %w", err)
}

Expand Down Expand Up @@ -347,13 +347,13 @@ func (client *Client) SubscribeToStream(
readClient, err := streamsClient.Read(ctx, subscriptionRequest, callOptions...)
if err != nil {
defer cancel()
err = client.grpcClient.handleError(handle, headers, trailers, err)
err = client.grpcClient.handleError(handle, trailers, err)
return nil, fmt.Errorf("failed to construct subscription. Reason: %w", err)
}
readResult, err := readClient.Recv()
if err != nil {
defer cancel()
err = client.grpcClient.handleError(handle, headers, trailers, err)
err = client.grpcClient.handleError(handle, trailers, err)
return nil, fmt.Errorf("failed to perform read. Reason: %w", err)
}
switch readResult.Content.(type) {
Expand Down Expand Up @@ -401,13 +401,13 @@ func (client *Client) SubscribeToAll(
readClient, err := streamsClient.Read(ctx, subscriptionRequest, callOptions...)
if err != nil {
defer cancel()
err = client.grpcClient.handleError(handle, headers, trailers, err)
err = client.grpcClient.handleError(handle, trailers, err)
return nil, fmt.Errorf("failed to construct subscription. Reason: %w", err)
}
readResult, err := readClient.Recv()
if err != nil {
defer cancel()
err = client.grpcClient.handleError(handle, headers, trailers, err)
err = client.grpcClient.handleError(handle, trailers, err)
return nil, fmt.Errorf("failed to perform read. Reason: %w", err)
}
switch readResult.Content.(type) {
Expand Down
38 changes: 26 additions & 12 deletions esdb/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,14 @@ type grpcClient struct {
perRPCCredentials credentials.PerRPCCredentials
}

func (client *grpcClient) handleError(handle *connectionHandle, headers metadata.MD, trailers metadata.MD, err error) error {
func (client *grpcClient) handleError(handle *connectionHandle, trailers metadata.MD, err error) error {
if client.isClosed() {
return &Error{
code: ErrorCodeConnectionClosed,
err: fmt.Errorf("connection is closed"),
}
}

values := trailers.Get("exception")

if values != nil && values[0] == "not-leader" {
Expand All @@ -55,9 +62,16 @@ func (client *grpcClient) handleError(handle *connectionHandle, headers metadata
endpoint: &endpoint,
}

client.channel <- msg
client.logger.error("not leader exception, reconnecting to %v", endpoint)
return &Error{code: ErrorCodeNotLeader}
if !client.isClosed() {
client.channel <- msg
client.logger.error("not leader exception, reconnecting to %v", endpoint)
return &Error{code: ErrorCodeNotLeader}
}

return &Error{
code: ErrorCodeConnectionClosed,
err: fmt.Errorf("connection is closed"),
}
}
}
}
Expand All @@ -70,19 +84,17 @@ func (client *grpcClient) handleError(handle *connectionHandle, headers metadata
client.logger.error("unexpected exception: %v", err)

code := errToCode(err)
if code == ErrorUnavailable {
msg := reconnect{
if code == ErrorUnavailable && !client.isClosed() {
client.channel <- reconnect{
correlation: handle.Id(),
}

client.channel <- msg
}

return &Error{code: code, err: err}
}

func (client *grpcClient) getConnectionHandle() (*connectionHandle, error) {
if atomic.LoadInt32(client.closeFlag) != 0 {
if client.isClosed() {
return nil, &Error{
code: ErrorCodeConnectionClosed,
err: fmt.Errorf("connection is closed"),
Expand All @@ -104,6 +116,10 @@ func (client *grpcClient) close() {
})
}

func (client *grpcClient) isClosed() bool {
return atomic.LoadInt32(client.closeFlag) != 0
}

type getConnection struct {
channel chan connectionHandle
}
Expand Down Expand Up @@ -254,7 +270,7 @@ func connectionStateMachine(config Configuration, closeFlag *int32, channel chan
}

if state.connection != nil {
state.connection.Close()
_ = state.connection.Close()
state.connection = nil
}

Expand Down Expand Up @@ -589,7 +605,6 @@ func errToCode(err error) ErrorCode {
}

func shuffleCandidates(src []string) []string {
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(src), func(i, j int) {
src[i], src[j] = src[j], src[i]
})
Expand All @@ -598,7 +613,6 @@ func shuffleCandidates(src []string) []string {
}

func shuffleMembers(src []*gossipApi.MemberInfo) []*gossipApi.MemberInfo {
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(src), func(i, j int) {
src[i], src[j] = src[j], src[i]
})
Expand Down
26 changes: 13 additions & 13 deletions esdb/persistent_subscription_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ func (client *persistentClient) ConnectToPersistentSubscription(
readClient, err := client.persistentSubscriptionClient.Read(ctx, callOptions...)
if err != nil {
defer cancel()
return nil, client.inner.handleError(handle, headers, trailers, err)
return nil, client.inner.handleError(handle, trailers, err)
}

err = readClient.Send(toPersistentReadRequest(bufferSize, groupName, []byte(streamName)))
if err != nil {
defer cancel()
return nil, client.inner.handleError(handle, headers, trailers, err)
return nil, client.inner.handleError(handle, trailers, err)
}

readResult, err := readClient.Recv()
if err != nil {
defer cancel()
return nil, client.inner.handleError(handle, headers, trailers, err)
return nil, client.inner.handleError(handle, trailers, err)
}
switch readResult.Content.(type) {
case *persistent.ReadResp_SubscriptionConfirmation_:
Expand Down Expand Up @@ -77,7 +77,7 @@ func (client *persistentClient) CreateStreamSubscription(
defer cancel()
_, err := client.persistentSubscriptionClient.Create(ctx, createSubscriptionConfig, callOptions...)
if err != nil {
return client.inner.handleError(handle, headers, trailers, err)
return client.inner.handleError(handle, trailers, err)
}

return nil
Expand Down Expand Up @@ -105,7 +105,7 @@ func (client *persistentClient) CreateAllSubscription(

_, err = client.persistentSubscriptionClient.Create(ctx, protoConfig, callOptions...)
if err != nil {
return client.inner.handleError(handle, headers, trailers, err)
return client.inner.handleError(handle, trailers, err)
}

return nil
Expand All @@ -129,7 +129,7 @@ func (client *persistentClient) UpdateStreamSubscription(

_, err := client.persistentSubscriptionClient.Update(ctx, updateSubscriptionConfig, callOptions...)
if err != nil {
return client.inner.handleError(handle, headers, trailers, err)
return client.inner.handleError(handle, trailers, err)
}

return nil
Expand All @@ -153,7 +153,7 @@ func (client *persistentClient) UpdateAllSubscription(

_, err := client.persistentSubscriptionClient.Update(ctx, updateSubscriptionConfig, callOptions...)
if err != nil {
return client.inner.handleError(handle, headers, trailers, err)
return client.inner.handleError(handle, trailers, err)
}

return nil
Expand All @@ -175,7 +175,7 @@ func (client *persistentClient) DeleteStreamSubscription(

_, err := client.persistentSubscriptionClient.Delete(ctx, deleteSubscriptionOptions, callOptions...)
if err != nil {
return client.inner.handleError(handle, headers, trailers, err)
return client.inner.handleError(handle, trailers, err)
}

return nil
Expand All @@ -196,7 +196,7 @@ func (client *persistentClient) DeleteAllSubscription(

_, err := client.persistentSubscriptionClient.Delete(ctx, deleteSubscriptionOptions, callOptions...)
if err != nil {
return client.inner.handleError(handle, headers, trailers, err)
return client.inner.handleError(handle, trailers, err)
}

return nil
Expand Down Expand Up @@ -246,7 +246,7 @@ func (client *persistentClient) listPersistentSubscriptions(
resp, err := client.persistentSubscriptionClient.List(ctx, listReq, callOptions...)

if err != nil {
return nil, client.inner.handleError(handle, headers, trailers, err)
return nil, client.inner.handleError(handle, trailers, err)
}

var infos []PersistentSubscriptionInfo
Expand Down Expand Up @@ -292,7 +292,7 @@ func (client *persistentClient) getPersistentSubscriptionInfo(

resp, err := client.persistentSubscriptionClient.GetInfo(ctx, getInfoReq, callOptions...)
if err != nil {
return nil, client.inner.handleError(handle, headers, trailers, err)
return nil, client.inner.handleError(handle, trailers, err)
}

info, err := subscriptionInfoFromWire(resp.SubscriptionInfo)
Expand Down Expand Up @@ -342,7 +342,7 @@ func (client *persistentClient) replayParkedMessages(
_, err := client.persistentSubscriptionClient.ReplayParked(ctx, replayReq, callOptions...)

if err != nil {
return client.inner.handleError(handle, headers, trailers, err)
return client.inner.handleError(handle, trailers, err)
}

return nil
Expand All @@ -362,7 +362,7 @@ func (client *persistentClient) restartSubsystem(
_, err := client.persistentSubscriptionClient.RestartSubsystem(ctx, &shared.Empty{}, callOptions...)

if err != nil {
return client.inner.handleError(handle, headers, trailers, err)
return client.inner.handleError(handle, trailers, err)
}

return nil
Expand Down
2 changes: 1 addition & 1 deletion esdb/projection_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ func (client *ProjectionClient) listInternal(

if err != nil {
if !errors.Is(err, io.EOF) {
err = client.inner.grpcClient.handleError(handle, headers, trailers, err)
err = client.inner.grpcClient.handleError(handle, trailers, err)
return nil, err
}

Expand Down
2 changes: 1 addition & 1 deletion esdb/reads.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (stream *ReadStream) Recv() (*ResolvedEvent, error) {
atomic.StoreInt32(stream.closed, 1)

if !errors.Is(err, io.EOF) {
err = stream.params.client.handleError(stream.params.handle, *stream.params.headers, *stream.params.trailers, err)
err = stream.params.client.handleError(stream.params.handle, *stream.params.trailers, err)
}

return nil, err
Expand Down

0 comments on commit 88d8c74

Please sign in to comment.