diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index c875688b..0d8e6e3c 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -167,10 +167,12 @@ def _create_dataarray_from_tensor( ---------- tensor : torch.Tensor The tensor to convert to a `xr.DataArray` with dimensions [time, - grid_index, feature] + grid_index, feature]. The tensor will be copied to the CPU if it is + not already there. time : Union[int,List[int]] The time index or indices for the data, given as integers or a list - of integers representing epoch time in nanoseconds. + of integers representing epoch time in nanoseconds. The ints will be + copied to the CPU memory if they are not already there. split : str The split of the data, either 'train', 'val', or 'test' category : str @@ -180,9 +182,9 @@ def _create_dataarray_from_tensor( # not how this should be done but whether WeatherDataset should be # provided to ARModel or where to put plotting still needs discussion weather_dataset = WeatherDataset(datastore=self._datastore, split=split) - time = np.array(time, dtype="datetime64[ns]") + time = np.array(time.cpu(), dtype="datetime64[ns]") da = weather_dataset.create_dataarray_from_tensor( - tensor=tensor, time=time, category=category + tensor=tensor.cpu().numpy(), time=time, category=category ) return da diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 532e3c90..b5f85580 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -529,7 +529,8 @@ def create_dataarray_from_tensor( tensor : torch.Tensor The tensor to construct the DataArray from, this assumed to have the same dimension ordering as returned by the __getitem__ method - (i.e. time, grid_index, {category}_feature). + (i.e. time, grid_index, {category}_feature). The tensor will be + copied to the CPU before constructing the DataArray. time : datetime.datetime or list[datetime.datetime] The time or times of the tensor. category : str @@ -581,7 +582,7 @@ def _is_listlike(obj): coords["time"] = time da = xr.DataArray( - tensor.numpy(), + tensor.cpu().numpy(), dims=dims, coords=coords, )