From ad5ca918e6e7d05799e5708536166b253ff224ed Mon Sep 17 00:00:00 2001 From: Ian Date: Thu, 25 Jan 2024 15:04:20 +0000 Subject: [PATCH] Small performance boost and bugfix --- src/mousipy/mousipy.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/mousipy/mousipy.py b/src/mousipy/mousipy.py index 2b5dac9..1ff1b7a 100644 --- a/src/mousipy/mousipy.py +++ b/src/mousipy/mousipy.py @@ -190,7 +190,6 @@ def translate_multiple(adata, original_data, multiple, stay_sparse=False, verbos """ var = adata.var.copy() ortholog_indices = {gene: i for i, gene in enumerate(var.index)} - new_genes = [] if stay_sparse: # Sparse implementation remains unchanged @@ -200,7 +199,11 @@ def translate_multiple(adata, original_data, multiple, stay_sparse=False, verbos for hgene in hgenes: if hgene not in ortholog_indices: - new_genes.append(hgene) + # Create a new DataFrame row for the new gene + new_row = pd.DataFrame({col: pd.NA for col in var.columns}, index=[hgene]) + new_row['original_gene_symbol'] = 'multiple' + var = pd.concat([var, new_row]) + X = csr_matrix(np.hstack((X.toarray(), mgene_data.reshape(-1, 1)))) ortholog_indices[hgene] = X.shape[1] - 1 else: @@ -220,7 +223,11 @@ def translate_multiple(adata, original_data, multiple, stay_sparse=False, verbos for hgene in hgenes: if hgene not in ortholog_indices: - new_genes.append(hgene) + # Create a new DataFrame row for the new gene + new_row = pd.DataFrame({col: pd.NA for col in var.columns}, index=[hgene]) + new_row['original_gene_symbol'] = 'multiple' + var = pd.concat([var, new_row]) + new_data[:, next_new_gene_idx] = mgene_data.ravel() ortholog_indices[hgene] = next_new_gene_idx next_new_gene_idx += 1 @@ -230,11 +237,12 @@ def translate_multiple(adata, original_data, multiple, stay_sparse=False, verbos X = new_data - # Creating new DataFrame rows for new genes - for hgene in new_genes: - new_row = pd.DataFrame({col: pd.NA for col in var.columns}, index=[hgene]) - new_row['original_gene_symbol'] = 'multiple' - var = pd.concat([var, new_row]) + # Check the dimensions of X and var + if X.shape[1] != var.shape[0]: + # If they do not match, modify var to match the dimensions + missing_rows = X.shape[1] - var.shape[0] + additional_rows = pd.DataFrame(index=range(var.shape[0], X.shape[1])) + var = pd.concat([var, additional_rows]) return AnnData(X, adata.obs, var, adata.uns, adata.obsm)