diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index efc76bb99f562..f3bbab69f2719 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1108,7 +1108,7 @@ def close(self) -> None: """ Close the channel. """ - ExecutePlanResponseReattachableIterator.shutdown_threadpool() + ExecutePlanResponseReattachableIterator.shutdown() self._channel.close() self._closed = True diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index c20d2b6e2e83d..82c7ae9772188 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -74,7 +74,7 @@ def _release_thread_pool(cls) -> ThreadPool: return cls._release_thread_pool_instance @classmethod - def shutdown_threadpool(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None: + def shutdown(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None: """ When the channel is closed, this method will be called before, to make sure all outstanding calls are closed. @@ -85,15 +85,6 @@ def shutdown_threadpool(cls: Type["ExecutePlanResponseReattachableIterator"]) -> cls._release_thread_pool.join() # type: ignore[attr-defined] cls._release_thread_pool_instance = None - def shutdown(self: "ExecutePlanResponseReattachableIterator") -> None: - """ - When the channel is closed, this method will be called before, to make sure all - outstanding calls are closed, and mark this iterator is shutdown. - """ - with self._lock: - self.shutdown_threadpool() - self._is_shutdown = True - def __init__( self, request: pb2.ExecutePlanRequest, @@ -101,7 +92,7 @@ def __init__( retrying: Callable[[], Retrying], metadata: Iterable[Tuple[str, str]], ): - self._is_shutdown = False + self._release_thread_pool # Trigger initialization self._request = request self._retrying = retrying if request.operation_id: @@ -219,8 +210,9 @@ def target() -> None: except Exception as e: warnings.warn(f"ReleaseExecute failed with exception: {e}.") - if not self._is_shutdown: - self._release_thread_pool.apply_async(target) + with self._lock: + if self._release_thread_pool_instance is not None: + self._release_thread_pool.apply_async(target) def _release_all(self) -> None: """ @@ -243,8 +235,9 @@ def target() -> None: except Exception as e: warnings.warn(f"ReleaseExecute failed with exception: {e}.") - if not self._is_shutdown: - self._release_thread_pool.apply_async(target) + with self._lock: + if self._release_thread_pool_instance is not None: + self._release_thread_pool.apply_async(target) self._result_complete = True def _call_iter(self, iter_fun: Callable) -> Any: