Skip to content

Commit

Permalink
Add runtime check to ensure handlers have auth decorators
Browse files Browse the repository at this point in the history
Also adds `@ws_authenticated` to make the check simpler
  • Loading branch information
krassowski committed Feb 12, 2024
1 parent 204d29f commit dccc423
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 21 deletions.
28 changes: 28 additions & 0 deletions jupyter_server/auth/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,31 @@ def wrapper(self, *args, **kwargs):
setattr(wrapper, "__allow_unauthenticated", True)

return cast(FuncT, wrapper)


def ws_authenticated(method: FuncT) -> FuncT:
"""A decorator for websockets derived from `WebSocketHandler`
that authenticates user before allowing to proceed.
Differently from tornado.web.authenticated, does not redirect
to the login page, which would be meaningless for websockets.
.. versionadded:: 2.13
Parameters
----------
method : bound callable
the endpoint method to add authentication for.
"""

@wraps(method)
def wrapper(self, *args, **kwargs):
user = self.current_user
if user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise HTTPError(403)
return method(self, *args, **kwargs)

setattr(wrapper, "__allow_unauthenticated", False)

return cast(FuncT, wrapper)
30 changes: 21 additions & 9 deletions jupyter_server/base/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import warnings
from http.client import responses
from logging import Logger
from typing import TYPE_CHECKING, Any, Awaitable, Sequence, cast
from typing import TYPE_CHECKING, Any, Awaitable, Coroutine, Sequence, cast
from urllib.parse import urlparse

import prometheus_client
Expand Down Expand Up @@ -1016,14 +1016,14 @@ def compute_etag(self) -> str | None:
# access is allowed as this class is used to serve static assets on login page
# TODO: create an allow-list of files used on login page and remove this decorator
@allow_unauthenticated
def get(self, *args, **kwargs) -> None:
return super().get(*args, **kwargs)
def get(self, path: str, include_body: bool = True) -> Coroutine[Any, Any, None]:
return super().get(path, include_body)

# access is allowed as this class is used to serve static assets on login page
# TODO: create an allow-list of files used on login page and remove this decorator
@allow_unauthenticated
def head(self, *args, **kwargs) -> None:
return super().head(*args, **kwargs)
def head(self, path: str) -> Awaitable[None]:
return super().head(path)

@classmethod
def get_absolute_path(cls, roots: Sequence[str], path: str) -> str:
Expand Down Expand Up @@ -1072,7 +1072,7 @@ class TrailingSlashHandler(web.RequestHandler):
This should be the first, highest priority handler.
"""

# does not require `allow_unauthenticated` (inherits from `web.RequestHandler`)
@allow_unauthenticated
def get(self) -> None:
"""Handle trailing slashes in a get."""
assert self.request.uri is not None
Expand Down Expand Up @@ -1136,14 +1136,14 @@ async def get(self, path: str = "") -> None:


class RedirectWithParams(web.RequestHandler):
"""Sam as web.RedirectHandler, but preserves URL parameters"""
"""Same as web.RedirectHandler, but preserves URL parameters"""

def initialize(self, url: str, permanent: bool = True) -> None:
"""Initialize a redirect handler."""
self._url = url
self._permanent = permanent

# does not require `allow_unauthenticated` (inherits from `web.RequestHandler`)
@allow_unauthenticated
def get(self) -> None:
"""Get a redirect."""
sep = "&" if "?" in self._url else "?"
Expand All @@ -1166,6 +1166,18 @@ def get(self) -> None:
self.write(prometheus_client.generate_latest(prometheus_client.REGISTRY))


class PublicStaticFileHandler(web.StaticFileHandler):
"""Same as web.StaticFileHandler, but decorated to acknowledge that auth is not required."""

@allow_unauthenticated
def head(self, path: str) -> Awaitable[None]:
return super().head(path)

@allow_unauthenticated
def get(self, path: str, include_body: bool = True) -> Coroutine[Any, Any, None]:
return super().get(path, include_body)


# -----------------------------------------------------------------------------
# URL pattern fragments for reuse
# -----------------------------------------------------------------------------
Expand All @@ -1181,6 +1193,6 @@ def get(self) -> None:
default_handlers = [
(r".*/", TrailingSlashHandler),
(r"api", APIVersionHandler),
(r"/(robots\.txt|favicon\.ico)", web.StaticFileHandler),
(r"/(robots\.txt|favicon\.ico)", PublicStaticFileHandler),
(r"/metrics", PrometheusMetricsHandler),
]
1 change: 1 addition & 0 deletions jupyter_server/base/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def _maybe_auth(self):
self.log.warning("Couldn't authenticate WebSocket connection")
raise web.HTTPError(403)

@no_type_check
def prepare(self, *args, **kwargs):
"""Handle a get request."""
self._maybe_auth()
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/extension/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def _prepare_handlers(self):
)
new_handlers.append(handler)

webapp.add_handlers(".*$", new_handlers) # type:ignore[arg-type]
webapp.add_handlers(".*$", new_handlers)

def _prepare_templates(self):
"""Add templates to web app settings if extension has templates."""
Expand Down
60 changes: 60 additions & 0 deletions jupyter_server/serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from tornado.httputil import url_concat
from tornado.log import LogFormatter, access_log, app_log, gen_log
from tornado.netutil import bind_sockets
from tornado.routing import Matcher, Rule

if not sys.platform.startswith("win"):
from tornado.netutil import bind_unix_socket
Expand Down Expand Up @@ -280,8 +281,52 @@ def __init__(
)
handlers = self.init_handlers(default_services, settings)

undecorated_methods = []
for matcher, handler, *_ in handlers:
undecorated_methods.extend(self._check_handler_auth(matcher, handler))

if undecorated_methods:
message = (
"Core endpoints without @allow_unauthenticated, @ws_authenticated, nor @web.authenticated:\n"
+ "\n".join(undecorated_methods)
)
if jupyter_app.allow_unauthenticated_access:
warnings.warn(
message,
RuntimeWarning,
stacklevel=2,
)
else:
raise Exception(message)

super().__init__(handlers, **settings)

def add_handlers(self, host_pattern, host_handlers):
undecorated_methods = []
for rule in host_handlers:
if isinstance(rule, Rule):
matcher = rule.matcher
handler = rule.target
else:
matcher, handler, *_ = rule
undecorated_methods.extend(self._check_handler_auth(matcher, handler))

if undecorated_methods:
message = (
"Extension endpoints without @allow_unauthenticated, @ws_authenticated, nor @web.authenticated:\n"
+ "\n".join(undecorated_methods)
)
if self.settings["allow_unauthenticated_access"]:
warnings.warn(
message,
RuntimeWarning,
stacklevel=2,
)
else:
raise Exception(message)

return super().add_handlers(host_pattern, host_handlers)

def init_settings(
self,
jupyter_app,
Expand Down Expand Up @@ -487,6 +532,21 @@ def last_activity(self):
sources.extend(self.settings["last_activity_times"].values())
return max(sources)

def _check_handler_auth(self, matcher: t.Union[str, Matcher], handler: web.RequestHandler):
missing_authentication = []
for method_name in handler.SUPPORTED_METHODS:
method = getattr(handler, method_name.lower())
is_unimplemented = method == web.RequestHandler._unimplemented_method
is_allowlisted = hasattr(method, "__allow_unauthenticated")
possibly_blocklisted = hasattr(
method, "__wrapped__"
) # TODO: can we make web.auth leave a better footprint?
if not is_unimplemented and not is_allowlisted and not possibly_blocklisted:
missing_authentication.append(
f"- {method_name} of {handler.__class__.__name__} registered for {matcher}"
)
return missing_authentication


class JupyterPasswordApp(JupyterApp):
"""Set a password for the Jupyter server.
Expand Down
5 changes: 5 additions & 0 deletions jupyter_server/services/api/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def initialize(self):
"""Initialize the API spec handler."""
web.StaticFileHandler.initialize(self, path=os.path.dirname(__file__))

@web.authenticated
@authorized
def head(self):
return self.get("api.yaml", include_body=False)

@web.authenticated
@authorized
def get(self):
Expand Down
10 changes: 3 additions & 7 deletions jupyter_server/services/events/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from jupyter_core.utils import ensure_async
from tornado import web, websocket

from jupyter_server.auth.decorator import authorized
from jupyter_server.auth.decorator import authorized, ws_authenticated
from jupyter_server.base.handlers import JupyterHandler

from ...base.handlers import APIHandler
Expand All @@ -29,23 +29,19 @@ class SubscribeWebsocket(
auth_resource = AUTH_RESOURCE

async def pre_get(self):
"""Handles authentication/authorization when
"""Handles authorization when
attempting to subscribe to events emitted by
Jupyter Server's eventbus.
"""
# authenticate the request before opening the websocket
user = self.current_user
if user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise web.HTTPError(403)

# authorize the user.
authorized = await ensure_async(
self.authorizer.is_authorized(self, user, "execute", "events")
)
if not authorized:
raise web.HTTPError(403)

@ws_authenticated
async def get(self, *args, **kwargs):
"""Get an event socket."""
await ensure_async(self.pre_get())
Expand Down
6 changes: 2 additions & 4 deletions jupyter_server/services/kernels/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tornado import web
from tornado.websocket import WebSocketHandler

from jupyter_server.auth.decorator import ws_authenticated
from jupyter_server.base.handlers import JupyterHandler
from jupyter_server.base.websocket import WebSocketMixin

Expand Down Expand Up @@ -34,11 +35,7 @@ def get_compression_options(self):

async def pre_get(self):
"""Handle a pre_get."""
# authenticate first
user = self.current_user
if user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise web.HTTPError(403)

# authorize the user.
authorized = await ensure_async(
Expand All @@ -61,6 +58,7 @@ async def pre_get(self):
if hasattr(self.connection, "prepare"):
await self.connection.prepare()

@ws_authenticated
async def get(self, kernel_id):
"""Handle a get request for a kernel."""
self.kernel_id = kernel_id
Expand Down

0 comments on commit dccc423

Please sign in to comment.