From c04cd1b548bfb1cdc30bfbaa823870577c6b8527 Mon Sep 17 00:00:00 2001 From: liuzhaoze <1045954863@qq.com> Date: Sun, 19 Jan 2025 12:47:19 +0800 Subject: [PATCH] Fix: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. --- tianshou/data/utils/converter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 8f07e0494..acda7827c 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -45,10 +45,10 @@ def to_torch( x.dtype.type, np.bool_ | np.number, ): # most often case - x = torch.from_numpy(x).to(device) + x = torch.from_numpy(x) if dtype is not None: x = x.type(dtype) - return x + return x.to(device) if isinstance(x, torch.Tensor): # second often case if dtype is not None: x = x.type(dtype)