
update installation #76

Merged · 10 commits · May 26, 2024
6 changes: 4 additions & 2 deletions docs/Installation.rst
@@ -16,11 +16,13 @@ Conda install

     conda activate scar

-4, Install scar::
+4, Install `PyTorch <https://pytorch.org/get-started/locally/>`_
+
+5, Install scar::

     conda install bioconda::scar

-5, Activate the scar conda environment::
+6, Activate the scar conda environment::

     conda activate scar

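A quick way to confirm the new PyTorch step worked before moving on to scar itself (a minimal Python check, not part of the documented steps):

    import torch

    print(torch.__version__)            # the installed PyTorch version
    print(torch.cuda.is_available())    # True only if a CUDA device is usable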
84 changes: 63 additions & 21 deletions scar/main/_scar.py
@@ -91,6 +91,17 @@ class model:
     Thank Will Macnair for the valuable feedback.

     .. versionadded:: 0.4.0

+    batch_key : str, optional
+        batch key in AnnData.obs, by default None. \
+        If assigned, batch ambient removal will be performed and \
+        the ambient profile will be estimated for each batch.
+
+        .. versionadded:: 0.6.1
+
+    device : str, optional
+        either "auto", "cpu" or "cuda", by default "auto"
+    verbose : bool, optional
+        whether to print the details, by default True

     Raises
     ------
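For context, a hypothetical call using the new batch_key parameter; the input file name and obs column are assumptions, while model and its arguments come from this diff:

    import scanpy as sc
    from scar import model

    adata = sc.read_h5ad("raw_counts.h5ad")   # hypothetical raw-droplet AnnData
    scar_model = model(
        raw_count=adata,       # AnnData input, so the batch logic applies
        batch_key="sample",    # assumed column in adata.obs
        device="auto",
    )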
@@ -200,6 +211,7 @@ def __init__(
         feature_type: str = "mRNA",
         count_model: str = "binomial",
         sparsity: float = 0.9,
+        batch_key: str = None,
         device: str = "auto",
         verbose: bool = True,
     ):
@@ -262,7 +274,7 @@ def __init__(
         """float, the sparsity of expected native signals. (0, 1]. \
         Forced to be one in the mode of "sgRNA(s)" and "tag(s)".
         """
-
+
         if isinstance(raw_count, str):
             raw_count = pd.read_pickle(raw_count)
         elif isinstance(raw_count, np.ndarray):
@@ -274,8 +286,25 @@ def __init__(
         elif isinstance(raw_count, pd.DataFrame):
             pass
         elif isinstance(raw_count, ad.AnnData):
+            if batch_key:
+                if batch_key not in raw_count.obs.columns:
+                    raise ValueError(f"{batch_key} not found in AnnData.obs.")
+
+                self.logger.info(
+                    f"Estimating ambient profile for each batch defined by {batch_key} in AnnData.obs..."
+                )
+                batch_id_per_cell = pd.Categorical(raw_count.obs[batch_key]).codes
+                ambient_profile = np.empty((len(np.unique(batch_id_per_cell)), raw_count.shape[1]))
+                for batch_id in np.unique(batch_id_per_cell):
+                    subset = raw_count[batch_id_per_cell == batch_id]
+                    ambient_profile[batch_id, :] = subset.X.sum(axis=0) / subset.X.sum()
+
+                # add a mapper to locate the batch id
+                self.batch_id = torch.from_numpy(batch_id_per_cell).int().to(self.device)
+                self.n_batch = np.unique(batch_id_per_cell).size
+
             # get ambient profile from AnnData.uns
-            if (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns):
+            elif (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns):
                 self.logger.info(
                     "Found ambient profile in AnnData.uns['ambient_profile_all']"
                 )
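The loop above computes one normalized pseudo-bulk profile per batch. A standalone sketch of the same estimate with toy numpy arrays in place of AnnData (shapes and batch labels are made up):

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    counts = rng.poisson(2.0, size=(6, 4))                    # cells x genes
    batch_id_per_cell = pd.Categorical(["a", "a", "b", "b", "b", "a"]).codes

    n_batch = np.unique(batch_id_per_cell).size
    ambient_profile = np.empty((n_batch, counts.shape[1]))
    for b in np.unique(batch_id_per_cell):
        subset = counts[batch_id_per_cell == b]
        # one normalized pseudo-bulk profile per batch
        ambient_profile[b, :] = subset.sum(axis=0) / subset.sum()

    assert np.allclose(ambient_profile.sum(axis=1), 1.0)      # rows sum to one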
@@ -296,7 +325,7 @@ def __init__(
         raw_count = raw_count.fillna(0)  # missing vals -> zeros

         # Loading numpy to tensor on GPU
-        self.raw_count = torch.from_numpy(raw_count.values).int().to(self.device)
+        self.raw_count = raw_count.values
         """raw_count : np.ndarray, raw count matrix.
         """
         self.n_features = raw_count.shape[1]
@@ -324,9 +353,12 @@ def __init__(
         ambient_profile = (
             ambient_profile.squeeze()
             .reshape(1, -1)
-            .repeat(raw_count.shape[0], axis=0)
         )
-        self.ambient_profile = torch.from_numpy(ambient_profile).float().to(self.device)
+        # add a mapper to locate the artificial batch id
+        self.batch_id = torch.zeros(raw_count.shape[0]).int().to(self.device)
+        self.n_batch = 1
+
+        self.ambient_profile = ambient_profile
         """ambient_profile : np.ndarray, the probability of occurrence of each ambient transcript.
         """

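Without batch_key, the code above falls back to a single artificial batch: one profile row and an all-zero batch id. A small sketch of that reduction (toy shapes assumed):

    import numpy as np
    import torch

    counts = np.random.poisson(2.0, size=(5, 4))
    profile = counts.sum(axis=0) / counts.sum()

    ambient_profile = profile.reshape(1, -1)        # shape (1, n_features)
    batch_id = torch.zeros(counts.shape[0]).int()   # every cell maps to row 0
    n_batch = 1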
@@ -410,21 +442,17 @@ def train(
         train_ids, test_ids = train_test_split(list_ids, train_size=train_size)

         # Generators
-        training_set = UMIDataset(self.raw_count, self.ambient_profile, train_ids)
+        training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=train_ids)
         training_generator = torch.utils.data.DataLoader(
             training_set, batch_size=batch_size, shuffle=shuffle
         )
-        val_set = UMIDataset(self.raw_count, self.ambient_profile, test_ids)
+        val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=test_ids)
         val_generator = torch.utils.data.DataLoader(
             val_set, batch_size=batch_size, shuffle=shuffle
         )

         loss_values = []

-        # self.n_batch_train = len(training_generator)
-        # self.n_batch_val = len(val_generator)
-        # self.batch_size = batch_size
-
         # Define model
         vae_nets = VAE(
             n_features=self.n_features,
@@ -435,6 +463,7 @@ def train(
             feature_type=self.feature_type,
             count_model=self.count_model,
             sparsity=self.sparsity,
+            n_batch=self.n_batch,
             verbose=verbose,
         ).to(self.device)
         # Define optimizer
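A hypothetical instantiation of the now batch-aware VAE; the layer widths and latent size here are assumptions, not the package defaults:

    from scar.main._vae import VAE

    vae_nets = VAE(
        n_features=2000,          # number of genes in the count matrix
        nn_layer1=150,            # assumed hidden width
        nn_layer2=100,            # assumed hidden width
        latent_dim=15,            # assumed latent size
        feature_type="mRNA",
        count_model="binomial",
        sparsity=0.9,
        n_batch=3,                # >1 enables the batch conditioning below
    ).to("cpu")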
@@ -459,15 +488,15 @@ def train(
             desc="Training",
         )
         progress_bar.clear()
-        for epoch in range(epochs):
+        for _ in range(epochs):
             train_tot_loss = 0
             train_kld_loss = 0
             train_recon_loss = 0

             vae_nets.train()
-            for x_batch, ambient_freq in training_generator:
+            for x_batch, ambient_freq, batch_id_onehot in training_generator:
                 optim.zero_grad()
-                dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch)
+                dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch, batch_id_onehot)
                 recon_loss_minibatch, kld_loss_minibatch, loss_minibatch = loss_fn(
                     x_batch,
                     dec_nr,
@@ -559,7 +588,7 @@ def inference(
         native_frequencies, and noise_ratio. \
         A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type.
         """
-        total_set = UMIDataset(self.raw_count, self.ambient_profile)
+        total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device)
         n_features = self.n_features
         sample_size = self.raw_count.shape[0]
         self.native_counts = np.empty([sample_size, n_features])
@@ -574,7 +603,7 @@ def inference(
             total_set, batch_size=batch_size, shuffle=False
         )

-        for x_batch_tot, ambient_freq_tot in generator_full_data:
+        for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data:
            minibatch_size = x_batch_tot.shape[
                 0
             ]  # if not the last batch, equals to batch size
@@ -586,6 +615,7 @@ def inference(
                 noise_ratio_batch,
             ) = self.trained_model.inference(
                 x_batch_tot,
+                x_batch_id_onehot_tot,
                 ambient_freq_tot[0, :],
                 count_model_inf=count_model_inf,
                 adjust=adjust,
@@ -677,10 +707,14 @@ def assignment(self, cutoff=3, moi=None):
 class UMIDataset(torch.utils.data.Dataset):
     """Characterizes dataset for PyTorch"""

-    def __init__(self, raw_count, ambient_profile, list_ids=None):
+    def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None):
         """Initialization"""
-        self.raw_count = raw_count
-        self.ambient_profile = ambient_profile
+        self.device = device
+        self.raw_count = torch.from_numpy(raw_count).int().to(device)
+        self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device)
+        self.batch_id = batch_id.to(torch.int64).to(device)
+        self.batch_onehot = self._onehot()
+
         if list_ids:
             self.list_ids = list_ids
         else:
@@ -695,5 +729,13 @@ def __getitem__(self, index):
         # Select sample
         sc_id = self.list_ids[index]
         sc_count = self.raw_count[sc_id, :]
-        sc_ambient = self.ambient_profile[sc_id, :]
-        return sc_count, sc_ambient
+        sc_ambient = self.ambient_profile[self.batch_id[sc_id], :]
+        sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :]
+        return sc_count, sc_ambient, sc_batch_id_onehot
+
+    def _onehot(self):
+        """One-hot encoding"""
+        n_batch = self.batch_id.unique().size()[0]
+        x_onehot = torch.zeros(n_batch, n_batch).to(self.device)
+        x_onehot.scatter_(1, self.batch_id.unique().unsqueeze(1), 1)
+        return x_onehot
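The new _onehot helper builds an n_batch x n_batch lookup whose row b one-hot-encodes batch b; __getitem__ then indexes it per cell. A quick check of that construction (toy batch ids assumed):

    import torch

    batch_id = torch.tensor([0, 0, 1, 2, 1])
    n_batch = batch_id.unique().size()[0]    # 3

    x_onehot = torch.zeros(n_batch, n_batch)
    x_onehot.scatter_(1, batch_id.unique().unsqueeze(1), 1)
    # x_onehot is the 3x3 identity: row b one-hot-encodes batch b
    print(x_onehot[batch_id[3], :])          # tensor([0., 0., 1.])

Note that the scatter_ call assumes batch ids run 0..n_batch-1, which the pd.Categorical(...).codes mapping upstream guarantees.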
30 changes: 19 additions & 11 deletions scar/main/_vae.py
@@ -51,6 +51,7 @@ def __init__(
         dropout_prob=0,
         feature_type="mRNA",
         count_model="binomial",
+        n_batch=1,
         sparsity=0.9,
         verbose=True,
     ):
@@ -81,10 +82,11 @@ def __init__(
             sparsity = 1

         self.encoder = Encoder(
-            n_features, nn_layer1, nn_layer2, latent_dim, dropout_prob
+            n_features, n_batch, nn_layer1, nn_layer2, latent_dim, dropout_prob
         )
         self.decoder = Decoder(
             n_features,
+            n_batch,
             nn_layer1,
             nn_layer2,
             latent_dim,
@@ -105,16 +107,17 @@ def __init__(
         vae_logger.info(f"...dropout_prob: {dropout_prob:.2f}")
         vae_logger.info(f"...expected data sparsity: {sparsity:.2f}")

-    def forward(self, input_matrix):
+    def forward(self, input_matrix, batch_id_onehot=None):
         """forward function"""
-        sampling, means, var = self.encoder(input_matrix)
-        dec_nr, dec_prob, dec_dp = self.decoder(sampling)
+        sampling, means, var = self.encoder(input_matrix, batch_id_onehot)
+        dec_nr, dec_prob, dec_dp = self.decoder(sampling, batch_id_onehot)
         return dec_nr, dec_prob, means, var, dec_dp

     @torch.no_grad()
     def inference(
         self,
         input_matrix,
+        batch_id_onehot,
         amb_prob,
         count_model_inf="poisson",
         adjust="micro",
@@ -128,7 +131,7 @@ def inference(
         assert adjust in [False, "global", "micro"]

         # Estimate native signals
-        dec_nr, dec_prob, _, _, _ = self.forward(input_matrix)
+        dec_nr, dec_prob, _, _, _ = self.forward(input_matrix, batch_id_onehot)

         # Copy tensor to CPU
         input_matrix_np = input_matrix.cpu().numpy()
@@ -230,11 +233,13 @@ class Encoder(nn.Module):
     Consists of 2 FC layers.
     """

-    def __init__(self, n_features, nn_layer1, nn_layer2, latent_dim, dropout_prob):
+    def __init__(self, n_features, n_batch, nn_layer1, nn_layer2, latent_dim, dropout_prob):
         """initialization"""
         super().__init__()
         self.activation = nn.SELU()
-        self.fc1 = nn.Linear(n_features, nn_layer1)
+        # if n_batch > 1:
+        #     n_features += n_batch
+        self.fc1 = nn.Linear(n_features + n_batch, nn_layer1)
         self.bn1 = nn.BatchNorm1d(nn_layer1, momentum=0.01, eps=0.001)
         self.dp1 = nn.Dropout(p=dropout_prob)
         self.fc2 = nn.Linear(nn_layer1, nn_layer2)
@@ -250,9 +255,10 @@ def reparametrize(self, means, log_vars):
         var = log_vars.exp() + 1e-4
         return torch.distributions.Normal(means, var.sqrt()).rsample(), var

-    def forward(self, input_matrix):
+    def forward(self, input_matrix, batch_id_onehot):
         """forward function"""
         input_matrix = (input_matrix + 1).log2()  # log transformation of count data
+        input_matrix = torch.cat([input_matrix, batch_id_onehot], 1)
         enc = self.fc1(input_matrix)
         enc = self.bn1(enc)
         enc = self.activation(enc)
@@ -284,6 +290,7 @@ class Decoder(nn.Module):
     def __init__(
         self,
         n_features,
+        n_batch,
         nn_layer1,
         nn_layer2,
         latent_dim,
@@ -297,7 +304,7 @@ def __init__(
         self.normalization_native_freq = hnormalization()
         self.noise_activation = mytanh()
         self.activation_native_freq = mysoftplus(sparsity)
-        self.fc4 = nn.Linear(latent_dim, nn_layer2)
+        self.fc4 = nn.Linear(latent_dim + n_batch, nn_layer2)
         self.bn4 = nn.BatchNorm1d(nn_layer2, momentum=0.01, eps=0.001)
         self.dp4 = nn.Dropout(p=dropout_prob)
         self.fc5 = nn.Linear(nn_layer2, nn_layer1)
@@ -311,10 +318,11 @@ def __init__(
         self.dropoutprob = nn.Linear(nn_layer1, 1)
         self.dropout_activation = mytanh()

-    def forward(self, sampling):
+    def forward(self, sampling, batch_id_onehot):
         """forward function"""
         # decoder
-        dec = self.fc4(sampling)
+        cond_sampling = torch.cat([sampling, batch_id_onehot], 1)
+        dec = self.fc4(cond_sampling)
         dec = self.bn4(dec)
         dec = self.activation(dec)
         dec = self.fc5(dec)
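Taken together, these _vae.py changes implement standard one-hot conditioning: the batch code is concatenated onto the encoder input and onto the latent sample, which is why fc1 and fc4 each grow by n_batch. A minimal sketch with toy dimensions (all names here are illustrative):

    import torch

    n_cells, n_features, n_batch, latent_dim = 8, 100, 3, 10
    x = torch.rand(n_cells, n_features)
    batch_onehot = torch.eye(n_batch)[torch.randint(n_batch, (n_cells,))]

    enc_in = torch.cat([(x + 1).log2(), batch_onehot], 1)   # fc1 input: n_features + n_batch
    z = torch.randn(n_cells, latent_dim)                    # stand-in for the sampled latent
    dec_in = torch.cat([z, batch_onehot], 1)                # fc4 input: latent_dim + n_batch

    assert enc_in.shape[1] == n_features + n_batch
    assert dec_in.shape[1] == latent_dim + n_batch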