diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py
index c07c9e2c103..2360baaebab 100644
--- a/src/datasets/formatting/formatting.py
+++ b/src/datasets/formatting/formatting.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numbers
 import operator
 from collections.abc import Iterable, Mapping, MutableMapping
 from functools import partial
@@ -566,7 +567,7 @@ def _check_valid_index_key(key: Union[int, slice, range, Iterable], size: int) -
 
 
 def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
-    if isinstance(key, int):
+    if isinstance(key, numbers.Integral):
         return "row"
     elif isinstance(key, str):
         return "column"
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 2e54aadf7b6..26788631b88 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -4494,12 +4494,22 @@ async def f(batch):
     assert len(out) == 1
 
 
+def test_dataset_getitem_int_np_equivalence():
+    ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
+
+    assert ds[1] == ds[np.int64(1)]
+
+
 def test_dataset_getitem_raises():
     ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
     with pytest.raises(TypeError):
         ds[False]
     with pytest.raises(TypeError):
         ds._getitem(True)
+    with pytest.raises(TypeError):
+        ds[np.bool_(True)]
+    with pytest.raises(TypeError):
+        ds[1.0]
 
 
 def test_categorical_dataset(tmpdir):