diff --git a/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebFluxGraphQlAutoConfiguration.java b/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebFluxGraphQlAutoConfiguration.java index 9279137..5b243f7 100644 --- a/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebFluxGraphQlAutoConfiguration.java +++ b/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebFluxGraphQlAutoConfiguration.java @@ -22,7 +22,6 @@ import graphql.schema.idl.SchemaPrinter; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import reactor.core.publisher.Mono; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfigureAfter; @@ -48,7 +47,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import org.springframework.web.reactive.function.server.ServerResponse; import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping; -import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.reactive.socket.server.support.WebSocketUpgradeHandlerPredicate; import static org.springframework.web.reactive.function.server.RequestPredicates.accept; import static org.springframework.web.reactive.function.server.RequestPredicates.contentType; @@ -115,20 +114,11 @@ public HandlerMapping graphQlWebSocketEndpoint( if (logger.isInfoEnabled()) { logger.info("GraphQL endpoint WebSocket " + path); } - WebSocketHandlerMapping handlerMapping = new WebSocketHandlerMapping(); - handlerMapping.setUrlMap(Collections.singletonMap(path, graphQlWebSocketHandler)); - handlerMapping.setOrder(-2); // Ahead of HTTP endpoint ("routerFunctionMapping" bean) - return handlerMapping; - } - } - - - private static class WebSocketHandlerMapping extends SimpleUrlHandlerMapping { - - @Override - public Mono getHandlerInternal(ServerWebExchange exchange) { - return ("WebSocket".equalsIgnoreCase(exchange.getRequest().getHeaders().getUpgrade()) ? - super.getHandlerInternal(exchange) : Mono.empty()); + SimpleUrlHandlerMapping mapping = new SimpleUrlHandlerMapping(); + mapping.setHandlerPredicate(new WebSocketUpgradeHandlerPredicate()); + mapping.setUrlMap(Collections.singletonMap(path, graphQlWebSocketHandler)); + mapping.setOrder(-2); // Ahead of HTTP endpoint ("routerFunctionMapping" bean) + return mapping; } } diff --git a/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebMvcGraphQlAutoConfiguration.java b/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebMvcGraphQlAutoConfiguration.java index 941225a..17f4307 100644 --- a/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebMvcGraphQlAutoConfiguration.java +++ b/graphql-spring-boot-starter/src/main/java/org/springframework/graphql/boot/WebMvcGraphQlAutoConfiguration.java @@ -19,7 +19,6 @@ import java.util.Map; import java.util.stream.Collectors; -import javax.servlet.http.HttpServletRequest; import javax.websocket.server.ServerContainer; import graphql.GraphQL; @@ -45,16 +44,15 @@ import org.springframework.graphql.web.WebInterceptor; import org.springframework.graphql.web.webmvc.GraphQlHttpHandler; import org.springframework.graphql.web.webmvc.GraphQlWebSocketHandler; -import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.RouterFunctions; import org.springframework.web.servlet.function.ServerResponse; -import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; +import org.springframework.web.socket.server.support.WebSocketHandlerMapping; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import static org.springframework.web.servlet.function.RequestPredicates.accept; @@ -127,23 +125,16 @@ public HandlerMapping graphQlWebSocketMapping(GraphQlWebSocketHandler handler, G if (logger.isInfoEnabled()) { logger.info("GraphQL endpoint WebSocket " + path); } - WebSocketHandlerMapping handlerMapping = new WebSocketHandlerMapping(); - handlerMapping.setUrlMap(Collections.singletonMap(path, + WebSocketHandlerMapping mapping = new WebSocketHandlerMapping(); + mapping.setWebSocketUpgradeMatch(true); + mapping.setUrlMap(Collections.singletonMap(path, new WebSocketHttpRequestHandler(handler, new DefaultHandshakeHandler()))); - handlerMapping.setOrder(2); // Ahead of HTTP endpoint ("routerFunctionMapping" bean) - return handlerMapping; + mapping.setOrder(2); // Ahead of HTTP endpoint ("routerFunctionMapping" bean) + return mapping; } } - private static class WebSocketHandlerMapping extends SimpleUrlHandlerMapping { - - @Override - protected Object getHandlerInternal(HttpServletRequest request) throws Exception { - return ("WebSocket".equalsIgnoreCase(request.getHeader(HttpHeaders.UPGRADE)) ? - super.getHandlerInternal(request) : null); - } - } }