Skip to content

Commit 521aae1

Browse files
Mike Plekhanovfacebook-github-bot
authored andcommitted
Fix device mismatch for JaggedTensor from_dense method.
Summary: In the current implementation of `from_dense` method `self.lengths` always created on cpu. This is the problem, because methods that expects both `values` and `lengths` to be on the same device will fail. Fixed implementation to use the same device as values and added test. Reviewed By: joshuadeng Differential Revision: D54330659 fbshipit-source-id: b4aee78c31bb00d03390e37581b349e60a05f6a5
1 parent 1f34283 commit 521aae1

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,13 @@ def from_dense(
358358
359359
# j1 = [[1.0], [], [7.0], [8.0], [10.0, 11.0, 12.0]]
360360
"""
361-
lengths = torch.IntTensor([value.size(0) for value in values])
361+
362362
values_tensor = torch.cat(values, dim=0)
363+
lengths = torch.tensor(
364+
[value.size(0) for value in values],
365+
dtype=torch.int32,
366+
device=values_tensor.device,
367+
)
363368
weights_tensor = torch.cat(weights, dim=0) if weights is not None else None
364369

365370
return JaggedTensor(

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,26 @@ def test_from_dense(self) -> None:
257257
torch.equal(j1.weights(), torch.Tensor([1.0, 7.0, 8.0, 10.0, 11.0, 12.0]))
258258
)
259259

260+
# pyre-ignore[56]
261+
@unittest.skipIf(
262+
torch.cuda.device_count() <= 0,
263+
"CUDA is not available",
264+
)
265+
def test_from_dense_device(self) -> None:
266+
device = torch.device("cuda", index=0)
267+
values = [
268+
torch.tensor([1.0], device=device),
269+
torch.tensor([7.0, 8.0], device=device),
270+
torch.tensor([10.0, 11.0, 12.0], device=device),
271+
]
272+
273+
j0 = JaggedTensor.from_dense(
274+
values=values,
275+
)
276+
self.assertEqual(j0.values().device, device)
277+
self.assertEqual(j0.lengths().device, device)
278+
self.assertEqual(j0.offsets().device, device)
279+
260280
def test_to_dense(self) -> None:
261281
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
262282
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])

0 commit comments

Comments
 (0)