diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index f3663c09902f52..a808f2cb63e861 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -314,7 +314,7 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) Args: elements (`torch.Tensor`): Input elements - test_elements (`torch.Tensor`): The elements to check against. + test_elements (`torch.Tensor` or `int`): The elements to check against. Returns: `torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements` @@ -322,6 +322,9 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) """ if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4: + test_elements = torch.tensor(test_elements) + if test_elements.ndim == 0: + test_elements = test_elements.unsqueeze(0) return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze() else: # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045 diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 8af47cde8e5315..430043496c4217 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1710,7 +1710,12 @@ def test_isin_mps_friendly(self): torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer) ) ) - # We can match against an tensor of integers + # We can match against a 0D tensor + random_test_tensor = torch.randint(0, 100, (1,)).squeeze() + self.assertTrue( + torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor)) + ) + # We can match against a 1D tensor (with many items) random_test_tensor = torch.randint(0, 100, (10,)) self.assertTrue( torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, 
random_test_tensor))