Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Dec 30, 2024
1 parent 798afdc commit f4a45d9
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion torch_frame/data/multi_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def cuda(self, *args, **kwargs):
return self._apply(lambda x: x.cuda(*args, **kwargs))

def pin_memory(self, *args, **kwargs):
return self._apply(lambda x: x.pin_memory(*args, **kwargs))
out = self._apply(lambda x: _pin_memory(x, *args, **kwargs))
print('result:', out, out.is_pinned(), out.device, args, kwargs)
return out

def is_pinned(self) -> bool:
return self.values.is_pinned() and self.offset.is_pinned()
Expand Down Expand Up @@ -394,3 +396,8 @@ def _batched_arange(count: Tensor) -> tuple[Tensor, Tensor]:
arange -= ptr[batch]

return batch, arange


def _pin_memory(tensor: Tensor, *args, **kwargs) -> Tensor:
print('input:', tensor, tensor.is_pinned(), tensor.device, args, kwargs)
return tensor.pin_memory(*args, **kwargs)

0 comments on commit f4a45d9

Please sign in to comment.