Skip to content

Commit

Permalink
fix: fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
CaibinSh committed May 26, 2024
1 parent 24e1f95 commit db17c87
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions scar/main/_scar.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None):
self.raw_count = torch.from_numpy(raw_count).int().to(device)
self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device)
self.batch_id = batch_id.to(torch.int64).to(device)
self.batch_onehot = self._onehot(batch_id.to(torch.int64)).to(device)
self.batch_onehot = self._onehot()

if list_ids:
self.list_ids = list_ids
Expand All @@ -733,10 +733,9 @@ def __getitem__(self, index):
sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :]
return sc_count, sc_ambient, sc_batch_id_onehot

def _onehot(self, batch_id):
def _onehot(self):
"""One-hot encoding"""
batch_id = batch_id.to(self.device)
n_batch = batch_id.unique().size()[0]
x_onehot = torch.zeros(n_batch, n_batch)
x_onehot.scatter_(1, batch_id.unique().unsqueeze(1), 1)
n_batch = self.batch_id.unique().size()[0]
x_onehot = torch.zeros(n_batch, n_batch).to(self.device)
x_onehot.scatter_(1, self.batch_id.unique().unsqueeze(1), 1)
return x_onehot

0 comments on commit db17c87

Please sign in to comment.