diff --git a/torch_optimizer/shampoo.py b/torch_optimizer/shampoo.py index fc2706d..88e7765 100644 --- a/torch_optimizer/shampoo.py +++ b/torch_optimizer/shampoo.py @@ -8,7 +8,8 @@ def _matrix_power(matrix: torch.Tensor, power: float) -> torch.Tensor: # use CPU for svd for speed up device = matrix.device matrix = matrix.cpu() - u, s, v = torch.svd(matrix) + u, s, vh = torch.linalg.svd(matrix, full_matrices=False) + v = vh.mH return (u @ s.pow_(power).diag() @ v.t()).to(device)