diff --git a/examples/notebooks b/examples/notebooks
index 5a9b127..d5ea844 160000
--- a/examples/notebooks
+++ b/examples/notebooks
@@ -1 +1 @@
-Subproject commit 5a9b127f06a39d326931728a0cf9850848fca205
+Subproject commit d5ea844b033e18d5fc3c82213c8bfde93465d47f
diff --git a/requirements.txt b/requirements.txt
index b12d8e4..59b101a 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,10 +29,10 @@
 torch
 pytorch-lightning
 torchvision
-spatialdata
+spatialdata>=0.2.0
 napari-spatialdata
 pyqt5
 lxml_html_clean
 ashlar>=1.19.0
 networkx
-py-lmd @ git+https://github.com/MannLabs/py-lmd.git@refs/pull/11/head#egg=py-lmd
+py-lmd>=1.3.1
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 088c9fb..9ff75f1 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -29,13 +29,13 @@
 torch
 pytorch-lightning
 torchvision
-spatialdata
+spatialdata>=0.2.0
 napari-spatialdata
 pyqt5
 lxml_html_clean
 ashlar>=1.19.0
 networkx
-py-lmd @ git+https://github.com/MannLabs/py-lmd.git@refs/pull/11/head#egg=py-lmd
+py-lmd>=1.3.1
 
 #packages for building the documentation
 sphinx
diff --git a/src/scportrait/__init__.py b/src/scportrait/__init__.py
index 972a0f8..772f875 100644
--- a/src/scportrait/__init__.py
+++ b/src/scportrait/__init__.py
@@ -1,7 +1,18 @@
 """Top-level package for scPortrait"""
 
+# silence warnings
+import warnings
+
 from scportrait import io
 from scportrait import pipeline as pipeline
 from scportrait import plotting as pl
 from scportrait import processing as pp
 from scportrait import tools as tl
+
+# silence warning from spatialdata resulting in an older dask version see #139
+warnings.filterwarnings("ignore", message="ignoring keyword argument 'read_only'")
+
+# silence warning from cellpose resulting in missing parameter set in model call see #141
+warnings.filterwarnings(
+    "ignore", message=r"You are using `torch.load` with `weights_only=False`.*", category=FutureWarning
+)
diff --git a/src/scportrait/pipeline/_base.py b/src/scportrait/pipeline/_base.py
index ab35274..d241cdb 100644
--- a/src/scportrait/pipeline/_base.py
+++ b/src/scportrait/pipeline/_base.py
@@ -9,6 +9,8 @@
 import numpy as np
 import torch
 
+from scportrait.pipeline._utils.helper import read_config
+
 
 class Logable:
     """Create log entries.
@@ -92,6 +94,27 @@ def _clean_log_file(self):
         if os.path.exists(log_file_path):
             os.remove(log_file_path)
 
+    # def _clear_cache(self, vars_to_delete=None):
+    #     """Helper function to help clear memory usage. Mainly relevant for GPU based segmentations.
+
+    #     Args:
+    #         vars_to_delete (list): List of variable names (as strings) to delete.
+    #     """
+
+    #     # delete all specified variables
+    #     if vars_to_delete is not None:
+    #         for var_name in vars_to_delete:
+    #             if var_name in globals():
+    #                 del globals()[var_name]
+
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+
+    #     if torch.backends.mps.is_available():
+    #         torch.mps.empty_cache()
+
+    #     gc.collect()
+
     def _clear_cache(self, vars_to_delete=None):
         """Helper function to help clear memory usage. Mainly relevant for GPU based segmentations."""
@@ -137,7 +160,7 @@ class ProcessingStep(Logable):
     DEFAULT_SEGMENTATION_DIR_NAME = "segmentation"
     DEFAULT_TILES_FOLDER = "tiles"
 
-    DEFAULT_EXTRACTIN_DIR_NAME = "extraction"
+    DEFAULT_EXTRACTION_DIR_NAME = "extraction"
     DEFAULT_DATA_DIR = "data"
 
     DEFAULT_IMAGE_DTYPE = np.uint16
@@ -155,19 +178,41 @@ class ProcessingStep(Logable):
     DEFAULT_SELECTION_DIR_NAME = "selection"
 
     def __init__(
-        self, config, directory, project_location, debug=False, overwrite=False, project=None, filehandler=None
+        self,
+        config,
+        directory=None,
+        project_location=None,
+        debug=False,
+        overwrite=False,
+        project=None,
+        filehandler=None,
+        from_project: bool = False,
    ):
         super().__init__(directory=directory)
 
         self.debug = debug
         self.overwrite = overwrite
-        self.project_location = project_location
-        self.config = config
+
+        if from_project:
+            self.project_run = True
+            self.project_location = project_location
+            self.project = project
+            self.filehandler = filehandler
+        else:
+            self.project_run = False
+            self.project_location = None
+            self.project = None
+            self.filehandler = None
+
+        if isinstance(config, str):
+            config = read_config(config)
+            if self.__class__.__name__ in config.keys():
+                self.config = config[self.__class__.__name__]
+            else:
+                self.config = config
+        else:
+            self.config = config
 
         self.overwrite = overwrite
-        self.project = project
-        self.filehandler = filehandler
-
         self.get_context()
 
         self.deep_debug = False
diff --git a/src/scportrait/pipeline/_utils/helper.py b/src/scportrait/pipeline/_utils/helper.py
index 9b11fa2..e930104 100644
--- a/src/scportrait/pipeline/_utils/helper.py
+++ b/src/scportrait/pipeline/_utils/helper.py
@@ -1,9 +1,20 @@
 from typing import TypeVar
 
+import yaml
+
 T = TypeVar("T")
 
 
-def flatten(nested_list: list[list[T]]) -> list[T]:
+def read_config(config_path: str) -> dict:
+    with open(config_path) as stream:
+        try:
+            config = yaml.safe_load(stream)
+        except yaml.YAMLError as exc:
+            print(exc)
+    return config
+
+
+def flatten(nested_list: list[list[T]]) -> list[T | tuple[T]]:
     """Flatten a list of lists into a single list.
 
     Args:
diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py
index 47017a1..89c480e 100644
--- a/src/scportrait/pipeline/_utils/sdata_io.py
+++ b/src/scportrait/pipeline/_utils/sdata_io.py
@@ -71,10 +71,12 @@ def _read_sdata(self) -> SpatialData:
             _sdata = SpatialData()
             _sdata.write(self.sdata_path, overwrite=True)
 
+        allowed_labels = ["seg_all_nucleus", "seg_all_cytosol"]
         for key in _sdata.labels:
-            segmentation_object = _sdata.labels[key]
-            if not hasattr(segmentation_object.attrs, "cell_ids"):
-                segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None)
+            if key in allowed_labels:
+                segmentation_object = _sdata.labels[key]
+                if not hasattr(segmentation_object.attrs, "cell_ids"):
+                    segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None)
 
         return _sdata
diff --git a/src/scportrait/pipeline/extraction.py b/src/scportrait/pipeline/extraction.py
index ff51dfd..20bfbb7 100644
--- a/src/scportrait/pipeline/extraction.py
+++ b/src/scportrait/pipeline/extraction.py
@@ -57,7 +57,13 @@ def __init__(self, *args, **kwargs):
         self.overwrite_run_path = self.overwrite
 
     def _get_compression_type(self):
-        self.compression_type = "lzf" if self.compression else None
+        if (self.compression is True) or (self.compression == "lzf"):
+            self.compression_type = "lzf"
+        elif self.compression == "gzip":
+            self.compression_type = "gzip"
+        else:
+            self.compression_type = None
+        self.log(f"Compression algorithm: {self.compression_type}")
         return self.compression_type
 
     def _check_config(self):
@@ -261,24 +267,55 @@ def _get_segmentation_info(self):
                 f"Found no segmentation masks with key {self.segmentation_key}. Cannot proceed with extraction."
             )
 
-        # get relevant segmentation masks to perform extraction on
-        nucleus_key = f"{self.segmentation_key}_nucleus"
+        # intialize default values to track what should be extracted
+        self.nucleus_key = None
+        self.cytosol_key = None
+        self.extract_nucleus_mask = False
+        self.extract_cytosol_mask = False
 
-        if nucleus_key in relevant_masks:
-            self.extract_nucleus_mask = True
-            self.nucleus_key = nucleus_key
-        else:
-            self.extract_nucleus_mask = False
-            self.nucleus_key = None
+        if "segmentation_mask" in self.config:
+            allowed_mask_values = ["nucleus", "cytosol"]
+            allowed_mask_values = [f"{self.segmentation_key}_{x}" for x in allowed_mask_values]
+
+            if isinstance(self.config["segmentation_mask"], str):
+                assert self.config["segmentation_mask"] in allowed_mask_values
 
-        cytosol_key = f"{self.segmentation_key}_cytosol"
+                if "nucleus" in self.config["segmentation_mask"]:
+                    self.nucleus_key = self.config["segmentation_mask"]
+                    self.extract_nucleus_mask = True
+
+                elif "cytosol" in self.config["segmentation_mask"]:
+                    self.cytosol_key = self.config["segmentation_mask"]
+                    self.extract_cytosol_mask = True
+                else:
+                    raise ValueError(
+                        f"Segmentation mask {self.config['segmentation_mask']} is not a valid mask to extract from."
+                    )
+
+            elif isinstance(self.config["segmentation_mask"], list):
+                assert all(x in allowed_mask_values for x in self.config["segmentation_mask"])
+
+                for x in self.config["segmentation_mask"]:
+                    if "nucleus" in x:
+                        self.nucleus_key = x
+                        self.extract_nucleus_mask = True
+                    if "cytosol" in x:
+                        self.cytosol_key = x
+                        self.extract_cytosol_mask = True
 
-        if cytosol_key in relevant_masks:
-            self.extract_cytosol_mask = True
-            self.cytosol_key = cytosol_key
         else:
-            self.extract_cytosol_mask = False
-            self.cytosol_key = None
+            # get relevant segmentation masks to perform extraction on
+            nucleus_key = f"{self.segmentation_key}_nucleus"
+
+            if nucleus_key in relevant_masks:
+                self.extract_nucleus_mask = True
+                self.nucleus_key = nucleus_key
+
+            cytosol_key = f"{self.segmentation_key}_cytosol"
+
+            if cytosol_key in relevant_masks:
+                self.extract_cytosol_mask = True
+                self.cytosol_key = cytosol_key
 
         self.n_masks = np.sum([self.extract_nucleus_mask, self.extract_cytosol_mask])
         self.masks = [x for x in [self.nucleus_key, self.cytosol_key] if x is not None]
@@ -415,7 +452,7 @@ def _save_removed_classes(self, classes):
         # define path where classes should be saved
         filtered_path = os.path.join(
             self.project_location,
-            self.DEFAULT_SEGMENTATION_DIR_NAME,
+            self.DEFAULT_EXTRACTION_DIR_NAME,
             self.DEFAULT_REMOVED_CLASSES_FILE,
         )
@@ -636,7 +673,7 @@ def _transfer_tempmmap_to_hdf5(self):
                 axs[i].imshow(img, vmin=0, vmax=1)
                 axs[i].axis("off")
             fig.tight_layout()
-            fig.show()
+            plt.show(fig)
 
         self.log("Transferring extracted single cells to .hdf5")
@@ -651,7 +688,8 @@ def _transfer_tempmmap_to_hdf5(self):
             )  # increase to 64 bit otherwise information may become truncated
 
             self.log("single-cell index created.")
-            self._clear_cache(vars_to_delete=[cell_ids])
+            del cell_ids
+            # self._clear_cache(vars_to_delete=[cell_ids]) # this is not working as expected so we will just delete the variable directly
 
             _, c, x, y = _tmp_single_cell_data.shape
             single_cell_data = hf.create_dataset(
@@ -668,7 +706,8 @@ def _transfer_tempmmap_to_hdf5(self):
                 single_cell_data[ix] = _tmp_single_cell_data[i]
 
             self.log("single-cell data created")
-            self._clear_cache(vars_to_delete=[single_cell_data])
+            del single_cell_data
+            # self._clear_cache(vars_to_delete=[single_cell_data]) # this is not working as expected so we will just delete the variable directly
 
             # also transfer labelled index to HDF5
             index_labelled = _tmp_single_cell_index[keep_index]
@@ -684,7 +723,8 @@ def _transfer_tempmmap_to_hdf5(self):
             hf.create_dataset("single_cell_index_labelled", data=index_labelled, chunks=None, dtype=dt)
 
             self.log("single-cell index labelled created.")
-            self._clear_cache(vars_to_delete=[index_labelled])
+            del index_labelled
+            # self._clear_cache(vars_to_delete=[index_labelled]) # this is not working as expected so we will just delete the variable directly
 
             hf.create_dataset(
                 "channel_information",
@@ -692,10 +732,18 @@ def _transfer_tempmmap_to_hdf5(self):
                 dtype=h5py.special_dtype(vlen=str),
             )
 
+            hf.create_dataset(
+                "n_masks",
+                data=self.n_masks,
+                dtype=int,
+            )
+
             self.log("channel information created.")
 
         # cleanup memory
-        self._clear_cache(vars_to_delete=[_tmp_single_cell_index, index_labelled])
+        del _tmp_single_cell_index
+        # self._clear_cache(vars_to_delete=[_tmp_single_cell_index]) # this is not working as expected so we will just delete the variable directly
+
         os.remove(self._tmp_single_cell_data_path)
         os.remove(self._tmp_single_cell_index_path)
@@ -808,7 +856,6 @@ def process(self, partial=False, n_cells=None, seed=42):
             # directory where intermediate results should be saved
             cache: "/mnt/temp/cache"
         """
-        total_time_start = timeit.default_timer()
 
         start_setup = timeit.default_timer()
@@ -871,13 +918,13 @@ def process(self, partial=False, n_cells=None, seed=42):
             self.log("Running in single threaded mode.")
 
             results = []
-            for arg in tqdm(args):
+            for arg in tqdm(args, total=len(args), desc="Processing cell batches"):
                 x = f(arg)
                 results.append(x)
         else:
             # set up function for multi-threaded processing
             f = func_partial(self._extract_classes_multi, self.px_centers)
-            batched_args = self._generate_batched_args(args)
+            args = self._generate_batched_args(args)
 
             self.log(f"Running in multiprocessing mode with {self.threads} threads.")
             with mp.get_context("fork").Pool(
@@ -885,17 +932,19 @@ def process(self, partial=False, n_cells=None, seed=42):
             ) as pool:  # both spawn and fork work but fork is faster so forcing fork here
                 results = list(
                     tqdm(
-                        pool.imap(f, batched_args),
-                        total=len(batched_args),
+                        pool.imap(f, args),
+                        total=len(args),
                         desc="Processing cell batches",
                     )
                 )
                 pool.close()
                 pool.join()
-                print("multiprocessing done.")
 
         self.save_index_to_remove = flatten(results)
 
+        # cleanup memory and remove any no longer required variables
+        del results, args
+        # self._clear_cache(vars_to_delete=["results", "args"]) # this is not working as expected at the moment so need to manually delete the variables
+
         stop_extraction = timeit.default_timer()
 
         # calculate duration
@@ -912,7 +961,6 @@ def process(self, partial=False, n_cells=None, seed=42):
         self.DEFAULT_LOG_NAME = "processing.log"  # change log name back to default
         self._post_extraction_cleanup()
-        total_time_stop = timeit.default_timer()
 
         total_time = total_time_stop - total_time_start
diff --git a/src/scportrait/pipeline/featurization.py b/src/scportrait/pipeline/featurization.py
index 316b6aa..9a0cbee 100644
--- a/src/scportrait/pipeline/featurization.py
+++ b/src/scportrait/pipeline/featurization.py
@@ -5,6 +5,7 @@
 from contextlib import redirect_stdout
 from functools import partial as func_partial
 
+import h5py
 import numpy as np
 import pandas as pd
 import pytorch_lightning as pl
@@ -36,6 +37,7 @@ def __init__(self, *args, **kwargs):
         self.model = None
         self.transforms = None
         self.expected_imagesize = None
+        self.data_type = None
 
         self._setup_channel_selection()
@@ -59,7 +61,10 @@ def _setup_output(self):
         if not os.path.isdir(self.directory):
             os.makedirs(self.directory)
 
-        self.run_path = os.path.join(self.directory, f"{self.data_type}_{self.label}")
+        if self.data_type is None:
+            self.run_path = os.path.join(self.directory, self.label)
+        else:
+            self.run_path = os.path.join(self.directory, f"{self.data_type}_{self.label}")
 
         if not os.path.isdir(self.run_path):
             os.makedirs(self.run_path)
@@ -104,6 +109,25 @@ def _detect_automatic_inference_device(self):
 
         return inference_device
 
+    def _get_nmasks(self):
+        if "n_masks" not in self.__dict__.keys():
+            if isinstance(self.extraction_file, str):
+                with h5py.File(self.extraction_file, "r") as f:
+                    self.n_masks = f["n_masks"][()].item()
+            if isinstance(self.extraction_file, list):
+                n_masks = []
+                for file in self.extraction_file:
+                    with h5py.File(file, "r") as f:
+                        n_masks.append(f["n_masks"][()].item())
+                assert (
+                    x == n_masks[0] for x in n_masks
+                ), "number of masks are not consistent over all passed HDF5 files."
+                self.n_masks = n_masks[0]
+            try:
+                self.n_masks = h5py.File(self.extraction_file, "r")["n_masks"][()].item()
+            except Exception as e:
+                raise ValueError(f"Could not extract number of masks from HDF5 file. Error: {e}") from e
+
     def _setup_inference_device(self):
         """
         Configure the featurization run to use the specified inference device.
@@ -166,10 +190,13 @@ def _setup_inference_device(self):
             self.inference_device = self._detect_automatic_inference_device()
             self.log(f"Automatically configured inferece device to {self.inference_device}")
 
-    def _general_setup(self):
+    def _general_setup(self, extraction_dir: str | list[str], return_results: bool = False):
         """Helper function to execute all setup functions that are common to all featurization steps."""
-        self._setup_output()
+        self.extraction_file = extraction_dir
+        if not return_results:
+            self._setup_output()
+        self._get_nmasks()
         self._setup_log_transform()
         self._setup_inference_device()
@@ -391,7 +418,8 @@ def configure_transforms(self, selected_transforms: list):
 
     def generate_dataloader(
         self,
-        extraction_dir: str,
+        extraction_dir: str | list[str],
+        labels: int | list[int] = 0,
         selected_transforms: transforms.Compose = transforms.Compose([]),
         size: int = 0,
         seed: int | None = 42,
@@ -428,11 +456,20 @@ def generate_dataloader(
             self.log(f"Expected image size is set to {self.expected_imagesize}. Resizing images to this size.")
             t = transforms.Compose([t, transforms.Resize(self.expected_imagesize)])
 
+        if isinstance(extraction_dir, list):
+            assert isinstance(labels, list), "If multiple directories are provided, multiple labels must be provided."
+            paths = extraction_dir
+            labels = labels
+        elif isinstance(extraction_dir, str):
+            assert isinstance(labels, int), "If only one directory is provided, only one label must be provided."
+            paths = [extraction_dir]
+            labels = [labels]
+
         f = io.StringIO()
         with redirect_stdout(f):
             dataset = dataset_class(
-                dir_list=[extraction_dir],
-                dir_labels=[0],
+                dir_list=paths,
+                dir_labels=labels,
                 transform=t,
                 return_id=True,
                 select_channel=self.channel_selection,
@@ -780,8 +817,8 @@ def _setup_transforms(self) -> None:
 
         return
 
-    def _setup(self):
-        self._general_setup()
+    def _setup(self, extraction_dir: str, return_results: bool):
+        self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._get_model_specs()
         self._get_network_dir()
@@ -799,7 +836,7 @@ def _setup(self):
         self._setup_encoders()
         self._setup_transforms()
 
-    def process(self, extraction_dir: str, size: int = 0):
+    def process(self, extraction_dir: str, labels: int | list[int] = 0, size: int = 0, return_results: bool = False):
         """
         Perform classification on the provided HDF5 dataset.
@@ -876,31 +913,39 @@ class based on the previous single-cell extraction. Therefore, only the second a
         self.log("Started MLClusterClassifier classification.")
 
         # perform setup
-        self._setup()
+        self._setup(extraction_dir=extraction_dir, return_results=return_results)
 
         self.dataloader = self.generate_dataloader(
             extraction_dir,
+            labels=labels,
             selected_transforms=self.transforms,
             size=size,
             dataset_class=self.DEFAULT_DATA_LOADER,
         )
 
         # perform inference
+        all_results = []
         for model in self.models:
             self.log(f"Starting inference for model encoder {model.__name__}")
             results = self.inference(self.dataloader, model)
 
-            output_name = f"inference_{model.__name__}"
-            path = os.path.join(self.run_path, f"{output_name}.csv")
+            if not return_results:
+                output_name = f"inference_{model.__name__}"
+                path = os.path.join(self.run_path, f"{output_name}.csv")
 
-            self._write_results_csv(results, path)
-            self._write_results_sdata(results, label=f"{self.label}_{model.__name__}")
-
-            self.log(f"Results saved to file: {path}")
+                self._write_results_csv(results, path)
+                self._write_results_sdata(results, label=f"{self.label}_{model.__name__}")
+            else:
+                all_results.append(results)
 
-        # perform post processing cleanup
-        if not self.deep_debug:
-            self._post_processing_cleanup()
+        if return_results:
+            self._clear_cache()
+            return all_results
+        else:
+            self.log(f"Results saved to file: {path}")
+            # perform post processing cleanup
+            if not self.deep_debug:
+                self._post_processing_cleanup()
 
 
 class EnsembleClassifier(_FeaturizationBase):
@@ -952,8 +997,8 @@ def _load_models(self):
         memory_usage = self._get_gpu_memory_usage()
         self.log(f"GPU memory usage after loading models: {memory_usage}")
 
-    def _setup(self):
-        self._general_setup()
+    def _setup(self, extraction_dir: str, return_results: bool):
+        self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._get_model_specs()
         self._setup_transforms()
@@ -965,7 +1010,7 @@ def _setup(self):
 
         self._load_models()
 
-    def process(self, extraction_dir, size=0):
+    def process(self, extraction_dir: str, labels: int | list[int] = 0, size: int = 0, return_results: bool = False):
         """
         Function called to perform classification on the provided HDF5 dataset.
@@ -1020,29 +1065,39 @@ class based on the previous single-cell extraction. Therefore, no parameters nee
         self.log("Starting Ensemble Classification")
 
-        self._setup()
+        self._setup(extraction_dir=extraction_dir, return_results=return_results)
 
         self.dataloader = self.generate_dataloader(
             extraction_dir,
+            labels=labels,
             selected_transforms=self.transforms,
             size=size,
             dataset_class=self.DEFAULT_DATA_LOADER,
         )
 
         # perform inference
+        all_results = {}
         for model_name, model in zip(self.model_names, self.model, strict=False):
             self.log(f"Starting inference for model {model_name}")
             results = self.inference(self.dataloader, model)
 
             output_name = f"ensemble_inference_{model_name}"
-            path = os.path.join(self.run_path, f"{output_name}.csv")
 
-            self._write_results_csv(results, path)
-            self._write_results_sdata(results, label=model_name)
+            if not return_results:
+                path = os.path.join(self.run_path, f"{output_name}.csv")
 
-        # perform post processing cleanup
-        if not self.deep_debug:
-            self._post_processing_cleanup()
+                self._write_results_csv(results, path)
+                self._write_results_sdata(results, label=model_name)
+            else:
+                all_results[model_name] = results
+
+        if return_results:
+            self._clear_cache()
+            return all_results
+        else:
+            # perform post processing cleanup
+            if not self.deep_debug:
+                self._post_processing_cleanup()
 
 
 ####### CellFeaturization based on Classic Featurecalculation #######
@@ -1079,10 +1134,24 @@ def _setup_transforms(self):
 
         return
 
     def _get_channel_specs(self):
-        if "channel_names" in self.project.__dict__.keys():
-            self.channel_names = self.project.channel_names
+        if self.project is None:
+            if isinstance(self.extraction_file, str):
+                with h5py.File(self.extraction_file, "r") as f:
+                    self.channel_names = list(f["channel_information"][:].astype(str))
+            if isinstance(self.extraction_file, list):
+                channel_names = []
+                for file in self.extraction_file:
+                    with h5py.File(file, "r") as f:
+                        channel_names.append(list(f["channel_information"][:].astype(str)))
+                assert (
+                    x == channel_names[0] for x in channel_names
+                ), "Channel names are not consistent over all passed HDF5 files."
+                self.channel_names = channel_names[0]
         else:
-            self.channel_names = self.project.input_image.c.values
+            if "channel_names" in self.project.__dict__.keys():
+                self.channel_names = self.project.channel_names
+            else:
+                self.channel_names = self.project.input_image.c.values
 
     def _generate_column_names(
         self,
@@ -1294,12 +1363,14 @@ def __init__(self, *args, **kwargs):
 
         self.channel_selection = None  # ensure that all images are passed to the function
 
-    def _setup(self):
-        self._general_setup()
+    def _setup(self, extraction_dir: str | list[str], return_results: bool):
+        self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._setup_transforms()
         self._get_channel_specs()
 
-    def process(self, extraction_dir, size=0):
+    def process(
+        self, extraction_dir: str | list[str], labels: int | list[int] = 0, size: int = 0, return_results: bool = False
+    ):
         """
         Perform featurization on the provided HDF5 dataset.
@@ -1354,10 +1425,11 @@ def process(self, extraction_dir, size=0):
         self.log("Started CellFeaturization of all available channels.")
 
         # perform setup
-        self._setup()
+        self._setup(extraction_dir=extraction_dir, return_results=return_results)
 
         self.dataloader = self.generate_dataloader(
             extraction_dir,
+            labels=labels,
             selected_transforms=self.transforms,
             size=size,
             dataset_class=self.DEFAULT_DATA_LOADER,
@@ -1384,15 +1456,19 @@ def process(self, extraction_dir, size=0):
             column_names=self.column_names,
         )
 
-        output_name = "calculated_image_features"
-        path = os.path.join(self.run_path, f"{output_name}.csv")
+        if return_results:
+            self._clear_cache()
+            return results
+        else:
+            output_name = "calculated_image_features"
+            path = os.path.join(self.run_path, f"{output_name}.csv")
 
-        self._write_results_csv(results, path)
-        self._write_results_sdata(results)
+            self._write_results_csv(results, path)
+            self._write_results_sdata(results)
 
-        # perform post processing cleanup
-        if not self.deep_debug:
-            self._post_processing_cleanup()
+            # perform post processing cleanup
+            if not self.deep_debug:
+                self._post_processing_cleanup()
 
 
 class CellFeaturizer_single_channel(_cellFeaturizerBase):
@@ -1408,20 +1484,23 @@ def _setup_channel_selection(self):
         self.channel_selection = [0, self.channel_selection]
         return
 
-    def _setup(self):
-        self._general_setup()
+    def _setup(self, extraction_dir: str | list[str], return_results: bool):
+        self._general_setup(extraction_dir=extraction_dir, return_results=return_results)
         self._setup_channel_selection()
         self._setup_transforms()
         self._get_channel_specs()
 
-    def process(self, extraction_dir, size=0):
+    def process(
+        self, extraction_dir: str | list[str], labels: int | list[int] = 0, size=0, return_results: bool = False
+    ):
         self.log(f"Started CellFeaturization of selected channel {self.channel_selection}.")
 
         # perform setup
-        self._setup()
+        self._setup(extraction_dir=extraction_dir, return_results=return_results)
 
         self.dataloader = self.generate_dataloader(
             extraction_dir,
+            labels=labels,
             selected_transforms=self.transforms,
             size=size,
             dataset_class=self.DEFAULT_DATA_LOADER,
diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py
index ae20d64..c219ee3 100644
--- a/src/scportrait/pipeline/project.py
+++ b/src/scportrait/pipeline/project.py
@@ -23,7 +23,6 @@
 import numpy as np
 import psutil
 import xarray
-import yaml
 from alphabase.io import tempmmap
 from napari_spatialdata import Interactive
 from ome_zarr.io import parse_url
@@ -33,6 +32,7 @@
 from scportrait.io import daskmmap
 from scportrait.pipeline._base import Logable
+from scportrait.pipeline._utils.helper import read_config
 from scportrait.pipeline._utils.sdata_io import sdata_filehandler
 from scportrait.pipeline._utils.spatialdata_helper import (
     calculate_centroids,
@@ -94,7 +94,7 @@ class Project(Logable):
     def __init__(
         self,
         project_location: str,
-        config_path: str,
+        config_path: str = None,
         segmentation_f=None,
         extraction_f=None,
         featurization_f=None,
@@ -185,11 +185,7 @@ def _load_config_from_file(self, file_path):
         if not os.path.isfile(file_path):
             raise ValueError(f"Your config path {file_path} is invalid.")
 
-        with open(file_path) as stream:
-            try:
-                self.config = yaml.safe_load(stream)
-            except yaml.YAMLError as exc:
-                print(exc)
+        self.config = read_config(file_path)
 
     def _get_config_file(self, config_path: str | None = None) -> None:
         """Load the config file for the project. If no config file is passed the default config file in the project directory is loaded.
@@ -257,6 +253,7 @@ def _setup_segmentation_f(self, segmentation_f):
                 overwrite=self.overwrite,
                 project=None,
                 filehandler=self.filehandler,
+                from_project=True,
             )
 
     def _setup_extraction_f(self, extraction_f):
@@ -285,6 +282,7 @@ def _setup_extraction_f(self, extraction_f):
             overwrite=self.overwrite,
             project=self,
             filehandler=self.filehandler,
+            from_project=True,
         )
 
     def _setup_featurization_f(self, featurization_f):
@@ -309,9 +307,10 @@ def _setup_featurization_f(self, featurization_f):
             self.featurization_directory,
             project_location=self.project_location,
             debug=self.debug,
-            overwrite=self.overwrite,
+            overwrite=False,  # this needs to be set to false as the featurization step should not remove previously created features
             project=self,
             filehandler=self.filehandler,
+            from_project=True,
         )
 
     def _setup_selection(self, selection_f):
@@ -339,6 +338,7 @@ def _setup_selection(self, selection_f):
             overwrite=self.overwrite,
             project=self,
             filehandler=self.filehandler,
+            from_project=True,
         )
 
     def update_featurization_f(self, featurization_f):
@@ -888,9 +888,10 @@ def load_input_from_sdata(
         # ensure that the provided nucleus and cytosol segmentations fullfill the scPortrait requirements
         # requirements are:
         # 1. The nucleus segmentation mask and the cytosol segmentation mask must contain the same ids
-        assert (
-            self.sdata[self.nuc_seg_name].attrs["cell_ids"] == self.sdata[self.cyto_seg_name].attrs["cell_ids"]
-        ), "The nucleus segmentation mask and the cytosol segmentation mask must contain the same ids."
+        if self.nuc_seg_status in self.sdata.keys() and self.cyto_seg_status in self.sdata.keys():
+            assert (
+                self.sdata[self.nuc_seg_name].attrs["cell_ids"] == self.sdata[self.cyto_seg_name].attrs["cell_ids"]
+            ), "The nucleus segmentation mask and the cytosol segmentation mask must contain the same ids."
 
         # 2. the nucleus segmentation ids and the cytosol segmentation ids need to match
         # THIS NEEDS TO BE IMPLEMENTED HERE
@@ -1066,6 +1067,8 @@ def featurize(
         # setup overwrite if specified in call
         if overwrite is not None:
             self.featurization_f.overwrite_run_path = overwrite
+        if overwrite is None:
+            self.featurization_f.overwrite_run_path = True
 
         # update the number of masks that are available in the segmentation object
         self.featurization_f.n_masks = sum([self.nuc_seg_status, self.cyto_seg_status])
@@ -1079,7 +1082,6 @@ def select(
         self,
         cell_sets: list[dict],
         calibration_marker: np.ndarray | None = None,
-        segmentation_name: str = "seg_all_nucleus",
         name: str | None = None,
     ):
         """
@@ -1091,14 +1093,12 @@ def select(
 
         self._check_sdata_status()
 
-        if not self.nuc_seg_status or not self.cyto_seg_status:
+        if not self.nuc_seg_status and not self.cyto_seg_status:
             raise ValueError("No nucleus or cytosol segmentation loaded. Please load a segmentation first.")
 
         assert self.sdata is not None, "No sdata object loaded."
 
-        assert segmentation_name in self.sdata.labels, f"Segmentation {segmentation_name} not found in sdata object."
 
         self.selection_f(
-            segmentation_name=segmentation_name,
             cell_sets=cell_sets,
             calibration_marker=calibration_marker,
             name=name,
diff --git a/src/scportrait/pipeline/segmentation/segmentation.py b/src/scportrait/pipeline/segmentation/segmentation.py
index 9b2edcd..82e1e8f 100644
--- a/src/scportrait/pipeline/segmentation/segmentation.py
+++ b/src/scportrait/pipeline/segmentation/segmentation.py
@@ -85,6 +85,7 @@ def __init__(
         overwrite,
         project,
         filehandler,
+        from_project: bool = False,
         **kwargs,
     ):
         super().__init__(
@@ -95,6 +96,7 @@ def __init__(
             overwrite=overwrite,
             project=project,
             filehandler=filehandler,
+            from_project=from_project,
         )
 
         if self.directory is not None:
@@ -742,7 +744,6 @@ def _resolve_sharding(self, sharding_plan):
 
             local_hf = h5py.File(local_output, "r")
             local_hdf_labels = local_hf.get(self.DEFAULT_MASK_NAME)[:]
-            print(type(local_hdf_labels))
 
             shifted_map, edge_labels = shift_labels(
                 local_hdf_labels,
                 class_id_shift,
@@ -902,8 +903,9 @@ def _resolve_sharding(self, sharding_plan):
         if not self.deep_debug:
             self._cleanup_shards(sharding_plan)
 
-    def _initializer_function(self, gpu_id_list):
+    def _initializer_function(self, gpu_id_list, n_processes):
         current_process().gpu_id_list = gpu_id_list
+        current_process().n_processes = n_processes
 
     def _perform_segmentation(self, shard_list):
         # get GPU status
@@ -921,7 +923,7 @@ def _perform_segmentation(self, shard_list):
             with mp.get_context(self.context).Pool(
                 processes=self.n_processes,
                 initializer=self._initializer_function,
-                initargs=[self.gpu_id_list],
+                initargs=[self.gpu_id_list, self.n_processes],
             ) as pool:
                 list(
                     tqdm(
diff --git a/src/scportrait/pipeline/segmentation/workflows.py b/src/scportrait/pipeline/segmentation/workflows.py
index 008a272..b677597 100644
--- a/src/scportrait/pipeline/segmentation/workflows.py
+++ b/src/scportrait/pipeline/segmentation/workflows.py
@@ -653,7 +653,11 @@ def _check_for_mask_matching_filtering(self) -> None:
         else:
             # add deprecation warning for old config setup
             if "filter_status" in self.config.keys():
-                Warning("filter_status is deprecated, please use match_masks instead Will not perform filtering.")
+                self.filter_match_masks = True
+                self.mask_matching_filtering_threshold = 0.95
+                Warning(
+                    "filter_status is deprecated, please use match_masks instead. Will use default settings for mask matching."
+                )
 
             # default behaviour that this filtering should be performed, otherwise another additional step is required before extraction
             self.filter_match_masks = True
@@ -1349,6 +1353,9 @@ def _check_gpu_status(self):
             gpu_id_list = current.gpu_id_list
             cpu_id = int(cpu_name[cpu_name.find("-") + 1 :]) - 1
 
+            if cpu_id >= len(gpu_id_list):
+                cpu_id = cpu_id % current.n_processes
+
             # track gpu_id and update GPU status
             self.gpu_id = gpu_id_list[cpu_id]
             self.status = "multi_GPU"
diff --git a/src/scportrait/pipeline/selection.py b/src/scportrait/pipeline/selection.py
index 0afb36a..aa53235 100644
--- a/src/scportrait/pipeline/selection.py
+++ b/src/scportrait/pipeline/selection.py
@@ -1,10 +1,20 @@
+import multiprocessing as mp
 import os
+import pickle
+import timeit
+from functools import partial as func_partial
 
+import h5py
+import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 from alphabase.io import tempmmap
 from lmd.lib import SegmentationLoader
+from scipy.sparse import coo_array
+from tqdm.auto import tqdm
 
 from scportrait.pipeline._base import ProcessingStep
+from scportrait.pipeline._utils.helper import flatten
 
 
 class LMDSelection(ProcessingStep):
@@ -13,20 +23,60 @@ class LMDSelection(ProcessingStep):
     This method class relies on the functionality of the pylmd library.
     """
 
-    # define all valid path optimization methods used with the "path_optimization" argument in the configuration
-    VALID_PATH_OPTIMIZERS = ["none", "hilbert", "greedy"]
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+        self._check_config()
+
         self.name = None
         self.cell_sets = None
         self.calibration_marker = None
 
-    def _setup_selection(self):
-        # set orientation transform
-        self.config["orientation_transform"] = np.array([[0, -1], [1, 0]])
+        self.deep_debug = False  # flag for deep debugging by developers
+
+    def _check_config(self):
+        assert "segmentation_channel" in self.config, "segmentation_channel not defined in config"
+        self.segmentation_channel_to_select = self.config["segmentation_channel"]
+
+        # check for optional config parameters
+
+        # this defines how large the box mask around the center of a cell is for the coordinate extraction
+        # assumption is that all pixels belonging to each mask are within the box otherwise they will be cut off during cutting contour generation
+        if "cell_width" in self.config:
+            self.cell_radius = self.config["cell_width"]
+        else:
+            self.cell_radius = 100
+
+        if "threads" in self.config:
+            self.threads = self.config["threads"]
+            assert self.threads > 0, "threads must be greater than 0"
+            assert isinstance(self.threads, int), "threads must be an integer"
+        else:
+            self.threads = 10
+
+        if "batch_size_coordinate_extraction" in self.config:
+            self.batch_size = self.config["batch_size_coordinate_extraction"]
+            assert self.batch_size > 0, "batch_size_coordinate_extraction must be greater than 0"
+            assert isinstance(self.batch_size, int), "batch_size_coordinate_extraction must be an integer"
+        else:
+            self.batch_size = 100
+
+        if "orientation_transform" in self.config:
+            self.orientation_transform = self.config["orientation_transform"]
+        else:
+            self.orientation_transform = np.array([[0, -1], [1, 0]])
+            self.config["orientation_transform"] = (
+                self.orientation_transform
+            )  # ensure its also in config so its passed on to the segmentation loader
+
+        if "processes_cell_sets" in self.config:
+            self.processes_cell_sets = self.config["processes_cell_sets"]
+            assert self.processes_cell_sets > 0, "processes_cell_sets must be greater than 0"
+            assert isinstance(self.processes_cell_sets, int), "processes_cell_sets must be an integer"
+        else:
+            self.processes_cell_sets = 1
+
+    def _setup_selection(self):
         # configure name of extraction
         if self.name is None:
             try:
@@ -39,6 +89,111 @@ def _setup_selection(self):
         savename = name.replace(" ", "_") + ".xml"
         self.savepath = os.path.join(self.directory, savename)
 
+        # check that the segmentation label exists
+        assert (
+            self.segmentation_channel_to_select in self.project.filehandler.get_sdata()._shared_keys
+        ), f"Segmentation channel {self.segmentation_channel_to_select} not found in sdata."
+
+    def __get_coords(
+        self, cell_ids: list, centers: list[tuple[int, int]], width: int = 60
+    ) -> list[tuple[int, np.ndarray]]:
+        results = []
+
+        _sdata = self.project.filehandler.get_sdata()
+        for i, _id in enumerate(cell_ids):
+            values = centers[i]
+
+            x_start = np.max([int(values[0]) - width, 0])
+            y_start = np.max([int(values[1]) - width, 0])
+
+            x_end = x_start + width * 2
+            y_end = y_start + width * 2
+
+            _cropped = _sdata[self.segmentation_channel_to_select][
+                slice(x_start, x_end), slice(y_start, y_end)
+            ].compute()
+
+            # optional plotting output for deep debugging
+            if self.deep_debug:
+                if self.threads == 1:
+                    plt.figure()
+                    plt.imshow(_cropped)
+                    plt.show()
+                else:
+                    raise ValueError("Deep debug is not supported with multiple threads.")
+
+            sparse = coo_array(_cropped == _id)
+
+            if (
+                0 in sparse.coords[0]
+                or 0 in sparse.coords[1]
+                or width * 2 - 1 in sparse.coords[0]
+                or width * 2 - 1 in sparse.coords[1]
+            ):
+                Warning(
+                    f"Cell {i} with id {_id} is potentially not fully contained in the bounding mask. Consider increasing the value for the 'cell_width' parameter in your config."
+                )
+
+            x = sparse.coords[0] + x_start
+            y = sparse.coords[1] + y_start
+
+            results.append((_id, np.array(list(zip(x, y, strict=True)))))
+
+        return results
+
+    def _get_coords_multi(self, width: int, arg: tuple[list[int], np.ndarray]) -> list[tuple[int, np.ndarray]]:
+        cell_ids, centers = arg
+        results = self.__get_coords(cell_ids, centers, width)
+        return results
+
+    def _get_coords(
+        self, cell_ids: list, centers: list[tuple[int, int]], width: int = 60, batch_size: int = 100, threads: int = 10
+    ) -> dict[int, np.ndarray]:
+        # create batches
+        n_batches = int(np.ceil(len(cell_ids) / batch_size))
+        slices = [(i * batch_size, i * batch_size + batch_size) for i in range(n_batches - 1)]
+        slices.append(((n_batches - 1) * batch_size, len(cell_ids)))
+
+        batched_args = [(cell_ids[start:end], centers[start:end]) for start, end in slices]
+
+        f = func_partial(self._get_coords_multi, width)
+
+        if (
+            threads == 1
+        ):  # if only one thread is used, the function is called directly to avoid the overhead of multiprocessing
+            results = [f(arg) for arg in batched_args]
+        else:
+            with mp.get_context(self.context).Pool(processes=threads) as pool:
+                results = list(
+                    tqdm(
+                        pool.imap(f, batched_args),
+                        total=len(batched_args),
+                        desc="Processing cell batches",
+                    )
+                )
+                pool.close()
+                pool.join()
+
+        results = flatten(results)  # type: ignore
+        return dict(results)  # type: ignore
+
+    def _get_cell_ids(self, cell_sets: list[dict]) -> list[int]:
+        cell_ids = []
+        for cell_set in cell_sets:
+            if "classes" in cell_set:
+                cell_ids.extend(cell_set["classes"])
+            else:
+                Warning(f"Cell set {cell_set['name']} does not contain any classes.")
+        return cell_ids
+
+    def _get_centers(self, cell_ids: list[int]) -> list[tuple[int, int]]:
+        _sdata = self.project.filehandler.get_sdata()
+        centers = _sdata["centers_cells"].compute()
+        centers = centers.loc[cell_ids, :]
+        return centers[
+            ["y", "x"]
+        ].values.tolist()  # needs to be returned as yx to match the coordinate system as saved in spatialdataobjects
+
     def _post_processing_cleanup(self, vars_to_delete: list | None = None):
         if vars_to_delete is not None:
             self._clear_cache(vars_to_delete=vars_to_delete)
@@ -51,7 +206,6 @@ def _post_processing_cleanup(self, vars_to_delete: list | None = None):
 
     def process(
         self,
-        segmentation_name: str,
         cell_sets: list[dict],
         calibration_marker: np.array,
         name: str | None = None,
@@ -61,9 +215,9 @@ def process(
 
         Under the hood this method relies on the pylmd library and utilizies its `SegmentationLoader` Class.
 
         Args:
-            segmentation_name (str): Name of the segmentation to be used for shape generation in the sdata object.
             cell_sets (list of dict): List of dictionaries containing the sets of cells which should be sorted into a single well. Mandatory keys for each dictionary are: name, classes. Optional keys are: well.
             calibration_marker (numpy.array): Array of size ‘(3,2)’ containing the calibration marker coordinates in the ‘(row, column)’ format.
+            name (str, optional): Name of the output file. If not provided, the name will be generated based on the names of the cell sets or if also not specified set to "selected_cells".
 
         Example:
@@ -77,7 +231,6 @@ def process(
             # A numpy Array of shape (3, 2) should be passed.
             calibration_marker = np.array([marker_0, marker_1, marker_2])
 
-
             # Sets of cells can be defined by providing a name and a list of classes in a dictionary.
             cells_to_select = [{"name": "dataset1", "classes": [1, 2, 3]}]
@@ -122,7 +275,7 @@ def process(
                 convolution_smoothing: 25
 
                 # fold reduction of datapoints for compression
-                poly_compression_factor: 30
+                rdp: 0.7
 
                 # Optimization of the cutting path inbetween shapes
                 # optimized paths improve the cutting time and the microscopes focus
@@ -160,32 +313,21 @@ def process(
 
         self._setup_selection()
 
-        ## TO Do
-        # check if classes and seglookup table already exist as pickle file
-        # if not create them
-        # else load them and proceed with selection
-
-        # load segmentation from hdf5
-        self.path_seg_mask = self.filehandler._load_seg_to_memmap(
-            [segmentation_name], tmp_dir_abs_path=self._tmp_dir_path
+        start_time = timeit.default_timer()
+        cell_ids = self._get_cell_ids(cell_sets)
+        centers = self._get_centers(cell_ids)
+        coord_index = self._get_coords(
+            cell_ids=cell_ids, centers=centers, width=self.cell_radius, batch_size=self.batch_size, threads=self.threads
         )
+        self.log(f"Coordinate lookup index calculation took {timeit.default_timer() - start_time} seconds.")
 
-        segmentation = tempmmap.mmap_array_from_path(self.path_seg_mask)
-
-        # create segmentation loader
         sl = SegmentationLoader(
             config=self.config,
             verbose=self.debug,
             processes=self.config["processes_cell_sets"],
         )
 
-        if len(segmentation.shape) == 3:
-            segmentation = np.squeeze(segmentation)
-        else:
-            raise ValueError(f"Segmentation shape is not correct. Expected 2D array, got {segmentation.shape}")
-
-        # get shape collections
-        shape_collection = sl(segmentation, self.cell_sets, self.calibration_marker)
+        shape_collection = sl(None, self.cell_sets, self.calibration_marker, coords_lookup=coord_index)
 
         if self.debug:
             shape_collection.plot(calibration=True)
@@ -196,4 +338,4 @@ def process(
         self.log(f"Saved output at {self.savepath}")
 
         # perform post processing cleanup
-        self._post_processing_cleanup(vars_to_delete=[shape_collection, sl, segmentation])
+        self._post_processing_cleanup(vars_to_delete=[shape_collection, sl, coord_index])
diff --git a/src/scportrait/tools/ml/transforms.py b/src/scportrait/tools/ml/transforms.py
index b976989..08baba1 100644
--- a/src/scportrait/tools/ml/transforms.py
+++ b/src/scportrait/tools/ml/transforms.py
@@ -19,7 +19,7 @@ def __init__(self, choices=4, include_zero=True):
             delta = (360 - angles[-1]) / 2
             angles = angles + delta
 
-        self.choices = angles
+        self.choices = angles.tolist()
 
     def __call__(self, tensor):
         angle = random.choice(self.choices)
diff --git a/src/scportrait/tools/stitch/_stitch.py b/src/scportrait/tools/stitch/_stitch.py
index 5b21148..e1020d6 100644
--- a/src/scportrait/tools/stitch/_stitch.py
+++ b/src/scportrait/tools/stitch/_stitch.py
@@ -21,12 +21,7 @@
 from scportrait.io.daskmmap import dask_array_from_path
 from scportrait.processing.images._image_processing import rescale_image
 from scportrait.tools.stitch._utils.ashlar_plotting import plot_edge_quality, plot_edge_scatter
-from scportrait.tools.stitch._utils.filereaders import (
-    BioformatsReaderRescale,
-    FilePatternReaderRescale,
-)
 from scportrait.tools.stitch._utils.filewriters import write_ome_zarr, write_spatialdata, write_tif, write_xml
-from scportrait.tools.stitch._utils.parallelized_ashlar import ParallelEdgeAligner, ParallelMosaic
 
 
 class Stitcher:
@@ -65,7 +60,7 @@ def __init__(
         do_intensity_rescale: bool | str = True,
         rescale_range: tuple = (1, 99),
         channel_order: list[str] = None,
-        reader_type=FilePatternReaderRescale,
+        reader_type="FilePatternReaderRescale",
         orientation: dict = None,
         plot_QC: bool = True,
         overwrite: bool = False,
@@ -114,6 +109,7 @@ def __init__(
         if orientation is None:
             orientation = {"flip_x": False, "flip_y": True}
 
+        self.input_dir = input_dir
         self.slidename = slidename
         self.outdir = outdir
@@ -139,6 +135,10 @@ def __init__(
         self.orientation = orientation
         self.reader_type = reader_type
 
+        # workaround for lazy imports of module
+        if self.reader_type == "FilePatternReaderRescale":
+            self.reader_type = self.FilePatternReaderRescale
+
         # workflow setup
         self.plot_QC = plot_QC
         self.overwrite = overwrite
@@ -158,10 +158,20 @@ def _lazy_imports(self):
         from ashlar.reg import EdgeAligner, Mosaic
         from ashlar.scripts.ashlar import process_axis_flip
 
+        from scportrait.tools.stitch._utils.filereaders import (
+            BioformatsReaderRescale,
+            FilePatternReaderRescale,
+        )
+        from scportrait.tools.stitch._utils.parallelized_ashlar import ParallelEdgeAligner, ParallelMosaic
+
         self.ashlar_thumbnail = thumbnail
         self.ashlar_EdgeAligner = EdgeAligner
         self.ashlar_Mosaic = Mosaic
         self.ashlar_process_axis_flip = process_axis_flip
+        self.BioformatsReaderRescale = BioformatsReaderRescale
+        self.FilePatternReaderRescale = FilePatternReaderRescale
+        self.ParallelEdgeAligner = ParallelEdgeAligner
+        self.ParallelMosaic = ParallelMosaic
 
     def __exit__(self):
         self._clear_cache()
@@ -294,14 +304,14 @@ def _initialize_reader(self):
         """
         Initialize the reader for reading image tiles.
""" - if self.reader_type == FilePatternReaderRescale: + if self.reader_type == self.FilePatternReaderRescale: self.reader = self.reader_type( self.input_dir, self.pattern, self.overlap, rescale_range=self.rescale_range, ) - elif self.reader_type == BioformatsReaderRescale: + elif self.reader_type == self.BioformatsReaderRescale: self.reader = self.reader_type(self.input_dir, rescale_range=self.rescale_range) # setup correct orientation of slide (this depends on microscope used to generate the data) @@ -564,7 +574,7 @@ class ParallelStitcher(Stitcher): do_intensity_rescale (bool or "full_image", optional): Flag to indicate whether to rescale image intensities (default is True). Alternatively, set to "full_image" to rescale the entire image. rescale_range (tuple or dict, optional): If all channels should be rescaled to the same range pass a tuple with the percentiles for rescaling (default is (1, 99)). Alternatively, a dictionary can be passed with the channel names as keys and the percentiles as values if each channel should be rescaled to a different range. channel_order (list, optional): Order of channels in the generated output mosaic. If none (default value) the order of the channels is left unchanged. - reader_type (class, optional): Type of reader to use for reading image tiles (default is FilePatternReaderRescale). + reader_type (class, optional): Type of reader to use for reading image tiles (default is "FilePatternReaderRescale"). orientation (dict, optional): Dictionary specifying which dimensions of the slide to flip (default is {'flip_x': False, 'flip_y': True}). plot_QC (bool, optional): Flag to indicate whether to plot quality control (QC) figures (default is True). overwrite (bool, optional): Flag to indicate whether to overwrite the output directory if it already exists (default is False). @@ -588,7 +598,7 @@ def __init__( WGAchannel: str = None, channel_order: list[str] = None, overwrite: bool = False, - reader_type=FilePatternReaderRescale, + reader_type="FilePatternReaderRescale", orientation=None, cache: str = None, threads: int = 20, @@ -613,8 +623,9 @@ def __init__( overwrite, cache, ) + # dirty fix to avoide multithreading error with BioformatsReader until this can be fixed - if self.reader_type == BioformatsReaderRescale: + if self.reader_type == self.BioformatsReaderRescale: threads = 1 print( "BioformatsReaderRescale does not support multithreading for calculating the error threshold currently. Proceeding with 1 thread." @@ -632,7 +643,7 @@ def _initialize_aligner(self): Returns: aligner (ParallelEdgeAligner): Initialized ParallelEdgeAligner object. """ - aligner = ParallelEdgeAligner( + aligner = self.ParallelEdgeAligner( self.reader, channel=self.stitching_channel_id, filter_sigma=self.filter_sigma, @@ -644,7 +655,7 @@ def _initialize_aligner(self): return aligner def _initialize_mosaic(self): - mosaic = ParallelMosaic( + mosaic = self.ParallelMosaic( self.aligner, self.aligner.mosaic_shape, verbose=True, channels=self.channels, n_threads=self.threads ) return mosaic