Skip to content

Commit

Permalink
minor changes to make things work on a gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
vineetbansal committed Nov 16, 2024
1 parent b0479a3 commit 899d5ac
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/paste3/paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def center_ot(
use_gpu=use_gpu,
)
pis.append(pi)
losses.append(loss["loss"][-1])
losses.append(loss["loss"][-1].item())
return pis, np.array(losses)


Expand Down Expand Up @@ -547,11 +547,11 @@ def center_NMF(
)

if fast:
nmf_model.fit(feature_matrix.T)
nmf_model.to(feature_matrix).fit(feature_matrix.T)
new_feature_matrix = nmf_model.W.double().detach().cpu().numpy()
new_coeff_matrix = nmf_model.H.T.detach().cpu().numpy()
else:
new_feature_matrix = nmf_model.fit_transform(feature_matrix)
new_feature_matrix = nmf_model.fit_transform(feature_matrix.cpu())
new_coeff_matrix = nmf_model.components_
return new_feature_matrix, new_coeff_matrix

Expand Down
2 changes: 1 addition & 1 deletion src/paste3/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def generalized_procrustes_analysis(
U, S, Vt = torch.linalg.svd(covariance_matrix, full_matrices=True)
rotation_matrix = Vt.T.matmul(U.T)
target_coordinates = rotation_matrix.matmul(target_coordinates.T).T
M = torch.Tensor([[0, -1], [1, 0]]).double()
M = torch.Tensor([[0, -1], [1, 0]]).to(covariance_matrix)
rotation_angle = torch.arctan(
torch.trace(M.matmul(covariance_matrix)) / torch.trace(covariance_matrix)
)
Expand Down

0 comments on commit 899d5ac

Please sign in to comment.