Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
More streamlit goodness (#602)
Browse files Browse the repository at this point in the history
Added debug mode (update app on template change) and options for page
title and descriptions
  • Loading branch information
mike0sv authored Feb 8, 2023
1 parent 551afe4 commit f6b5c14
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 19 deletions.
7 changes: 6 additions & 1 deletion mlem/contrib/streamlit/_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from mlem.contrib.streamlit.utils import model_form
from mlem.runtime.client import HTTPClient

streamlit.set_page_config(
page_title="{{page_title}}",
)


@streamlit.cache(hash_funcs={HTTPClient: lambda x: 0})
def get_client():
Expand All @@ -11,7 +15,8 @@ def get_client():
)


streamlit.title("MLEM Streamlit UI")
streamlit.title("{{title}}")
streamlit.write("""{{description}}""")
model_form(get_client())
streamlit.markdown("---")
streamlit.write(
Expand Down
45 changes: 33 additions & 12 deletions mlem/contrib/streamlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import subprocess
import tempfile
from threading import Thread
from time import sleep
from typing import ClassVar, Dict, Optional

Expand All @@ -26,6 +27,9 @@ class StreamlitScript(TemplateModel):

server_host: str = "0.0.0.0"
server_port: str = "8080"
page_title: str = "MLEM Streamlit UI"
title: str = "MLEM Streamlit UI"
description: str = ""


class StreamlitServer(Server, LibRequirementsMixin):
Expand All @@ -50,38 +54,55 @@ class StreamlitServer(Server, LibRequirementsMixin):
"""Path to alternative template for streamlit app"""
standardize: bool = False # changing default for streamlit
"""Use standard model interface"""
debug: bool = False
"""Update app on template change"""

def serve(self, interface: Interface):
with self.prepare_streamlit_script():
with self.prepare_streamlit_script() as path:
if self.run_server:
from mlem.contrib.fastapi import FastAPIServer

if self.debug:
Thread(
target=lambda: self._idle(path), daemon=True
).start()
FastAPIServer(
host=self.server_host,
port=self.server_port,
standardize=self.standardize,
debug=self.debug,
).serve(interface)
else:
while True:
sleep(100)
self._idle(path)

def _idle(self, path):
while True:
sleep(1)
if self.debug:
self._write_streamlit_script(path)

def _write_streamlit_script(self, path):
templates_dir = []
dirname = os.path.dirname(path)
if self.template is not None:
shutil.copy(self.template, os.path.join(dirname, TEMPLATE_PY))
templates_dir = [dirname]
StreamlitScript(
server_host=self.server_host,
server_port=str(self.server_port),
templates_dir=templates_dir,
).write(path)

@contextlib.contextmanager
def prepare_streamlit_script(self):
with tempfile.TemporaryDirectory(
prefix="mlem_streamlit_script_"
) as tempdir:
path = os.path.join(tempdir, SCRIPT_PY)
templates_dir = []
if self.template is not None:
shutil.copy(self.template, os.path.join(tempdir, TEMPLATE_PY))
templates_dir = [tempdir]
StreamlitScript(
server_host=self.server_host,
server_port=str(self.server_port),
templates_dir=templates_dir,
).write(path)
self._write_streamlit_script(path)
with self.run_streamlit_daemon(path):
yield
yield path

@contextlib.contextmanager
def run_streamlit_daemon(self, path):
Expand Down
16 changes: 10 additions & 6 deletions mlem/contrib/streamlit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@ def method_form(method_name: str, method: InterfaceMethod, client: HTTPClient):
with streamlit.tabs(["Response:"])[0]:
with streamlit.spinner("Processing..."):
try:
response = getattr(client, method_name)(
**{
k: v.dict() if isinstance(v, BaseModel) else v
for k, v in arg_values.items()
}
)
response = call_method(client, method_name, arg_values)
except ExecutionError as e:
streamlit.error(e)
return
Expand All @@ -55,6 +50,15 @@ def method_form(method_name: str, method: InterfaceMethod, client: HTTPClient):
streamlit.write(response)


def call_method(client: HTTPClient, method_name: str, arg_values: dict):
return getattr(client, method_name)(
**{
k: v.dict() if isinstance(v, BaseModel) else v
for k, v in arg_values.items()
}
)


def method_args(method_name: str, method: InterfaceMethod):
arg_values = {}
for arg in method.args:
Expand Down

0 comments on commit f6b5c14

Please sign in to comment.