diff --git a/src/steamship/base/tasks.py b/src/steamship/base/tasks.py index 89e5e741e..b88f54c95 100644 --- a/src/steamship/base/tasks.py +++ b/src/steamship/base/tasks.py @@ -244,6 +244,7 @@ def wait( ---------- max_timeout_s : int Max timeout in seconds. Default: 180s. After this timeout, an exception will be thrown. + A timeout of -1 is equivalent to no timeout. retry_delay_s : float Delay between status checks. Default: 1s. on_each_refresh : Optional[Callable[[int, float, Task], None]] @@ -254,7 +255,9 @@ def wait( """ t0 = time.perf_counter() refresh_count = 0 - while time.perf_counter() - t0 < max_timeout_s and self.state not in ( + while ( + (max_timeout_s == -1) or (time.perf_counter() - t0 < max_timeout_s) + ) and self.state not in ( TaskState.succeeded, TaskState.failed, ): @@ -273,6 +276,29 @@ def wait( ) return self.output + def wait_until_completed( + self, + retry_delay_s: float = 1, + on_each_refresh: "Optional[Callable[[int, float, Task], None]]" = None, + ): + """Polls and blocks until the task has succeeded or failed. No timeout on waiting is applied. + + Parameters + ---------- + retry_delay_s : float + Delay between status checks. Default: 1s. + on_each_refresh : Optional[Callable[[int, float, Task], None]] + Optional call back you can get after each refresh is made, including success state refreshes. + The signature represents: (refresh #, total elapsed time, task) + + WARNING: Do not pass a long-running function to this variable. It will block the update polling. + """ + return self.wait( + max_timeout_s=-1, # Indicates to not apply a timeout + retry_delay_s=retry_delay_s, + on_each_refresh=on_each_refresh, + ) + def refresh(self): if self.task_id is None: raise SteamshipError(message="Unable to refresh task because `task_id` is None") diff --git a/src/steamship/data/package/package_version.py b/src/steamship/data/package/package_version.py index d8ce071e5..dc222a0dd 100644 --- a/src/steamship/data/package/package_version.py +++ b/src/steamship/data/package/package_version.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Any, Dict, Type +from typing import Any, Dict, Optional, Type from pydantic import BaseModel, Field @@ -35,12 +35,12 @@ def parse_obj(cls: Type[BaseModel], obj: Any) -> BaseModel: @staticmethod def create( client: Client, - package_id: str = None, - handle: str = None, - filename: str = None, - filebytes: bytes = None, - config_template: Dict[str, Any] = None, - hosting_handler: str = None, + package_id: Optional[str] = None, + handle: Optional[str] = None, + filename: Optional[str] = None, + filebytes: Optional[bytes] = None, + config_template: Optional[Dict[str, Any]] = None, + hosting_handler: Optional[str] = None, ) -> PackageVersion: if filename is None and filebytes is None: @@ -65,7 +65,7 @@ def create( file=("package.zip", filebytes, "multipart/form-data"), expect=PackageVersion, ) - task.wait() + task.wait_until_completed() return task.output def delete(self) -> PackageVersion: diff --git a/src/steamship/data/plugin/plugin_version.py b/src/steamship/data/plugin/plugin_version.py index 9449d7561..92ca7d3eb 100644 --- a/src/steamship/data/plugin/plugin_version.py +++ b/src/steamship/data/plugin/plugin_version.py @@ -56,15 +56,15 @@ def parse_obj(cls: Type[BaseModel], obj: Any) -> BaseModel: def create( client: Client, handle: str, - plugin_id: str = None, - filename: str = None, - filebytes: bytes = None, + plugin_id: Optional[str] = None, + filename: Optional[str] = None, + filebytes: Optional[bytes] = None, hosting_memory: Optional[HostingMemory] = None, hosting_timeout: Optional[HostingTimeout] = None, - hosting_handler: str = None, - is_public: bool = None, - is_default: bool = None, - config_template: Dict[str, Any] = None, + hosting_handler: Optional[str] = None, + is_public: Optional[bool] = None, + is_default: Optional[bool] = None, + config_template: Optional[Dict[str, Any]] = None, ) -> PluginVersion: if filename is None and filebytes is None: @@ -94,7 +94,7 @@ def create( expect=PluginVersion, ) - task.wait() + task.wait_until_completed() return task.output @staticmethod