diff --git a/src/mousipy/mousipy.py b/src/mousipy/mousipy.py index 1ff1b7a..f5733e5 100644 --- a/src/mousipy/mousipy.py +++ b/src/mousipy/mousipy.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd from anndata import AnnData -from scipy.sparse import csr_matrix, issparse +from scipy.sparse import csr_matrix, issparse, lil_matrix from tqdm import tqdm # Biomart tables @@ -185,66 +185,51 @@ def translate_direct(adata, direct, no_index): def translate_multiple(adata, original_data, multiple, stay_sparse=False, verbose=False): - """ - Adds the counts of multiple-hit genes to ALL their orthologs. - """ + """Adds the counts of multiple-hit genes to ALL their orthologs in an optimized manner.""" + # Ensure X is in the desired format from the start, reducing unnecessary conversions. + X = adata.X if stay_sparse else adata.X.toarray() if issparse(adata.X) else adata.X var = adata.var.copy() - ortholog_indices = {gene: i for i, gene in enumerate(var.index)} - - if stay_sparse: - # Sparse implementation remains unchanged - X = adata.X.copy() - for mgene, hgenes in (tqdm(multiple.items()) if verbose else multiple.items()): - mgene_data = make_dense(original_data[:, mgene].X) - - for hgene in hgenes: - if hgene not in ortholog_indices: - # Create a new DataFrame row for the new gene - new_row = pd.DataFrame({col: pd.NA for col in var.columns}, index=[hgene]) - new_row['original_gene_symbol'] = 'multiple' - var = pd.concat([var, new_row]) - - X = csr_matrix(np.hstack((X.toarray(), mgene_data.reshape(-1, 1)))) - ortholog_indices[hgene] = X.shape[1] - 1 - else: - idx = ortholog_indices[hgene] - X[:, idx] += csr_matrix(mgene_data).reshape(-1, 1) - else: - # Dense implementation - num_new_genes = sum(1 for hgenes in multiple.values() for hgene in hgenes if hgene not in ortholog_indices) - X = make_dense(adata.X) - new_data = np.zeros((X.shape[0], X.shape[1] + num_new_genes)) - - new_data[:, :X.shape[1]] = X - next_new_gene_idx = X.shape[1] - - for mgene, hgenes in (tqdm(multiple.items()) if verbose else multiple.items()): - mgene_data = make_dense(original_data[:, mgene].X).reshape(-1, 1) - - for hgene in hgenes: - if hgene not in ortholog_indices: - # Create a new DataFrame row for the new gene - new_row = pd.DataFrame({col: pd.NA for col in var.columns}, index=[hgene]) - new_row['original_gene_symbol'] = 'multiple' - var = pd.concat([var, new_row]) - - new_data[:, next_new_gene_idx] = mgene_data.ravel() - ortholog_indices[hgene] = next_new_gene_idx - next_new_gene_idx += 1 - else: - idx = ortholog_indices[hgene] - new_data[:, idx] += mgene_data.ravel() - X = new_data + # Prepare for efficient updates + new_genes = [] # To track genes not currently in `var` + gene_updates = {} # To aggregate updates before applying them + + # Use tqdm for verbose mode + iterator = tqdm(multiple.items()) if verbose else multiple.items() - # Check the dimensions of X and var - if X.shape[1] != var.shape[0]: - # If they do not match, modify var to match the dimensions - missing_rows = X.shape[1] - var.shape[0] - additional_rows = pd.DataFrame(index=range(var.shape[0], X.shape[1])) - var = pd.concat([var, additional_rows]) + for mgene, hgenes in iterator: + mgene_data = original_data[:, mgene].X.toarray().flatten() if issparse(original_data[:, mgene].X) else original_data[:, mgene].X - return AnnData(X, adata.obs, var, adata.uns, adata.obsm) + for hgene in hgenes: + if hgene not in var.index: + # Prepare to add a new gene + new_genes.append(hgene) + gene_updates[hgene] = mgene_data + else: + # Aggregate updates for existing genes + if hgene in gene_updates: + gene_updates[hgene] += mgene_data + else: + idx = np.where(var.index == hgene)[0][0] + if stay_sparse: + X[:, idx] += csr_matrix(mgene_data).transpose() + else: + X[:, idx] += mgene_data + + # Efficiently handle new genes + if new_genes: + new_gene_matrix = np.array([gene_updates[hgene] for hgene in new_genes]).T + if stay_sparse: + new_gene_matrix = csr_matrix(new_gene_matrix) + X = np.hstack((X, new_gene_matrix)) if not stay_sparse else csr_matrix(np.hstack((X.toarray(), new_gene_matrix.toarray()))) + new_var_entries = pd.DataFrame(index=new_genes) + new_var_entries['original_gene_symbol'] = 'multiple' + var = pd.concat([var, new_var_entries]) + + # Convert back to csr_matrix if originally sparse and requested to stay sparse + X_final = csr_matrix(X) if stay_sparse and not issparse(adata.X) else X + + return AnnData(X_final, adata.obs, var, adata.uns, adata.obsm) def collapse_duplicate_genes(adata, stay_sparse=False): diff --git a/src/tests/main_test.py b/src/tests/main_test.py index aaa70de..9420d21 100644 --- a/src/tests/main_test.py +++ b/src/tests/main_test.py @@ -42,6 +42,6 @@ def test_PBMC_hcop(): adata = read("data/Pancreas/pbmc3k_raw.h5ad", backup_url=url) adata.var_names_make_unique() - mousified_adata = translate(adata, source='hcop') + mousified_adata = translate(adata, source='hcop', stay_sparse=True) assert mousified_adata.n_obs == adata.n_obs, "We lost cells during mapping, which should not happen!" - assert mousified_adata.n_vars > 10000, "Very few genes (less than 10k) could be mapped! Expecting more!" \ No newline at end of file + assert mousified_adata.n_vars > 10000, "Very few genes (less than 10k) could be mapped! Expecting more!"