Merge pull request #145 from MannLabs/development
merge development branch
sophiamaedler authored Jan 31, 2025
2 parents b46b1e9 + 7391e85 commit 758e7c5
Showing 15 changed files with 511 additions and 153 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
@@ -29,10 +29,10 @@ torch
 pytorch-lightning
 torchvision
 
-spatialdata
+spatialdata>=0.2.0
 napari-spatialdata
 pyqt5
 lxml_html_clean
 ashlar>=1.19.0
 networkx
-py-lmd @ git+https://github.com/MannLabs/py-lmd.git@refs/pull/11/head#egg=py-lmd
+py-lmd>=1.3.1
4 changes: 2 additions & 2 deletions requirements_dev.txt
@@ -29,13 +29,13 @@ torch
 pytorch-lightning
 torchvision
 
-spatialdata
+spatialdata>=0.2.0
 napari-spatialdata
 pyqt5
 lxml_html_clean
 ashlar>=1.19.0
 networkx
-py-lmd @ git+https://github.com/MannLabs/py-lmd.git@refs/pull/11/head#egg=py-lmd
+py-lmd>=1.3.1
 
 #packages for building the documentation
 sphinx
11 changes: 11 additions & 0 deletions src/scportrait/__init__.py
@@ -1,7 +1,18 @@
 """Top-level package for scPortrait"""
 
+# silence warnings
+import warnings
+
 from scportrait import io
 from scportrait import pipeline as pipeline
 from scportrait import plotting as pl
 from scportrait import processing as pp
 from scportrait import tools as tl
+
+# silence warning from spatialdata resulting in an older dask version see #139
+warnings.filterwarnings("ignore", message="ignoring keyword argument 'read_only'")
+
+# silence warning from cellpose resulting in missing parameter set in model call see #141
+warnings.filterwarnings(
+    "ignore", message=r"You are using `torch.load` with `weights_only=False`.*", category=FutureWarning
+)
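For context, `warnings.filterwarnings` treats `message` as a regular expression matched against the start of the warning text, so the two filters added here suppress only these specific messages rather than whole categories. A minimal sketch of that scoping (the warning texts below are illustrative stand-ins):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # same pattern as in the diff: anchored regex, FutureWarning only
    warnings.filterwarnings(
        "ignore", message=r"You are using `torch.load` with `weights_only=False`.*", category=FutureWarning
    )
    warnings.warn("You are using `torch.load` with `weights_only=False`. Be careful.", FutureWarning)
    warnings.warn("an unrelated warning", UserWarning)

assert len(caught) == 1  # only the unrelated UserWarning surfaces
```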
59 changes: 52 additions & 7 deletions src/scportrait/pipeline/_base.py
@@ -9,6 +9,8 @@
 import numpy as np
 import torch
 
+from scportrait.pipeline._utils.helper import read_config
+
 
 class Logable:
     """Create log entries.
@@ -92,6 +94,27 @@ def _clean_log_file(self):
         if os.path.exists(log_file_path):
             os.remove(log_file_path)
 
+    # def _clear_cache(self, vars_to_delete=None):
+    #     """Helper function to help clear memory usage. Mainly relevant for GPU based segmentations.
+
+    #     Args:
+    #         vars_to_delete (list): List of variable names (as strings) to delete.
+    #     """
+
+    #     # delete all specified variables
+    #     if vars_to_delete is not None:
+    #         for var_name in vars_to_delete:
+    #             if var_name in globals():
+    #                 del globals()[var_name]
+
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+
+    #     if torch.backends.mps.is_available():
+    #         torch.mps.empty_cache()
+
+    #     gc.collect()
+
     def _clear_cache(self, vars_to_delete=None):
         """Helper function to help clear memory usage. Mainly relevant for GPU based segmentations."""
 
@@ -137,7 +160,7 @@ class ProcessingStep(Logable):
     DEFAULT_SEGMENTATION_DIR_NAME = "segmentation"
     DEFAULT_TILES_FOLDER = "tiles"
 
-    DEFAULT_EXTRACTIN_DIR_NAME = "extraction"
+    DEFAULT_EXTRACTION_DIR_NAME = "extraction"
     DEFAULT_DATA_DIR = "data"
 
     DEFAULT_IMAGE_DTYPE = np.uint16
@@ -155,19 +178,41 @@ class ProcessingStep(Logable):
     DEFAULT_SELECTION_DIR_NAME = "selection"
 
     def __init__(
-        self, config, directory, project_location, debug=False, overwrite=False, project=None, filehandler=None
+        self,
+        config,
+        directory=None,
+        project_location=None,
+        debug=False,
+        overwrite=False,
+        project=None,
+        filehandler=None,
+        from_project: bool = False,
     ):
         super().__init__(directory=directory)
 
         self.debug = debug
         self.overwrite = overwrite
-        self.project_location = project_location
-        self.config = config
-        self.overwrite = overwrite
-
-        self.project = project
-        self.filehandler = filehandler
+        if from_project:
+            self.project_run = True
+            self.project_location = project_location
+            self.project = project
+            self.filehandler = filehandler
+        else:
+            self.project_run = False
+            self.project_location = None
+            self.project = None
+            self.filehandler = None
+
+        if isinstance(config, str):
+            config = read_config(config)
+            if self.__class__.__name__ in config.keys():
+                self.config = config[self.__class__.__name__]
+            else:
+                self.config = config
+        else:
+            self.config = config
 
         self.get_context()
 
         self.deep_debug = False
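The practical effect of the new `from_project` flag: a processing step can now be constructed standalone from a config file path (project attributes default to `None`, and the YAML is scoped to the class name if such a top-level key exists), or wired up by a project as before. A hedged sketch, assuming a `config.yml` on disk and using a hypothetical minimal subclass:

```python
from scportrait.pipeline._base import ProcessingStep

class MyStep(ProcessingStep):  # hypothetical minimal subclass
    pass

# Standalone: config is read from YAML via read_config; if the file has a
# top-level "MyStep" key, only that section becomes step.config.
step = MyStep("config.yml", debug=True)
assert step.project is None and step.project_location is None
```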
13 changes: 12 additions & 1 deletion src/scportrait/pipeline/_utils/helper.py
@@ -1,9 +1,20 @@
 from typing import TypeVar
 
+import yaml
+
 T = TypeVar("T")
 
 
+def read_config(config_path: str) -> dict:
+    with open(config_path) as stream:
+        try:
+            config = yaml.safe_load(stream)
+        except yaml.YAMLError as exc:
+            print(exc)
+    return config
+
+
-def flatten(nested_list: list[list[T]]) -> list[T]:
+def flatten(nested_list: list[list[T]]) -> list[T | tuple[T]]:
     """Flatten a list of lists into a single list.
 
     Args:
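A short usage sketch for the new helper (file contents hypothetical). One caveat worth knowing: on a YAML parse error the exception is printed rather than re-raised, so the subsequent `return config` fails with `UnboundLocalError` because `config` was never assigned.

```python
from scportrait.pipeline._utils.helper import read_config

# config.yml (hypothetical):
#   MyStep:
#     threads: 4
config = read_config("config.yml")
print(config["MyStep"]["threads"])  # -> 4
```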
8 changes: 5 additions & 3 deletions src/scportrait/pipeline/_utils/sdata_io.py
@@ -71,10 +71,12 @@ def _read_sdata(self) -> SpatialData:
             _sdata = SpatialData()
             _sdata.write(self.sdata_path, overwrite=True)
 
+        allowed_labels = ["seg_all_nucleus", "seg_all_cytosol"]
         for key in _sdata.labels:
-            segmentation_object = _sdata.labels[key]
-            if not hasattr(segmentation_object.attrs, "cell_ids"):
-                segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None)
+            if key in allowed_labels:
+                segmentation_object = _sdata.labels[key]
+                if not hasattr(segmentation_object.attrs, "cell_ids"):
+                    segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None)
 
         return _sdata
104 changes: 76 additions & 28 deletions src/scportrait/pipeline/extraction.py
@@ -57,7 +57,13 @@ def __init__(self, *args, **kwargs):
         self.overwrite_run_path = self.overwrite
 
     def _get_compression_type(self):
-        self.compression_type = "lzf" if self.compression else None
+        if (self.compression is True) or (self.compression == "lzf"):
+            self.compression_type = "lzf"
+        elif self.compression == "gzip":
+            self.compression_type = "gzip"
+        else:
+            self.compression_type = None
+        self.log(f"Compression algorithm: {self.compression_type}")
        return self.compression_type
 
     def _check_config(self):
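Both accepted values map directly onto h5py's dataset compression filters: "lzf" favors speed with moderate ratios, "gzip" compresses harder but slower. A minimal sketch of how such a `compression_type` is typically consumed (file name and array shape hypothetical):

```python
import h5py
import numpy as np

compression_type = "lzf"  # or "gzip", or None for uncompressed
data = np.zeros((10, 2, 128, 128), dtype=np.uint16)

with h5py.File("single_cells.h5", "w") as hf:
    # chunking per cell keeps single-cell reads cheap under compression
    hf.create_dataset("single_cell_data", data=data, compression=compression_type, chunks=(1, 2, 128, 128))
```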
@@ -261,24 +267,55 @@ def _get_segmentation_info(self):
                 f"Found no segmentation masks with key {self.segmentation_key}. Cannot proceed with extraction."
             )
 
-        # get relevant segmentation masks to perform extraction on
-        nucleus_key = f"{self.segmentation_key}_nucleus"
-
-        if nucleus_key in relevant_masks:
-            self.extract_nucleus_mask = True
-            self.nucleus_key = nucleus_key
-        else:
-            self.extract_nucleus_mask = False
-            self.nucleus_key = None
-
-        cytosol_key = f"{self.segmentation_key}_cytosol"
-
-        if cytosol_key in relevant_masks:
-            self.extract_cytosol_mask = True
-            self.cytosol_key = cytosol_key
-        else:
-            self.extract_cytosol_mask = False
-            self.cytosol_key = None
+        # initialize default values to track what should be extracted
+        self.nucleus_key = None
+        self.cytosol_key = None
+        self.extract_nucleus_mask = False
+        self.extract_cytosol_mask = False
+
+        if "segmentation_mask" in self.config:
+            allowed_mask_values = ["nucleus", "cytosol"]
+            allowed_mask_values = [f"{self.segmentation_key}_{x}" for x in allowed_mask_values]
+
+            if isinstance(self.config["segmentation_mask"], str):
+                assert self.config["segmentation_mask"] in allowed_mask_values
+
+                if "nucleus" in self.config["segmentation_mask"]:
+                    self.nucleus_key = self.config["segmentation_mask"]
+                    self.extract_nucleus_mask = True
+
+                elif "cytosol" in self.config["segmentation_mask"]:
+                    self.cytosol_key = self.config["segmentation_mask"]
+                    self.extract_cytosol_mask = True
+                else:
+                    raise ValueError(
+                        f"Segmentation mask {self.config['segmentation_mask']} is not a valid mask to extract from."
+                    )
+
+            elif isinstance(self.config["segmentation_mask"], list):
+                assert all(x in allowed_mask_values for x in self.config["segmentation_mask"])
+
+                for x in self.config["segmentation_mask"]:
+                    if "nucleus" in x:
+                        self.nucleus_key = x
+                        self.extract_nucleus_mask = True
+                    if "cytosol" in x:
+                        self.cytosol_key = x
+                        self.extract_cytosol_mask = True
+
+        else:
+            # get relevant segmentation masks to perform extraction on
+            nucleus_key = f"{self.segmentation_key}_nucleus"
+
+            if nucleus_key in relevant_masks:
+                self.extract_nucleus_mask = True
+                self.nucleus_key = nucleus_key
+
+            cytosol_key = f"{self.segmentation_key}_cytosol"
+
+            if cytosol_key in relevant_masks:
+                self.extract_cytosol_mask = True
+                self.cytosol_key = cytosol_key
 
         self.n_masks = np.sum([self.extract_nucleus_mask, self.extract_cytosol_mask])
         self.masks = [x for x in [self.nucleus_key, self.cytosol_key] if x is not None]
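Config-wise, the new `segmentation_mask` entry accepts a single mask name or a list of names; if it is omitted, both masks are auto-detected from the available labels as before. A hedged YAML sketch (mask names assume the default segmentation key `seg_all`; the section name is hypothetical):

```yaml
HDF5CellExtraction:
  # extract from a single mask ...
  segmentation_mask: seg_all_nucleus
  # ... or from both masks explicitly:
  # segmentation_mask: [seg_all_nucleus, seg_all_cytosol]
```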
@@ -415,7 +452,7 @@ def _save_removed_classes(self, classes):
         # define path where classes should be saved
         filtered_path = os.path.join(
             self.project_location,
-            self.DEFAULT_SEGMENTATION_DIR_NAME,
+            self.DEFAULT_EXTRACTION_DIR_NAME,
             self.DEFAULT_REMOVED_CLASSES_FILE,
         )
 
@@ -636,7 +673,7 @@ def _transfer_tempmmap_to_hdf5(self):
                 axs[i].imshow(img, vmin=0, vmax=1)
                 axs[i].axis("off")
             fig.tight_layout()
-            fig.show()
+            plt.show(fig)
 
         self.log("Transferring extracted single cells to .hdf5")
 
@@ -651,7 +688,8 @@
             ) # increase to 64 bit otherwise information may become truncated
 
             self.log("single-cell index created.")
-            self._clear_cache(vars_to_delete=[cell_ids])
+            del cell_ids
+            # self._clear_cache(vars_to_delete=[cell_ids]) # this is not working as expected so we will just delete the variable directly
 
             _, c, x, y = _tmp_single_cell_data.shape
             single_cell_data = hf.create_dataset(
@@ -668,7 +706,8 @@
                 single_cell_data[ix] = _tmp_single_cell_data[i]
 
             self.log("single-cell data created")
-            self._clear_cache(vars_to_delete=[single_cell_data])
+            del single_cell_data
+            # self._clear_cache(vars_to_delete=[single_cell_data]) # this is not working as expected so we will just delete the variable directly
 
             # also transfer labelled index to HDF5
             index_labelled = _tmp_single_cell_index[keep_index]
@@ -684,18 +723,27 @@
             hf.create_dataset("single_cell_index_labelled", data=index_labelled, chunks=None, dtype=dt)
 
             self.log("single-cell index labelled created.")
-            self._clear_cache(vars_to_delete=[index_labelled])
+            del index_labelled
+            # self._clear_cache(vars_to_delete=[index_labelled]) # this is not working as expected so we will just delete the variable directly
 
             hf.create_dataset(
                 "channel_information",
                 data=np.char.encode(self.channel_names.astype(str)),
                 dtype=h5py.special_dtype(vlen=str),
             )
 
+            hf.create_dataset(
+                "n_masks",
+                data=self.n_masks,
+                dtype=int,
+            )
+
             self.log("channel information created.")
 
             # cleanup memory
-            self._clear_cache(vars_to_delete=[_tmp_single_cell_index, index_labelled])
+            del _tmp_single_cell_index
+            # self._clear_cache(vars_to_delete=[_tmp_single_cell_index]) # this is not working as expected so we will just delete the variable directly
 
         os.remove(self._tmp_single_cell_data_path)
         os.remove(self._tmp_single_cell_index_path)
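With the added `n_masks` dataset, downstream consumers can recover the mask count alongside the channel names from the same file. A hedged reading sketch (output path hypothetical):

```python
import h5py

with h5py.File("single_cells.h5", "r") as hf:
    n_masks = int(hf["n_masks"][()])
    # vlen str datasets come back as bytes by default
    channel_names = [c.decode() for c in hf["channel_information"][:]]
print(n_masks, channel_names)
```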

@@ -808,7 +856,6 @@ def process(self, partial=False, n_cells=None, seed=42):
             # directory where intermediate results should be saved
             cache: "/mnt/temp/cache"
         """
-
         total_time_start = timeit.default_timer()
 
         start_setup = timeit.default_timer()
@@ -871,31 +918,33 @@
 
             self.log("Running in single threaded mode.")
             results = []
-            for arg in tqdm(args):
+            for arg in tqdm(args, total=len(args), desc="Processing cell batches"):
                 x = f(arg)
                 results.append(x)
         else:
             # set up function for multi-threaded processing
             f = func_partial(self._extract_classes_multi, self.px_centers)
-            batched_args = self._generate_batched_args(args)
+            args = self._generate_batched_args(args)
 
             self.log(f"Running in multiprocessing mode with {self.threads} threads.")
             with mp.get_context("fork").Pool(
                 processes=self.threads
             ) as pool:  # both spawn and fork work but fork is faster so forcing fork here
                 results = list(
                     tqdm(
-                        pool.imap(f, batched_args),
-                        total=len(batched_args),
+                        pool.imap(f, args),
+                        total=len(args),
                         desc="Processing cell batches",
                     )
                 )
                 pool.close()
                 pool.join()
                 print("multiprocessing done.")
 
         self.save_index_to_remove = flatten(results)
 
+        # cleanup memory and remove any no longer required variables
+        del results, args
+        # self._clear_cache(vars_to_delete=["results", "args"]) # this is not working as expected at the moment so need to manually delete the variables
         stop_extraction = timeit.default_timer()
 
         # calculate duration
@@ -912,7 +961,6 @@ def process(self, partial=False, n_cells=None, seed=42):
         self.DEFAULT_LOG_NAME = "processing.log" # change log name back to default
 
         self._post_extraction_cleanup()
-
         total_time_stop = timeit.default_timer()
         total_time = total_time_stop - total_time_start
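On the multiprocessing path, `mp.get_context("fork")` pins the start method regardless of the platform default, which matters because forked workers inherit large objects such as `px_centers` through copy-on-write rather than re-pickling them per worker (fork is unavailable on Windows). A runnable miniature of the same pattern, with a hypothetical worker and batches:

```python
import multiprocessing as mp
from functools import partial as func_partial

from tqdm import tqdm

def _extract_batch(px_centers, batch):
    # stand-in for _extract_classes_multi: px_centers is bound once via partial,
    # batches are fed lazily through pool.imap
    return [cell_id * 2 for cell_id in batch]

if __name__ == "__main__":
    f = func_partial(_extract_batch, [(0, 0), (1, 1)])  # hypothetical centers
    batches = [[1, 2, 3], [4, 5, 6]]
    with mp.get_context("fork").Pool(processes=2) as pool:
        results = list(tqdm(pool.imap(f, batches), total=len(batches), desc="Processing cell batches"))
    print(results)  # [[2, 4, 6], [8, 10, 12]]
```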
