Skip to content

Commit

Permalink
allow get_user to be async
Browse files Browse the repository at this point in the history
careful to deprecate overridden get_current_user without ignoring auth

Needs some changes due to early steps that are called before prepare,
but must now be moved to prepare due to the reliance on auth info.

- setting CORS headers (set_default_headers)
- check_xsrf_cookie

now that get_user is async, we have to re-run these bits
  • Loading branch information
minrk committed Apr 28, 2022
1 parent 488ea88 commit 0afc6e7
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 14 deletions.
2 changes: 2 additions & 0 deletions jupyter_server/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class IdentityProvider(LoggingConfigurable):
"""
Interface for providing identity
_may_ be a coroutine.
Two principle methods:
- :meth:`~.IdentityProvider.get_user` returns a :class:`~.User` object
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/auth/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def is_token_authenticated(cls, handler):
"""
if getattr(handler, "_user_id", None) is None:
# ensure get_user has been called, so we know if we're token-authenticated
handler.get_current_user()
handler.current_user
return getattr(handler, "_token_authenticated", False)

@classmethod
Expand Down
69 changes: 58 additions & 11 deletions jupyter_server/base/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Distributed under the terms of the Modified BSD License.
import datetime
import functools
import inspect
import ipaddress
import json
import mimetypes
Expand Down Expand Up @@ -134,7 +135,21 @@ def clear_login_cookie(self):
self.force_clear_cookie(self.cookie_name)

def get_current_user(self):
return self.identity_provider.get_user(self)
clsname = self.__class__.__name__
msg = (
f"Calling `{clsname}.get_current_user()` directly is deprecated in jupyter-server 2.0."
" Use `self.current_user` instead (works in all versions)."
)
if hasattr(self, "_jupyter_current_user"):
# backward-compat: return _jupyter_current_user
warnings.warn(
msg,
DeprecationWarning,
stacklevel=2,
)
return self._jupyter_current_user
# haven't called get_user in prepare, raise
raise RuntimeError(msg)

def skip_check_origin(self):
"""Ask my login_handler if I should skip the origin_check
Expand Down Expand Up @@ -164,7 +179,7 @@ def cookie_name(self):
@property
def logged_in(self):
"""Is a user currently logged in?"""
user = self.get_current_user()
user = self.current_user
return user and not user == "anonymous"

@property
Expand Down Expand Up @@ -346,6 +361,13 @@ def allow_credentials(self):
def set_default_headers(self):
"""Add CORS headers, if defined"""
super().set_default_headers()

def set_cors_headers(self):
"""Add CORS headers, if defined
Now that current_user is async (jupyter-server 2.0),
must be called at the end of prepare(), instead of in set_default_headers.
"""
if self.allow_origin:
self.set_header("Access-Control-Allow-Origin", self.allow_origin)
elif self.allow_origin_pat:
Expand Down Expand Up @@ -484,6 +506,9 @@ def check_referer(self):

def check_xsrf_cookie(self):
"""Bypass xsrf cookie checks when token-authenticated"""
if not hasattr(self, "_jupyter_current_user"):
# Called too early, will be checked later
return
if self.token_authenticated or self.settings.get("disable_check_xsrf", False):
# Token-authenticated requests do not need additional XSRF-check
# Servers without authentication are vulnerable to XSRF
Expand Down Expand Up @@ -543,9 +568,39 @@ def check_host(self):
)
return allow

def prepare(self):
async def prepare(self):
if not self.check_host():
raise web.HTTPError(403)

from jupyter_server.auth import IdentityProvider

if (
type(self.identity_provider) is IdentityProvider
and inspect.getmodule(self.get_current_user).__name__ != __name__
):
# check for overridden get_current_user + default IdentityProvider
# deprecated way to override auth (e.g. JupyterHub < 3.0)
# allow deprecated, overridden get_current_user
warnings.warn(
"Overriding JupyterHandler.get_current_user is deprecated in jupyter-server 2.0."
" Use an IdentityProvider class.",
DeprecationWarning,
# stacklevel not useful here
)
user = self.get_current_user()
else:
user = self.identity_provider.get_user(self)
if inspect.isawaitable(user):
# IdentityProvider.get_user _may_ be async
user = await user

# self.current_user for tornado's @web.authenticated
# self._jupyter_current_user for backward-compat in deprecated get_current_user calls
# and our own private checks for whether .current_user has been set
self.current_user = self._jupyter_current_user = user
# complete initial steps which require auth to resolve first:
self.set_cors_headers()
self.check_xsrf_cookie()
return super().prepare()

# ---------------------------------------------------------------
Expand Down Expand Up @@ -663,14 +718,6 @@ def write_error(self, status_code, **kwargs):
self.log.warning(reply["message"])
self.finish(json.dumps(reply))

def get_current_user(self):
"""Raise 403 on API handlers instead of redirecting to human login page"""
# preserve _user_cache so we don't raise more than once
if hasattr(self, "_user_cache"):
return self._user_cache
self._user_cache = user = super().get_current_user()
return user

def get_login_url(self):
# if get_login_url is invoked in an API handler,
# that means @web.authenticated is trying to trigger a redirect.
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/base/zmqhandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def pre_get(self):
the websocket finishes completing.
"""
# authenticate the request before opening the websocket
user = self.get_current_user()
user = self.current_user
if user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise web.HTTPError(403)
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/gateway/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def authenticate(self):
the websocket finishes completing.
"""
# authenticate the request before opening the websocket
if self.get_current_user() is None:
if self.current_user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise web.HTTPError(403)

Expand Down

0 comments on commit 0afc6e7

Please sign in to comment.