Skip to content

Commit

Permalink
refactor(scar): refactor scar to allow efficient usage of GPU
Browse files (browse the repository at this point in the history)
  • Loading branch information
CaibinSh committed Jul 29, 2024
1 parent c846d9b commit 4c808ab
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions scar/main/_scar.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,7 @@

import sys, time, contextlib, torch
from typing import Optional, Union
from scipy import sparse
import numpy as np, pandas as pd, anndata as ad

from collections import OrderedDict
Expand Down Expand Up @@ -597,10 +598,11 @@ def inference(
# total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity)
n_features = self.n_features
sample_size = self.raw_count.shape[0]
self.native_counts = np.empty([sample_size, n_features])
self.bayesfactor = np.empty([sample_size, n_features])
self.native_frequencies = np.empty([sample_size, n_features])
self.noise_ratio = np.empty([sample_size, 1])

native_counts = sparse.lil_matrix((sample_size, n_features), dtype=np.float32)
bayesfactor = sparse.lil_matrix((sample_size, n_features), dtype=np.float32)
native_frequencies = sparse.lil_matrix((sample_size, n_features), dtype=np.float32)
noise_ratio = sparse.lil_matrix((sample_size, 1), dtype=np.float32)

if not batch_size:
batch_size = sample_size
Expand Down Expand Up @@ -629,20 +631,25 @@ def inference(
round_to_int=round_to_int,
clip_to_obs=clip_to_obs,
)
self.native_counts[
native_counts[
i * batch_size : i * batch_size + minibatch_size, :
] = native_counts_batch
self.bayesfactor[
bayesfactor[
i * batch_size : i * batch_size + minibatch_size, :
] = bayesfactor_batch
self.native_frequencies[
native_frequencies[
i * batch_size : i * batch_size + minibatch_size, :
] = native_frequencies_batch
self.noise_ratio[
noise_ratio[
i * batch_size : i * batch_size + minibatch_size, :
] = noise_ratio_batch
i += 1

self.native_counts = native_counts.tocsr()
self.bayesfactor = bayesfactor.tocsr()
self.native_frequencies = native_frequencies.tocsr()
self.noise_ratio = noise_ratio.tocsr()

if self.feature_type.lower() in [
"sgrna",
"sgrnas",
Expand Down Expand Up @@ -683,7 +690,7 @@ def assignment(self, cutoff=3, moi=None):
index=self.cell_id, columns=[self.feature_type, f"n_{self.feature_type}"]
)
bayesfactor_df = pd.DataFrame(
self.bayesfactor, index=self.cell_id, columns=self.feature_names
self.bayesfactor.toarray(), index=self.cell_id, columns=self.feature_names
)
bayesfactor_df[bayesfactor_df < cutoff] = 0 # Apply the cutoff for Bayesfactors

Expand Down

0 comments on commit 4c808ab

Please sign in to comment.