From 638bd571e3fb523e163eb52247d8ba65ac071b38 Mon Sep 17 00:00:00 2001 From: LemurPwned Date: Fri, 20 Dec 2024 16:25:05 +0100 Subject: [PATCH] adding export-import session --- view/streamlit_app.py | 52 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/view/streamlit_app.py b/view/streamlit_app.py index 08092c4..1d03aae 100644 --- a/view/streamlit_app.py +++ b/view/streamlit_app.py @@ -5,6 +5,24 @@ from autofit import autofit from helpers import simulate_pimm, simulate_vsd from utils import GENERIC_BOUNDS, GENERIC_UNITS +import json + + +def export_session_state(): + export_dict = {} + opts = ["_btn", "_file", "_state", "low_", "up_", "check_", "upload"] + for k, v in st.session_state.items(): + skip = any(forb_opts in k for forb_opts in opts) + if not skip: + export_dict[k] = v + + return json.dumps(export_dict) + + +def import_session_state(file): + for k, v in json.load(file).items(): + st.session_state[k] = v + with st.expander("# Read me"): st.write( @@ -17,13 +35,32 @@ ) with st.sidebar: - st.file_uploader( - "Upload your data here", - help="Upload your data here. Must be `\t` separated values and have H and f headers.", - type=["txt", "dat"], - accept_multiple_files=False, - key="upload", - ) + with st.expander("Export/Import"): + st.download_button( + label="Export session state", + data=export_session_state(), + file_name="session_state.json", + mime="application/json", + type="primary", + ) + + st.file_uploader( + "Upload session state", + help="Upload your data here. Must be `\t` separated values and have H and f headers.", + type=["json"], + accept_multiple_files=False, + key="import_file", + ) + if st.session_state.import_file: + import_session_state(st.session_state.import_file) + + st.file_uploader( + "Upload your data here", + help="Upload your data here. Must be `\t` separated values and have H and f headers.", + type=["txt", "dat"], + accept_multiple_files=False, + key="upload", + ) N = st.number_input( "Number of layers", min_value=1, max_value=10, value=1, key="N", format="%d" ) @@ -149,7 +186,6 @@ help="Maximum frequency (cutoff) visible in plot", ) - pimm_tab, vsd_tab, opt_tab = st.tabs(["PIMM", "VSD", "Optimization"]) with opt_tab: with st.expander("Optimization parameters"):