Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Add some args to streamlit server configuration (#607)
Browse files Browse the repository at this point in the history
including custom args for custom templates

closes #583
  • Loading branch information
mike0sv authored Feb 14, 2023
1 parent f51fbe2 commit 362bf2f
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 26 deletions.
36 changes: 23 additions & 13 deletions mlem/contrib/streamlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import streamlit
import streamlit_pydantic
from pydantic import BaseModel

from mlem.core.errors import MlemError
from mlem.core.requirements import LibRequirementsMixin, Requirements
Expand All @@ -21,27 +22,37 @@
SCRIPT_PY = "script.py"


class StreamlitScript(TemplateModel):
TEMPLATE_FILE: ClassVar = TEMPLATE_PY
TEMPLATE_DIR: ClassVar = os.path.dirname(__file__)

class StreamlitScript(BaseModel):
server_host: str = "0.0.0.0"
server_port: str = "8080"
"""Hostname for running FastAPI backend"""
server_port: int = 8081
"""Port for running FastAPI backend"""
page_title: str = "MLEM Streamlit UI"
"""Title of the page in browser"""
title: str = "MLEM Streamlit UI"
"""Title of the page"""
description: str = ""
"""Additional text after title"""
args: Dict[str, str]
"""Additional args for custom template"""


class StreamlitServer(Server, LibRequirementsMixin):
class StreamlitTemplate(StreamlitScript, TemplateModel):
TEMPLATE_FILE: ClassVar = TEMPLATE_PY
TEMPLATE_DIR: ClassVar = os.path.dirname(__file__)

def prepare_dict(self):
d = super().prepare_dict()
d.update(self.args)
return d


class StreamlitServer(Server, StreamlitScript, LibRequirementsMixin):
"""Streamlit UI server"""

type: ClassVar = "streamlit"
libraries: ClassVar = (streamlit, streamlit_pydantic)

server_host: str = "0.0.0.0"
"""Hostname for running FastAPI backend"""
server_port: int = 8080
"""Port for running FastAPI backend"""
run_server: bool = True
"""Whether to run backend server or use existing one"""
ui_host: str = "0.0.0.0"
Expand Down Expand Up @@ -88,9 +99,8 @@ def _write_streamlit_script(self, path):
if self.template is not None:
shutil.copy(self.template, os.path.join(dirname, TEMPLATE_PY))
templates_dir = [dirname]
StreamlitScript(
server_host=self.server_host,
server_port=str(self.server_port),
StreamlitTemplate.from_model(
self,
templates_dir=templates_dir,
).write(path)

Expand Down
11 changes: 10 additions & 1 deletion mlem/utils/templates.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, List
from typing import ClassVar, List, Type, TypeVar

from fsspec import AbstractFileSystem
from fsspec.implementations.local import LocalFileSystem
Expand All @@ -10,6 +10,8 @@
)
from pydantic import BaseModel

T = TypeVar("T", bound="TemplateModel")


class TemplateModel(BaseModel):
"""Base class to render jinja templates from pydantic models"""
Expand Down Expand Up @@ -38,3 +40,10 @@ def write(self, path: str, fs: AbstractFileSystem = None, **additional):
fs = fs or LocalFileSystem()
with fs.open(path, "w") as f:
f.write(self.generate(**additional))

@classmethod
def from_model(cls: Type[T], obj, templates_dir: List[str] = None) -> T:
args = {
f: getattr(obj, f) for f in cls.__fields__ if f != "templates_dir"
}
return cls(templates_dir=templates_dir or [], **args)
36 changes: 36 additions & 0 deletions tests/contrib/test_streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pydantic import BaseModel

from mlem.contrib.streamlit.server import StreamlitServer
from mlem.contrib.streamlit.utils import augment_model


Expand Down Expand Up @@ -33,3 +34,38 @@ class M4(BaseModel):

aug, model = augment_model(M4)
assert model is None


def test_custom_template(tmpdir):
template_path = str(tmpdir / "template")
with open(template_path, "w", encoding="utf8") as f:
f.write(
"""{{page_title}}
{{title}}
{{description}}
{{server_host}}
{{server_port}}
{{custom_arg}}"""
)
server = StreamlitServer(
template=template_path,
page_title="page title",
title="title",
description="description",
server_host="host",
server_port=0,
args={"custom_arg": "custom arg"},
)
path = str(tmpdir / "script")
server._write_streamlit_script(path) # pylint: disable=protected-access

with open(path, encoding="utf8") as f:
assert (
f.read()
== """page title
title
description
host
0
custom arg"""
)
32 changes: 20 additions & 12 deletions tests/utils/test_save.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
"execution_count": 1,
"metadata": {
"execution": {
"iopub.execute_input": "2023-01-05T20:28:08.980607Z",
"iopub.status.busy": "2023-01-05T20:28:08.980471Z",
"iopub.status.idle": "2023-01-05T20:28:08.985874Z",
"shell.execute_reply": "2023-01-05T20:28:08.985523Z"
"iopub.execute_input": "2023-02-13T14:11:29.261665Z",
"iopub.status.busy": "2023-02-13T14:11:29.261394Z",
"iopub.status.idle": "2023-02-13T14:11:29.267566Z",
"shell.execute_reply": "2023-02-13T14:11:29.266734Z"
},
"pycharm": {
"name": "#%%\n"
Expand All @@ -24,20 +24,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2023-01-05T20:28:08.987817Z",
"iopub.status.busy": "2023-01-05T20:28:08.987672Z",
"iopub.status.idle": "2023-01-05T20:28:09.517178Z",
"shell.execute_reply": "2023-01-05T20:28:09.516853Z"
"iopub.execute_input": "2023-02-13T14:11:29.270506Z",
"iopub.status.busy": "2023-02-13T14:11:29.270233Z",
"iopub.status.idle": "2023-02-13T14:11:29.969468Z",
"shell.execute_reply": "2023-02-13T14:11:29.968705Z"
},
"pycharm": {
"name": "#%%\n",
"is_executing": true
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"numpy==1.22.4\n"
]
}
],
"source": [
"from mlem.utils.module import get_object_requirements\n",
"\n",
Expand Down

0 comments on commit 362bf2f

Please sign in to comment.