diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py index 9fbc608c12a4b..62b42c1fcd02c 100644 --- a/python/pyspark/pandas/indexes/multi.py +++ b/python/pyspark/pandas/indexes/multi.py @@ -810,7 +810,18 @@ def symmetric_difference( # type: ignore[override] sdf_self = self._psdf._internal.spark_frame.select(self._internal.index_spark_columns) sdf_other = other._psdf._internal.spark_frame.select(other._internal.index_spark_columns) - sdf_symdiff = sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other)) + tmp_tag_col = verify_temp_column_name(sdf_self, "__multi_index_tag__") + tmp_max_col = verify_temp_column_name(sdf_self, "__multi_index_max_tag__") + tmp_min_col = verify_temp_column_name(sdf_self, "__multi_index_min_tag__") + + sdf_symdiff = ( + sdf_self.withColumn(tmp_tag_col, F.lit(0)) + .union(sdf_other.withColumn(tmp_tag_col, F.lit(1))) + .groupBy(*self._internal.index_spark_column_names) + .agg(F.min(tmp_tag_col).alias(tmp_min_col), F.max(tmp_tag_col).alias(tmp_max_col)) + .where(F.col(tmp_min_col) == F.col(tmp_max_col)) + .select(*self._internal.index_spark_column_names) + ) if sort: sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names)