From 75ae100cf158191ee9097750e658b2f822cc837b Mon Sep 17 00:00:00 2001
From: Michaela Mueller <51025211+mumichae@users.noreply.github.com>
Date: Thu, 3 Mar 2022 19:22:04 +0100
Subject: [PATCH] Use new scib package (#13)

* update scib package name
* fix batch variable check between integrated and nonintegrated adata
* update environments to correct scib versions
* using pip version of scib
* fix R import issue for Harmony (as mentioned in
  https://github.com/immunogenomics/harmony/pull/134)
* use value counts to check for batch relabeling after integration
---
 data/generate_data.py                    |   3 +-
 envs/scIB-python-paper.yml               |  32 ++--
 envs/scib-pipeline.yml                   |  14 +-
 scripts/integration/runIntegration.py    |  43 +++---
 scripts/integration/runMethods.R         |   1 +
 scripts/integration/runPost.py           |  12 +-
 scripts/integration_fail_file.py         |  26 ++--
 scripts/merge_benchmarks.py              |   4 +-
 scripts/metrics/merge_metrics.py         |  17 +--
 scripts/metrics/metrics.py               | 175 ++++++++++++-----
 scripts/precompute_conn.py               |  45 +++---
 scripts/preprocessing/runPP.py           |  36 +++--
 scripts/update_timestamp.py              |  38 ++---
 scripts/visualization/save_embeddings.py |  30 ++--
 tests/pipeline/test_metrics.py           |  14 +-
 15 files changed, 258 insertions(+), 232 deletions(-)

diff --git a/data/generate_data.py b/data/generate_data.py
index 49f24e16..c3deb5af 100644
--- a/data/generate_data.py
+++ b/data/generate_data.py
@@ -2,6 +2,7 @@
 import numpy as np
 import scib
 import warnings
+
 warnings.simplefilter(action='ignore', category=FutureWarning)
 
 
@@ -39,7 +40,7 @@ def get_adata_pbmc():
     """
     Code from https://scanpy-tutorials.readthedocs.io/en/latest/integrating-data-using-ingest.html
     """
-    #adata_ref = sc.datasets.pbmc3k_processed()
+    # adata_ref = sc.datasets.pbmc3k_processed()
     # quick fix for broken dataset paths, should be removed with scanpy>=1.6.0
     adata_ref = sc.read(
         "pbmc3k_processed.h5ad",
diff --git a/envs/scIB-python-paper.yml b/envs/scIB-python-paper.yml
index eae54e77..dfe67d1e 100644
--- a/envs/scIB-python-paper.yml
+++ b/envs/scIB-python-paper.yml
@@ -3,33 +3,32 @@ channels:
   - conda-forge
   - bioconda
 dependencies:
-  - python==3.7
-  - numpy==1.18.1
+  - python=3.7
+  - numpy=1.18.1
   - pandas
   - seaborn
   - matplotlib
-  - scanpy==1.4.6
-  - anndata==0.7.1
+  - scanpy=1.4.6
+  - anndata=0.7.1
   - h5py<3
   - scipy
   - memory_profiler
-  - rpy2==3.1.0
+  - rpy2=3.1.0
   - r-stringi
-  - anndata2ri==1.0.2
-  - bbknn==1.3.9
+  - anndata2ri=1.0.2
+  - bbknn=1.3.9
   - libgcc-ng
   - gsl
   - scikit-learn
   - networkx
   - r-base
   - r-devtools
-  - r-seurat==3.1.1
+  - r-seurat=3.1.1
   - bioconductor-scater
   - bioconductor-scran
   - pip
   - numba<=0.46
   - llvmlite
-  - tensorflow==1.15
   - gxx_linux-64
   - gxx_impl_linux-64
   - gcc_linux-64
@@ -39,8 +38,8 @@ dependencies:
   - igraph
   - openblas
   - r-essentials
-  - r-globals==0.12.5
-  - r-listenv==0.8.0
+  - r-globals=0.12.5
+  - r-listenv=0.8.0
   - r-rlang
   - r-ellipsis
   - r-evaluate
@@ -52,12 +51,13 @@ dependencies:
   - r-testthat
   - r-vctrs
   - xlrd
-  - umap-learn==0.3.10
-  - louvain==0.6.1
-  - scvi==0.6.7
-  - scanorama==1.7.0
+  - umap-learn=0.3.10
+  - louvain=0.6.1
+  - scvi=0.6.7
+  - scanorama=1.7.0
   - pip:
-      - git+git://github.com/theislab/scib.git
+      - git+git://github.com/theislab/scib.git@0.2.0
+      - tensorflow==1.15
       #- trvae==1.1.2
       - trvaep==0.1.0
       - mnnpy==0.1.9.5
diff --git a/envs/scib-pipeline.yml b/envs/scib-pipeline.yml
index 8a28dd99..2d1679e1 100644
--- a/envs/scib-pipeline.yml
+++ b/envs/scib-pipeline.yml
@@ -14,6 +14,8 @@ dependencies:
   - openblas
   - llvmlite
   - libgcc-ng
+  - numba<=0.46 # for mnnpy
+  - anndata2ri
   - r-base
   - r-essentials
   - r-devtools
@@ -21,14 +23,12 @@ dependencies:
   - bioconductor-scater
   - bioconductor-scran
   # Methods
-  - scvi==0.6.7
-  - scanorama==1.7.0
-  - bbknn==1.3.9
-  - r-seurat==3.1.1
-  - numba<=0.46 # for mnnpy
-  - anndata2ri==1.0.5 # 1.0.6 has issues with HDF5 conversion
+  - scvi=0.6.7
+  - scanorama=1.7.0
+  - bbknn=1.3.9
+  - r-seurat=3.1.1
   - pip:
-      - git+git://github.com/theislab/scib.git
+      - scib==1.0.0
       - trvaep==0.1.0
       - mnnpy==0.1.9.5
       - scgen==1.1.5
diff --git a/scripts/integration/runIntegration.py b/scripts/integration/runIntegration.py
index 34ad30b1..e61070c8 100755
--- a/scripts/integration/runIntegration.py
+++ b/scripts/integration/runIntegration.py
@@ -2,8 +2,9 @@
 # coding: utf-8
 
 import scanpy as sc
-import scIB
+import scib
 import warnings
+
 warnings.filterwarnings('ignore')
 
 
@@ -16,16 +17,15 @@ def runIntegration(inPath, outPath, method, hvg, batch, celltype=None):
     """
 
     adata = sc.read(inPath)
-
+
     if timing:
         if celltype is not None:
-            integrated_tmp = scIB.metrics.measureTM(method, adata, batch, celltype)
+            integrated_tmp = scib.metrics.measureTM(method, adata, batch, celltype)
         else:
-            integrated_tmp = scIB.metrics.measureTM(method, adata, batch)
+            integrated_tmp = scib.metrics.measureTM(method, adata, batch)
 
         integrated = integrated_tmp[2][0]
-
         integrated.uns['mem'] = integrated_tmp[0]
         integrated.uns['runtime'] = integrated_tmp[1]
 
@@ -34,10 +34,11 @@ def runIntegration(inPath, outPath, method, hvg, batch, celltype=None):
             integrated = method(adata, batch, celltype)
         else:
             integrated = method(adata, batch)
-
+
     sc.write(outPath, integrated)
 
-if __name__=='__main__':
+
+if __name__ == '__main__':
     import argparse
 
     parser = argparse.ArgumentParser(description='Run the integration methods')
@@ -59,22 +60,22 @@ def runIntegration(inPath, outPath, method, hvg, batch, celltype=None):
     celltype = args.celltype
     method = args.method
     methods = {
-        'scanorama': scIB.integration.runScanorama,
-        'trvae': scIB.integration.runTrVae,
-        'trvaep': scIB.integration.runTrVaep,
-        'scgen': scIB.integration.runScGen,
-        'mnn': scIB.integration.runMNN,
-        'bbknn': scIB.integration.runBBKNN,
-        'scvi': scIB.integration.runScvi,
-        'scanvi': scIB.integration.runScanvi,
-        'combat': scIB.integration.runCombat,
-        'saucie': scIB.integration.runSaucie,
-        'desc': scIB.integration.runDESC
+        'scanorama': scib.integration.scanorama,
+        'trvae': scib.integration.trvae,
+        'trvaep': scib.integration.trvaep,
+        'scgen': scib.integration.scgen,
+        'mnn': scib.integration.mnn,
+        'bbknn': scib.integration.bbknn,
+        'scvi': scib.integration.scvi,
+        'scanvi': scib.integration.scanvi,
+        'combat': scib.integration.combat,
+        'saucie': scib.integration.saucie,
+        'desc': scib.integration.desc
     }
-
+
     if method not in methods.keys():
         raise ValueError(f'Method "{method}" does not exist. Please use one of '
                          f'the following:\n{list(methods.keys())}')
-
-    run= methods[method]
+
+    run = methods[method]
     runIntegration(file, out, run, hvg, batch, celltype)
diff --git a/scripts/integration/runMethods.R b/scripts/integration/runMethods.R
index 6fca2ee5..ff6eb53a 100644
--- a/scripts/integration/runMethods.R
+++ b/scripts/integration/runMethods.R
@@ -11,6 +11,7 @@ getScriptPath <- function(){
 setwd(getScriptPath())
 
 library('optparse')
+library(rlang)
 require(Seurat)
 
 option_list <- list(make_option(c("-m", "--method"), type="character", default=NA, help="integration method to use"),
diff --git a/scripts/integration/runPost.py b/scripts/integration/runPost.py
index c52f6e5a..6ecf8bf7 100755
--- a/scripts/integration/runPost.py
+++ b/scripts/integration/runPost.py
@@ -1,9 +1,9 @@
 #!/usr/bin/env python
 # coding: utf-8
 
-import scanpy as sc
-import scIB
+import scib
 import warnings
+
 warnings.filterwarnings('ignore')
 
 
@@ -15,14 +15,14 @@ def runPost(inPath, outPath, conos):
         conos: set if input is conos obect
     """
     if conos:
-        adata = scIB.pp.readConos(inPath)
+        adata = scib.pp.read_conos(inPath)
     else:
-        adata = scIB.pp.readSeurat(inPath)
+        adata = scib.pp.read_seurat(inPath)
 
     adata.write(outPath)
 
-if __name__=='__main__':
+
+if __name__ == '__main__':
     import argparse
 
     parser = argparse.ArgumentParser(description='Run the integration methods')
@@ -35,5 +35,5 @@ def runPost(inPath, outPath, conos):
     file = args.input_file
     out = args.output_file
     conos = args.conos
-
+
     runPost(file, out, conos)
diff --git a/scripts/integration_fail_file.py b/scripts/integration_fail_file.py
index 648adac3..f39c8f23 100755
--- a/scripts/integration_fail_file.py
+++ b/scripts/integration_fail_file.py
@@ -1,8 +1,7 @@
-
 from snakemake.io import load_configfile
 from pathlib import Path
 
-if __name__=='__main__':
+if __name__ == '__main__':
     import argparse
 
     parser = argparse.ArgumentParser(description='Create an empty output file for failed integration runs')
@@ -25,31 +24,30 @@
 
     # Check inputs
     if method not in params['METHODS']:
-        raise ValueError(f'{method} is not a valid method.\n'
+        raise ValueError(f'{method} is not a valid method.\n'
                          f'Please choose one of: {list(params["METHODS"].keys())}')
 
     if task not in params['DATA_SCENARIOS']:
-        raise ValueError(f'{task} is not a valid integration task.\n'
+        raise ValueError(f'{task} is not a valid integration task.\n'
                          f'Please choose one of: {list(params["DATA_SCENARIOS"].keys())}')
-
+
     # Get path values
     folder = params['ROOT']
     t_folder = task
     s_folder = 'scaled' if scale else 'unscaled'
    h_folder = 'hvg' if hvgs else 'full_feature'
     r_folder = 'R/' if 'R' in params['METHODS'][method] else ''
-    filename = method+'.h5ad'
+    filename = method + '.h5ad'
 
-    folder_path = '/'.join([folder,task,'integration',s_folder,h_folder])+'/'+r_folder
-    full_path = folder_path+filename
+    folder_path = '/'.join([folder, task, 'integration', s_folder, h_folder]) + '/' + r_folder
+    full_path = folder_path + filename
 
     if 'R' in params['METHODS'][method]:
-        filename_r = method+'.RDS'
-        full_path_r = folder_path+filename_r
+        filename_r = method + '.RDS'
+        full_path_r = folder_path + filename_r
         Path(full_path_r).touch()
-        Path(full_path_r+".benchmark").touch()
+        Path(full_path_r + ".benchmark").touch()
 
-    #print(full_path)
+    # print(full_path)
     Path(full_path).touch()
-    Path(full_path+".benchmark").touch()
-
+    Path(full_path + ".benchmark").touch()
diff --git a/scripts/merge_benchmarks.py b/scripts/merge_benchmarks.py
index 1678afc2..a1dcaea7 100644
--- a/scripts/merge_benchmarks.py
+++ b/scripts/merge_benchmarks.py
@@ -2,7 +2,7 @@
 import argparse
 import os
 
-if __name__=='__main__':
+if __name__ == '__main__':
     """
     Merge benchmark output for all scenarios, methods and settings
     """
@@ -14,7 +14,6 @@
                         help='root directory for scIB output')
     args = parser.parse_args()
 
-
     print("Searching for .benchmark files...")
     bench_files = []
     for path, dirs, files in os.walk(args.root):
@@ -43,4 +42,3 @@
 
     results.to_csv(args.output, index_label='scenario')
     print("Done!")
-
diff --git a/scripts/metrics/merge_metrics.py b/scripts/metrics/merge_metrics.py
index 8ec7137c..58d45ad0 100644
--- a/scripts/metrics/merge_metrics.py
+++ b/scripts/metrics/merge_metrics.py
@@ -2,35 +2,32 @@
 # coding: utf-8
 
 import pandas as pd
-import scIB
 import warnings
+
 warnings.filterwarnings('ignore')
 import argparse
 from functools import reduce
 
-if __name__=='__main__':
+if __name__ == '__main__':
     """
     Merge metrics output for all scenarios, methods and settings
     """
-
+
     parser = argparse.ArgumentParser(description='Collect all metrics')
     parser.add_argument('-i', '--input', nargs='+', required=True, help='input directory')
     parser.add_argument('-o', '--output', required=True, help='output file')
-    parser.add_argument('-r', '--root', required=True,
+    parser.add_argument('-r', '--root', required=True,
                         help='root directory for inferring column names from path')
     args = parser.parse_args()
-
-
+
     res_list = []
     for file in args.input:
         clean_name = file.replace(args.root, "").replace(".csv", "")
         res = pd.read_csv(file, index_col=0)
         res.rename(columns={res.columns[0]: clean_name}, inplace=True)
         res_list.append(res)
-
-    results = reduce(lambda left,right: pd.merge(left, right, left_index=True, right_index=True), res_list)
+
+    results = reduce(lambda left, right: pd.merge(left, right, left_index=True, right_index=True), res_list)
     results = results.T
     results.to_csv(args.output)
-
-
diff --git a/scripts/metrics/metrics.py b/scripts/metrics/metrics.py
index 42de39af..ec7a548c 100755
--- a/scripts/metrics/metrics.py
+++ b/scripts/metrics/metrics.py
@@ -2,45 +2,49 @@
 # coding: utf-8
 
 import scanpy as sc
-import scIB
+import scib
 import numpy as np
 import warnings
+
 warnings.filterwarnings('ignore')
 
 # types of integration output
 RESULT_TYPES = [
-    "full", # reconstructed expression data
-    "embed", # embedded/latent space
-    "knn" # only corrected neighbourhood graph as output
+    "full",  # reconstructed expression data
+    "embed",  # embedded/latent space
+    "knn"  # only corrected neighbourhood graph as output
 ]
 ASSAYS = ["expression", "atac", "simulation"]
 
-if __name__=='__main__':
+if __name__ == '__main__':
     """
     read adata object, compute all metrics and output csv.
     """
-
+
     import argparse
     import os
 
     parser = argparse.ArgumentParser(description='Compute all metrics')
-
+
     parser.add_argument('-u', '--uncorrected', required=True)
     parser.add_argument('-i', '--integrated', required=True)
     parser.add_argument('-o', '--output', required=True, help='Output file')
     parser.add_argument('-m', '--method', required=True, help='Name of method')
-
+
     parser.add_argument('-b', '--batch_key', required=True, help='Key of batch')
     parser.add_argument('-l', '--label_key', required=True, help='Key of annotated labels e.g. "cell_type"')
-
+
     parser.add_argument('--organism', required=True)
-    parser.add_argument('--type', required=True, choices=RESULT_TYPES, help='Type of result: full, embed, knn\n full: scanorama, seurat, MNN\n embed: scanorama, Harmony\n knn: BBKNN')
+    parser.add_argument('--type', required=True, choices=RESULT_TYPES,
+                        help='Type of result: full, embed, knn\n full: scanorama, seurat, MNN\n embed: scanorama, Harmony\n knn: BBKNN')
     parser.add_argument('--assay', default='expression', choices=ASSAYS, help='Experimental assay')
-    parser.add_argument('--hvgs', default=0, help='Number of highly variable genes. Use 0 to specify that no feature selection had been used.', type=int)
+    parser.add_argument('--hvgs', default=0,
+                        help='Number of highly variable genes. Use 0 to specify that no feature selection had been used.',
+                        type=int)
     parser.add_argument('-v', '--verbose', action='store_true')
-
+
     args = parser.parse_args()
-
+
     verbose = args.verbose
     type_ = args.type
     batch_key = args.batch_key
@@ -48,10 +52,10 @@
     assay = args.assay
     organism = args.organism
     n_hvgs = args.hvgs if args.hvgs > 0 else None
-
+
     # encode setup for column name
     setup = f'{args.method}_{args.type}'
-
+
     # create cluster NMI output file
     file_stump = os.path.splitext(args.output)[0]
     cluster_nmi = f'{file_stump}_nmi.txt'
@@ -66,9 +70,8 @@
         print(f'    n_hvgs:\t{n_hvgs}')
         print(f'    setup:\t{setup}')
         print(f'    optimised clustering results:\t{cluster_nmi}')
-
-    ###
+    ###
 
     empty_file = False
@@ -87,51 +90,50 @@
     if (n_hvgs is not None):
         if (adata_int.n_vars < n_hvgs):
             raise ValueError("There are less genes in the corrected adata than specified for HVG selection")
-
+
     # check input files
     if adata.n_obs != adata_int.n_obs:
         message = "The datasets have different numbers of cells before and after integration."
         message += "Please make sure that both datasets match."
         raise ValueError(message)
-
-    #check if the obsnames were changed and rename them in that case
+
+    # check if the obsnames were changed and rename them in that case
     if len(set(adata.obs_names).difference(set(adata_int.obs_names))) > 0:
-        #rename adata_int.obs[batch_key] labels by overwriting them with the pre-integration labels
+        # rename adata_int.obs[batch_key] labels by overwriting them with the pre-integration labels
        new_obs_names = ['-'.join(idx.split('-')[:-1]) for idx in adata_int.obs_names]
         if len(set(adata.obs_names).difference(set(new_obs_names))) == 0:
             adata_int.obs_names = new_obs_names
         else:
             raise ValueError('obs_names changed after integration!')
-
-    #batch_key might be overwritten, so we match it to the pre-integrated labels
+
+    # batch_key might be overwritten, so we match it to the pre-integrated labels
     adata_int.obs[batch_key] = adata_int.obs[batch_key].astype('category')
-    if not np.array_equal(adata.obs[batch_key].cat.categories,adata_int.obs[batch_key].cat.categories):
-        #pandas uses the table index to match the correct labels
+    batch_u = adata.obs[batch_key].value_counts().index
+    batch_i = adata_int.obs[batch_key].value_counts().index
+    if not batch_i.equals(batch_u):
+        # pandas uses the table index to match the correct labels
         adata_int.obs[batch_key] = adata.obs[batch_key]
-        #print(adata.obs[batch_key].value_counts())
-        #print(adata_int.obs[batch_key].value_counts())
-
     if (n_hvgs is not None) and (adata_int.n_vars < n_hvgs):
         # check number of HVGs to be computed
         message = "There are fewer genes in the uncorrected adata "
         message += "than specified for HVG selection."
         raise ValueError(message)
-
+
     # DATA REDUCTION
     # select options according to type
-
+
     # case 1: full expression matrix, default settings
     precompute_pca = True
     recompute_neighbors = True
     embed = 'X_pca'
-
+
     # distinguish between subsetted and full expression matrix
     # compute HVGs only if output is not already subsetted
     if adata.n_vars > adata_int.n_vars:
         n_hvgs = None
-
+
     # case 2: embedding output
     if (type_ == "embed"):
         n_hvgs = None
@@ -139,13 +141,13 @@
         # legacy check
         if ('emb' in adata_int.uns) and (adata_int.uns['emb']):
             adata_int.obsm["X_emb"] = adata_int.obsm["X_pca"].copy()
-
+
     # case3: kNN graph output
     elif (type_ == "knn"):
         n_hvgs = None
         precompute_pca = False
         recompute_neighbors = False
-
+
     if verbose:
         print('reduce integrated data:')
         print(f'    HVG selection:\t{n_hvgs}')
@@ -156,11 +158,15 @@
         print(f'    precompute PCA:\t{precompute_pca}')
 
     if not empty_file:
-        scIB.preprocessing.reduce_data(adata_int,
-                                       n_top_genes=n_hvgs,
-                                       neighbors=recompute_neighbors, use_rep=embed,
-                                       pca=precompute_pca, umap=False)
-
+        scib.preprocessing.reduce_data(
+            adata_int,
+            n_top_genes=n_hvgs,
+            neighbors=recompute_neighbors,
+            use_rep=embed,
+            pca=precompute_pca,
+            umap=False
+        )
+
     print("computing metrics")
     # DEFAULT
     silhouette_ = True
@@ -172,10 +178,9 @@
     hvg_score_ = True
     graph_conn_ = True
     kBET_ = True
-    #lisi_ = True
+    # lisi_ = True
     lisi_graph_ = True
-
-
+
     # by output type
     if (type_ == "embed"):
         hvg_score_ = False
@@ -184,15 +189,15 @@
         pcr_ = False
         cell_cycle_ = False
         hvg_score_ = False
-        #lisi_ = False
-
-    # by assay
+        # lisi_ = False
+
+    # by assay
     if args.assay == 'atac':
         cell_cycle_ = False
         hvg_score_ = False
     elif args.assay == 'simulation':
         cell_cycle_ = False
-
+
     # check if pseudotime data exists in original data
     if 'dpt_pseudotime' in adata.obs:
         trajectory_ = True
@@ -200,21 +205,21 @@
         trajectory_ = False
 
     if empty_file:
-        silhouette_=False
-        nmi_=False
-        ari_=False
-        pcr_=False
-        cell_cycle_=False
-        isolated_labels_=False
-        hvg_score_=False
-        graph_conn_=False
-        kBET_=False
-        #lisi_=False
-        lisi_graph_=False
-        trajectory_=False
+        silhouette_ = False
+        nmi_ = False
+        ari_ = False
+        pcr_ = False
+        cell_cycle_ = False
+        isolated_labels_ = False
+        hvg_score_ = False
+        graph_conn_ = False
+        kBET_ = False
+        # lisi_=False
+        lisi_graph_ = False
+        trajectory_ = False
 
     if adata.n_obs > 300000:
-        kBET_=False
+        kBET_ = False
 
     if verbose:
         print(f'type:\t{type_}')
@@ -227,31 +232,41 @@
         print(f'    iso lab ASW:\t{isolated_labels_ and silhouette_}')
         print(f'    HVGs:\t{hvg_score_}')
         print(f'    kBET:\t{kBET_}')
-        #print(f'    LISI:\t{lisi_}')
+        # print(f'    LISI:\t{lisi_}')
         print(f'    LISI:\t{lisi_graph_}')
         print(f'    Trajectory:\t{trajectory_}')
-
-    results = scIB.me.metrics(adata, adata_int, verbose=verbose,
-                              hvg_score_=hvg_score_, cluster_nmi=cluster_nmi,
-                              batch_key=batch_key, label_key=label_key,
-                              silhouette_=silhouette_, embed=embed,
-                              type_ = type_,
-                              nmi_=nmi_, nmi_method='arithmetic', nmi_dir=None,
-                              ari_=ari_,
-                              pcr_=pcr_,
-                              cell_cycle_=cell_cycle_, organism=organism,
-                              isolated_labels_=isolated_labels_, n_isolated=None,
-                              graph_conn_=graph_conn_,
-                              kBET_=kBET_,
-                              #lisi_=lisi_,
-                              lisi_graph_= lisi_graph_,
-                              trajectory_=trajectory_
-                              )
-    results.rename(columns={results.columns[0]:setup}, inplace=True)
+
+    results = scib.me.metrics(
+        adata,
+        adata_int,
+        verbose=verbose,
+        hvg_score_=hvg_score_,
+        cluster_nmi=cluster_nmi,
+        batch_key=batch_key,
+        label_key=label_key,
+        silhouette_=silhouette_,
+        embed=embed,
+        type_=type_,
+        nmi_=nmi_,
+        nmi_method='arithmetic',
+        nmi_dir=None,
+        ari_=ari_,
+        pcr_=pcr_,
+        cell_cycle_=cell_cycle_,
+        organism=organism,
+        isolated_labels_=isolated_labels_,
+        n_isolated=None,
+        graph_conn_=graph_conn_,
+        kBET_=kBET_,
+        # lisi_=lisi_,
+        lisi_graph_=lisi_graph_,
+        trajectory_=trajectory_
+    )
+    results.rename(columns={results.columns[0]: setup}, inplace=True)
+
     if verbose:
         print(results)
+
     # save metrics' results
     results.to_csv(args.output)
-
     print("done")
-
diff --git a/scripts/precompute_conn.py b/scripts/precompute_conn.py
index 2a113bc0..502aa9e5 100644
--- a/scripts/precompute_conn.py
+++ b/scripts/precompute_conn.py
@@ -3,65 +3,60 @@
 
 import scanpy as sc
 import scipy.io as scio
-from scIB.metrics import diffusion_conn
-import numpy as np
+from scib.metrics.utils import diffusion_conn
 import warnings
+
 warnings.filterwarnings('ignore')
 
 # types of integration output
 RESULT_TYPES = [
-    "full", # reconstructed expression data
-    "embed", # embedded/latent space
-    "knn" # only corrected neighbourhood graph as output
+    "full",  # reconstructed expression data
+    "embed",  # embedded/latent space
+    "knn"  # only corrected neighbourhood graph as output
 ]
 
-if __name__=='__main__':
+if __name__ == '__main__':
     """
     read adata object, precompute diffusion connectivities for knn data integration methods
     and output connectivity matrix.
     """
-
+
    import argparse
     import os
 
-    parser = argparse.ArgumentParser(description='Precompute diffusion connectivities for knn data integration methods.')
-
-
+    parser = argparse.ArgumentParser(
+        description='Precompute diffusion connectivities for knn data integration methods.')
+
     parser.add_argument('-i', '--input', required=True)
     parser.add_argument('-o', '--output', required=True, help='output directory')
     parser.add_argument('-v', '--verbose', action='store_true')
-    parser.add_argument('-t', '--type', required=True, choices=RESULT_TYPES, help='Type of result: full, embed, knn\n full: scanorama, seurat, MNN\n embed: scanorama, Harmony\n knn: BBKNN')
-
+    parser.add_argument('-t', '--type', required=True, choices=RESULT_TYPES,
+                        help='Type of result: full, embed, knn\n full: scanorama, seurat, MNN\n embed: scanorama, Harmony\n knn: BBKNN')
+
     args = parser.parse_args()
-
+
     verbose = args.verbose
     type_ = args.type
-
+
     # set prefix for output and results column name
     base = os.path.basename(args.input).split('.h5ad')[0]
-
-
     if verbose:
         print('Options')
         print(f'    type:\t{type_}')
-
-
-
+
     ###
-
+
     print("reading adata input file")
-    if os.stat(args.input).st_size>0:
+    if os.stat(args.input).st_size > 0:
         adata = sc.read(args.input, cache=True)
         print(adata)
         if (type_ == 'knn'):
             neighbors = adata.obsp['connectivities']
             del adata
             diff_neighbors = diffusion_conn(neighbors, min_k=50, copy=False, max_iterations=20)
-            scio.mmwrite(target = os.path.join(args.output, f'{base}_diffconn.mtx'), a = diff_neighbors)
+            scio.mmwrite(target=os.path.join(args.output, f'{base}_diffconn.mtx'), a=diff_neighbors)
             print("done")
         else:
             print('Wrong type chosen, doing nothing.')
     else:
         print("No file found. Doing nothing.")
-
-
-
diff --git a/scripts/preprocessing/runPP.py b/scripts/preprocessing/runPP.py
index 56386990..5737bc01 100755
--- a/scripts/preprocessing/runPP.py
+++ b/scripts/preprocessing/runPP.py
@@ -2,8 +2,9 @@
 # coding: utf-8
 
 import scanpy as sc
-import scIB
+import scib
 import warnings
+
 warnings.filterwarnings('ignore')
 
 
@@ -19,35 +20,42 @@ def runPP(inPath, outPath, hvg, batch, rout, scale, seurat):
     """
 
     adata = sc.read(inPath)
-    hvgs=adata.var.index
+    hvgs = adata.var.index
 
     # remove HVG if already precomputed
     if 'highly_variable' in adata.var:
         del adata.var['highly_variable']
-
+
     if hvg > 500:
         print("Computing HVGs ...")
         if seurat:
-            hvgs= scIB.preprocessing.hvg_batch(adata,batch_key=batch, target_genes=hvg, adataOut=False)
+            hvgs = scib.preprocessing.hvg_batch(
+                adata,
+                batch_key=batch,
+                target_genes=hvg,
+                adataOut=False
+            )
         else:
-            adata = scIB.preprocessing.hvg_batch(adata,
-                                                 batch_key=batch,
-                                                 target_genes=hvg,
-                                                 adataOut=True)
+            adata = scib.preprocessing.hvg_batch(
+                adata,
+                batch_key=batch,
+                target_genes=hvg,
+                adataOut=True
+            )
     if scale:
         print("Scaling data ...")
-        adata = scIB.preprocessing.scale_batch(adata, batch)
-
+        adata = scib.preprocessing.scale_batch(adata, batch)
+
     if rout:
         print("Save as RDS")
-        scIB.preprocessing.saveSeurat(adata, outPath, batch, hvgs)
-
+        scib.preprocessing.saveSeurat(adata, outPath, batch, hvgs)
+
     else:
         print("Save as HDF5")
         sc.write(outPath, adata)
 
 
-if __name__=='__main__':
+if __name__ == '__main__':
     import argparse
 
     parser = argparse.ArgumentParser(description='Run the integration methods')
@@ -68,5 +76,5 @@ def runPP(inPath, outPath, hvg, batch, rout, scale, seurat):
     rout = args.rout
     seurat = args.seurat
     scale = args.scale
-
+
     runPP(file, out, hvg, batch, rout, scale, seurat)
diff --git a/scripts/update_timestamp.py b/scripts/update_timestamp.py
index 4fc9ce37..e3de80be 100644
--- a/scripts/update_timestamp.py
+++ b/scripts/update_timestamp.py
@@ -4,12 +4,13 @@
 from pathlib import Path
 from os.path import isfile
 
+
 def touch_if_exists(file_path):
     if isfile(file_path):
         Path(file_path).touch()
     else:
         print(f'{file_path} does not exist.')
-
+
 
 def update_timestamp_task(config, task, update_metrics=True):
     """
@@ -21,7 +22,7 @@ def update_timestamp_task(config, task, update_metrics=True):
     Note that this function does not update the timestamp of the aggregated metrics
     files.
     """
-
+
     base_folder = config['ROOT']
     scaling = config['SCALING']
     hvgs = list(config['FEATURE_SELECTION'].keys())
@@ -32,23 +33,23 @@
 
     for scal in scaling:
         for feat in hvgs:
-            folder_path = '/'.join([base_folder,task,'prepare',scal,feat])+'/'
+            folder_path = '/'.join([base_folder, task, 'prepare', scal, feat]) + '/'
             file_base = 'adata_pre'
 
             for end in prep_endings:
-                filename = file_base+end
-                full_path = folder_path+filename
+                filename = file_base + end
+                full_path = folder_path + filename
                 touch_if_exists(full_path)
 
-            full_path = folder_path+'prep_h5ad.benchmark'
+            full_path = folder_path + 'prep_h5ad.benchmark'
             touch_if_exists(full_path)
 
-            full_path = folder_path+'prep_RDS.benchmark'
+            full_path = folder_path + 'prep_RDS.benchmark'
             touch_if_exists(full_path)
 
     # Integration & convert files
     for scal in scaling:
         for feat in hvgs:
-            folder_base = '/'.join([base_folder,task,'integration',scal,feat])+'/'
+            folder_base = '/'.join([base_folder, task, 'integration', scal, feat]) + '/'
             for method in methods:
                 if 'R' in config['METHODS'][method]:
                     r_folder = 'R/'
@@ -58,8 +59,8 @@ def update_timestamp_task(config, task, update_metrics=True):
                 method_endings = ['.h5ad', '.h5ad.benchmark']
 
                 for end in method_endings:
-                    folder_path = folder_base+r_folder
-                    full_path = folder_path+method+end
+                    folder_path = folder_base + r_folder
+                    full_path = folder_path + method + end
                     touch_if_exists(full_path)
 
     # Metrics files
@@ -68,30 +69,29 @@ def update_timestamp_task(config, task, update_metrics=True):
     if update_metrics:
         for scal in scaling:
             for feat in hvgs:
-                folder_base = '/'.join([base_folder,task,'metrics',scal,feat])+'/'
+                folder_base = '/'.join([base_folder, task, 'metrics', scal, feat]) + '/'
                 for method in methods:
                     out_types = config['METHODS'][method]['output_type']
                     if isinstance(out_types, str):
                         out_types = [out_types]
                     for out_type in out_types:
-                        file_base = '_'.join([method,out_type])
+                        file_base = '_'.join([method, out_type])
                         for end in metric_endings:
-                            full_path = folder_base+file_base+end
+                            full_path = folder_base + file_base + end
                             touch_if_exists(full_path)
 
-
-if __name__=='__main__':
+
+if __name__ == '__main__':
     import argparse
 
     parser = argparse.ArgumentParser(description='Update timestamp on all output files'
-                                     ' for an integration task')
+                                                 ' for an integration task')
     parser.add_argument('-c', '--config', help='Snakemake config file', required=True)
-    parser.add_argument('-t', '--task', help='Integration task to update',
+    parser.add_argument('-t', '--task', help='Integration task to update',
                         required=True)
-    parser.add_argument('-m', '--include-metrics', action='store_true',
+    parser.add_argument('-m', '--include-metrics', action='store_true',
                         help='Also update timestamp of metrics files')
     args = parser.parse_args()
@@ -103,7 +103,7 @@ def update_timestamp_task(config, task, update_metrics=True):
     params = load_configfile(config)
 
     if task not in params['DATA_SCENARIOS']:
-        raise ValueError(f'{task} is not a valid integration task.\n'
+        raise ValueError(f'{task} is not a valid integration task.\n'
                          'Please choose one of:\n'
                          f'{list(params["DATA_SCENARIOS"].keys())}')
diff --git a/scripts/visualization/save_embeddings.py b/scripts/visualization/save_embeddings.py
index dd589c19..71f1bb77 100644
--- a/scripts/visualization/save_embeddings.py
+++ b/scripts/visualization/save_embeddings.py
@@ -2,7 +2,7 @@
 # coding: utf-8
 
 import scanpy as sc
-import scIB
+import scib
 import argparse
 import os
 import sys
@@ -12,11 +12,11 @@
 
 RESULT_TYPES = [
     "full",  # reconstructed expression data
-    "embed", # embedded/latent space
-    "knn" # corrected neighbourhood graph
+    "embed",  # embedded/latent space
+    "knn"  # corrected neighbourhood graph
 ]
 
-if __name__=='__main__':
+if __name__ == '__main__':
     """
     Save embeddings for all scenarios, methods and settings
     """
@@ -61,12 +61,24 @@
 
     print('Preparing dataset...')
     if result == 'embed':
-        scIB.pp.reduce_data(adata, n_top_genes=None, neighbors=True,
-                            use_rep='X_emb', pca=False, umap=False)
+        scib.pp.reduce_data(
+            adata,
+            n_top_genes=None,
+            neighbors=True,
+            use_rep='X_emb',
+            pca=False,
+            umap=False
+        )
     elif result == 'full':
         sc.pp.filter_genes(adata, min_cells=1)
-        scIB.pp.reduce_data(adata, n_top_genes=2000, neighbors=True,
-                            use_rep='X_pca', pca=True, umap=False)
+        scib.pp.reduce_data(
+            adata,
+            n_top_genes=2000,
+            neighbors=True,
+            use_rep='X_pca',
+            pca=True,
+            umap=False
+        )
 
     # Calculate embedding
     if args.method.startswith('conos'):
@@ -96,6 +108,6 @@
     print('Saving embedding coordinates...')
     adata.obs[label + '1'] = adata.obsm['X_' + basis][:, 0]
     adata.obs[label + '2'] = adata.obsm['X_' + basis][:, 1]
-    coords = adata.obs[[label_key, batch_key, label + '1', label + '2' ]]
+    coords = adata.obs[[label_key, batch_key, label + '1', label + '2']]
 
     coords.to_csv(os.path.join(outfile), index_label='CellID')
diff --git a/tests/pipeline/test_metrics.py b/tests/pipeline/test_metrics.py
index d25d720d..4b57ad48 100644
--- a/tests/pipeline/test_metrics.py
+++ b/tests/pipeline/test_metrics.py
@@ -1,5 +1,5 @@
 from .test_pipeline import *
-from scIB.integration import *
+import scib
 import os
 import subprocess
 
@@ -10,14 +10,14 @@ def metrics_all_methods(adata_factory):
     adata = adata_factory()
 
     methods = {
-        'scanorama': runScanorama,
-        'trvae': runTrVae,
+        'scanorama': scib.ig.scanorama,
+        'trvae': scib.ig.trvae,
         'seurat': runSeurat,
         'harmony': runHarmony,
-        'mnn': runMNN,
-        'bbknn': runBBKNN,
+        'mnn': scib.ig.mnn,
+        'bbknn': scib.ig.bbknn,
         'conos': runConos,
-        'scvi': runScvi
+        'scvi': scib.ig.scvi
     }
     # for name, func in methods.items():
 
@@ -26,7 +26,7 @@ def test_all_metrics(adata_factory, test_metrics):
     adata = adata_factory()
     adata_int = adata.copy()
 
-    script = os.path.join(os.path.dirname(scIB.__file__), "scripts", "metrics.py")
+    script = os.path.join(os.path.dirname(scib.__file__), "scripts", "metrics.py")
 
     for ot in ["full", "embed", "knn"]:
         all_metrics(adata, adata_int, script=script, type_=ot, pipeline_dir=test_metrics, method="orig")
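
Note on the migration above: the patch moves the pipeline from the legacy camelCase scIB package to the snake_case scib release (pinned as scib==1.0.0 in envs/scib-pipeline.yml), and swaps the categorical-categories comparison for a value_counts() check when verifying batch labels after integration. A minimal before/after sketch for a downstream script follows; the input path "adata.h5ad" and the 'batch' key are hypothetical placeholders, while the function names and the relabeling check are taken directly from the hunks above.

    import scanpy as sc
    import scib

    # hypothetical input; any AnnData with a batch annotation in .obs works
    adata = sc.read("adata.h5ad")

    # old API: integrated = scIB.integration.runScanorama(adata, 'batch')
    # new API: the camelCase run* wrappers become snake_case functions
    integrated = scib.integration.scanorama(adata, 'batch')

    # relabeling guard from scripts/metrics/metrics.py: compare batch labels via
    # value_counts() indices instead of categorical categories, and restore the
    # pre-integration labels if they diverge (pandas aligns on the obs index)
    batch_u = adata.obs['batch'].value_counts().index
    batch_i = integrated.obs['batch'].value_counts().index
    if not batch_i.equals(batch_u):
        integrated.obs['batch'] = adata.obs['batch']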