Skip to content

Commit

Permalink
add simple complex number example
Browse files Browse the repository at this point in the history
  • Loading branch information
KaelanDt committed Apr 18, 2024
1 parent 055c2ba commit 8710202
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,20 @@ def partial_fvp(v):
damping = 1

sol = torch.linalg.solve(fisher + damping * torch.eye(fisher.shape[0]), v)
sol_cg, _ = cg(partial_fvp, v, x0=None, damping=damping, tol=torch.Tensor([1e-10]))
sol_cg, _ = cg(partial_fvp, v, x0=None, damping=damping, tol=1e-10)

assert torch.allclose(sol, sol_cg, rtol=1e-1)

# simple complex number example
A = torch.tensor([[0, -1j], [1j, 0]])

def mvp(x):
return A @ x

b = torch.randn(2, dtype=torch.cfloat)

sol = torch.linalg.solve(A, b)
sol_cg, _ = cg(mvp, b, x0=None, tol=1e-10)

assert torch.allclose(sol, sol_cg, rtol=1e-1)

Expand All @@ -298,9 +311,7 @@ def partial_fvp(v):

v, _ = tree_ravel(params)
sol = torch.linalg.solve(fisher + damping * torch.eye(fisher.shape[0]), v)
sol_cg, _ = cg(
partial_fvp, params, x0=None, damping=damping, tol=torch.Tensor([1e-10])
)
sol_cg, _ = cg(partial_fvp, params, x0=None, damping=damping, tol=1e-10)

assert torch.allclose(sol, tree_ravel(sol_cg)[0], rtol=1e-3)

Expand Down

0 comments on commit 8710202

Please sign in to comment.