Skip to content

Commit

Permalink
cleanup stride logic (#2206)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2206

- VBE vs standard KJTs have some ambiguous / counter-intuitive behavior around self.stride().  This diff clarifies the contract that stride is always just the max value of the individual strides (aka batch size), i.e. the dense representation value if you pad.
- Also, for Inference use cases we need to maintain the correct stride behavior for rebatching logic, so notice the stride is kept now post splits.

Reviewed By: ge0405

Differential Revision: D59340146

fbshipit-source-id: effbde94b3fab1b553acb8f48d018f25ab3234ed
  • Loading branch information
dstaay-fb authored and facebook-github-bot committed Jul 9, 2024
1 parent 9c74d8a commit 90f3054
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
16 changes: 7 additions & 9 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,17 +1425,15 @@ def __init__(
self._stride: int = -1

if stride_per_key_per_rank is not None:
if stride is not None:
raise ValueError(
"Cannot initialize KJT with both `stride` and `stride_per_key_per_rank`"
)
self._stride_per_key_per_rank = stride_per_key_per_rank
self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank]
self._variable_stride_per_key = True
if stride_per_key_per_rank is not None:
self._stride = 0
elif all(s == self.stride_per_key()[0] for s in self.stride_per_key()):
self._stride = self.stride_per_key()[0]
if stride is not None:
self._stride = stride
else:
self._stride = (
max(self._stride_per_key) if len(self._stride_per_key) > 0 else 0
)
else:
stride = _maybe_compute_stride_kjt(keys, stride, lengths, offsets)
self._stride = stride
Expand Down Expand Up @@ -1790,7 +1788,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
end_offset = _offset_per_key[end]
keys: List[str] = self._keys[start:end]
stride, stride_per_key_per_rank = (
(None, self.stride_per_key_per_rank()[start:end])
(self._stride, self.stride_per_key_per_rank()[start:end])
if self.variable_stride_per_key()
else (self._stride, None)
)
Expand Down
23 changes: 23 additions & 0 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,22 @@ def test_split(self) -> None:
torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0]))
)

def test_empty_vb(self) -> None:
    """Check a variable-batch KJT built from empty tensors.

    The KJT is constructed with a single key, empty values/lengths and an
    empty per-rank stride list for that key. Its lengths and values must
    come back empty and its stride must be 0.
    """
    empty_kjt = KeyedJaggedTensor(
        keys=["index_0"],
        values=torch.tensor([]),
        lengths=torch.tensor([]),
        stride_per_key_per_rank=[[]],
    )

    # Both accessors are expected to return an empty tensor.
    expected_empty = torch.Tensor([])
    for actual in (empty_kjt.lengths(), empty_kjt.values()):
        self.assertTrue(torch.equal(actual, expected_empty))

    # The only key has no per-rank strides, so the reported stride is 0.
    self.assertEqual(empty_kjt.stride(), 0)

def test_split_vb(self) -> None:
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
keys = ["index_0", "index_1", "index_2", "index_3"]
Expand All @@ -1214,6 +1230,9 @@ def test_split_vb(self) -> None:
self.assertEqual(j0.keys(), ["index_0"])
self.assertEqual(j1.keys(), ["index_1"])
self.assertEqual(j2.keys(), ["index_2", "index_3"])
self.assertEqual(j0.stride(), 4)
self.assertEqual(j1.stride(), 4)
self.assertEqual(j2.stride(), 4)
self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1])))
self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0])))
self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([])))
Expand All @@ -1229,6 +1248,10 @@ def test_split_vb(self) -> None:
self.assertEqual(j1.keys(), ["index_0", "index_1", "index_2"])
self.assertEqual(j2.keys(), [])
self.assertEqual(j3.keys(), ["index_3"])
self.assertEqual(j0.stride(), 4)
self.assertEqual(j1.stride(), 4)
self.assertEqual(j2.stride(), 4)
self.assertEqual(j3.stride(), 4)
self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([])))
self.assertTrue(torch.equal(j0.values(), torch.Tensor([])))
self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([2, 0, 1, 1])))
Expand Down

0 comments on commit 90f3054

Please sign in to comment.