Skip to content

Commit

Permalink
Provide RequestHeaders to ServerErrorHandler (#4037)
Browse files Browse the repository at this point in the history
Motivation:

A user sometimes wants to generate an error response differently
depending on a certain request header value. For example:

- Include a certain request header value in an error response for easier
  troubleshooting; or
- Provide a different level of detail for the requests that contain
  a certain header.

Modifications:

- Added an optional `RequestHeaders` parameter to `ServerErrorHandler`
  methods.
- Modified `Http{1,2}RequestDecoder` so it constructs a `RequestHeaders`
  as early as possible, so that `ServerErrorHandler` is given with a
  non-null `RequestHeaders` in most cases.
- Updated `CustomServerErrorHandlerTest` to demonstrate this feature.

Result:

- (new feature) A user can generate an error response differently
  depending on the request headers.
- (breaking changes) The method signatures of `ServerErrorHandler` has
  been changed.
  • Loading branch information
trustin authored Jan 25, 2022
1 parent 1166d8e commit 1284b5e
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.util.Exceptions;
Expand Down Expand Up @@ -68,7 +69,8 @@ public HttpResponse onServiceException(ServiceRequestContext ctx, Throwable caus
logger.warn("{} Failed processing a request:", ctx, cause);
}

return internalRenderStatus(serviceConfig, HttpStatus.BAD_REQUEST, cause);
return internalRenderStatus(serviceConfig, ctx.request().headers(),
HttpStatus.BAD_REQUEST, cause);
}
}

Expand All @@ -85,33 +87,39 @@ public HttpResponse onServiceException(ServiceRequestContext ctx, Throwable caus
}

if (cause instanceof RequestTimeoutException) {
return internalRenderStatus(serviceConfig, HttpStatus.SERVICE_UNAVAILABLE, cause);
return internalRenderStatus(serviceConfig, ctx.request().headers(),
HttpStatus.SERVICE_UNAVAILABLE, cause);
}

if (isAnnotatedService && needsToWarn() && !Exceptions.isExpected(cause)) {
logger.warn("{} Unhandled exception from a service:", ctx, cause);
}

return internalRenderStatus(serviceConfig, HttpStatus.INTERNAL_SERVER_ERROR, cause);
return internalRenderStatus(serviceConfig, ctx.request().headers(),
HttpStatus.INTERNAL_SERVER_ERROR, cause);
}

@SuppressWarnings("deprecation")
private static boolean needsToWarn() {
return Flags.annotatedServiceExceptionVerbosity() == ExceptionVerbosity.UNHANDLED &&
logger.isWarnEnabled();
}

private static HttpResponse internalRenderStatus(ServiceConfig serviceConfig,
RequestHeaders headers,
HttpStatus status,
@Nullable Throwable cause) {
final AggregatedHttpResponse res = serviceConfig.server().config().errorHandler()
.renderStatus(serviceConfig, status, null, cause);
final AggregatedHttpResponse res =
serviceConfig.server().config().errorHandler()
.renderStatus(serviceConfig, headers, status, null, cause);
assert res != null;
return res.toHttpResponse();
}

@Nonnull
@Override
public AggregatedHttpResponse renderStatus(ServiceConfig config,
@Nullable RequestHeaders headers,
HttpStatus status,
@Nullable String description,
@Nullable Throwable cause) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,33 +145,45 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
keepAliveHandler.increaseNumRequests();
final HttpRequest nettyReq = (HttpRequest) msg;
if (!nettyReq.decoderResult().isSuccess()) {
fail(id, HttpStatus.BAD_REQUEST, "Decoder failure", null);
fail(id, null, HttpStatus.BAD_REQUEST, "Decoder failure", null);
return;
}

final HttpHeaders nettyHeaders = nettyReq.headers();

// Do not accept unsupported methods.
final io.netty.handler.codec.http.HttpMethod nettyMethod = nettyReq.method();
if (nettyMethod == io.netty.handler.codec.http.HttpMethod.CONNECT ||
!HttpMethod.isSupported(nettyMethod.name())) {
fail(id, HttpStatus.METHOD_NOT_ALLOWED, "Unsupported method", null);
if (!HttpMethod.isSupported(nettyMethod.name())) {
fail(id, null, HttpStatus.METHOD_NOT_ALLOWED, "Unsupported method", null);
return;
}

// Handle `expect: 100-continue` first to give `handle100Continue()` a chance to remove
// the `expect` header before converting the Netty HttpHeaders into Armeria RequestHeaders.
// This is because removing a header from RequestHeaders is more expensive due to its
// immutability.
final boolean hasInvalidExpectHeader = !handle100Continue(id, nettyReq);

// Convert the Netty HttpHeaders into Armeria RequestHeaders.
final RequestHeaders headers =
ArmeriaHttpUtil.toArmeria(ctx, nettyReq, cfg, scheme.toString());

// Do not accept a CONNECT request.
if (headers.method() == HttpMethod.CONNECT) {
fail(id, headers, HttpStatus.METHOD_NOT_ALLOWED, "Unsupported method", null);
return;
}

// Validate the 'content-length' header.
final String contentLengthStr = nettyHeaders.get(HttpHeaderNames.CONTENT_LENGTH);
final String contentLengthStr = headers.get(HttpHeaderNames.CONTENT_LENGTH);
final boolean contentEmpty;
if (contentLengthStr != null) {
final long contentLength;
long contentLength;
try {
contentLength = Long.parseLong(contentLengthStr);
} catch (NumberFormatException ignored) {
fail(id, HttpStatus.BAD_REQUEST, "Invalid content length", null);
return;
contentLength = -1;
}
if (contentLength < 0) {
fail(id, HttpStatus.BAD_REQUEST, "Invalid content length", null);
fail(id, headers, HttpStatus.BAD_REQUEST, "Invalid content length", null);
return;
}

Expand All @@ -180,32 +192,30 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
contentEmpty = true;
}

if (!handle100Continue(id, nettyReq, nettyHeaders)) {
// Reject the requests with an `expect` header whose value is not `100-continue`.
if (hasInvalidExpectHeader) {
ctx.pipeline().fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE);
fail(id, HttpStatus.EXPECTATION_FAILED, null, null);
fail(id, headers, HttpStatus.EXPECTATION_FAILED, null, null);
return;
}

// Close the request early when it is sure that there will be
// neither content nor trailers.
// Close the request early when it is certain there will be neither content nor trailers.
final EventLoop eventLoop = ctx.channel().eventLoop();
final RequestHeaders armeriaRequestHeaders =
ArmeriaHttpUtil.toArmeria(ctx, nettyReq, cfg, scheme.toString());
final boolean keepAlive = HttpUtil.isKeepAlive(nettyReq);
if (contentEmpty && !HttpUtil.isTransferEncodingChunked(nettyReq)) {
this.req = req = new EmptyContentDecodedHttpRequest(
eventLoop, id, 1, armeriaRequestHeaders, keepAlive);
eventLoop, id, 1, headers, keepAlive);
} else {
this.req = req = new DefaultDecodedHttpRequest(
eventLoop, id, 1, armeriaRequestHeaders, keepAlive, inboundTrafficController,
eventLoop, id, 1, headers, keepAlive, inboundTrafficController,
// FIXME(trustin): Use a different maxRequestLength for a different virtual
// host.
cfg.defaultVirtualHost().maxRequestLength());
}

ctx.fireChannelRead(req);
} else {
fail(id, HttpStatus.BAD_REQUEST, "Invalid decoder state", null);
fail(id, null, HttpStatus.BAD_REQUEST, "Invalid decoder state", null);
return;
}
}
Expand All @@ -223,11 +233,11 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
final HttpContent content = (HttpContent) msg;
final DecoderResult decoderResult = content.decoderResult();
if (!decoderResult.isSuccess()) {
final HttpStatus badRequest = HttpStatus.BAD_REQUEST;
fail(id, badRequest, Http2Error.PROTOCOL_ERROR, "Decoder failure", null);
fail(id, decodedReq.headers(), HttpStatus.BAD_REQUEST,
Http2Error.PROTOCOL_ERROR, "Decoder failure", null);
final ProtocolViolationException cause =
new ProtocolViolationException(decoderResult.cause());
decodedReq.close(HttpStatusException.of(badRequest, cause));
decodedReq.close(HttpStatusException.of(HttpStatus.BAD_REQUEST, cause));
return;
}

Expand All @@ -244,11 +254,11 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
.contentLength(req.headers())
.transferred(transferredLength)
.build();
final HttpStatus entityTooLarge = HttpStatus.REQUEST_ENTITY_TOO_LARGE;
fail(id, entityTooLarge, Http2Error.CANCEL, null, cause);
fail(id, decodedReq.headers(), HttpStatus.REQUEST_ENTITY_TOO_LARGE,
Http2Error.CANCEL, null, cause);
// Wrap the cause with the returned status to let LoggingService correctly log the
// status.
decodedReq.close(HttpStatusException.of(entityTooLarge, cause));
decodedReq.close(HttpStatusException.of(HttpStatus.REQUEST_ENTITY_TOO_LARGE, cause));
return;
}

Expand All @@ -268,25 +278,27 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}
}
} catch (URISyntaxException e) {
final HttpStatus badRequest = HttpStatus.BAD_REQUEST;
fail(id, badRequest, Http2Error.CANCEL, "Invalid request path", e);
if (req != null) {
req.close(HttpStatusException.of(badRequest, e));
fail(id, req.headers(), HttpStatus.BAD_REQUEST, Http2Error.CANCEL, "Invalid request path", e);
req.close(HttpStatusException.of(HttpStatus.BAD_REQUEST, e));
} else {
fail(id, null, HttpStatus.BAD_REQUEST, Http2Error.CANCEL, "Invalid request path", e);
}
} catch (Throwable t) {
final HttpStatus serverError = HttpStatus.INTERNAL_SERVER_ERROR;
fail(id, serverError, Http2Error.INTERNAL_ERROR, null, t);
if (req != null) {
req.close(HttpStatusException.of(serverError, t));
fail(id, req.headers(), HttpStatus.INTERNAL_SERVER_ERROR, Http2Error.INTERNAL_ERROR, null, t);
req.close(HttpStatusException.of(HttpStatus.INTERNAL_SERVER_ERROR, t));
} else {
fail(id, null, HttpStatus.INTERNAL_SERVER_ERROR, Http2Error.INTERNAL_ERROR, null, t);
logger.warn("Unexpected exception:", t);
}
} finally {
ReferenceCountUtil.release(msg);
}
}

private boolean handle100Continue(int id, HttpRequest nettyReq, HttpHeaders nettyHeaders) {
private boolean handle100Continue(int id, HttpRequest nettyReq) {
final HttpHeaders nettyHeaders = nettyReq.headers();
if (nettyReq.protocolVersion().compareTo(HttpVersion.HTTP_1_1) < 0) {
// Ignore HTTP/1.0 requests.
return true;
Expand All @@ -311,24 +323,25 @@ private boolean handle100Continue(int id, HttpRequest nettyReq, HttpHeaders nett
return true;
}

private void fail(int id, HttpStatus status, Http2Error error,
private void fail(int id, @Nullable RequestHeaders headers, HttpStatus status, Http2Error error,
@Nullable String message, @Nullable Throwable cause) {
if (encoder.isResponseHeadersSent(id, 1)) {
// The response is sent or being sent by HttpResponseSubscriber so we cannot send
// the error response.
encoder.writeReset(id, 1, error);
} else {
fail(id, status, message, cause);
fail(id, headers, status, message, cause);
}
}

private void fail(int id, HttpStatus status, @Nullable String message, @Nullable Throwable cause) {
private void fail(int id, @Nullable RequestHeaders headers, HttpStatus status,
@Nullable String message, @Nullable Throwable cause) {
discarding = true;
req = null;

// FIXME(trustin): Use a different verboseResponses for a different virtual host.
encoder.writeErrorResponse(id, 1, cfg.defaultVirtualHost().fallbackServiceConfig(),
status, message, cause);
headers, status, message, cause);
}

@Override
Expand Down
Loading

0 comments on commit 1284b5e

Please sign in to comment.