diff --git a/millipede/util.py b/millipede/util.py index e6a8e62..a0bc76b 100644 --- a/millipede/util.py +++ b/millipede/util.py @@ -20,7 +20,7 @@ def safe_cholesky(A, epsilon=1.0e-8): if A.shape == (1, 1): return A.sqrt() try: - return torch.linalg.cholesky(A) + return torch.linalg.cholesky(A, upper=False) except RuntimeError as e: Aprime = A.clone() jitter_prev = 0.0 @@ -29,7 +29,7 @@ def safe_cholesky(A, epsilon=1.0e-8): Aprime.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev) jitter_prev = jitter_new try: - return torch.linalg.cholesky(Aprime) + return torch.linalg.cholesky(Aprime, upper=False) except RuntimeError: continue raise e