Skip to content

Commit

Permalink
Make the executor close LeaseJobRuns stream (#3876)
Browse files Browse the repository at this point in the history
* Make the executor close LeaseJobRuns stream

Signed-off-by: JamesMurkin <[email protected]>

* Fix some tests

Signed-off-by: JamesMurkin <[email protected]>

* Simplify + tests

Signed-off-by: JamesMurkin <[email protected]>

* Simplify + tests

Signed-off-by: JamesMurkin <[email protected]>

* Improve tests

Signed-off-by: JamesMurkin <[email protected]>

* Make error while closing stream non-fatal

Signed-off-by: JamesMurkin <[email protected]>

---------

Signed-off-by: JamesMurkin <[email protected]>
  • Loading branch information
JamesMurkin authored Sep 5, 2024
1 parent bdf9f4e commit 996b203
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 90 deletions.
70 changes: 38 additions & 32 deletions internal/executor/service/lease_requester.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package service

import (
"context"
"fmt"
"io"

grpcretry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
Expand Down Expand Up @@ -69,45 +69,51 @@ func (requester *JobLeaseRequester) LeaseJobRuns(ctx *armadacontext.Context, req
leaseRuns := []*executorapi.JobRunLease{}
runIdsToCancel := []*armadaevents.Uuid{}
runIdsToPreempt := []*armadaevents.Uuid{}
for {
shouldEndStreamCall := false
select {
case <-ctx.Done():
if ctx.Err() == context.DeadlineExceeded {
shouldEndStreamCall = true
} else {
return nil, ctx.Err()
}
default:
res, err := stream.Recv()
if err == io.EOF {
shouldEndStreamCall = true
} else if err != nil {
return nil, err
}

switch typed := res.GetEvent().(type) {
case *executorapi.LeaseStreamMessage_Lease:
leaseRuns = append(leaseRuns, typed.Lease)
case *executorapi.LeaseStreamMessage_PreemptRuns:
runIdsToPreempt = append(runIdsToPreempt, typed.PreemptRuns.JobRunIdsToPreempt...)
case *executorapi.LeaseStreamMessage_CancelRuns:
runIdsToCancel = append(runIdsToCancel, typed.CancelRuns.JobRunIdsToCancel...)
case *executorapi.LeaseStreamMessage_End:
shouldEndStreamCall = true
default:
log.Errorf("unexpected lease stream message type %T", typed)
}
shouldEndStream := false
for !shouldEndStream {
res, err := stream.Recv()
if err != nil {
return nil, err
}

if shouldEndStreamCall {
break
switch typed := res.GetEvent().(type) {
case *executorapi.LeaseStreamMessage_Lease:
leaseRuns = append(leaseRuns, typed.Lease)
case *executorapi.LeaseStreamMessage_PreemptRuns:
runIdsToPreempt = append(runIdsToPreempt, typed.PreemptRuns.JobRunIdsToPreempt...)
case *executorapi.LeaseStreamMessage_CancelRuns:
runIdsToCancel = append(runIdsToCancel, typed.CancelRuns.JobRunIdsToCancel...)
case *executorapi.LeaseStreamMessage_End:
shouldEndStream = true
default:
log.Errorf("unexpected lease stream message type %T", typed)
}
}

err = closeStream(stream)
if err != nil {
log.Warnf("Failed to close lease jobs stream cleanly - %s", err)
}

return &LeaseResponse{
LeasedRuns: leaseRuns,
RunIdsToCancel: runIdsToCancel,
RunIdsToPreempt: runIdsToPreempt,
}, nil
}

// closeStream confirms that the LeaseJobRuns stream has been finished.
//
// It is intended to be called once the end of stream message
// (LeaseStreamMessage_End) has been seen. It performs one further Recv and
// treats io.EOF as a clean close. Any other error from Recv is returned
// unchanged, and receiving a further message instead of EOF is reported as an
// error naming the type of the unexpected event.
func closeStream(stream executorapi.ExecutorApi_LeaseJobRunsClient) error {
	msg, recvErr := stream.Recv()
	switch {
	case recvErr == io.EOF:
		// The expected outcome: nothing is left on the stream.
		return nil
	case recvErr != nil:
		return recvErr
	}
	// Recv succeeded, so another message arrived after the end marker.
	return fmt.Errorf("failed closing stream - unexpectedly received event of type %T", msg.GetEvent())
}
147 changes: 89 additions & 58 deletions internal/executor/service/lease_requester_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func TestLeaseJobRuns(t *testing.T) {
mockStream.EXPECT().Send(gomock.Any()).Return(nil)
setStreamExpectations(mockStream, tc.leaseMessages, tc.cancelMessages, tc.preemptMessages)
mockStream.EXPECT().Recv().Return(endMarker, nil)
mockStream.EXPECT().Recv().Return(nil, io.EOF)

response, err := jobRequester.LeaseJobRuns(ctx, &LeaseRequest{})
assert.NoError(t, err)
Expand Down Expand Up @@ -114,99 +115,84 @@ func TestLeaseJobRuns_Send(t *testing.T) {
mockExecutorApiClient.EXPECT().LeaseJobRuns(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockStream, nil)
mockStream.EXPECT().Send(expectedRequest).Return(nil)
mockStream.EXPECT().Recv().Return(endMarker, nil)
mockStream.EXPECT().Recv().Return(nil, io.EOF)

_, err := jobRequester.LeaseJobRuns(shortCtx, leaseRequest)
assert.NoError(t, err)
}

// TestLeaseJobRuns_HandlesNoEndMarkerMessage covers a stream that delivers
// lease messages but never sends an end marker. It expects LeaseJobRuns to
// return without error once the context times out, still reporting the
// leases that were received.
func TestLeaseJobRuns_HandlesNoEndMarkerMessage(t *testing.T) {
leaseMessages := []*executorapi.JobRunLease{lease1, lease2}
shortCtx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 200*time.Millisecond)
defer cancel()

jobRequester, mockExecutorApiClient, mockStream := setup(t)
mockExecutorApiClient.EXPECT().LeaseJobRuns(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockStream, nil)
mockStream.EXPECT().Send(gomock.Any()).Return(nil)
setStreamExpectations(mockStream, leaseMessages, nil, nil)
// No end marker is sent: this Recv sleeps for 400ms, which is longer than
// the 200ms timeout on shortCtx.
mockStream.EXPECT().Recv().Do(func() {
time.Sleep(time.Millisecond * 400)
})

response, err := jobRequester.LeaseJobRuns(shortCtx, &LeaseRequest{})
// Context expiry is not expected to surface as an error
assert.NoError(t, err)
// Still receive leases that were received prior to the timeout
assert.Equal(t, leaseMessages, response.LeasedRuns)
}

func TestLeaseJobRuns_Error(t *testing.T) {
func TestLeaseJobRuns_ReceiveError(t *testing.T) {
endStreamMarkerTimeoutErr := fmt.Errorf("end of stream marker timeout")
closeStreamErr := fmt.Errorf("close stream timeout")
receiveErr := fmt.Errorf("recv error")
ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 30*time.Second)
defer cancel()
tests := map[string]struct {
streamError bool
sendError bool
recvError bool
recvEndOfFileError bool
shouldError bool
leaseMessages []*executorapi.JobRunLease
expectedLeases []*executorapi.JobRunLease
recvError bool
endOfStreamErr bool
closeStreamErr bool
shouldError bool
expectedError error
leaseMessages []*executorapi.JobRunLease
expectedLeases []*executorapi.JobRunLease
}{
"StreamError": {
sendError: true,
shouldError: true,
"Happy Path": {
shouldError: false,
leaseMessages: []*executorapi.JobRunLease{lease1, lease2},
expectedLeases: nil,
expectedLeases: []*executorapi.JobRunLease{lease1, lease2},
},
"SendError": {
sendError: true,
"RecvError": {
recvError: true,
shouldError: true,
expectedError: receiveErr,
leaseMessages: []*executorapi.JobRunLease{lease1, lease2},
expectedLeases: nil,
},
"RecvError": {
recvError: true,
"Timeout - end stream marker": {
endOfStreamErr: true,
shouldError: true,
expectedError: endStreamMarkerTimeoutErr,
leaseMessages: []*executorapi.JobRunLease{lease1, lease2},
expectedLeases: nil,
},
"RecvEOF": {
recvEndOfFileError: true,
shouldError: false,
leaseMessages: []*executorapi.JobRunLease{lease1, lease2},
expectedLeases: []*executorapi.JobRunLease{lease1, lease2},
"Close stream error": {
closeStreamErr: true,
shouldError: false,
expectedError: closeStreamErr,
leaseMessages: []*executorapi.JobRunLease{lease1, lease2},
expectedLeases: []*executorapi.JobRunLease{lease1, lease2},
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
jobRequester, mockExecutorApiClient, mockStream := setup(t)
if tc.streamError {
mockExecutorApiClient.EXPECT().LeaseJobRuns(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("stream error")).AnyTimes()
} else {
mockExecutorApiClient.EXPECT().LeaseJobRuns(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockStream, nil).AnyTimes()
}
mockExecutorApiClient.EXPECT().LeaseJobRuns(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockStream, nil)
mockStream.EXPECT().Send(gomock.Any()).Return(nil)

if tc.sendError {
mockStream.EXPECT().Send(gomock.Any()).Return(fmt.Errorf("send error")).AnyTimes()
if tc.recvError {
mockStream.EXPECT().Recv().Return(nil, receiveErr).AnyTimes()
} else {
mockStream.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes()
}
setStreamExpectations(mockStream, tc.leaseMessages, nil, nil)

if tc.recvError || tc.recvEndOfFileError {
if tc.recvError {
mockStream.EXPECT().Recv().Return(nil, fmt.Errorf("recv error")).AnyTimes()
}
if tc.recvEndOfFileError {
setStreamExpectations(mockStream, tc.leaseMessages, nil, nil)
mockStream.EXPECT().Recv().Return(nil, io.EOF).AnyTimes()
if tc.endOfStreamErr {
mockStream.EXPECT().Recv().Return(nil, endStreamMarkerTimeoutErr)
} else {
mockStream.EXPECT().Recv().Return(endMarker, nil)

if tc.closeStreamErr {
mockStream.EXPECT().Recv().Return(nil, closeStreamErr)
} else {
mockStream.EXPECT().Recv().Return(nil, io.EOF)
}
}
}

response, err := jobRequester.LeaseJobRuns(ctx, &LeaseRequest{})
if tc.shouldError {
assert.Error(t, err)
assert.Nil(t, response)
assert.Contains(t, err.Error(), tc.expectedError.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, tc.expectedLeases, response.LeasedRuns)
Expand All @@ -215,6 +201,51 @@ func TestLeaseJobRuns_Error(t *testing.T) {
}
}

// TestLeaseJobRuns_SendError checks that LeaseJobRuns returns an error and a
// nil response when either opening the stream or sending the request on it
// fails, and that Recv is never called on the stream in those cases.
func TestLeaseJobRuns_SendError(t *testing.T) {
	streamError := fmt.Errorf("stream error")
	sendError := fmt.Errorf("send error")
	ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 30*time.Second)
	defer cancel()

	type testCase struct {
		streamError   bool
		sendError     bool
		expectedError error
	}
	tests := map[string]testCase{
		"StreamError": {streamError: true, expectedError: streamError},
		"SendError":   {sendError: true, expectedError: sendError},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			jobRequester, mockExecutorApiClient, mockStream := setup(t)

			// Opening the stream either fails outright or yields the mock stream.
			if tc.streamError {
				mockExecutorApiClient.EXPECT().LeaseJobRuns(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, streamError).AnyTimes()
			} else {
				mockExecutorApiClient.EXPECT().LeaseJobRuns(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockStream, nil).AnyTimes()
			}

			// Send succeeds unless this case asks for a send failure.
			var sendResult error
			if tc.sendError {
				sendResult = sendError
			}
			mockStream.EXPECT().Send(gomock.Any()).Return(sendResult).AnyTimes()

			// Nothing should be read from the stream in either case.
			mockStream.EXPECT().Recv().Times(0)

			response, err := jobRequester.LeaseJobRuns(ctx, &LeaseRequest{})
			assert.Error(t, err)
			assert.Nil(t, response)
			assert.Contains(t, err.Error(), tc.expectedError.Error())
		})
	}
}

func setup(t *testing.T) (*JobLeaseRequester, *mocks.MockExecutorApiClient, *mocks.MockExecutorApi_LeaseJobRunsClient) {
ctrl := gomock.NewController(t)
mockExecutorApiClient := mocks.NewMockExecutorApiClient(ctrl)
Expand Down

0 comments on commit 996b203

Please sign in to comment.