Skip to content

Commit

Permalink
added statemanager class
Browse files Browse the repository at this point in the history
  • Loading branch information
ch1ru committed Feb 28, 2024
1 parent bb485bc commit e03a6ca
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 34 deletions.
3 changes: 1 addition & 2 deletions app/Dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,7 @@ def draw_page(self):
active_adata: AdataModel = AdataModel(work_id=adata.work_id, adata=anndata, filename=adata.filename, created=adata.created, adata_name=adata.adata_name, notes = adata.notes, id=adata.id)
st.session_state["adata_state"] = AdataState(active=active_adata, insert_into_db=False)
os.environ['WORKDIR'] = workspace.data_dir #set wd
with open(os.path.join(os.getenv('TMP_DIR'), 'session_state.pkl'), 'wb') as pkl_file:
pickle.dump(st.session_state, pkl_file)

with st.sidebar:
st.subheader(f"{workspace.workspace_name}")
st.markdown(f"""<p style='font-size: 16px; color: rgba(255, 255, 255, 1)'>{workspace.description}<p>""", unsafe_allow_html=True)
Expand Down
8 changes: 5 additions & 3 deletions app/pages/1_Upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy.orm import Session
from database.schemas import schemas
from state.AdataState import AdataState
from state.StateManager import StateManager
from state.ScriptState import ScriptState
from utils.session_cache import load_data_from_cache, cache_data_to_session
import loompy as lmp
Expand Down Expand Up @@ -91,12 +92,15 @@ def upload_file(self):
upload_path = os.path.join(workspace_model.data_dir, "uploads")
download_path = os.path.join(workspace_model.data_dir, "downloads")
adata_path = os.path.join(workspace_model.data_dir, "adata")
tmp_path = os.path.join(workspace_model.data_dir, "tmp")
if not os.path.exists(upload_path):
os.mkdir(upload_path)
if not os.path.exists(download_path):
os.mkdir(download_path)
if not os.path.exists(adata_path):
os.mkdir(adata_path)
if not os.path.exists(tmp_path):
os.mkdir(tmp_path)



Expand Down Expand Up @@ -125,7 +129,7 @@ def upload_file(self):
except KeyError as ke:
print("KeyError: ", ke)
if(st.session_state == {}):
load_data_from_cache()
StateManager.load_session()
else:
st.error("Couldn't find workspace in session, have you selected one?")

Expand Down Expand Up @@ -322,8 +326,6 @@ def show_anndata(self, adata, f = None, filename = ""):

self.show_sidebar_preview(f)

cache_data_to_session()


except ValidationError as e:
st.error(e)
Expand Down
11 changes: 9 additions & 2 deletions app/pages/2_Preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import re
import plotly.graph_objects as go
from models.ScriptModel import Language
from state.StateManager import StateManager


st.set_page_config(layout="wide", page_title='Nuwa', page_icon='🧬')
Expand Down Expand Up @@ -75,6 +76,8 @@ def __init__(self, adata: AnnData):

if "preprocess_plots" not in st.session_state:
st.session_state["preprocess_plots"] = dict(pca=None)

self.state_manager = StateManager()


def save_adata(self):
Expand Down Expand Up @@ -111,8 +114,12 @@ def filter_highest_expr_genes(self):
fig = highest_expr_genes_box_plot(self.adata, n_top=n_top_genes)
st.plotly_chart(fig)

# write to script state
Highest_expr_genes.add_script(n_top_genes=n_top_genes, language=Language.ALL_SUPPORTED)
# save session
self.state_manager \
.add_adata(self.adata) \
.add_script(Highest_expr_genes(n_top_genes=n_top_genes, language=Language.ALL_SUPPORTED)) \
.save_session()




Expand Down
10 changes: 10 additions & 0 deletions app/scripts/Script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import streamlit as st
from state.ScriptState import ScriptState

class Script:

def __init__(self):
self.script_state: ScriptState = st.session_state.script_state

def add_script(self):
return
2 changes: 0 additions & 2 deletions app/scripts/preprocessing/Annotate_mito.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ class Annotate_mito:
Exports an R or python script for annotating mitochondrial genes.
"""

@staticmethod
def add_script(language: Language | str, object: str = None):

script_state: ScriptState = st.session_state.script_state
Expand Down Expand Up @@ -45,7 +44,6 @@ class Filter_mito:
Exports an R or python script for filtering out mitochondrial genes.
"""

@staticmethod
def add_script(language: Language | str, mito_pct: int, object: str = None):

script_state: ScriptState = st.session_state.script_state
Expand Down
38 changes: 22 additions & 16 deletions app/scripts/preprocessing/Highest_expr_genes.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
from models.ScriptModel import Language
import streamlit as st
from state.ScriptState import ScriptState
from scripts.Script import Script

class Highest_expr_genes:
class Highest_expr_genes(Script):
"""
Exports an R or python script for plotting highest expressed genes.
"""

@staticmethod
def add_script(language: Language | str, n_top_genes: int = 20, object: str = None):
def __init__(self, language: Language | str, n_top_genes: int = 20, object: str = None):
super().__init__()
if isinstance(language, Language):
self.language = language

self.n_top_genes = n_top_genes
self.object = object

script_state: ScriptState = st.session_state.script_state
def add_script(self):

if language == Language.R or language == Language.R.value or language == Language.ALL_SUPPORTED:
if object == None:
object = "pbmc.data"
if self.language == Language.R or self.language == Language.R.value or self.language == Language.ALL_SUPPORTED:
if self.object == None:
self.object = "pbmc.data"
script = f""" \
\n# This uses the scater library \
\nplotHighestExprs( \
\n\t{object}, \
\n\tn = {n_top_genes}, \
\n\t{self.object}, \
\n\tn = {self.n_top_genes}, \
\n\tcolour_cells_by = NULL, \
\n\tdrop_features = NULL, \
\n\texprs_values = "counts", \
Expand All @@ -29,18 +35,18 @@ def add_script(language: Language | str, n_top_genes: int = 20, object: str = No
\n\tswap_rownames = NULL \
\n)
"""
script_state.add_script(script, language=Language.R)
self.script_state.add_script(script, language=Language.R)

if language == Language.python or language == Language.python.value or language == Language.ALL_SUPPORTED:
if object == None:
object = "adata"
if self.language == Language.python or self.language == Language.python.value or self.language == Language.ALL_SUPPORTED:
if self.object == None:
self.object = "adata"
script = f"""
\n# Plot highest expr genes \
\nsc.pl.highest_expr_genes({object}, n_top={n_top_genes})
\nsc.pl.highest_expr_genes({self.object}, n_top={self.n_top_genes})
"""
script_state.add_script(script, language=Language.python)
self.script_state.add_script(script, language=Language.python)

if not isinstance(language, Language):
if not isinstance(self.language, Language):
print("Error: Unknown language, not adding to script state")
return

2 changes: 1 addition & 1 deletion app/state/AdataState.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, active: AdataModel, insert_into_db=True):
current: schemas.Adata = db_adatas.first()
self.current = AdataModel(work_id=current.work_id, adata_name=current.adata_name, created=current.created, notes=current.notes, id=current.id, filename=current.filename)

#add original adata to object
# add original adata to object
self.current.adata = active.adata
#set current to newly created adata
self.current_index = self.get_index_of_current()
Expand Down
53 changes: 53 additions & 0 deletions app/state/StateManager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from state.AdataState import AdataState
from state.ScriptState import ScriptState
from models.AdataModel import AdataModel
from models.ScriptModel import ScriptModel
from models.WorkspaceModel import WorkspaceModel
from scripts.Script import Script
from utils.session_cache import cache_data_to_session, load_data_from_cache
import scanpy as sc
import os
import streamlit as st
from anndata import AnnData

class StateManager:
"""
Makes changes made to the database and filesystems synchronously using data from session state in an atomic way. Also responsible for loading data into session state and initialising files when loading new dataset.
"""

def add_script(self, script: Script):
# add script if present
if script is not None:
if isinstance(script, Script):
self.script = script
return self

def add_adata(self, adata: AnnData):
if isinstance(adata, AnnData):
self.add_adata = adata
return self

def load_session():
raise NotImplementedError


def save_session(self):
"""
Takes a copy of current session state and saves in pickle format. Extracts adata from session state and updates filesystem and database. Also
takes an optional script to update database.
"""
# write adata h5ad object to file
if hasattr(self, 'adata'):
sc.write(filename=os.path.join(os.getenv('WORKDIR'), 'adata', st.session_state.adata_state.current.adata_name), adata=self.adata)

# add script if present
if hasattr(self, 'script'):
self.script.add_script()

# cache data to pickle file
cache_data_to_session()


def init_session():
raise NotImplementedError
18 changes: 14 additions & 4 deletions app/utils/session_cache.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import hashlib
import pickle
import os
import streamlit as st
from database.database import SessionLocal

def cache_data_to_session():
try:
dbfile = open(os.path.join(os.getenv('TMP_DIR'), 'session_state'), 'wb')

state = {}
for key in st.session_state:
#streamlit form and button can't be set using session state, so remove them here
Expand All @@ -14,6 +15,15 @@ def cache_data_to_session():

state['adata_state'].conn = None
state['script_state'].conn = None

# create hash to be used as filename
encoded = pickle.dumps(state)
hash = hashlib.md5()
hash.update(encoded)
state_hash = hash.hexdigest()

# Write to file
dbfile = open(os.path.join(os.getenv('WORKDIR'), 'tmp', state_hash), 'wb')
pickle.dump(state, dbfile)
#python doesn't copy the objects so db connection in state is destroyed. Add it back here
st.session_state["adata_state"].conn = SessionLocal()
Expand All @@ -22,12 +32,12 @@ def cache_data_to_session():
except Exception as e:
st.toast(e, icon="❌")

def load_data_from_cache():
def load_data_from_cache(state_file):
try:
dbfile = open(os.path.join(os.getenv('TMP_DIR'), 'session_state'), 'rb')
dbfile = open(os.path.join(os.getenv('TMP_DIR'), state_file), 'rb')
session = pickle.load(dbfile)
for key in session:
st.session_state[key] = session[key] #load in keys to session state
st.session_state[key] = session[key] # load in keys to session state
dbfile.close()

adata_state = st.session_state.adata_state
Expand Down
1 change: 0 additions & 1 deletion cpu.docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ services:
- POSTGRES_DB=postgres
- POSTGRES_HOST=postgres
- POSTGRES_PORT=5432
- TMP_DIR=/app/tmp
ports:
- ${SERVER_ADDR}:${HOST_PORT}:8501
postgres:
Expand Down
1 change: 0 additions & 1 deletion cuda.docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ services:
- POSTGRES_DB=postgres
- POSTGRES_HOST=postgres
- POSTGRES_PORT=5432
- TMP_DIR=/app/tmp
ports:
- ${SERVER_ADDR}:${HOST_PORT}:8501
deploy:
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/clustering/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ sort: 1

# Clustering

This set of tutorials will be a complete walkthrough of preprocessing, integrating and clustering multiple covid datasets using Nuwa. It is based on the set of tutorials by [NBISweden](https://nbisweden.github.io/workshop-scRNAseq/home_contents.html). Only minimal knowledge of single cell analysis is required for this tutorial.
This set of tutorials will be a complete walkthrough of preprocessing, clustering and differential expression of a covid dataset using Nuwa. It is based on the set of tutorials by [NBISweden](https://nbisweden.github.io/workshop-scRNAseq/home_contents.html). Only minimal knowledge of single cell analysis is required for this tutorial.

{% include list.liquid all=true %}

Expand Down
1 change: 0 additions & 1 deletion tests.docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ services:
- POSTGRES_DB=postgres
- POSTGRES_HOST=postgres
- POSTGRES_PORT=5432
- TMP_DIR=/app/tmp
ports:
- ${SERVER_ADDR}:${HOST_PORT}:8501
postgres:
Expand Down

0 comments on commit e03a6ca

Please sign in to comment.