diff --git a/oxia/internal/executor.go b/oxia/internal/executor.go index 507a044e..e998ffb2 100644 --- a/oxia/internal/executor.go +++ b/oxia/internal/executor.go @@ -113,7 +113,7 @@ func (e *executorImpl) writeStream(shardId *int64) (*streamWrapper, error) { e.RLock() sw, ok := e.writeStreams[*shardId] - if ok { + if ok && !sw.failed.Load() { e.RUnlock() return sw, nil } diff --git a/oxia/internal/write_stream.go b/oxia/internal/write_stream.go index 432e2bed..8489e1c7 100644 --- a/oxia/internal/write_stream.go +++ b/oxia/internal/write_stream.go @@ -19,6 +19,7 @@ import ( "io" "log/slog" "sync" + "sync/atomic" "github.com/streamnative/oxia/common" "github.com/streamnative/oxia/proto" @@ -29,6 +30,7 @@ type streamWrapper struct { stream proto.OxiaClient_WriteStreamClient pendingRequests []common.Future[*proto.WriteResponse] + failed atomic.Bool } func newStreamWrapper(stream proto.OxiaClient_WriteStreamClient) *streamWrapper { @@ -48,6 +50,7 @@ func (sw *streamWrapper) Send(ctx context.Context, req *proto.WriteRequest) (*pr sw.Lock() sw.pendingRequests = append(sw.pendingRequests, f) if err := sw.stream.Send(req); err != nil { + sw.failed.Store(true) sw.Unlock() return nil, err } @@ -62,17 +65,23 @@ func (sw *streamWrapper) handleStreamClosed() { // Fail all pending requests sw.Lock() + defer sw.Unlock() + for _, f := range sw.pendingRequests { f.Fail(io.EOF) } sw.pendingRequests = nil - sw.Unlock() + sw.failed.Store(true) } func (sw *streamWrapper) handleResponses() { for { response, err := sw.stream.Recv() + sw.Lock() + if err != nil { + sw.failed.Store(true) + sw.Unlock() return } @@ -81,7 +90,6 @@ func (sw *streamWrapper) handleResponses() { slog.Any("err", err), ) - sw.Lock() var f common.Future[*proto.WriteResponse] f, sw.pendingRequests = sw.pendingRequests[0], sw.pendingRequests[1:] sw.Unlock()