Skip to content

Commit

Permalink
refactor: cast Unitary to cp array on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
andrea-pasquale committed Feb 3, 2025
1 parent 3fac930 commit dadf2ec
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/qibojit/backends/matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@ def _cast(self, x, dtype):

# Necessary to avoid https://github.com/qiboteam/qibo/issues/928
def Unitary(self, u):
if isinstance(u, self.cp.ndarray):
u = u.get()
return super().Unitary(u)
dtype = getattr(np, self.dtype)
return self._cast(u, dtype=dtype)


class CustomCuQuantumMatrices(CustomMatrices): # pragma: no cover
Expand Down

0 comments on commit dadf2ec

Please sign in to comment.