Skip to content

Commit

Permalink
Merge pull request #11 from blink1073/add-auth
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 authored Apr 3, 2022
2 parents c82901e + e488282 commit 77a4f85
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 3 deletions.
16 changes: 14 additions & 2 deletions jupyter_server_terminals/api_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,28 @@
from .base import TerminalsMixin

try:
from jupyter_server.auth import authorized
from jupyter_server.base.handlers import APIHandler
except ModuleNotFoundError:
raise ModuleNotFoundError("Jupyter Server must be installed to use this extension.")


class TerminalRootHandler(TerminalsMixin, APIHandler):
AUTH_RESOURCE = "terminals"


class TerminalAPIHandler(APIHandler):
auth_resource = AUTH_RESOURCE


class TerminalRootHandler(TerminalsMixin, TerminalAPIHandler):
@web.authenticated
@authorized
def get(self):
models = self.terminal_manager.list()
self.finish(json.dumps(models))

@web.authenticated
@authorized
def post(self):
"""POST /terminals creates a new terminal and redirects to it"""
data = self.get_json_body() or {}
Expand All @@ -25,15 +35,17 @@ def post(self):
self.finish(json.dumps(model))


class TerminalHandler(TerminalsMixin, APIHandler):
class TerminalHandler(TerminalsMixin, TerminalAPIHandler):
SUPPORTED_METHODS = ("GET", "DELETE")

@web.authenticated
@authorized
def get(self, name):
model = self.terminal_manager.get(name)
self.finish(json.dumps(model))

@web.authenticated
@authorized
async def delete(self, name):
await self.terminal_manager.terminate(name, force=True)
self.set_status(204)
Expand Down
1 change: 1 addition & 0 deletions jupyter_server_terminals/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def initialize_handlers(self):
)
)
self.handlers.extend(api_handlers.default_handlers)
self.serverapp.web_app.settings["terminal_manager"] = self.terminal_manager

def current_activity(self):
if self.terminals_available:
Expand Down
18 changes: 17 additions & 1 deletion jupyter_server_terminals/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,38 @@

try:
from jupyter_server._tz import utcnow
from jupyter_server.auth.utils import warn_disabled_authorization
from jupyter_server.base.handlers import JupyterHandler
from jupyter_server.base.zmqhandlers import WebSocketMixin
except ModuleNotFoundError:
raise ModuleNotFoundError("Jupyter Server must be installed to use this extension.")

AUTH_RESOURCE = "terminals"


class TermSocket(TerminalsMixin, WebSocketMixin, JupyterHandler, terminado.TermSocket):

auth_resource = AUTH_RESOURCE

def origin_check(self):
"""Terminado adds redundant origin_check
Tornado already calls check_origin, so don't do anything here.
"""
return True

def get(self, *args, **kwargs):
if not self.get_current_user():
user = self.current_user

if not user:
raise web.HTTPError(403)

# authorize the user.
if not self.authorizer:
# Warn if an authorizer is unavailable.
warn_disabled_authorization()
elif not self.authorizer.is_authorized(self, user, "execute", self.auth_resource):
raise web.HTTPError(403)

if not args[0] in self.term_manager.terminals:
raise web.HTTPError(404)
return super().get(*args, **kwargs)
Expand Down
174 changes: 174 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Tests for authorization"""

import pytest
from jupyter_client.kernelspec import NATIVE_KERNEL_NAME
from jupyter_server.auth.authorizer import Authorizer
from jupyter_server.auth.utils import HTTP_METHOD_TO_AUTH_ACTION, match_url_to_resource
from nbformat import writes
from nbformat.v4 import new_notebook
from tornado.httpclient import HTTPClientError
from tornado.websocket import WebSocketHandler
from traitlets.config import Config


class AuthorizerforTesting(Authorizer):

# Set these class attributes from within a test
# to verify that they match the arguments passed
# by the REST API.
permissions = {}

def normalize_url(self, path):
"""Drop the base URL and make sure path leads with a /"""
base_url = self.parent.base_url
# Remove base_url
if path.startswith(base_url):
path = path[len(base_url) :]
# Make sure path starts with /
if not path.startswith("/"):
path = "/" + path
return path

def is_authorized(self, handler, user, action, resource):
# Parse Request
if isinstance(handler, WebSocketHandler):
method = "WEBSOCKET"
else:
method = handler.request.method
url = self.normalize_url(handler.request.path)

# Map request parts to expected action and resource.
expected_action = HTTP_METHOD_TO_AUTH_ACTION[method]
expected_resource = match_url_to_resource(url)

# Assert that authorization layer returns the
# correct action + resource.
assert action == expected_action
assert resource == expected_resource

# Now, actually apply the authorization layer.
return all(
[
action in self.permissions.get("actions", []),
resource in self.permissions.get("resources", []),
]
)


@pytest.fixture
def jp_server_config():
return Config(
{
"ServerApp": {
"jpserver_extensions": {"jupyter_server_terminals": True},
"authorizer_class": AuthorizerforTesting,
}
}
)


@pytest.fixture
def send_request(jp_fetch, jp_ws_fetch):
"""Send to Jupyter Server and return response code."""

async def _(url, **fetch_kwargs):
if url.endswith("channels") or "/websocket/" in url:
fetch = jp_ws_fetch
else:
fetch = jp_fetch

try:
r = await fetch(url, **fetch_kwargs, allow_nonstandard_methods=True)
code = r.code
except HTTPClientError as err:
code = err.code
else:
if fetch is jp_ws_fetch:
r.close()

print(code, url, fetch_kwargs)
return code

return _


HTTP_REQUESTS = [
{
"method": "POST",
"url": "/api/terminals",
"body": "",
},
{
"method": "GET",
"url": "/api/terminals",
},
{
"method": "GET",
"url": "/terminals/websocket/{term_name}",
},
{
"method": "DELETE",
"url": "/api/terminals/{term_name}",
},
]

HTTP_REQUESTS_PARAMETRIZED = [(req["method"], req["url"], req.get("body")) for req in HTTP_REQUESTS]

# -------- Test scenarios -----------


@pytest.mark.parametrize("method, url, body", HTTP_REQUESTS_PARAMETRIZED)
@pytest.mark.parametrize("allowed", (True, False))
async def test_authorized_requests(
request,
io_loop,
send_request,
tmp_path,
jp_serverapp,
jp_cleanup_subprocesses,
method,
url,
body,
allowed,
):
# Setup stuff for the Contents API
# Add a notebook on disk
contents_dir = tmp_path / jp_serverapp.root_dir
p = contents_dir / "dir_for_testing"
p.mkdir(parents=True, exist_ok=True)

# Create a notebook
nb = writes(new_notebook(), version=4)
nbname = p.joinpath("nb_for_testing.ipynb")
nbname.write_text(nb, encoding="utf-8")

# Setup
nbpath = "dir_for_testing/nb_for_testing.ipynb"
kernelspec = NATIVE_KERNEL_NAME
km = jp_serverapp.kernel_manager

if "terminal" in url:
term_manager = jp_serverapp.web_app.settings["terminal_manager"]
request.addfinalizer(lambda: io_loop.run_sync(term_manager.terminate_all))
term_model = term_manager.create()
term_name = term_model["name"]

url = url.format(**locals())
if allowed:
# Create a server with full permissions
permissions = {
"actions": ["read", "write", "execute"],
"resources": [
"terminals",
],
}
expected_codes = {200, 201, 204, None} # Websockets don't return a code
else:
permissions = {"actions": [], "resources": []}
expected_codes = {403}
jp_serverapp.authorizer.permissions = permissions

code = await send_request(url, body=body, method=method)
assert code in expected_codes

await jp_cleanup_subprocesses()

0 comments on commit 77a4f85

Please sign in to comment.