Skip to content

Commit

Permalink
Fix / Request counting in delayed execution interceptor for streaming…
Browse files Browse the repository at this point in the history
… calls (#485)

* Fixes in delayed execution interceptor for streaming calls

* Disable the rapid fire test to see other error messages

* Updates in validator interceptor

* Simplify and fix delayed execution interceptor

* Do not trigger startCall() for each onHalfClose() event

* Re-enable rapid fire test for file storage

* Update comments in validation interceptor
  • Loading branch information
martin-traverse authored Dec 21, 2024
1 parent 5252fcf commit ec69f83
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,42 +31,32 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(

var delayedCall = new DelayedExecutionCall<>(call);

return new DelayedExecutionListener<>(delayedCall, headers, next);
return new DelayedExecutionListener<>(delayedCall, headers, next);
}

protected static class DelayedExecutionCall<ReqT, RespT> extends ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT> {

private long totalRequested;
private boolean firstRequest = true;

public DelayedExecutionCall(ServerCall<ReqT, RespT> delegate) {

super(delegate);

// Since startCall() is delayed, request one message straight away to start the pipeline
delegate.request(1);
}

@Override
public void request(int numMessages) {

// In the listener, onReady() sends out an extra request() call to pull the first message
// When the first request() comes from the main handler, it is a duplicate
// This method ignores the second request in the stream, to discard that duplicate

var priorRequested = totalRequested;

totalRequested += numMessages;

if (priorRequested >= 2) {
delegate().request(numMessages);
}
else if (priorRequested == 1) {
// Ignore the first request in the pipeline (do not send a duplicate)
if (firstRequest) {
firstRequest = false;
if (numMessages > 1)
delegate().request(numMessages - 1);
}
else {
if (numMessages == 1) {
delegate().request(numMessages);
}
else {
delegate().request(numMessages - 1);
}
delegate().request(numMessages);
}
}
}
Expand All @@ -78,7 +68,6 @@ protected static class DelayedExecutionListener<ReqT, RespT> extends ForwardingS
private final ServerCallHandler<ReqT, RespT> next;

private ServerCall.Listener<ReqT> delegate;
private boolean ready;

public DelayedExecutionListener(
ServerCall<ReqT, RespT> call, Metadata headers,
Expand All @@ -87,49 +76,50 @@ public DelayedExecutionListener(
this.call = call;
this.headers = headers;
this.next = next;

// By default, the interceptor is ready and can respond as soon as events arrive
this.ready = true;
}

@Override
@SuppressWarnings("unchecked")
protected ServerCall.Listener<ReqT> delegate() {

if (delegate != null)
return delegate;

else if (ready) {
if (delegate == null)
startCall();
return delegate;
}

return (ServerCall.Listener<ReqT>) NOOP_SINK;
return delegate;
}

protected void startCall() {

delegate = next.startCall(call, headers);
}

protected void setReady(boolean ready) {

// Allow child classes to delay the flow of events by turning this flag on / off
// Do not start the call twice if this method is called explicitly

this.ready = ready;
if (delegate == null)
delegate = next.startCall(call, headers);
}

@Override
public void onReady() {

// Do not trigger startCall() until the first real message is received

if (delegate == null)
call.request(1);
else
if (delegate != null)
delegate.onReady();
}
}

private static final ServerCall.Listener<?> NOOP_SINK = new ServerCall.Listener<>() {};
@Override
public void onCancel() {

// Do not trigger startCall() if the request is cancelled before it starts

if (delegate != null)
delegate.onCancel();
}

@Override
public void onHalfClose() {

// Do not trigger startCall() if the request is closed before it starts

if (delegate != null)
delegate.onHalfClose();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ public GrpcRequestValidator(GrpcServiceRegister serviceRegister, boolean logging
private class ValidationListener<ReqT, RespT> extends DelayedExecutionListener<ReqT, RespT> {

private final ServerCall<ReqT, RespT> serverCall;
private final Descriptors.MethodDescriptor methodDescriptor;
private final boolean loggingEnabled;

private final Descriptors.MethodDescriptor methodDescriptor;
private boolean validated = false;

public ValidationListener(
Expand All @@ -76,17 +76,16 @@ public ValidationListener(
ServerCallHandler<ReqT, RespT> nextHandler,
boolean loggingEnabled) {

// Using setReady(false) will prevent delayed interceptor from calling startCall()

super(serverCall, metadata, nextHandler);
super.setReady(false);

this.serverCall = serverCall;
this.loggingEnabled = loggingEnabled;

// Look up the descriptor for this call, to use for validation
var grpcDescriptor = serverCall.getMethodDescriptor();
var grpcMethodName = grpcDescriptor.getFullMethodName();

this.serverCall = serverCall;
this.methodDescriptor = serviceRegister.getMethodDescriptor(grpcMethodName);
this.loggingEnabled = loggingEnabled;
}

@Override
Expand All @@ -101,8 +100,9 @@ public void onMessage(ReqT req) {
validator.validateFixedMethod(message, methodDescriptor);
validated = true;

// Allow delayed interceptor to call startCAll() and start the normal flow of events
setReady(true);
// Allow delayed interceptor to start the normal flow of events
startCall();
delegate().onMessage(req);
}
catch (EValidation validationError) {

Expand All @@ -122,9 +122,10 @@ public void onMessage(ReqT req) {
serverCall.close(status, mappedError.getTrailers());
}
}
else {

// If validation has not succeeded, messages are sent to a no-op sink
delegate().onMessage(req);
delegate().onMessage(req);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ void testRoundTrip_smallTextFile() throws Exception {
roundTripTest(content, false);
}

@RepeatedTest(100) @Tag("slow")
@RepeatedTest(100)
void rapidFireTest() throws Exception {

testRoundTrip_heterogeneousChunks();
Expand Down

0 comments on commit ec69f83

Please sign in to comment.