Skip to content

Commit

Permalink
Fix: Cannot convert a MPS Tensor to float64 dtype as the MPS framewor…
Browse files Browse the repository at this point in the history
…k doesn't support float64.
  • Loading branch information
liuzhaoze committed Jan 19, 2025
1 parent 151000b commit c04cd1b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tianshou/data/utils/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def to_torch(
x.dtype.type,
np.bool_ | np.number,
): # most often case
x = torch.from_numpy(x).to(device)
x = torch.from_numpy(x)
if dtype is not None:
x = x.type(dtype)
return x
return x.to(device)
if isinstance(x, torch.Tensor): # second often case
if dtype is not None:
x = x.type(dtype)
Expand Down

0 comments on commit c04cd1b

Please sign in to comment.