Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow serving REST APIs as part of panel serve #1164

Merged
merged 10 commits into from
Jul 29, 2020
31 changes: 31 additions & 0 deletions panel/command/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,21 @@

import ast
import base64
import logging # isort:skip
import os

from glob import glob

from bokeh.command.subcommands.serve import Serve as _BkServe

from ..auth import OAuthProvider
from ..config import config
from ..io.rest import REST_PROVIDERS
from ..io.server import INDEX_HTML, get_static_routes
from ..io.state import state

log = logging.getLogger(__name__)


def parse_var(s):
"""
Expand Down Expand Up @@ -83,6 +89,17 @@ class Serve(_BkServe):
action = 'store',
type = str,
help = "A random string used to encode the user information."
)),
('--rest-provider', dict(
action = 'store',
type = str,
help = "The interface to use to serve REST API"
)),
('--rest-endpoint', dict(
action = 'store',
type = str,
help = "Endpoint to store REST API on.",
default = 'rest'
))
)

Expand All @@ -103,6 +120,20 @@ def customize_kwargs(self, args, server_kwargs):
static_dirs['panel_dist'] = os.path.join(os.path.dirname(os.path.split(__file__)[0]), 'dist')
patterns += get_static_routes(static_dirs)

files = []
for f in args.files:
if args.glob:
files.extend(glob(f))
else:
files.append(f)

# Handle tranquilized functions in the supplied functions
if args.rest_provider in REST_PROVIDERS:
pattern = REST_PROVIDERS[args.rest_provider](files, args.rest_endpoint)
patterns.extend(pattern)
elif args.rest_provider is not None:
raise ValueError("rest-provider %r not recognized." % args.rest_provider)

if args.oauth_provider:
config.oauth_provider = args.oauth_provider
if config.oauth_key and args.oauth_key:
Expand Down
174 changes: 174 additions & 0 deletions panel/io/rest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import json
import os
import pkg_resources
import tempfile
import traceback

from runpy import run_path
from unittest.mock import MagicMock
from urllib.parse import parse_qs

from tornado import web
from tornado.wsgi import WSGIContainer

from .state import state


class HTTPError(web.HTTPError):
"""
Custom HTTPError type
"""


class BaseHandler(web.RequestHandler):

def write_error(self, status_code, **kwargs):
self.set_header('Content-Type', 'application/json')
if self.settings.get("serve_traceback") and "exc_info" in kwargs:
# in debug mode, try to send a traceback
lines = []
for line in traceback.format_exception(*kwargs["exc_info"]):
lines.append(line)
self.finish(json.dumps({
'error': {
'code': status_code,
'message': self._reason,
'traceback': lines,
}
}))
else:
self.finish(json.dumps({
'error': {
'code': status_code,
'message': self._reason,
}
}))

class ParamHandler(BaseHandler):

def __init__(self, app, request, **kwargs):
self.root = kwargs.pop('root', None)
super().__init__(app, request, **kwargs)

@classmethod
def serialize(cls, parameterized, parameters):
values = {p: getattr(parameterized, p) for p in parameters}
return parameterized.param.serialize_parameters(values)

@classmethod
def deserialize(cls, parameterized, parameters):
for p in parameters:
if p not in parameterized.param:
reason = f"'{p}' query parameter not recognized."
raise HTTPError(reason=reason, status_code=400)
return {p: parameterized.param.deserialize_value(p, v)
for p, v in parameters.items()}

async def get(self):
path = self.request.path
endpoint = path[path.index(self.root)+len(self.root):]
parameterized, parameters, _ = state._rest_endpoints.get(
endpoint, (None, None, None)
)
if not parameterized:
return
args = parse_qs(self.request.query)
params = self.deserialize(parameterized[0], args)
parameterized[0].param.set_param(**params)
self.set_header('Content-Type', 'application/json')
self.write(self.serialize(parameterized[0], parameters))


def build_tranquilize_application(files):
from tranquilizer.handler import ScriptHandler, NotebookHandler
from tranquilizer.main import make_app, UnsupportedFileType

functions = []
for filename in files:
extension = filename.split('.')[-1]
if extension == 'py':
source = ScriptHandler(filename)
elif extension == 'ipynb':
try:
import nbconvert # noqa
except ImportError as e: # pragma no cover
raise ImportError("Please install nbconvert to serve Jupyter Notebooks.") from e

source = NotebookHandler(filename)
else:
raise UnsupportedFileType('{} is not a script (.py) or notebook (.ipynb)'.format(filename))
functions.extend(source.tranquilized_functions)
return make_app(functions, 'Panel REST API', prefix='rest/')


def tranquilizer_rest_provider(files, endpoint):
"""
Returns a Tranquilizer based REST API. Builds the API by evaluating
the scripts and notebooks being served and finding all tranquilized
functions inside them.

Arguments
---------
files: list(str)
A list of paths being served
endpoint: str
The endpoint to serve the REST API on

Returns
-------
A Tornado routing pattern containing the route and handler
"""
app = build_tranquilize_application(files)
tr = WSGIContainer(app)
return [(r"^/%s/.*" % endpoint, web.FallbackHandler, dict(fallback=tr))]


def param_rest_provider(files, endpoint):
"""
Returns a Param based REST API given the scripts or notebooks
containing the tranquilized functions.

Arguments
---------
files: list(str)
A list of paths being served
endpoint: str
The endpoint to serve the REST API on

Returns
-------
A Tornado routing pattern containing the route and handler
"""
for filename in files:
extension = filename.split('.')[-1]
if extension == 'py':
run_path(filename)
elif extension == 'ipynb':
try:
import nbconvert # noqa
except ImportError:
raise ImportError("Please install nbconvert to serve Jupyter Notebooks.")
from nbconvert import ScriptExporter
exporter = ScriptExporter()
source, _ = exporter.from_filename(filename)
source_dir = os.path.dirname(filename)
with tempfile.NamedTemporaryFile(mode='w', dir=source_dir, delete=True) as tmp:
tmp.write(source)
tmp.flush()
run_path(tmp.name, init_globals={'get_ipython': MagicMock()})
else:
raise ValueError('{} is not a script (.py) or notebook (.ipynb)'.format(filename))

if endpoint and not endpoint.endswith('/'):
endpoint += '/'
return [((r"^/%s.*" % endpoint if endpoint else r"^.*"), ParamHandler, dict(root=endpoint))]


REST_PROVIDERS = {
'tranquilizer': tranquilizer_rest_provider,
'param': param_rest_provider
}

# Populate REST Providers from external extensions
for entry_point in pkg_resources.iter_entry_points('panel.io.rest'):
REST_PROVIDERS[entry_point.name] = entry_point.resolve()
100 changes: 80 additions & 20 deletions panel/io/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,12 @@ class _state(param.Parameterized):
# Stores a set of locked Websockets, reset after every change event
_locks = WeakSet()

# Indicators listening to the busy state
_indicators = []

# Endpoints
_rest_endpoints = {}

def __repr__(self):
server_info = []
for server, panel, docs in self._servers.values():
Expand All @@ -95,30 +99,37 @@ def _update_busy(self):
for indicator in self._indicators:
indicator.value = self.busy

def _get_callback(self, endpoint):
_updating = {}
def link(*events):
event = events[0]
obj = event.cls if event.obj is None else event.obj
parameterizeds = self._rest_endpoints[endpoint][0]
if obj not in parameterizeds:
return
updating = _updating.get(id(obj), [])
values = {event.name: event.new for event in events
if event.name not in updating}
if not values:
return
_updating[id(obj)] = list(values)
for parameterized in parameterizeds:
if parameterized in _updating:
continue
try:
parameterized.param.set_param(**values)
except Exception:
raise
finally:
if id(obj) in _updating:
not_updated = [p for p in _updating[id(obj)] if p not in values]
_updating[id(obj)] = not_updated
return link

#----------------------------------------------------------------
# Public Methods
#----------------------------------------------------------------

def kill_all_servers(self):
"""Stop all servers and clear them from the current state."""
for server_id in self._servers:
try:
self._servers[server_id][0].stop()
except AssertionError: # can't stop a server twice
pass
self._servers = {}

def onload(self, callback):
"""
Callback that is triggered when a session has been served.
"""
if self.curdoc is None:
callback()
return
if self.curdoc not in self._onload:
self._onload[self.curdoc] = []
self._onload[self.curdoc].append(callback)

def add_periodic_callback(self, callback, period=500, count=None,
timeout=None, start=True):
"""
Expand Down Expand Up @@ -151,6 +162,55 @@ def add_periodic_callback(self, callback, period=500, count=None,
cb.start()
return cb

def kill_all_servers(self):
"""Stop all servers and clear them from the current state."""
for server_id in self._servers:
try:
self._servers[server_id][0].stop()
except AssertionError: # can't stop a server twice
pass
self._servers = {}

def onload(self, callback):
"""
Callback that is triggered when a session has been served.
"""
if self.curdoc is None:
callback()
return
if self.curdoc not in self._onload:
self._onload[self.curdoc] = []
self._onload[self.curdoc].append(callback)

def publish(self, endpoint, parameterized, parameters=None):
"""
Publish parameters on a Parameterized object as a REST API.

Arguments
---------
endpoint: str
The endpoint at which to serve the REST API.
parameterized: param.Parameterized
The Parameterized object to publish parameters from.
parameters: list(str) or None
A subset of parameters on the Parameterized to publish.
"""
if parameters is None:
parameters = list(parameterized.param)
if endpoint.startswith('/'):
endpoint = endpoint[1:]
if endpoint in self._rest_endpoints:
parameterizeds, old_parameters, cb = self._rest_endpoints[endpoint]
if set(parameters) != set(old_parameters):
raise ValueError("Param REST API output parameters must match across sessions.")
values = {k: v for k, v in parameterizeds[0].param.get_param_values() if k in parameters}
parameterized.param.set_param(**values)
parameterizeds.append(parameterized)
else:
cb = self._get_callback(endpoint)
self._rest_endpoints[endpoint] = ([parameterized], parameters, cb)
parameterized.param.watch(cb, parameters)

def sync_busy(self, indicator):
"""
Syncs the busy state with an indicator with a boolean value
Expand Down
3 changes: 3 additions & 0 deletions panel/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections import defaultdict, OrderedDict
from contextlib import contextmanager
from datetime import datetime
from distutils.version import LooseVersion
from six import string_types

try: # python >= 3.3
Expand All @@ -25,6 +26,7 @@

from html import escape # noqa

import bokeh
import param
import numpy as np

Expand All @@ -33,6 +35,7 @@
if sys.version_info.major > 2:
unicode = str

bokeh_version = LooseVersion(bokeh.__version__)

def isfile(path):
"""Safe version of os.path.isfile robust to path length issues on Windows"""
Expand Down