Skip to content

Commit

Permalink
ensure tensor copy to cpu mem before data-array creation
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Nov 29, 2024
1 parent cf8e3e4 commit 85160ce
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
10 changes: 6 additions & 4 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,12 @@ def _create_dataarray_from_tensor(
----------
tensor : torch.Tensor
The tensor to convert to a `xr.DataArray` with dimensions [time,
grid_index, feature]
grid_index, feature]. The tensor will be copied to the CPU if it is
not already there.
time : Union[int,List[int]]
The time index or indices for the data, given as integers or a list
of integers representing epoch time in nanoseconds.
of integers representing epoch time in nanoseconds. The ints will be
copied to the CPU memory if they are not already there.
split : str
The split of the data, either 'train', 'val', or 'test'
category : str
Expand All @@ -180,9 +182,9 @@ def _create_dataarray_from_tensor(
# not how this should be done but whether WeatherDataset should be
# provided to ARModel or where to put plotting still needs discussion
weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
time = np.array(time, dtype="datetime64[ns]")
time = np.array(time.cpu(), dtype="datetime64[ns]")
da = weather_dataset.create_dataarray_from_tensor(
tensor=tensor, time=time, category=category
tensor=tensor.cpu().numpy(), time=time, category=category
)
return da

Expand Down
5 changes: 3 additions & 2 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,8 @@ def create_dataarray_from_tensor(
tensor : torch.Tensor
The tensor to construct the DataArray from, this assumed to have
the same dimension ordering as returned by the __getitem__ method
(i.e. time, grid_index, {category}_feature).
(i.e. time, grid_index, {category}_feature). The tensor will be
copied to the CPU before constructing the DataArray.
time : datetime.datetime or list[datetime.datetime]
The time or times of the tensor.
category : str
Expand Down Expand Up @@ -581,7 +582,7 @@ def _is_listlike(obj):
coords["time"] = time

da = xr.DataArray(
tensor.numpy(),
tensor.cpu().numpy(),
dims=dims,
coords=coords,
)
Expand Down

0 comments on commit 85160ce

Please sign in to comment.