From 362bf2fef9f3a6eaea0b8bd4c6339b250b6c992a Mon Sep 17 00:00:00 2001 From: Mikhail Sveshnikov Date: Tue, 14 Feb 2023 21:39:20 +0300 Subject: [PATCH] Add some args to streamlit server configuration (#607) including custom args for custom templates closes #583 --- mlem/contrib/streamlit/server.py | 36 ++++++++++++++++++++------------ mlem/utils/templates.py | 11 +++++++++- tests/contrib/test_streamlit.py | 36 ++++++++++++++++++++++++++++++++ tests/utils/test_save.ipynb | 32 +++++++++++++++++----------- 4 files changed, 89 insertions(+), 26 deletions(-) diff --git a/mlem/contrib/streamlit/server.py b/mlem/contrib/streamlit/server.py index 7ba37422..110ccf8b 100644 --- a/mlem/contrib/streamlit/server.py +++ b/mlem/contrib/streamlit/server.py @@ -9,6 +9,7 @@ import streamlit import streamlit_pydantic +from pydantic import BaseModel from mlem.core.errors import MlemError from mlem.core.requirements import LibRequirementsMixin, Requirements @@ -21,27 +22,37 @@ SCRIPT_PY = "script.py" -class StreamlitScript(TemplateModel): - TEMPLATE_FILE: ClassVar = TEMPLATE_PY - TEMPLATE_DIR: ClassVar = os.path.dirname(__file__) - +class StreamlitScript(BaseModel): server_host: str = "0.0.0.0" - server_port: str = "8080" + """Hostname for running FastAPI backend""" + server_port: int = 8081 + """Port for running FastAPI backend""" page_title: str = "MLEM Streamlit UI" + """Title of the page in browser""" title: str = "MLEM Streamlit UI" + """Title of the page""" description: str = "" + """Additional text after title""" + args: Dict[str, str] + """Additional args for custom template""" -class StreamlitServer(Server, LibRequirementsMixin): +class StreamlitTemplate(StreamlitScript, TemplateModel): + TEMPLATE_FILE: ClassVar = TEMPLATE_PY + TEMPLATE_DIR: ClassVar = os.path.dirname(__file__) + + def prepare_dict(self): + d = super().prepare_dict() + d.update(self.args) + return d + + +class StreamlitServer(Server, StreamlitScript, LibRequirementsMixin): """Streamlit UI server""" type: ClassVar = "streamlit" libraries: ClassVar = (streamlit, streamlit_pydantic) - server_host: str = "0.0.0.0" - """Hostname for running FastAPI backend""" - server_port: int = 8080 - """Port for running FastAPI backend""" run_server: bool = True """Whether to run backend server or use existing one""" ui_host: str = "0.0.0.0" @@ -88,9 +99,8 @@ def _write_streamlit_script(self, path): if self.template is not None: shutil.copy(self.template, os.path.join(dirname, TEMPLATE_PY)) templates_dir = [dirname] - StreamlitScript( - server_host=self.server_host, - server_port=str(self.server_port), + StreamlitTemplate.from_model( + self, templates_dir=templates_dir, ).write(path) diff --git a/mlem/utils/templates.py b/mlem/utils/templates.py index db7b8dca..eeefea03 100644 --- a/mlem/utils/templates.py +++ b/mlem/utils/templates.py @@ -1,4 +1,4 @@ -from typing import ClassVar, List +from typing import ClassVar, List, Type, TypeVar from fsspec import AbstractFileSystem from fsspec.implementations.local import LocalFileSystem @@ -10,6 +10,8 @@ ) from pydantic import BaseModel +T = TypeVar("T", bound="TemplateModel") + class TemplateModel(BaseModel): """Base class to render jinja templates from pydantic models""" @@ -38,3 +40,10 @@ def write(self, path: str, fs: AbstractFileSystem = None, **additional): fs = fs or LocalFileSystem() with fs.open(path, "w") as f: f.write(self.generate(**additional)) + + @classmethod + def from_model(cls: Type[T], obj, templates_dir: List[str] = None) -> T: + args = { + f: getattr(obj, f) for f in cls.__fields__ if f != "templates_dir" + } + return cls(templates_dir=templates_dir or [], **args) diff --git a/tests/contrib/test_streamlit.py b/tests/contrib/test_streamlit.py index 0bd7578b..d4abb86a 100644 --- a/tests/contrib/test_streamlit.py +++ b/tests/contrib/test_streamlit.py @@ -2,6 +2,7 @@ from pydantic import BaseModel +from mlem.contrib.streamlit.server import StreamlitServer from mlem.contrib.streamlit.utils import augment_model @@ -33,3 +34,38 @@ class M4(BaseModel): aug, model = augment_model(M4) assert model is None + + +def test_custom_template(tmpdir): + template_path = str(tmpdir / "template") + with open(template_path, "w", encoding="utf8") as f: + f.write( + """{{page_title}} +{{title}} +{{description}} +{{server_host}} +{{server_port}} +{{custom_arg}}""" + ) + server = StreamlitServer( + template=template_path, + page_title="page title", + title="title", + description="description", + server_host="host", + server_port=0, + args={"custom_arg": "custom arg"}, + ) + path = str(tmpdir / "script") + server._write_streamlit_script(path) # pylint: disable=protected-access + + with open(path, encoding="utf8") as f: + assert ( + f.read() + == """page title +title +description +host +0 +custom arg""" + ) diff --git a/tests/utils/test_save.ipynb b/tests/utils/test_save.ipynb index 318904dd..68ccf781 100644 --- a/tests/utils/test_save.ipynb +++ b/tests/utils/test_save.ipynb @@ -5,10 +5,10 @@ "execution_count": 1, "metadata": { "execution": { - "iopub.execute_input": "2023-01-05T20:28:08.980607Z", - "iopub.status.busy": "2023-01-05T20:28:08.980471Z", - "iopub.status.idle": "2023-01-05T20:28:08.985874Z", - "shell.execute_reply": "2023-01-05T20:28:08.985523Z" + "iopub.execute_input": "2023-02-13T14:11:29.261665Z", + "iopub.status.busy": "2023-02-13T14:11:29.261394Z", + "iopub.status.idle": "2023-02-13T14:11:29.267566Z", + "shell.execute_reply": "2023-02-13T14:11:29.266734Z" }, "pycharm": { "name": "#%%\n" @@ -24,20 +24,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "execution": { - "iopub.execute_input": "2023-01-05T20:28:08.987817Z", - "iopub.status.busy": "2023-01-05T20:28:08.987672Z", - "iopub.status.idle": "2023-01-05T20:28:09.517178Z", - "shell.execute_reply": "2023-01-05T20:28:09.516853Z" + "iopub.execute_input": "2023-02-13T14:11:29.270506Z", + "iopub.status.busy": "2023-02-13T14:11:29.270233Z", + "iopub.status.idle": "2023-02-13T14:11:29.969468Z", + "shell.execute_reply": "2023-02-13T14:11:29.968705Z" }, "pycharm": { - "name": "#%%\n", - "is_executing": true + "is_executing": true, + "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "numpy==1.22.4\n" + ] + } + ], "source": [ "from mlem.utils.module import get_object_requirements\n", "\n",