Skip to content

Commit

Permalink
Fix device mismatch for JaggedTensor from_dense method.
Browse files Browse the repository at this point in the history
Summary:
In the current implementation of `from_dense` method `self.lengths` always created on cpu. This is the problem, because methods that expects both `values` and `lengths` to be on the same device will fail.

Fixed implementation to use the same device as values and added test.

Reviewed By: joshuadeng

Differential Revision: D54330659

fbshipit-source-id: b4aee78c31bb00d03390e37581b349e60a05f6a5
  • Loading branch information
Mike Plekhanov authored and facebook-github-bot committed Feb 29, 2024
1 parent 1f34283 commit 521aae1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,13 @@ def from_dense(
# j1 = [[1.0], [], [7.0], [8.0], [10.0, 11.0, 12.0]]
"""
lengths = torch.IntTensor([value.size(0) for value in values])

values_tensor = torch.cat(values, dim=0)
lengths = torch.tensor(
[value.size(0) for value in values],
dtype=torch.int32,
device=values_tensor.device,
)
weights_tensor = torch.cat(weights, dim=0) if weights is not None else None

return JaggedTensor(
Expand Down
20 changes: 20 additions & 0 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,26 @@ def test_from_dense(self) -> None:
torch.equal(j1.weights(), torch.Tensor([1.0, 7.0, 8.0, 10.0, 11.0, 12.0]))
)

# pyre-ignore[56]
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"CUDA is not available",
)
def test_from_dense_device(self) -> None:
device = torch.device("cuda", index=0)
values = [
torch.tensor([1.0], device=device),
torch.tensor([7.0, 8.0], device=device),
torch.tensor([10.0, 11.0, 12.0], device=device),
]

j0 = JaggedTensor.from_dense(
values=values,
)
self.assertEqual(j0.values().device, device)
self.assertEqual(j0.lengths().device, device)
self.assertEqual(j0.offsets().device, device)

def test_to_dense(self) -> None:
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
Expand Down

0 comments on commit 521aae1

Please sign in to comment.