From dccc42385a31f7e5847ddcad30e9620a88a3edd9 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Mon, 12 Feb 2024 14:34:13 +0000 Subject: [PATCH] Add runtime check to ensure handlers have auth decorators Also adds `@ws_authenticated` to make the check simpler --- jupyter_server/auth/decorator.py | 28 +++++++++ jupyter_server/base/handlers.py | 30 +++++++--- jupyter_server/base/websocket.py | 1 + jupyter_server/extension/application.py | 2 +- jupyter_server/serverapp.py | 60 ++++++++++++++++++++ jupyter_server/services/api/handlers.py | 5 ++ jupyter_server/services/events/handlers.py | 10 +--- jupyter_server/services/kernels/websocket.py | 6 +- 8 files changed, 121 insertions(+), 21 deletions(-) diff --git a/jupyter_server/auth/decorator.py b/jupyter_server/auth/decorator.py index 4ac8a75689..0d13a7a0df 100644 --- a/jupyter_server/auth/decorator.py +++ b/jupyter_server/auth/decorator.py @@ -112,3 +112,31 @@ def wrapper(self, *args, **kwargs): setattr(wrapper, "__allow_unauthenticated", True) return cast(FuncT, wrapper) + + +def ws_authenticated(method: FuncT) -> FuncT: + """A decorator for websockets derived from `WebSocketHandler` + that authenticates user before allowing to proceed. + + Differently from tornado.web.authenticated, does not redirect + to the login page, which would be meaningless for websockets. + + .. versionadded:: 2.13 + + Parameters + ---------- + method : bound callable + the endpoint method to add authentication for. + """ + + @wraps(method) + def wrapper(self, *args, **kwargs): + user = self.current_user + if user is None: + self.log.warning("Couldn't authenticate WebSocket connection") + raise HTTPError(403) + return method(self, *args, **kwargs) + + setattr(wrapper, "__allow_unauthenticated", False) + + return cast(FuncT, wrapper) diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index c88237cf35..f27eb9f284 100644 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -14,7 +14,7 @@ import warnings from http.client import responses from logging import Logger -from typing import TYPE_CHECKING, Any, Awaitable, Sequence, cast +from typing import TYPE_CHECKING, Any, Awaitable, Coroutine, Sequence, cast from urllib.parse import urlparse import prometheus_client @@ -1016,14 +1016,14 @@ def compute_etag(self) -> str | None: # access is allowed as this class is used to serve static assets on login page # TODO: create an allow-list of files used on login page and remove this decorator @allow_unauthenticated - def get(self, *args, **kwargs) -> None: - return super().get(*args, **kwargs) + def get(self, path: str, include_body: bool = True) -> Coroutine[Any, Any, None]: + return super().get(path, include_body) # access is allowed as this class is used to serve static assets on login page # TODO: create an allow-list of files used on login page and remove this decorator @allow_unauthenticated - def head(self, *args, **kwargs) -> None: - return super().head(*args, **kwargs) + def head(self, path: str) -> Awaitable[None]: + return super().head(path) @classmethod def get_absolute_path(cls, roots: Sequence[str], path: str) -> str: @@ -1072,7 +1072,7 @@ class TrailingSlashHandler(web.RequestHandler): This should be the first, highest priority handler. """ - # does not require `allow_unauthenticated` (inherits from `web.RequestHandler`) + @allow_unauthenticated def get(self) -> None: """Handle trailing slashes in a get.""" assert self.request.uri is not None @@ -1136,14 +1136,14 @@ async def get(self, path: str = "") -> None: class RedirectWithParams(web.RequestHandler): - """Sam as web.RedirectHandler, but preserves URL parameters""" + """Same as web.RedirectHandler, but preserves URL parameters""" def initialize(self, url: str, permanent: bool = True) -> None: """Initialize a redirect handler.""" self._url = url self._permanent = permanent - # does not require `allow_unauthenticated` (inherits from `web.RequestHandler`) + @allow_unauthenticated def get(self) -> None: """Get a redirect.""" sep = "&" if "?" in self._url else "?" @@ -1166,6 +1166,18 @@ def get(self) -> None: self.write(prometheus_client.generate_latest(prometheus_client.REGISTRY)) +class PublicStaticFileHandler(web.StaticFileHandler): + """Same as web.StaticFileHandler, but decorated to acknowledge that auth is not required.""" + + @allow_unauthenticated + def head(self, path: str) -> Awaitable[None]: + return super().head(path) + + @allow_unauthenticated + def get(self, path: str, include_body: bool = True) -> Coroutine[Any, Any, None]: + return super().get(path, include_body) + + # ----------------------------------------------------------------------------- # URL pattern fragments for reuse # ----------------------------------------------------------------------------- @@ -1181,6 +1193,6 @@ def get(self) -> None: default_handlers = [ (r".*/", TrailingSlashHandler), (r"api", APIVersionHandler), - (r"/(robots\.txt|favicon\.ico)", web.StaticFileHandler), + (r"/(robots\.txt|favicon\.ico)", PublicStaticFileHandler), (r"/metrics", PrometheusMetricsHandler), ] diff --git a/jupyter_server/base/websocket.py b/jupyter_server/base/websocket.py index c7bd311dd5..a424d06376 100644 --- a/jupyter_server/base/websocket.py +++ b/jupyter_server/base/websocket.py @@ -97,6 +97,7 @@ def _maybe_auth(self): self.log.warning("Couldn't authenticate WebSocket connection") raise web.HTTPError(403) + @no_type_check def prepare(self, *args, **kwargs): """Handle a get request.""" self._maybe_auth() diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index aeeab5a94d..b676086747 100644 --- a/jupyter_server/extension/application.py +++ b/jupyter_server/extension/application.py @@ -358,7 +358,7 @@ def _prepare_handlers(self): ) new_handlers.append(handler) - webapp.add_handlers(".*$", new_handlers) # type:ignore[arg-type] + webapp.add_handlers(".*$", new_handlers) def _prepare_templates(self): """Add templates to web app settings if extension has templates.""" diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 16f73863dc..597838912e 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -41,6 +41,7 @@ from tornado.httputil import url_concat from tornado.log import LogFormatter, access_log, app_log, gen_log from tornado.netutil import bind_sockets +from tornado.routing import Matcher, Rule if not sys.platform.startswith("win"): from tornado.netutil import bind_unix_socket @@ -280,8 +281,52 @@ def __init__( ) handlers = self.init_handlers(default_services, settings) + undecorated_methods = [] + for matcher, handler, *_ in handlers: + undecorated_methods.extend(self._check_handler_auth(matcher, handler)) + + if undecorated_methods: + message = ( + "Core endpoints without @allow_unauthenticated, @ws_authenticated, nor @web.authenticated:\n" + + "\n".join(undecorated_methods) + ) + if jupyter_app.allow_unauthenticated_access: + warnings.warn( + message, + RuntimeWarning, + stacklevel=2, + ) + else: + raise Exception(message) + super().__init__(handlers, **settings) + def add_handlers(self, host_pattern, host_handlers): + undecorated_methods = [] + for rule in host_handlers: + if isinstance(rule, Rule): + matcher = rule.matcher + handler = rule.target + else: + matcher, handler, *_ = rule + undecorated_methods.extend(self._check_handler_auth(matcher, handler)) + + if undecorated_methods: + message = ( + "Extension endpoints without @allow_unauthenticated, @ws_authenticated, nor @web.authenticated:\n" + + "\n".join(undecorated_methods) + ) + if self.settings["allow_unauthenticated_access"]: + warnings.warn( + message, + RuntimeWarning, + stacklevel=2, + ) + else: + raise Exception(message) + + return super().add_handlers(host_pattern, host_handlers) + def init_settings( self, jupyter_app, @@ -487,6 +532,21 @@ def last_activity(self): sources.extend(self.settings["last_activity_times"].values()) return max(sources) + def _check_handler_auth(self, matcher: t.Union[str, Matcher], handler: web.RequestHandler): + missing_authentication = [] + for method_name in handler.SUPPORTED_METHODS: + method = getattr(handler, method_name.lower()) + is_unimplemented = method == web.RequestHandler._unimplemented_method + is_allowlisted = hasattr(method, "__allow_unauthenticated") + possibly_blocklisted = hasattr( + method, "__wrapped__" + ) # TODO: can we make web.auth leave a better footprint? + if not is_unimplemented and not is_allowlisted and not possibly_blocklisted: + missing_authentication.append( + f"- {method_name} of {handler.__class__.__name__} registered for {matcher}" + ) + return missing_authentication + class JupyterPasswordApp(JupyterApp): """Set a password for the Jupyter server. diff --git a/jupyter_server/services/api/handlers.py b/jupyter_server/services/api/handlers.py index 8b9e44f9cf..625d9ca372 100644 --- a/jupyter_server/services/api/handlers.py +++ b/jupyter_server/services/api/handlers.py @@ -25,6 +25,11 @@ def initialize(self): """Initialize the API spec handler.""" web.StaticFileHandler.initialize(self, path=os.path.dirname(__file__)) + @web.authenticated + @authorized + def head(self): + return self.get("api.yaml", include_body=False) + @web.authenticated @authorized def get(self): diff --git a/jupyter_server/services/events/handlers.py b/jupyter_server/services/events/handlers.py index 1ca28b948c..ce580048f2 100644 --- a/jupyter_server/services/events/handlers.py +++ b/jupyter_server/services/events/handlers.py @@ -12,7 +12,7 @@ from jupyter_core.utils import ensure_async from tornado import web, websocket -from jupyter_server.auth.decorator import authorized +from jupyter_server.auth.decorator import authorized, ws_authenticated from jupyter_server.base.handlers import JupyterHandler from ...base.handlers import APIHandler @@ -29,16 +29,11 @@ class SubscribeWebsocket( auth_resource = AUTH_RESOURCE async def pre_get(self): - """Handles authentication/authorization when + """Handles authorization when attempting to subscribe to events emitted by Jupyter Server's eventbus. """ - # authenticate the request before opening the websocket user = self.current_user - if user is None: - self.log.warning("Couldn't authenticate WebSocket connection") - raise web.HTTPError(403) - # authorize the user. authorized = await ensure_async( self.authorizer.is_authorized(self, user, "execute", "events") @@ -46,6 +41,7 @@ async def pre_get(self): if not authorized: raise web.HTTPError(403) + @ws_authenticated async def get(self, *args, **kwargs): """Get an event socket.""" await ensure_async(self.pre_get()) diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py index 4c2c1c8914..374df76f3e 100644 --- a/jupyter_server/services/kernels/websocket.py +++ b/jupyter_server/services/kernels/websocket.py @@ -6,6 +6,7 @@ from tornado import web from tornado.websocket import WebSocketHandler +from jupyter_server.auth.decorator import ws_authenticated from jupyter_server.base.handlers import JupyterHandler from jupyter_server.base.websocket import WebSocketMixin @@ -34,11 +35,7 @@ def get_compression_options(self): async def pre_get(self): """Handle a pre_get.""" - # authenticate first user = self.current_user - if user is None: - self.log.warning("Couldn't authenticate WebSocket connection") - raise web.HTTPError(403) # authorize the user. authorized = await ensure_async( @@ -61,6 +58,7 @@ async def pre_get(self): if hasattr(self.connection, "prepare"): await self.connection.prepare() + @ws_authenticated async def get(self, kernel_id): """Handle a get request for a kernel.""" self.kernel_id = kernel_id