From e5d866bc68c4762ebd6e3e888e4abeaf4ccd9302 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 26 Feb 2025 07:01:27 -0800 Subject: [PATCH] Short circuit Index.equal if compared Index isn't same type (#18067) closes https://github.com/rapidsai/cudf/issues/8689 Before, comparing two different Index subclasses would execute a GPU kernel when we know they wouldn't be equal (e.g. DatetimeIndex equals RangeIndex). This PR add a short circuit clause to check that we are comparing the same subclasses. Also ensures we don't return a `np.bool_` object from this result. Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/18067 --- python/cudf/cudf/core/column/column.py | 2 +- python/cudf/cudf/core/index.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 06dc4058115..67a0aa7a781 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -713,7 +713,7 @@ def all(self, skipna: bool = True) -> bool: # is empty. if self.null_count == self.size: return True - return self.reduce("all") + return bool(self.reduce("all")) def any(self, skipna: bool = True) -> bool: # Early exit for fast cases. diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 1730a692dc1..f4e5f6e96ae 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -1286,6 +1286,15 @@ def equals(self, other) -> bool: elif other_is_categorical and not self_is_categorical: self = self.astype(other.dtype) check_dtypes = True + elif ( + not self_is_categorical + and not other_is_categorical + and not isinstance(other, RangeIndex) + and not isinstance(self, type(other)) + ): + # Can compare Index to CategoricalIndex or RangeIndex + # Other comparisons are invalid + return False try: return self._column.equals(