Skip to content

Commit

Permalink
feat: goodies for portal
Browse files Browse the repository at this point in the history
  • Loading branch information
supersergiy committed Feb 14, 2025
1 parent 35674b5 commit b71ff5c
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 12 deletions.
5 changes: 5 additions & 0 deletions zetta_utils/common/ctx_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@ def set_env_ctx_mngr(**environ):
finally:
os.environ.clear()
os.environ.update(old_environ)


@contextlib.contextmanager
def noop_ctx_mngr(): # pragma: no cover
yield
7 changes: 6 additions & 1 deletion zetta_utils/mazepa/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def execute(
checkpoint_interval_sec: Optional[float] = 150,
raise_on_failed_checkpoint: bool = True,
write_progress_summary: bool = False,
require_interrupt_confirm: bool = True,
):
"""
Executes a target until completion using the given execution queue.
Expand Down Expand Up @@ -140,6 +141,7 @@ def execute(
checkpoint_interval_sec=checkpoint_interval_sec,
raise_on_failed_checkpoint=raise_on_failed_checkpoint,
write_progress_summary=write_progress_summary,
require_interrupt_confirm=require_interrupt_confirm,
)

end_time = time.time()
Expand All @@ -159,6 +161,7 @@ def _execute_from_state(
checkpoint_interval_sec: Optional[float],
raise_on_failed_checkpoint: bool,
write_progress_summary: bool,
require_interrupt_confirm: bool,
num_procs: int = 8,
):
if do_dryrun_estimation:
Expand All @@ -180,7 +183,9 @@ def _execute_from_state(

progress_updater = stack.enter_context(
progress_ctx_mngr(
expected_operation_counts, write_progress_to_path=write_progress_to_path
expected_operation_counts,
write_progress_to_path=write_progress_to_path,
require_interrupt_confirm=require_interrupt_confirm,
)
)
else:
Expand Down
29 changes: 21 additions & 8 deletions zetta_utils/mazepa/progress_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from rich.console import Console

from zetta_utils.common import custom_signal_handler_ctx, get_user_confirmation
from zetta_utils.common.ctx_managers import noop_ctx_mngr

from .execution_state import ProgressReport

Expand Down Expand Up @@ -51,6 +52,7 @@ def progress_ctx_mngr(
expected_total_counts: dict[str, int],
write_progress_to_path: str | None = None,
write_progress_interval_sec: int = 5,
require_interrupt_confirm: bool = True,
) -> Generator[ProgressUpdateFN, None, None]: # pragma: no cover
progress_bar = progress.Progress(
progress.SpinnerColumn(),
Expand Down Expand Up @@ -91,7 +93,13 @@ def custom_debugger_hook():
last_progress_writeout_ts = 0.0

with progress_bar as progress_bar:
with custom_signal_handler_ctx(get_confirm_sigint_fn(progress_bar), signal.SIGINT):
if require_interrupt_confirm:
handler_ctx = custom_signal_handler_ctx(
get_confirm_sigint_fn(progress_bar), signal.SIGINT
)
else:
handler_ctx = noop_ctx_mngr()
with handler_ctx:
submission_tracker_ids = {
k: progress_bar.add_task(
f"[cyan]Submission {k}",
Expand All @@ -105,6 +113,15 @@ def custom_debugger_hook():

execution_tracker_ids: dict[str, progress.TaskID] = {}

def write_progress_file():
temp_console = Console(record=True, width=80)
for line in progress_bar.get_renderables():
temp_console.print(line)
progress_html = temp_console.export_html(inline_styles=True)

with fsspec.open(write_progress_to_path, "w") as f:
f.write(progress_html)

def update_fn(progress_reports: dict[str, ProgressReport]) -> None:
nonlocal last_progress_writeout_ts
if not hasattr(sys, "gettrace") or sys.gettrace() is None:
Expand All @@ -130,17 +147,13 @@ def update_fn(progress_reports: dict[str, ProgressReport]) -> None:
if (write_progress_to_path is not None) and (
time.time() - last_progress_writeout_ts > write_progress_interval_sec
):
temp_console = Console(record=True, width=80)
for line in progress_bar.get_renderables():
temp_console.print(line)
progress_html = temp_console.export_html(inline_styles=True)

with fsspec.open(write_progress_to_path, "w") as f:
f.write(progress_html)
write_progress_file()

last_progress_writeout_ts = time.time()

yield update_fn
if write_progress_to_path is not None:
write_progress_file()
try:
progress_bar.stop()
except IndexError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,4 +310,5 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals
checkpoint_interval_sec=checkpoint_interval_sec,
raise_on_failed_checkpoint=raise_on_failed_checkpoint,
write_progress_summary=write_progress_summary,
require_interrupt_confirm=False,
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from zetta_utils import builder, mazepa, tensor_ops
from zetta_utils.common import ComparablePartial
from zetta_utils.geometry import BBox3D, Vec3D
from zetta_utils.layer.volumetric.protocols import VolumetricBasedLayerProtocol
from zetta_utils.layer.volumetric.layer import VolumetricLayer
from zetta_utils.mazepa.flows import sequential_flow
from zetta_utils.mazepa_layer_processing.common.subchunkable_apply_flow import (
build_subchunkable_apply_flow,
Expand Down Expand Up @@ -54,8 +54,8 @@ def make_interpolate_operation(

@builder.register("build_interpolate_flow")
def build_interpolate_flow( # pylint: disable=too-many-locals
src: VolumetricBasedLayerProtocol,
dst: VolumetricBasedLayerProtocol | None,
src: VolumetricLayer,
dst: VolumetricLayer | None,
src_resolution: Sequence[float],
dst_resolutions: Sequence[Sequence[float]] | Sequence[float],
mode: tensor_ops.InterpolationMode,
Expand Down

0 comments on commit b71ff5c

Please sign in to comment.