diff --git a/snp2cell/snp2cell_class.py b/snp2cell/snp2cell_class.py index e178fc8..78f5a61 100644 --- a/snp2cell/snp2cell_class.py +++ b/snp2cell/snp2cell_class.py @@ -7,6 +7,7 @@ import re import textwrap from pathlib import Path +from typing import Union, Optional, Any, Callable, Iterable import dill import matplotlib.pyplot as plt @@ -220,7 +221,7 @@ def propagate_scores(self, score_keys, num_cores=None, log=logging.getLogger()): @add_logger() def rand_sim( self, - score_key="score", + score_key: Union[str, list[str]] = "score", perturb_key=None, n=1000, num_cores=None, @@ -251,7 +252,11 @@ def rand_sim( ) @add_logger() - def add_score_statistics(self, score_keys="score", log=logging.getLogger()): + def add_score_statistics( + self, + score_keys: Union[str, list[str], dict[str, str]] = "score", + log=logging.getLogger(), + ): log.info(f"adding statistics for: {score_keys}") if isinstance(score_keys, str): score_keys = [score_keys] @@ -383,7 +388,7 @@ def adata_add_de_scores( if "method" in kwargs and kwargs["method"] == "logreg": de_df = get_rank_df(self.adata) else: - de_df = sc.get.rank_genes_groups_df(self.adata, group=None) + de_df = sc.get.rank_genes_groups_df(self.adata, group=None) # type: ignore if rank_by == "abs": log.info("ranking by up- and downregulation...")