change dtype of block_bucketize_row_pos and fix flaky test_kjt_bucketize_before_all2all_cpu (#2689)

Summary:
Pull Request resolved: #2689

# context
* found a test failure in an OSS [test run](https://github.com/pytorch/torchrec/actions/runs/12816026713/job/35736016089): P1714445461
* the issue is that a recent change (D65912888) incorrectly calls the `_fx_wrap_tensor_to_device_dtype` function:
```
        block_bucketize_pos=(
            _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
            if block_bucketize_row_pos is not None
            else None
        ),
```
where `block_bucketize_row_pos` is a `List[torch.Tensor]`, but the function only accepts a single `torch.Tensor`:
```
@torch.fx.wrap
def _fx_wrap_tensor_to_device_dtype(
    t: torch.Tensor, tensor_device_dtype: torch.Tensor
) -> torch.Tensor:
    return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)
```
* at first the fix seems straightforward: apply a list comprehension over the function call
```
        block_bucketize_pos=(
            [
                _fx_wrap_tensor_to_device_dtype(pos, kjt.lengths())  # <---- pay attention here, kjt.lengths()
                for pos in block_bucketize_row_pos
            ]
```
* according to the previous comments, `block_bucketize_pos`'s dtype should match `kjt._lengths`; however, that choice triggers the following error:
 {F1974430883}
* according to the operator implementation ([code pointer](https://fburl.com/code/9gyyl8h4)), `block_bucketize_pos` should instead have the same dtype as `kjt._values`: the lengths are typed as `offset_t`, while the values are typed as `index_t`, the same as `block_bucketize_pos`. A standalone sketch of the resulting conversion is shown below.
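
Below is a minimal, self-contained sketch (not part of the diff; the keys, values, dtypes, and bucket positions are made-up for illustration) of the resulting pattern: the helper from the snippet above is applied element-wise over the list, and each position tensor is aligned with `kjt.values()` rather than `kjt.lengths()`.
```
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@torch.fx.wrap
def _fx_wrap_tensor_to_device_dtype(
    t: torch.Tensor, tensor_device_dtype: torch.Tensor
) -> torch.Tensor:
    # copy of the helper above: move `t` to the device/dtype of the reference tensor
    return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)


# illustrative KJT: lengths carry offset_t (int32), values carry index_t (int64)
kjt = KeyedJaggedTensor(
    keys=["f0", "f1"],
    values=torch.tensor([3, 7, 5, 1], dtype=torch.long),
    lengths=torch.tensor([2, 1, 1, 0], dtype=torch.int),
)

# one hypothetical bucket-boundary tensor per feature
block_bucketize_row_pos = [
    torch.tensor([0, 4, 8]),
    torch.tensor([0, 3, 6]),
]

# the helper only accepts a single tensor, so it is applied per element of the list;
# the target dtype/device comes from kjt.values() (index_t), not kjt.lengths() (offset_t)
block_bucketize_pos = [
    _fx_wrap_tensor_to_device_dtype(pos, kjt.values())
    for pos in block_bucketize_row_pos
]

assert all(pos.dtype == kjt.values().dtype for pos in block_bucketize_pos)
```
This is the same pattern the diff below applies inside `bucketize_kjt_before_all2all`.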

Reviewed By: dstaay-fb

Differential Revision: D68358894

fbshipit-source-id: 13303c54288c99c6cf58d550365f8d3c698c34b1
TroyGarden authored and facebook-github-bot committed Jan 21, 2025
1 parent 53752ea commit f52fd32
Showing 2 changed files with 18 additions and 100 deletions.
5 changes: 4 additions & 1 deletion torchrec/distributed/embedding_sharding.py
@@ -274,7 +274,10 @@ def bucketize_kjt_before_all2all(
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
         block_bucketize_pos=(
-            _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
+            [
+                _fx_wrap_tensor_to_device_dtype(pos, kjt.values())
+                for pos in block_bucketize_row_pos
+            ]
             if block_bucketize_row_pos is not None
             else None
         ),
113 changes: 14 additions & 99 deletions torchrec/distributed/tests/test_utils.py
@@ -263,98 +263,6 @@ def block_bucketize_ref(


class KJTBucketizeTest(unittest.TestCase):
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"CUDA is not available",
)
# pyre-ignore[56]
@given(
index_type=st.sampled_from([torch.int, torch.long]),
offset_type=st.sampled_from([torch.int, torch.long]),
world_size=st.integers(1, 129),
num_features=st.integers(1, 15),
batch_size=st.integers(1, 15),
)
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
def test_kjt_bucketize_before_all2all(
self,
index_type: torch.dtype,
offset_type: torch.dtype,
world_size: int,
num_features: int,
batch_size: int,
) -> None:
MAX_BATCH_SIZE = 15
MAX_LENGTH = 10
# max number of rows needed for a given feature to have unique row index
MAX_ROW_COUNT = MAX_LENGTH * MAX_BATCH_SIZE

lengths_list = [
random.randrange(MAX_LENGTH + 1) for _ in range(num_features * batch_size)
]
keys_list = [f"feature_{i}" for i in range(num_features)]
# for each feature, generate unrepeated row indices
indices_lists = [
random.sample(
range(MAX_ROW_COUNT),
# number of indices needed is the length sum of all batches for a feature
sum(
lengths_list[
feature_offset * batch_size : (feature_offset + 1) * batch_size
]
),
)
for feature_offset in range(num_features)
]
indices_list = list(itertools.chain(*indices_lists))

weights_list = [random.randint(1, 100) for _ in range(len(indices_list))]

# for each feature, calculate the minimum block size needed to
# distribute all rows to the available trainers
block_sizes_list = [
(
math.ceil((max(feature_indices_list) + 1) / world_size)
if feature_indices_list
else 1
)
for feature_indices_list in indices_lists
]

kjt = KeyedJaggedTensor(
keys=keys_list,
lengths=torch.tensor(lengths_list, dtype=offset_type)
.view(num_features * batch_size)
.cuda(),
values=torch.tensor(indices_list, dtype=index_type).cuda(),
weights=torch.tensor(weights_list, dtype=torch.float).cuda(),
)
"""
each entry in block_sizes identifies how many hashes for each feature goes
to every rank; we have three featues in `self.features`
"""
block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda()

block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
kjt=kjt,
num_buckets=world_size,
block_sizes=block_sizes,
)

expected_block_bucketized_kjt = block_bucketize_ref(
kjt,
world_size,
block_sizes,
)

self.assertTrue(
keyed_jagged_tensor_equals(
block_bucketized_kjt,
expected_block_bucketized_kjt,
is_pooled_features=True,
)
)

# pyre-ignore[56]
@given(
index_type=st.sampled_from([torch.int, torch.long]),
@@ -363,16 +271,20 @@ def test_kjt_bucketize_before_all2all(
         num_features=st.integers(1, 15),
         batch_size=st.integers(1, 15),
         variable_bucket_pos=st.booleans(),
+        device=st.sampled_from(
+            ["cpu"] + (["cuda"] if torch.cuda.device_count() > 0 else [])
+        ),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
-    def test_kjt_bucketize_before_all2all_cpu(
+    @settings(verbosity=Verbosity.verbose, max_examples=50, deadline=None)
+    def test_kjt_bucketize_before_all2all(
         self,
         index_type: torch.dtype,
         offset_type: torch.dtype,
         world_size: int,
         num_features: int,
         batch_size: int,
         variable_bucket_pos: bool,
+        device: str,
     ) -> None:
         MAX_BATCH_SIZE = 15
         MAX_LENGTH = 10
@@ -423,17 +335,17 @@ def test_kjt_bucketize_before_all2all_cpu(

         kjt = KeyedJaggedTensor(
             keys=keys_list,
-            lengths=torch.tensor(lengths_list, dtype=offset_type).view(
+            lengths=torch.tensor(lengths_list, dtype=offset_type, device=device).view(
                 num_features * batch_size
             ),
-            values=torch.tensor(indices_list, dtype=index_type),
-            weights=torch.tensor(weights_list, dtype=torch.float),
+            values=torch.tensor(indices_list, dtype=index_type, device=device),
+            weights=torch.tensor(weights_list, dtype=torch.float, device=device),
         )
         """
         each entry in block_sizes identifies how many hashes for each feature goes
         to every rank; we have three featues in `self.features`
         """
-        block_sizes = torch.tensor(block_sizes_list, dtype=index_type)
+        block_sizes = torch.tensor(block_sizes_list, dtype=index_type, device=device)
         block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
             kjt=kjt,
             num_buckets=world_size,
@@ -442,7 +354,10 @@ def test_kjt_bucketize_before_all2all_cpu(
         )

         expected_block_bucketized_kjt = block_bucketize_ref(
-            kjt, world_size, block_sizes, "cpu"
+            kjt,
+            world_size,
+            block_sizes,
+            device,
         )

         self.assertTrue(
