From b71ff5c1869d174ad0fcdbe32dc0236cfcdb5943 Mon Sep 17 00:00:00 2001 From: Sergiy Date: Fri, 14 Feb 2025 21:29:09 +0000 Subject: [PATCH] feat: goodies for portal --- zetta_utils/common/ctx_managers.py | 5 ++++ zetta_utils/mazepa/execution.py | 7 ++++- zetta_utils/mazepa/progress_tracker.py | 29 ++++++++++++++----- .../configurations/execute_on_gcp_with_sqs.py | 1 + .../common/interpolate_flow.py | 6 ++-- 5 files changed, 36 insertions(+), 12 deletions(-) diff --git a/zetta_utils/common/ctx_managers.py b/zetta_utils/common/ctx_managers.py index 35772a6ef..729bfab3d 100644 --- a/zetta_utils/common/ctx_managers.py +++ b/zetta_utils/common/ctx_managers.py @@ -25,3 +25,8 @@ def set_env_ctx_mngr(**environ): finally: os.environ.clear() os.environ.update(old_environ) + + +@contextlib.contextmanager +def noop_ctx_mngr(): # pragma: no cover + yield diff --git a/zetta_utils/mazepa/execution.py b/zetta_utils/mazepa/execution.py index fb8e83364..721e16089 100644 --- a/zetta_utils/mazepa/execution.py +++ b/zetta_utils/mazepa/execution.py @@ -77,6 +77,7 @@ def execute( checkpoint_interval_sec: Optional[float] = 150, raise_on_failed_checkpoint: bool = True, write_progress_summary: bool = False, + require_interrupt_confirm: bool = True, ): """ Executes a target until completion using the given execution queue. @@ -140,6 +141,7 @@ def execute( checkpoint_interval_sec=checkpoint_interval_sec, raise_on_failed_checkpoint=raise_on_failed_checkpoint, write_progress_summary=write_progress_summary, + require_interrupt_confirm=require_interrupt_confirm, ) end_time = time.time() @@ -159,6 +161,7 @@ def _execute_from_state( checkpoint_interval_sec: Optional[float], raise_on_failed_checkpoint: bool, write_progress_summary: bool, + require_interrupt_confirm: bool, num_procs: int = 8, ): if do_dryrun_estimation: @@ -180,7 +183,9 @@ def _execute_from_state( progress_updater = stack.enter_context( progress_ctx_mngr( - expected_operation_counts, write_progress_to_path=write_progress_to_path + expected_operation_counts, + write_progress_to_path=write_progress_to_path, + require_interrupt_confirm=require_interrupt_confirm, ) ) else: diff --git a/zetta_utils/mazepa/progress_tracker.py b/zetta_utils/mazepa/progress_tracker.py index fd7135d8a..240b0d908 100644 --- a/zetta_utils/mazepa/progress_tracker.py +++ b/zetta_utils/mazepa/progress_tracker.py @@ -13,6 +13,7 @@ from rich.console import Console from zetta_utils.common import custom_signal_handler_ctx, get_user_confirmation +from zetta_utils.common.ctx_managers import noop_ctx_mngr from .execution_state import ProgressReport @@ -51,6 +52,7 @@ def progress_ctx_mngr( expected_total_counts: dict[str, int], write_progress_to_path: str | None = None, write_progress_interval_sec: int = 5, + require_interrupt_confirm: bool = True, ) -> Generator[ProgressUpdateFN, None, None]: # pragma: no cover progress_bar = progress.Progress( progress.SpinnerColumn(), @@ -91,7 +93,13 @@ def custom_debugger_hook(): last_progress_writeout_ts = 0.0 with progress_bar as progress_bar: - with custom_signal_handler_ctx(get_confirm_sigint_fn(progress_bar), signal.SIGINT): + if require_interrupt_confirm: + handler_ctx = custom_signal_handler_ctx( + get_confirm_sigint_fn(progress_bar), signal.SIGINT + ) + else: + handler_ctx = noop_ctx_mngr() + with handler_ctx: submission_tracker_ids = { k: progress_bar.add_task( f"[cyan]Submission {k}", @@ -105,6 +113,15 @@ def custom_debugger_hook(): execution_tracker_ids: dict[str, progress.TaskID] = {} + def write_progress_file(): + temp_console = Console(record=True, width=80) + for line in progress_bar.get_renderables(): + temp_console.print(line) + progress_html = temp_console.export_html(inline_styles=True) + + with fsspec.open(write_progress_to_path, "w") as f: + f.write(progress_html) + def update_fn(progress_reports: dict[str, ProgressReport]) -> None: nonlocal last_progress_writeout_ts if not hasattr(sys, "gettrace") or sys.gettrace() is None: @@ -130,17 +147,13 @@ def update_fn(progress_reports: dict[str, ProgressReport]) -> None: if (write_progress_to_path is not None) and ( time.time() - last_progress_writeout_ts > write_progress_interval_sec ): - temp_console = Console(record=True, width=80) - for line in progress_bar.get_renderables(): - temp_console.print(line) - progress_html = temp_console.export_html(inline_styles=True) - - with fsspec.open(write_progress_to_path, "w") as f: - f.write(progress_html) + write_progress_file() last_progress_writeout_ts = time.time() yield update_fn + if write_progress_to_path is not None: + write_progress_file() try: progress_bar.stop() except IndexError: diff --git a/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py b/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py index 208567bf0..da9554b73 100644 --- a/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py +++ b/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py @@ -310,4 +310,5 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals checkpoint_interval_sec=checkpoint_interval_sec, raise_on_failed_checkpoint=raise_on_failed_checkpoint, write_progress_summary=write_progress_summary, + require_interrupt_confirm=False, ) diff --git a/zetta_utils/mazepa_layer_processing/common/interpolate_flow.py b/zetta_utils/mazepa_layer_processing/common/interpolate_flow.py index 23d123915..cc0def8af 100644 --- a/zetta_utils/mazepa_layer_processing/common/interpolate_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/interpolate_flow.py @@ -7,7 +7,7 @@ from zetta_utils import builder, mazepa, tensor_ops from zetta_utils.common import ComparablePartial from zetta_utils.geometry import BBox3D, Vec3D -from zetta_utils.layer.volumetric.protocols import VolumetricBasedLayerProtocol +from zetta_utils.layer.volumetric.layer import VolumetricLayer from zetta_utils.mazepa.flows import sequential_flow from zetta_utils.mazepa_layer_processing.common.subchunkable_apply_flow import ( build_subchunkable_apply_flow, @@ -54,8 +54,8 @@ def make_interpolate_operation( @builder.register("build_interpolate_flow") def build_interpolate_flow( # pylint: disable=too-many-locals - src: VolumetricBasedLayerProtocol, - dst: VolumetricBasedLayerProtocol | None, + src: VolumetricLayer, + dst: VolumetricLayer | None, src_resolution: Sequence[float], dst_resolutions: Sequence[Sequence[float]] | Sequence[float], mode: tensor_ops.InterpolationMode,