diff --git a/core/src/main/java/com/linecorp/armeria/server/DefaultServerErrorHandler.java b/core/src/main/java/com/linecorp/armeria/server/DefaultServerErrorHandler.java index a736242d7ee..3bc719a2f0b 100644 --- a/core/src/main/java/com/linecorp/armeria/server/DefaultServerErrorHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/DefaultServerErrorHandler.java @@ -28,6 +28,7 @@ import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.Exceptions; @@ -68,7 +69,8 @@ public HttpResponse onServiceException(ServiceRequestContext ctx, Throwable caus logger.warn("{} Failed processing a request:", ctx, cause); } - return internalRenderStatus(serviceConfig, HttpStatus.BAD_REQUEST, cause); + return internalRenderStatus(serviceConfig, ctx.request().headers(), + HttpStatus.BAD_REQUEST, cause); } } @@ -85,26 +87,31 @@ public HttpResponse onServiceException(ServiceRequestContext ctx, Throwable caus } if (cause instanceof RequestTimeoutException) { - return internalRenderStatus(serviceConfig, HttpStatus.SERVICE_UNAVAILABLE, cause); + return internalRenderStatus(serviceConfig, ctx.request().headers(), + HttpStatus.SERVICE_UNAVAILABLE, cause); } if (isAnnotatedService && needsToWarn() && !Exceptions.isExpected(cause)) { logger.warn("{} Unhandled exception from a service:", ctx, cause); } - return internalRenderStatus(serviceConfig, HttpStatus.INTERNAL_SERVER_ERROR, cause); + return internalRenderStatus(serviceConfig, ctx.request().headers(), + HttpStatus.INTERNAL_SERVER_ERROR, cause); } + @SuppressWarnings("deprecation") private static boolean needsToWarn() { return Flags.annotatedServiceExceptionVerbosity() == ExceptionVerbosity.UNHANDLED && logger.isWarnEnabled(); } private static HttpResponse internalRenderStatus(ServiceConfig serviceConfig, + RequestHeaders headers, HttpStatus status, @Nullable Throwable cause) { - final AggregatedHttpResponse res = serviceConfig.server().config().errorHandler() - .renderStatus(serviceConfig, status, null, cause); + final AggregatedHttpResponse res = + serviceConfig.server().config().errorHandler() + .renderStatus(serviceConfig, headers, status, null, cause); assert res != null; return res.toHttpResponse(); } @@ -112,6 +119,7 @@ private static HttpResponse internalRenderStatus(ServiceConfig serviceConfig, @Nonnull @Override public AggregatedHttpResponse renderStatus(ServiceConfig config, + @Nullable RequestHeaders headers, HttpStatus status, @Nullable String description, @Nullable Throwable cause) { diff --git a/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java b/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java index 8bd8ec86621..a01b5ef3fac 100644 --- a/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java @@ -145,33 +145,45 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception keepAliveHandler.increaseNumRequests(); final HttpRequest nettyReq = (HttpRequest) msg; if (!nettyReq.decoderResult().isSuccess()) { - fail(id, HttpStatus.BAD_REQUEST, "Decoder failure", null); + fail(id, null, HttpStatus.BAD_REQUEST, "Decoder failure", null); return; } - final HttpHeaders nettyHeaders = nettyReq.headers(); - // Do not accept unsupported methods. final io.netty.handler.codec.http.HttpMethod nettyMethod = nettyReq.method(); - if (nettyMethod == io.netty.handler.codec.http.HttpMethod.CONNECT || - !HttpMethod.isSupported(nettyMethod.name())) { - fail(id, HttpStatus.METHOD_NOT_ALLOWED, "Unsupported method", null); + if (!HttpMethod.isSupported(nettyMethod.name())) { + fail(id, null, HttpStatus.METHOD_NOT_ALLOWED, "Unsupported method", null); + return; + } + + // Handle `expect: 100-continue` first to give `handle100Continue()` a chance to remove + // the `expect` header before converting the Netty HttpHeaders into Armeria RequestHeaders. + // This is because removing a header from RequestHeaders is more expensive due to its + // immutability. + final boolean hasInvalidExpectHeader = !handle100Continue(id, nettyReq); + + // Convert the Netty HttpHeaders into Armeria RequestHeaders. + final RequestHeaders headers = + ArmeriaHttpUtil.toArmeria(ctx, nettyReq, cfg, scheme.toString()); + + // Do not accept a CONNECT request. + if (headers.method() == HttpMethod.CONNECT) { + fail(id, headers, HttpStatus.METHOD_NOT_ALLOWED, "Unsupported method", null); return; } // Validate the 'content-length' header. - final String contentLengthStr = nettyHeaders.get(HttpHeaderNames.CONTENT_LENGTH); + final String contentLengthStr = headers.get(HttpHeaderNames.CONTENT_LENGTH); final boolean contentEmpty; if (contentLengthStr != null) { - final long contentLength; + long contentLength; try { contentLength = Long.parseLong(contentLengthStr); } catch (NumberFormatException ignored) { - fail(id, HttpStatus.BAD_REQUEST, "Invalid content length", null); - return; + contentLength = -1; } if (contentLength < 0) { - fail(id, HttpStatus.BAD_REQUEST, "Invalid content length", null); + fail(id, headers, HttpStatus.BAD_REQUEST, "Invalid content length", null); return; } @@ -180,24 +192,22 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception contentEmpty = true; } - if (!handle100Continue(id, nettyReq, nettyHeaders)) { + // Reject the requests with an `expect` header whose value is not `100-continue`. + if (hasInvalidExpectHeader) { ctx.pipeline().fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE); - fail(id, HttpStatus.EXPECTATION_FAILED, null, null); + fail(id, headers, HttpStatus.EXPECTATION_FAILED, null, null); return; } - // Close the request early when it is sure that there will be - // neither content nor trailers. + // Close the request early when it is certain there will be neither content nor trailers. final EventLoop eventLoop = ctx.channel().eventLoop(); - final RequestHeaders armeriaRequestHeaders = - ArmeriaHttpUtil.toArmeria(ctx, nettyReq, cfg, scheme.toString()); final boolean keepAlive = HttpUtil.isKeepAlive(nettyReq); if (contentEmpty && !HttpUtil.isTransferEncodingChunked(nettyReq)) { this.req = req = new EmptyContentDecodedHttpRequest( - eventLoop, id, 1, armeriaRequestHeaders, keepAlive); + eventLoop, id, 1, headers, keepAlive); } else { this.req = req = new DefaultDecodedHttpRequest( - eventLoop, id, 1, armeriaRequestHeaders, keepAlive, inboundTrafficController, + eventLoop, id, 1, headers, keepAlive, inboundTrafficController, // FIXME(trustin): Use a different maxRequestLength for a different virtual // host. cfg.defaultVirtualHost().maxRequestLength()); @@ -205,7 +215,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception ctx.fireChannelRead(req); } else { - fail(id, HttpStatus.BAD_REQUEST, "Invalid decoder state", null); + fail(id, null, HttpStatus.BAD_REQUEST, "Invalid decoder state", null); return; } } @@ -223,11 +233,11 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception final HttpContent content = (HttpContent) msg; final DecoderResult decoderResult = content.decoderResult(); if (!decoderResult.isSuccess()) { - final HttpStatus badRequest = HttpStatus.BAD_REQUEST; - fail(id, badRequest, Http2Error.PROTOCOL_ERROR, "Decoder failure", null); + fail(id, decodedReq.headers(), HttpStatus.BAD_REQUEST, + Http2Error.PROTOCOL_ERROR, "Decoder failure", null); final ProtocolViolationException cause = new ProtocolViolationException(decoderResult.cause()); - decodedReq.close(HttpStatusException.of(badRequest, cause)); + decodedReq.close(HttpStatusException.of(HttpStatus.BAD_REQUEST, cause)); return; } @@ -244,11 +254,11 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception .contentLength(req.headers()) .transferred(transferredLength) .build(); - final HttpStatus entityTooLarge = HttpStatus.REQUEST_ENTITY_TOO_LARGE; - fail(id, entityTooLarge, Http2Error.CANCEL, null, cause); + fail(id, decodedReq.headers(), HttpStatus.REQUEST_ENTITY_TOO_LARGE, + Http2Error.CANCEL, null, cause); // Wrap the cause with the returned status to let LoggingService correctly log the // status. - decodedReq.close(HttpStatusException.of(entityTooLarge, cause)); + decodedReq.close(HttpStatusException.of(HttpStatus.REQUEST_ENTITY_TOO_LARGE, cause)); return; } @@ -268,17 +278,18 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } } } catch (URISyntaxException e) { - final HttpStatus badRequest = HttpStatus.BAD_REQUEST; - fail(id, badRequest, Http2Error.CANCEL, "Invalid request path", e); if (req != null) { - req.close(HttpStatusException.of(badRequest, e)); + fail(id, req.headers(), HttpStatus.BAD_REQUEST, Http2Error.CANCEL, "Invalid request path", e); + req.close(HttpStatusException.of(HttpStatus.BAD_REQUEST, e)); + } else { + fail(id, null, HttpStatus.BAD_REQUEST, Http2Error.CANCEL, "Invalid request path", e); } } catch (Throwable t) { - final HttpStatus serverError = HttpStatus.INTERNAL_SERVER_ERROR; - fail(id, serverError, Http2Error.INTERNAL_ERROR, null, t); if (req != null) { - req.close(HttpStatusException.of(serverError, t)); + fail(id, req.headers(), HttpStatus.INTERNAL_SERVER_ERROR, Http2Error.INTERNAL_ERROR, null, t); + req.close(HttpStatusException.of(HttpStatus.INTERNAL_SERVER_ERROR, t)); } else { + fail(id, null, HttpStatus.INTERNAL_SERVER_ERROR, Http2Error.INTERNAL_ERROR, null, t); logger.warn("Unexpected exception:", t); } } finally { @@ -286,7 +297,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } } - private boolean handle100Continue(int id, HttpRequest nettyReq, HttpHeaders nettyHeaders) { + private boolean handle100Continue(int id, HttpRequest nettyReq) { + final HttpHeaders nettyHeaders = nettyReq.headers(); if (nettyReq.protocolVersion().compareTo(HttpVersion.HTTP_1_1) < 0) { // Ignore HTTP/1.0 requests. return true; @@ -311,24 +323,25 @@ private boolean handle100Continue(int id, HttpRequest nettyReq, HttpHeaders nett return true; } - private void fail(int id, HttpStatus status, Http2Error error, + private void fail(int id, @Nullable RequestHeaders headers, HttpStatus status, Http2Error error, @Nullable String message, @Nullable Throwable cause) { if (encoder.isResponseHeadersSent(id, 1)) { // The response is sent or being sent by HttpResponseSubscriber so we cannot send // the error response. encoder.writeReset(id, 1, error); } else { - fail(id, status, message, cause); + fail(id, headers, status, message, cause); } } - private void fail(int id, HttpStatus status, @Nullable String message, @Nullable Throwable cause) { + private void fail(int id, @Nullable RequestHeaders headers, HttpStatus status, + @Nullable String message, @Nullable Throwable cause) { discarding = true; req = null; // FIXME(trustin): Use a different verboseResponses for a different virtual host. encoder.writeErrorResponse(id, 1, cfg.defaultVirtualHost().fallbackServiceConfig(), - status, message, cause); + headers, status, message, cause); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java b/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java index dd6f37a415d..6c784a13baa 100644 --- a/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/server/Http2RequestDecoder.java @@ -91,53 +91,73 @@ public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) { } @Override - public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers nettyHeaders, int padding, boolean endOfStream) throws Http2Exception { keepAliveChannelRead(true); DecodedHttpRequest req = requests.get(streamId); if (req == null) { assert encoder != null; + // Handle `expect: 100-continue` first to give `handle100Continue()` a chance to remove + // the `expect` header before converting the Netty HttpHeaders into Armeria RequestHeaders. + // This is because removing a header from RequestHeaders is more expensive due to its + // immutability. + final boolean hasInvalidExpectHeader = !handle100Continue(streamId, nettyHeaders); + // Validate the method. - final CharSequence methodText = headers.method(); + final CharSequence methodText = nettyHeaders.method(); if (methodText == null) { - writeErrorResponse(streamId, HttpStatus.BAD_REQUEST, "Missing method", null); + writeErrorResponse(streamId, null, HttpStatus.BAD_REQUEST, "Missing method", null); return; } // Reject a request with an unsupported method. - // Note: Accept a CONNECT request with a :protocol header, as defined in: - // https://datatracker.ietf.org/doc/html/rfc8441#section-4 final HttpMethod method = HttpMethod.tryParse(methodText.toString()); - if (method == null || - method == HttpMethod.CONNECT && !headers.contains(HttpHeaderNames.PROTOCOL)) { - writeErrorResponse(streamId, HttpStatus.METHOD_NOT_ALLOWED, "Unsupported method", null); + if (method == null) { + writeErrorResponse(streamId, null, HttpStatus.METHOD_NOT_ALLOWED, "Unsupported method", null); + return; + } + + // Convert the Netty Http2Headers into Armeria RequestHeaders. + final RequestHeaders headers = + ArmeriaHttpUtil.toArmeriaRequestHeaders(ctx, nettyHeaders, endOfStream, scheme, cfg); + + // Accept a CONNECT request only when it has a :protocol header, as defined in: + // https://datatracker.ietf.org/doc/html/rfc8441#section-4 + if (method == HttpMethod.CONNECT && !nettyHeaders.contains(HttpHeaderNames.PROTOCOL)) { + writeErrorResponse(streamId, headers, HttpStatus.METHOD_NOT_ALLOWED, + "Unsupported method", null); return; } // Validate the 'content-length' header if exists. - if (headers.contains(HttpHeaderNames.CONTENT_LENGTH)) { - final long contentLength = headers.getLong(HttpHeaderNames.CONTENT_LENGTH, -1L); + final String contentLengthStr = headers.get(HttpHeaderNames.CONTENT_LENGTH); + if (contentLengthStr != null) { + long contentLength; + try { + contentLength = Long.parseLong(contentLengthStr); + } catch (NumberFormatException ignored) { + contentLength = -1; + } if (contentLength < 0) { - writeErrorResponse(streamId, HttpStatus.BAD_REQUEST, "Invalid content length", null); + writeErrorResponse(streamId, headers, HttpStatus.BAD_REQUEST, + "Invalid content length", null); return; } } - if (!handle100Continue(streamId, headers)) { - writeErrorResponse(streamId, HttpStatus.EXPECTATION_FAILED, null, null); + if (hasInvalidExpectHeader) { + writeErrorResponse(streamId, headers, HttpStatus.EXPECTATION_FAILED, null, null); return; } final EventLoop eventLoop = ctx.channel().eventLoop(); - final RequestHeaders armeriaRequestHeaders = - ArmeriaHttpUtil.toArmeriaRequestHeaders(ctx, headers, endOfStream, scheme, cfg); final int id = ++nextId; if (endOfStream) { // Close the request early when it is sure that there will be neither content nor trailers. - req = new EmptyContentDecodedHttpRequest(eventLoop, id, streamId, armeriaRequestHeaders, true); + req = new EmptyContentDecodedHttpRequest(eventLoop, id, streamId, headers, true); } else { - req = new DefaultDecodedHttpRequest(eventLoop, id, streamId, armeriaRequestHeaders, true, + req = new DefaultDecodedHttpRequest(eventLoop, id, streamId, headers, true, inboundTrafficController, // FIXME(trustin): Use a different maxRequestLength for // a different host. @@ -150,7 +170,7 @@ public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers final DefaultDecodedHttpRequest decodedReq = (DefaultDecodedHttpRequest) req; try { // Trailers is received. The decodedReq will be automatically closed. - decodedReq.write(ArmeriaHttpUtil.toArmeria(headers, true, endOfStream)); + decodedReq.write(ArmeriaHttpUtil.toArmeria(nettyHeaders, true, endOfStream)); } catch (Throwable t) { decodedReq.close(t); throw connectionError(INTERNAL_ERROR, t, "failed to consume a HEADERS frame"); @@ -235,11 +255,10 @@ public int onDataRead( .transferred(transferredLength) .build(); - final HttpStatus entityTooLarge = HttpStatus.REQUEST_ENTITY_TOO_LARGE; - writeErrorResponse(streamId, entityTooLarge, null, cause); + writeErrorResponse(streamId, req.headers(), HttpStatus.REQUEST_ENTITY_TOO_LARGE, null, cause); if (decodedReq.isOpen()) { - decodedReq.close(HttpStatusException.of(entityTooLarge, cause)); + decodedReq.close(HttpStatusException.of(HttpStatus.REQUEST_ENTITY_TOO_LARGE, cause)); } } else { // The response has been started already. Abort the request and let the response continue. @@ -272,12 +291,13 @@ private static boolean isWritable(@Nullable Http2Stream stream) { } } - private void writeErrorResponse( - int streamId, HttpStatus status, @Nullable String message, @Nullable Throwable cause) { + private void writeErrorResponse(int streamId, @Nullable RequestHeaders headers, + HttpStatus status, @Nullable String message, + @Nullable Throwable cause) { assert encoder != null; encoder.writeErrorResponse(0 /* unused */, streamId, cfg.defaultVirtualHost().fallbackServiceConfig(), - status, message, cause); + headers, status, message, cause); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/server/HttpResponseSubscriber.java b/core/src/main/java/com/linecorp/armeria/server/HttpResponseSubscriber.java index b76156805ae..f0628497df5 100644 --- a/core/src/main/java/com/linecorp/armeria/server/HttpResponseSubscriber.java +++ b/core/src/main/java/com/linecorp/armeria/server/HttpResponseSubscriber.java @@ -315,7 +315,7 @@ public void onError(Throwable cause) { final ServiceConfig serviceConfig = reqCtx.config(); final AggregatedHttpResponse res = serviceConfig.server().config().errorHandler() - .renderStatus(serviceConfig, status, null, cause0); + .renderStatus(serviceConfig, req.headers(), status, null, cause0); assert res != null; failAndRespond(cause0, res, Http2Error.CANCEL, false); } else if (Exceptions.isStreamCancelling(cause)) { diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerErrorHandler.java b/core/src/main/java/com/linecorp/armeria/server/ServerErrorHandler.java index 09897fea947..2684fdea9e5 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerErrorHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerErrorHandler.java @@ -21,6 +21,7 @@ import com.linecorp.armeria.common.ContentTooLargeException; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.logging.RequestLog; @@ -109,6 +110,7 @@ static ServerErrorHandler ofDefault() { * for each request failed at the protocol level. * * @param config the {@link ServiceConfig} that provides the configuration properties. + * @param headers the received {@link RequestHeaders}, or {@code null} in case of severe protocol violation. * @param status the desired {@link HttpStatus} of the error response. * @param description an optional human-readable description of the error. * @param cause an optional exception that may contain additional information about the error, such as @@ -119,10 +121,11 @@ static ServerErrorHandler ofDefault() { */ @Nullable default AggregatedHttpResponse onProtocolViolation(ServiceConfig config, + @Nullable RequestHeaders headers, HttpStatus status, @Nullable String description, @Nullable Throwable cause) { - return renderStatus(config, status, description, cause); + return renderStatus(config, headers, status, description, cause); } /** @@ -135,9 +138,10 @@ default AggregatedHttpResponse onProtocolViolation(ServiceConfig config, * {@link ServerErrorHandler} or even independently, and thus should not be used for counting * {@link HttpStatusException}s or collecting stats. Use * {@link #onServiceException(ServiceRequestContext, Throwable)} and - * {@link #onProtocolViolation(ServiceConfig, HttpStatus, String, Throwable)} instead. + * {@link #onProtocolViolation(ServiceConfig, RequestHeaders, HttpStatus, String, Throwable)} instead. * * @param config the {@link ServiceConfig} that provides the configuration properties. + * @param headers the received {@link RequestHeaders}, or {@code null} in case of severe protocol violation. * @param status the desired {@link HttpStatus} of the error response. * @param description an optional human-readable description of the error. * @param cause an optional exception that may contain additional information about the error, such as @@ -148,6 +152,7 @@ default AggregatedHttpResponse onProtocolViolation(ServiceConfig config, */ @Nullable default AggregatedHttpResponse renderStatus(ServiceConfig config, + @Nullable RequestHeaders headers, HttpStatus status, @Nullable String description, @Nullable Throwable cause) { @@ -202,28 +207,32 @@ public HttpResponse onServiceException(ServiceRequestContext ctx, Throwable caus @Nullable @Override public AggregatedHttpResponse onProtocolViolation(ServiceConfig config, + @Nullable RequestHeaders headers, HttpStatus status, @Nullable String description, @Nullable Throwable cause) { final AggregatedHttpResponse response = - ServerErrorHandler.this.onProtocolViolation(config, status, description, cause); + ServerErrorHandler.this.onProtocolViolation( + config, headers, status, description, cause); if (response != null) { return response; } - return other.onProtocolViolation(config, status, description, cause); + return other.onProtocolViolation(config, headers, status, description, cause); } @Nullable @Override - public AggregatedHttpResponse renderStatus(ServiceConfig config, HttpStatus status, + public AggregatedHttpResponse renderStatus(ServiceConfig config, + @Nullable RequestHeaders headers, + HttpStatus status, @Nullable String description, @Nullable Throwable cause) { final AggregatedHttpResponse response = - ServerErrorHandler.super.renderStatus(config, status, description, cause); + ServerErrorHandler.super.renderStatus(config, headers, status, description, cause); if (response != null) { return response; } - return other.renderStatus(config, status, description, cause); + return other.renderStatus(config, headers, status, description, cause); } }; } diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerHttp1ObjectEncoder.java b/core/src/main/java/com/linecorp/armeria/server/ServerHttp1ObjectEncoder.java index 1deaa0d020b..cc01dba6554 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerHttp1ObjectEncoder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerHttp1ObjectEncoder.java @@ -20,6 +20,7 @@ import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; @@ -184,6 +185,7 @@ public boolean isResponseHeadersSent(int id, int streamId) { @Override public ChannelFuture writeErrorResponse(int id, int streamId, ServiceConfig serviceConfig, + RequestHeaders headers, HttpStatus status, @Nullable String message, @Nullable Throwable cause) { @@ -193,7 +195,7 @@ public ChannelFuture writeErrorResponse(int id, int streamId, keepAliveHandler().destroy(); final ChannelFuture future = ServerHttpObjectEncoder.super.writeErrorResponse( - id, streamId, serviceConfig, status, message, cause); + id, streamId, serviceConfig, headers, status, message, cause); // Update the closed ID to prevent the HttpResponseSubscriber from // writing additional headers or messages. updateClosedId(id); diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerHttp2ObjectEncoder.java b/core/src/main/java/com/linecorp/armeria/server/ServerHttp2ObjectEncoder.java index 3bfb5c60f1e..3fece6c74ab 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerHttp2ObjectEncoder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerHttp2ObjectEncoder.java @@ -19,6 +19,7 @@ import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpHeadersBuilder; import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.stream.ClosedStreamException; @@ -128,11 +129,11 @@ private void onKeepAliveReadOrWrite() { @Override public ChannelFuture writeErrorResponse(int id, int streamId, ServiceConfig serviceConfig, - HttpStatus status, @Nullable String message, - @Nullable Throwable cause) { + @Nullable RequestHeaders headers, HttpStatus status, + @Nullable String message, @Nullable Throwable cause) { ChannelFuture future = ServerHttpObjectEncoder.super.writeErrorResponse( - id, streamId, serviceConfig, status, message, cause); + id, streamId, serviceConfig, headers, status, message, cause); final Http2Stream stream = findStream(streamId); if (stream != null) { diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerHttpObjectEncoder.java b/core/src/main/java/com/linecorp/armeria/server/ServerHttpObjectEncoder.java index 9fb20ee56d0..b8eac58e19e 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerHttpObjectEncoder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerHttpObjectEncoder.java @@ -21,6 +21,7 @@ import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpObject; import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.internal.common.HttpObjectEncoder; @@ -70,36 +71,37 @@ ChannelFuture doWriteHeaders(int id, int streamId, ResponseHeaders headers, bool default ChannelFuture writeErrorResponse(int id, int streamId, ServiceConfig serviceConfig, + @Nullable RequestHeaders headers, HttpStatus status, @Nullable String message, @Nullable Throwable cause) { final AggregatedHttpResponse res = serviceConfig.server().config().errorHandler() - .onProtocolViolation(serviceConfig, status, message, cause); + .onProtocolViolation(serviceConfig, headers, status, message, cause); assert res != null; final HttpData content = res.content(); boolean transferredContent = false; try { - final ResponseHeaders headers = res.headers(); - final HttpHeaders trailers = res.trailers(); - if (trailers.isEmpty()) { + final ResponseHeaders resHeaders = res.headers(); + final HttpHeaders resTrailers = res.trailers(); + if (resTrailers.isEmpty()) { if (content.isEmpty()) { - return writeHeaders(id, streamId, headers, true); + return writeHeaders(id, streamId, resHeaders, true); } - writeHeaders(id, streamId, headers, false); + writeHeaders(id, streamId, resHeaders, false); transferredContent = true; return writeData(id, streamId, content, true); } - writeHeaders(id, streamId, headers, false); + writeHeaders(id, streamId, resHeaders, false); if (!content.isEmpty()) { transferredContent = true; writeData(id, streamId, content, false); } - return writeTrailers(id, streamId, trailers); + return writeTrailers(id, streamId, resTrailers); } finally { if (!transferredContent) { content.close(); diff --git a/core/src/test/java/com/linecorp/armeria/server/CustomServerErrorHandlerTest.java b/core/src/test/java/com/linecorp/armeria/server/CustomServerErrorHandlerTest.java index ef77cbfea41..564e3a5b559 100644 --- a/core/src/test/java/com/linecorp/armeria/server/CustomServerErrorHandlerTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/CustomServerErrorHandlerTest.java @@ -36,10 +36,10 @@ import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpMethod; -import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; @@ -48,10 +48,13 @@ class CustomServerErrorHandlerTest { + private static final int MAX_REQUEST_LENGTH = 10; + @RegisterExtension static ServerExtension server = new ServerExtension() { @Override protected void configure(ServerBuilder sb) { + sb.maxRequestLength(MAX_REQUEST_LENGTH); sb.service("/timeout", (ctx, req) -> { ctx.timeoutNow(); return HttpResponse.of(200); @@ -66,43 +69,10 @@ protected void configure(ServerBuilder sb) { new UnsupportedOperationException("Unsupported!")), 100, TimeUnit.MILLISECONDS); return HttpResponse.from(future); }); - sb.errorHandler(new ServerErrorHandler() { - @Override - public @Nullable HttpResponse onServiceException(ServiceRequestContext ctx, Throwable cause) { - if (cause instanceof RequestTimeoutException) { - return HttpResponse.of(ResponseHeaders.of(HttpStatus.GATEWAY_TIMEOUT), - HttpData.ofUtf8("timeout!"), - HttpHeaders.of("trailer-exists", true)); - } - if (cause instanceof IllegalArgumentException) { - return HttpResponse.of(ResponseHeaders.of(HttpStatus.BAD_REQUEST), - HttpData.ofUtf8(cause.getMessage()), - HttpHeaders.of("trailer-exists", true)); - } - if (cause instanceof UnsupportedOperationException) { - return HttpResponse.of(ResponseHeaders.of(HttpStatus.NOT_IMPLEMENTED), - HttpData.ofUtf8(cause.getMessage()), - HttpHeaders.of("trailer-exists", true)); - } - return null; - } - - @Override - public AggregatedHttpResponse renderStatus(ServiceConfig config, - HttpStatus status, - @Nullable String description, - @Nullable Throwable cause) { - assertThat(config).isNotNull(); - return AggregatedHttpResponse.of( - ResponseHeaders.builder(HttpStatus.BAD_REQUEST) // Always emit 400. - .contentType(MediaType.JSON) - .set("alice", "bob") - .build(), - HttpData.ofUtf8("{\n \"code\": %d,\n \"message\": \"%s\"\n}", - status.code(), firstNonNull(description, "")), - HttpHeaders.of("charlie", "daniel")); - } - }); + sb.service("/post", (ctx, req) -> HttpResponse.from( + req.aggregate().thenApply(aggregated -> HttpResponse.of(HttpStatus.OK)))); + + sb.errorHandler(new CustomServerErrorHandler()); } }; @@ -136,16 +106,72 @@ void logIsCompleteEvenIfResponseContentIsDeferred() throws InterruptedException @ParameterizedTest @CsvSource({ "H1C", "H2C" }) - void protocolErrors(SessionProtocol protocol) { + void unsupportedMethods(SessionProtocol protocol) { + final WebClient client = WebClient.of(server.uri(protocol)); + final AggregatedHttpResponse res1 = client + .execute(RequestHeaders.of(HttpMethod.CONNECT, "/", "user-id", "42")) + .aggregate() + .join(); + assertThat(res1.status()).isSameAs(HttpStatus.BAD_REQUEST); + assertThat(res1.headers()).contains(Maps.immutableEntry(HttpHeaderNames.of("alice"), "bob")); + assertThatJson(res1.content().toStringUtf8()).isEqualTo( + "{ \"code\": 405, \"message\": \"Unsupported method\", \"user-id\": \"42\" }"); + assertThat(res1.trailers()).contains(Maps.immutableEntry(HttpHeaderNames.of("charlie"), "daniel")); + } + + @ParameterizedTest + @CsvSource({ "H1C", "H2C" }) + void tooLargeContent(SessionProtocol protocol) { final WebClient client = WebClient.of(server.uri(protocol)); final AggregatedHttpResponse res1 = client - .execute(HttpRequest.of(HttpMethod.CONNECT, "/")) + .execute(RequestHeaders.of(HttpMethod.POST, "/post", "user-id", "24"), + HttpData.wrap(new byte[MAX_REQUEST_LENGTH + 1])) .aggregate() .join(); assertThat(res1.status()).isSameAs(HttpStatus.BAD_REQUEST); assertThat(res1.headers()).contains(Maps.immutableEntry(HttpHeaderNames.of("alice"), "bob")); assertThatJson(res1.content().toStringUtf8()).isEqualTo( - "{ \"code\": 405, \"message\": \"Unsupported method\" }"); + "{ \"code\": 413, \"message\": \"\", \"user-id\": \"24\" }"); assertThat(res1.trailers()).contains(Maps.immutableEntry(HttpHeaderNames.of("charlie"), "daniel")); } + + private static class CustomServerErrorHandler implements ServerErrorHandler { + @Override + public @Nullable HttpResponse onServiceException(ServiceRequestContext ctx, Throwable cause) { + if (cause instanceof RequestTimeoutException) { + return HttpResponse.of(ResponseHeaders.of(HttpStatus.GATEWAY_TIMEOUT), + HttpData.ofUtf8("timeout!"), + HttpHeaders.of("trailer-exists", true)); + } + if (cause instanceof IllegalArgumentException) { + return HttpResponse.of(ResponseHeaders.of(HttpStatus.BAD_REQUEST), + HttpData.ofUtf8(cause.getMessage()), + HttpHeaders.of("trailer-exists", true)); + } + if (cause instanceof UnsupportedOperationException) { + return HttpResponse.of(ResponseHeaders.of(HttpStatus.NOT_IMPLEMENTED), + HttpData.ofUtf8(cause.getMessage()), + HttpHeaders.of("trailer-exists", true)); + } + return null; + } + + @Override + public AggregatedHttpResponse renderStatus(ServiceConfig config, + @Nullable RequestHeaders headers, + HttpStatus status, + @Nullable String description, + @Nullable Throwable cause) { + assertThat(config).isNotNull(); + return AggregatedHttpResponse.of( + ResponseHeaders.builder(HttpStatus.BAD_REQUEST) // Always emit 400. + .contentType(MediaType.JSON) + .set("alice", "bob") + .build(), + HttpData.ofUtf8("{\n \"code\": %d,\n \"message\": \"%s\",\n \"user-id\": \"%s\"\n}", + status.code(), firstNonNull(description, ""), + headers != null ? headers.get("user-id", "") : ""), + HttpHeaders.of("charlie", "daniel")); + } + } }