diff --git a/app/Dashboard.py b/app/Dashboard.py index f31c4af..e3ca8c5 100644 --- a/app/Dashboard.py +++ b/app/Dashboard.py @@ -286,8 +286,7 @@ def draw_page(self): active_adata: AdataModel = AdataModel(work_id=adata.work_id, adata=anndata, filename=adata.filename, created=adata.created, adata_name=adata.adata_name, notes = adata.notes, id=adata.id) st.session_state["adata_state"] = AdataState(active=active_adata, insert_into_db=False) os.environ['WORKDIR'] = workspace.data_dir #set wd - with open(os.path.join(os.getenv('TMP_DIR'), 'session_state.pkl'), 'wb') as pkl_file: - pickle.dump(st.session_state, pkl_file) + with st.sidebar: st.subheader(f"{workspace.workspace_name}") st.markdown(f"""

{workspace.description}

""", unsafe_allow_html=True) diff --git a/app/pages/1_Upload.py b/app/pages/1_Upload.py index fc9ccfd..c7ed0eb 100644 --- a/app/pages/1_Upload.py +++ b/app/pages/1_Upload.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from database.schemas import schemas from state.AdataState import AdataState +from state.StateManager import StateManager from state.ScriptState import ScriptState from utils.session_cache import load_data_from_cache, cache_data_to_session import loompy as lmp @@ -91,12 +92,15 @@ def upload_file(self): upload_path = os.path.join(workspace_model.data_dir, "uploads") download_path = os.path.join(workspace_model.data_dir, "downloads") adata_path = os.path.join(workspace_model.data_dir, "adata") + tmp_path = os.path.join(workspace_model.data_dir, "tmp") if not os.path.exists(upload_path): os.mkdir(upload_path) if not os.path.exists(download_path): os.mkdir(download_path) if not os.path.exists(adata_path): os.mkdir(adata_path) + if not os.path.exists(tmp_path): + os.mkdir(tmp_path) @@ -125,7 +129,7 @@ def upload_file(self): except KeyError as ke: print("KeyError: ", ke) if(st.session_state == {}): - load_data_from_cache() + StateManager.load_session() else: st.error("Couldn't find workspace in session, have you selected one?") @@ -322,8 +326,6 @@ def show_anndata(self, adata, f = None, filename = ""): self.show_sidebar_preview(f) - cache_data_to_session() - except ValidationError as e: st.error(e) diff --git a/app/pages/2_Preprocess.py b/app/pages/2_Preprocess.py index 3624c9d..820b1da 100644 --- a/app/pages/2_Preprocess.py +++ b/app/pages/2_Preprocess.py @@ -33,6 +33,7 @@ import re import plotly.graph_objects as go from models.ScriptModel import Language +from state.StateManager import StateManager st.set_page_config(layout="wide", page_title='Nuwa', page_icon='🧬') @@ -75,6 +76,8 @@ def __init__(self, adata: AnnData): if "preprocess_plots" not in st.session_state: st.session_state["preprocess_plots"] = dict(pca=None) + + self.state_manager = StateManager() def save_adata(self): @@ -111,8 +114,12 @@ def filter_highest_expr_genes(self): fig = highest_expr_genes_box_plot(self.adata, n_top=n_top_genes) st.plotly_chart(fig) - # write to script state - Highest_expr_genes.add_script(n_top_genes=n_top_genes, language=Language.ALL_SUPPORTED) + # save session + self.state_manager \ + .add_adata(self.adata) \ + .add_script(Highest_expr_genes(n_top_genes=n_top_genes, language=Language.ALL_SUPPORTED)) \ + .save_session() + diff --git a/app/scripts/Script.py b/app/scripts/Script.py new file mode 100644 index 0000000..126c524 --- /dev/null +++ b/app/scripts/Script.py @@ -0,0 +1,10 @@ +import streamlit as st +from state.ScriptState import ScriptState + +class Script: + + def __init__(self): + self.script_state: ScriptState = st.session_state.script_state + + def add_script(self): + return \ No newline at end of file diff --git a/app/scripts/preprocessing/Annotate_mito.py b/app/scripts/preprocessing/Annotate_mito.py index c43c47c..5212c04 100644 --- a/app/scripts/preprocessing/Annotate_mito.py +++ b/app/scripts/preprocessing/Annotate_mito.py @@ -7,7 +7,6 @@ class Annotate_mito: Exports an R or python script for annotating mitochondrial genes. """ - @staticmethod def add_script(language: Language | str, object: str = None): script_state: ScriptState = st.session_state.script_state @@ -45,7 +44,6 @@ class Filter_mito: Exports an R or python script for filtering out mitochondrial genes. """ - @staticmethod def add_script(language: Language | str, mito_pct: int, object: str = None): script_state: ScriptState = st.session_state.script_state diff --git a/app/scripts/preprocessing/Highest_expr_genes.py b/app/scripts/preprocessing/Highest_expr_genes.py index 1aa92db..a7644c1 100644 --- a/app/scripts/preprocessing/Highest_expr_genes.py +++ b/app/scripts/preprocessing/Highest_expr_genes.py @@ -1,25 +1,31 @@ from models.ScriptModel import Language import streamlit as st from state.ScriptState import ScriptState +from scripts.Script import Script -class Highest_expr_genes: +class Highest_expr_genes(Script): """ Exports an R or python script for plotting highest expressed genes. """ - @staticmethod - def add_script(language: Language | str, n_top_genes: int = 20, object: str = None): + def __init__(self, language: Language | str, n_top_genes: int = 20, object: str = None): + super().__init__() + if isinstance(language, Language): + self.language = language + + self.n_top_genes = n_top_genes + self.object = object - script_state: ScriptState = st.session_state.script_state + def add_script(self): - if language == Language.R or language == Language.R.value or language == Language.ALL_SUPPORTED: - if object == None: - object = "pbmc.data" + if self.language == Language.R or self.language == Language.R.value or self.language == Language.ALL_SUPPORTED: + if self.object == None: + self.object = "pbmc.data" script = f""" \ \n# This uses the scater library \ \nplotHighestExprs( \ - \n\t{object}, \ - \n\tn = {n_top_genes}, \ + \n\t{self.object}, \ + \n\tn = {self.n_top_genes}, \ \n\tcolour_cells_by = NULL, \ \n\tdrop_features = NULL, \ \n\texprs_values = "counts", \ @@ -29,18 +35,18 @@ def add_script(language: Language | str, n_top_genes: int = 20, object: str = No \n\tswap_rownames = NULL \ \n) """ - script_state.add_script(script, language=Language.R) + self.script_state.add_script(script, language=Language.R) - if language == Language.python or language == Language.python.value or language == Language.ALL_SUPPORTED: - if object == None: - object = "adata" + if self.language == Language.python or self.language == Language.python.value or self.language == Language.ALL_SUPPORTED: + if self.object == None: + self.object = "adata" script = f""" \n# Plot highest expr genes \ - \nsc.pl.highest_expr_genes({object}, n_top={n_top_genes}) + \nsc.pl.highest_expr_genes({self.object}, n_top={self.n_top_genes}) """ - script_state.add_script(script, language=Language.python) + self.script_state.add_script(script, language=Language.python) - if not isinstance(language, Language): + if not isinstance(self.language, Language): print("Error: Unknown language, not adding to script state") return diff --git a/app/state/AdataState.py b/app/state/AdataState.py index aa496e6..384efd4 100644 --- a/app/state/AdataState.py +++ b/app/state/AdataState.py @@ -27,7 +27,7 @@ def __init__(self, active: AdataModel, insert_into_db=True): current: schemas.Adata = db_adatas.first() self.current = AdataModel(work_id=current.work_id, adata_name=current.adata_name, created=current.created, notes=current.notes, id=current.id, filename=current.filename) - #add original adata to object + # add original adata to object self.current.adata = active.adata #set current to newly created adata self.current_index = self.get_index_of_current() diff --git a/app/state/StateManager.py b/app/state/StateManager.py new file mode 100644 index 0000000..e4c5d39 --- /dev/null +++ b/app/state/StateManager.py @@ -0,0 +1,53 @@ +from state.AdataState import AdataState +from state.ScriptState import ScriptState +from models.AdataModel import AdataModel +from models.ScriptModel import ScriptModel +from models.WorkspaceModel import WorkspaceModel +from scripts.Script import Script +from utils.session_cache import cache_data_to_session, load_data_from_cache +import scanpy as sc +import os +import streamlit as st +from anndata import AnnData + +class StateManager: + """ + Makes changes made to the database and filesystems synchronously using data from session state in an atomic way. Also responsible for loading data into session state and initialising files when loading new dataset. + """ + + def add_script(self, script: Script): + # add script if present + if script is not None: + if isinstance(script, Script): + self.script = script + return self + + def add_adata(self, adata: AnnData): + if isinstance(adata, AnnData): + self.add_adata = adata + return self + + def load_session(): + raise NotImplementedError + + + def save_session(self): + """ + Takes a copy of current session state and saves in pickle format. Extracts adata from session state and updates filesystem and database. Also + takes an optional script to update database. + + """ + # write adata h5ad object to file + if hasattr(self, 'adata'): + sc.write(filename=os.path.join(os.getenv('WORKDIR'), 'adata', st.session_state.adata_state.current.adata_name), adata=self.adata) + + # add script if present + if hasattr(self, 'script'): + self.script.add_script() + + # cache data to pickle file + cache_data_to_session() + + + def init_session(): + raise NotImplementedError \ No newline at end of file diff --git a/app/utils/session_cache.py b/app/utils/session_cache.py index a627d29..2b32c84 100644 --- a/app/utils/session_cache.py +++ b/app/utils/session_cache.py @@ -1,3 +1,4 @@ +import hashlib import pickle import os import streamlit as st @@ -5,7 +6,7 @@ def cache_data_to_session(): try: - dbfile = open(os.path.join(os.getenv('TMP_DIR'), 'session_state'), 'wb') + state = {} for key in st.session_state: #streamlit form and button can't be set using session state, so remove them here @@ -14,6 +15,15 @@ def cache_data_to_session(): state['adata_state'].conn = None state['script_state'].conn = None + + # create hash to be used as filename + encoded = pickle.dumps(state) + hash = hashlib.md5() + hash.update(encoded) + state_hash = hash.hexdigest() + + # Write to file + dbfile = open(os.path.join(os.getenv('WORKDIR'), 'tmp', state_hash), 'wb') pickle.dump(state, dbfile) #python doesn't copy the objects so db connection in state is destroyed. Add it back here st.session_state["adata_state"].conn = SessionLocal() @@ -22,12 +32,12 @@ def cache_data_to_session(): except Exception as e: st.toast(e, icon="❌") -def load_data_from_cache(): +def load_data_from_cache(state_file): try: - dbfile = open(os.path.join(os.getenv('TMP_DIR'), 'session_state'), 'rb') + dbfile = open(os.path.join(os.getenv('TMP_DIR'), state_file), 'rb') session = pickle.load(dbfile) for key in session: - st.session_state[key] = session[key] #load in keys to session state + st.session_state[key] = session[key] # load in keys to session state dbfile.close() adata_state = st.session_state.adata_state diff --git a/cpu.docker-compose.yml b/cpu.docker-compose.yml index 046d78f..56e741c 100644 --- a/cpu.docker-compose.yml +++ b/cpu.docker-compose.yml @@ -20,7 +20,6 @@ services: - POSTGRES_DB=postgres - POSTGRES_HOST=postgres - POSTGRES_PORT=5432 - - TMP_DIR=/app/tmp ports: - ${SERVER_ADDR}:${HOST_PORT}:8501 postgres: diff --git a/cuda.docker-compose.yml b/cuda.docker-compose.yml index 09df3c0..dbdc3ed 100644 --- a/cuda.docker-compose.yml +++ b/cuda.docker-compose.yml @@ -20,7 +20,6 @@ services: - POSTGRES_DB=postgres - POSTGRES_HOST=postgres - POSTGRES_PORT=5432 - - TMP_DIR=/app/tmp ports: - ${SERVER_ADDR}:${HOST_PORT}:8501 deploy: diff --git a/docs/tutorials/clustering/README.md b/docs/tutorials/clustering/README.md index 97087ca..672dd60 100644 --- a/docs/tutorials/clustering/README.md +++ b/docs/tutorials/clustering/README.md @@ -4,7 +4,7 @@ sort: 1 # Clustering -This set of tutorials will be a complete walkthrough of preprocessing, integrating and clustering multiple covid datasets using Nuwa. It is based on the set of tutorials by [NBISweden](https://nbisweden.github.io/workshop-scRNAseq/home_contents.html). Only minimal knowledge of single cell analysis is required for this tutorial. +This set of tutorials will be a complete walkthrough of preprocessing, clustering and differential expression of a covid dataset using Nuwa. It is based on the set of tutorials by [NBISweden](https://nbisweden.github.io/workshop-scRNAseq/home_contents.html). Only minimal knowledge of single cell analysis is required for this tutorial. {% include list.liquid all=true %} diff --git a/tests.docker-compose.yml b/tests.docker-compose.yml index a0045ee..364d324 100644 --- a/tests.docker-compose.yml +++ b/tests.docker-compose.yml @@ -21,7 +21,6 @@ services: - POSTGRES_DB=postgres - POSTGRES_HOST=postgres - POSTGRES_PORT=5432 - - TMP_DIR=/app/tmp ports: - ${SERVER_ADDR}:${HOST_PORT}:8501 postgres: