Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scale function(_get_mean_var) updated for dense array, speedup upto ~4.65x #3280

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

kaushalprasadhial
Copy link
Contributor

We had created this PR before #3099. This one is the same PR with editing enabled for maintainers.
Hi,
We are submitting PR for speed up of the _get_mean_var function.

Time(sec)
Original 18.49
Updated 3.97
Speedup 4.65743073

experiment setup : AWS r7i.24xlarge

import time
import numpy as np

import pandas as pd

import scanpy as sc
from sklearn.cluster import KMeans

import os
import wget

import warnings



warnings.filterwarnings('ignore', 'Expected ')
warnings.simplefilter('ignore')
input_file = "./1M_brain_cells_10X.sparse.h5ad"

if not os.path.exists(input_file):
    print('Downloading import file...')
    wget.download('https://rapids-single-cell-examples.s3.us-east-2.amazonaws.com/1M_brain_cells_10X.sparse.h5ad',input_file)


# marker genes
MITO_GENE_PREFIX = "mt-" # Prefix for mitochondrial genes to regress out
markers = ["Stmn2", "Hes1", "Olig1"] # Marker genes for visualization

# filtering cells
min_genes_per_cell = 200 # Filter out cells with fewer genes than this expressed
max_genes_per_cell = 6000 # Filter out cells with more genes than this expressed

# filtering genes
min_cells_per_gene = 1 # Filter out genes expressed in fewer cells than this
n_top_genes = 4000 # Number of highly variable genes to retain

# PCA
n_components = 50 # Number of principal components to compute

# t-SNE
tsne_n_pcs = 20 # Number of principal components to use for t-SNE

# k-means
k = 35 # Number of clusters for k-means

# Gene ranking

ranking_n_top_genes = 50 # Number of differential genes to compute for each cluster

# Number of parallel jobs
sc._settings.ScanpyConfig.n_jobs = os.cpu_count()

start=time.time()
tr=time.time()
adata = sc.read(input_file)
adata.var_names_make_unique()
adata.shape
print("Total read time : %s" % (time.time()-tr))



tr=time.time()
# To reduce the number of cells:
USE_FIRST_N_CELLS = 1300000
adata = adata[0:USE_FIRST_N_CELLS]
adata.shape

sc.pp.filter_cells(adata, min_genes=min_genes_per_cell)
sc.pp.filter_cells(adata, max_genes=max_genes_per_cell)
sc.pp.filter_genes(adata, min_cells=min_cells_per_gene)
sc.pp.normalize_total(adata, target_sum=1e4)
print("Total filter and normalize time : %s" % (time.time()-tr))


tr=time.time()
sc.pp.log1p(adata)
print("Total log time : %s" % (time.time()-tr))


# Select highly variable genes
sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor = "cell_ranger")

# Retain marker gene expression
for marker in markers:
        adata.obs[marker + "_raw"] = adata.X[:, adata.var.index == marker].toarray().ravel()

# Filter matrix to only variable genes
adata = adata[:, adata.var.highly_variable]

ts=time.time()
#Regress out confounding factors (number of counts, mitochondrial gene expression)
mito_genes = adata.var_names.str.startswith(MITO_GENE_PREFIX)
n_counts = np.array(adata.X.sum(axis=1))
adata.obs['percent_mito'] = np.array(np.sum(adata[:, mito_genes].X, axis=1)) / n_counts
adata.obs['n_counts'] = n_counts


sc.pp.regress_out(adata, ['n_counts', 'percent_mito'])
print("Total regress out time : %s" % (time.time()-ts))

#scale

ts=time.time()
sc.pp.scale(adata)
print("Total scale time : %s" % (time.time()-ts))

add timer around _get_mean_var call

mean, var = _get_mean_var(X)

we can also create _get_mean_var_std function that return std as well so we don't require to compute it in scale function(L168-L169).

Copy link

codecov bot commented Oct 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 76.59%. Comparing base (5654389) to head (8f357fb).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3280      +/-   ##
==========================================
- Coverage   76.67%   76.59%   -0.08%     
==========================================
  Files         112      112              
  Lines       12946    12949       +3     
==========================================
- Hits         9926     9918       -8     
- Misses       3020     3031      +11     
Files with missing lines Coverage Δ
src/scanpy/preprocessing/_utils.py 55.88% <100.00%> (+0.32%) ⬆️

... and 3 files with indirect coverage changes

@kaushalprasadhial kaushalprasadhial marked this pull request as draft October 14, 2024 09:36
@kaushalprasadhial kaushalprasadhial marked this pull request as ready for review October 14, 2024 09:51
@Intron7 Intron7 self-assigned this Oct 14, 2024
Comment on lines 36 to 38
if isinstance(X, np.ndarray):
n_threads = numba.get_num_threads()
mean, var = _compute_mean_var(X, axis=axis, n_threads=n_threads)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happened to sparse?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still try to integrate all the stuff it's a mess though.

else:
mean = axis_mean(X, axis=axis, dtype=np.float64)
mean_sq = axis_mean(elem_mul(X, X), axis=axis, dtype=np.float64)
var = mean_sq - mean**2
# enforce R convention (unbiased estimator) for variance
if X.shape[axis] != 1:
var *= X.shape[axis] / (X.shape[axis] - 1)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove erroneous diffs

src/scanpy/preprocessing/_utils.py Outdated Show resolved Hide resolved
Copy link
Member

@flying-sheep flying-sheep left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! This seems a bit unfinished. If you need our help with a local test environment or so, we’re happy to help.

@@ -33,7 +33,10 @@ def _(X: np.ndarray, *, axis: Literal[0, 1], dtype: DTypeLike) -> np.ndarray:
def _get_mean_var(
X: _SupportedArray, *, axis: Literal[0, 1] = 0
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
if isinstance(X, sparse.spmatrix):
if isinstance(X, np.ndarray):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of adding a second code path that handles np.ndarray, you should replace the existing one above:

@axis_mean.register(np.ndarray)
def _(X: np.ndarray, ...): ...

Comment on lines +73 to +75
axis_i = 0
mean = np.zeros(X.shape[axis_i], dtype=np.float64)
var = np.zeros(X.shape[axis_i], dtype=np.float64)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pull this out of the if branch

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean these assignments. When you have two branches, and both start with the same 3 lines, just do those before the if statement instead.

var[r] += value * value
for c in numba.prange(X.shape[0]):
mean[c] = mean[c] / X.shape[1]
var[c] = (var[c] - mean[c] ** 2) / (X.shape[1] - 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no return statement

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i have updated the code please check

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, the tests aren’t passing, so it still doesn’t seem to be working

@Intron7
Copy link
Member

Intron7 commented Nov 22, 2024

This also doesn't work for F-Arrays

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants