-
Notifications
You must be signed in to change notification settings - Fork 467
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
change dtype of block_bucketize_row_pos and fix flaky test_kjt_bucket…
…ize_before_all2all_cpu (#2689) Summary: Pull Request resolved: #2689 # context * found a test failure from OSS [test run](https://github.com/pytorch/torchrec/actions/runs/12816026713/job/35736016089): P1714445461 * the issue is a recent change (D65912888) incorrectly calling the `_fx_wrap_tensor_to_device_dtype` function ``` block_bucketize_pos=( _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths()) if block_bucketize_row_pos is not None else None ), ``` where `block_bucketize_row_pos: List[torch.tensor]`, but the function only accepts torch.Tensor ``` torch.fx.wrap def _fx_wrap_tensor_to_device_dtype( t: torch.Tensor, tensor_device_dtype: torch.Tensor ) -> torch.Tensor: return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype) ``` * the fix is supposed to be straightforward to apply a list-comprehension over the function ``` block_bucketize_pos=( [ _fx_wrap_tensor_to_device_dtype(pos, kjt.lengths()) # <---- pay attention here, kjt.lengths() for pos in block_bucketize_row_pos ] ``` * according to the previous comments, the `block_bucketize_pos`'s `dtype` should be the same as `kjt._length`, however, it triggers the following error {F1974430883} * according to the operator implementation ([codepointer](https://fburl.com/code/9gyyl8h4)), the `block_bucketize_pos` should have the same dtype as `kjt._values`. length has a type name of `offset_t`, values has a type name of `index_t`, the same as `block_bucketize_pos`. Reviewed By: dstaay-fb Differential Revision: D68358894 fbshipit-source-id: 13303c54288c99c6cf58d550365f8d3c698c34b1
- Loading branch information
1 parent
53752ea
commit f52fd32
Showing
2 changed files
with
18 additions
and
100 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters