From d6c78ae93ba2691edf3ed9c7d52b05b9bacb6b82 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Thu, 19 Sep 2024 14:19:07 -0700 Subject: [PATCH] remove extra squeeze length check --- modin/pandas/dataframe.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/modin/pandas/dataframe.py b/modin/pandas/dataframe.py index de96ea0ab26..773e62b2e2a 100644 --- a/modin/pandas/dataframe.py +++ b/modin/pandas/dataframe.py @@ -2074,16 +2074,19 @@ def squeeze( Squeeze 1 dimensional axis objects into scalars. """ axis = self._get_axis_number(axis) if axis is not None else None - if axis is None and (len(self.columns) == 1 or len(self.index) == 1): - return Series(query_compiler=self._query_compiler).squeeze() - if axis == 1 and len(self.columns) == 1: + len_columns = len(self.columns) + if axis == 1 and len_columns == 1: self._query_compiler._shape_hint = "column" return Series(query_compiler=self._query_compiler) - if axis == 0 and len(self.index) == 1: - qc = self.T._query_compiler - qc._shape_hint = "column" - return Series(query_compiler=qc) - else: + if axis in [0, None]: + # Only compute the length of the index if axis is 0 or None. + len_index = len(self) + if axis is None and (len_columns == 1 or len_index == 1): + return Series(query_compiler=self._query_compiler).squeeze() + if axis == 0 and len_index == 1: + qc = self.T._query_compiler + qc._shape_hint = "column" + return Series(query_compiler=qc) return self.copy() def stack(