Skip to content

Commit

Permalink
feat: write out progress report when flag is set
Browse files Browse the repository at this point in the history
  • Loading branch information
supersergiy committed Feb 7, 2025
1 parent 6f212bb commit 6d8866b
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 3 deletions.
20 changes: 18 additions & 2 deletions zetta_utils/mazepa/execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=too-many-locals
from __future__ import annotations

import os
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
Expand All @@ -16,7 +17,7 @@
from zetta_utils.message_queues.base import PullMessageQueue, PushMessageQueue

from . import Flow, Task, dryrun, sequential_flow
from .execution_checkpoint import record_execution_checkpoint
from .execution_checkpoint import EXECUTION_CHECKPOINT_PATH, record_execution_checkpoint
from .execution_state import ExecutionState, InMemoryExecutionState
from .id_generation import get_unique_id
from .progress_tracker import progress_ctx_mngr
Expand Down Expand Up @@ -75,6 +76,7 @@ def execute(
checkpoint: Optional[str] = None,
checkpoint_interval_sec: Optional[float] = 150,
raise_on_failed_checkpoint: bool = True,
write_progress_summary: bool = False,
):
"""
Executes a target until completion using the given execution queue.
Expand Down Expand Up @@ -137,6 +139,7 @@ def execute(
show_progress=show_progress,
checkpoint_interval_sec=checkpoint_interval_sec,
raise_on_failed_checkpoint=raise_on_failed_checkpoint,
write_progress_summary=write_progress_summary,
)

end_time = time.time()
Expand All @@ -155,6 +158,7 @@ def _execute_from_state(
show_progress: bool,
checkpoint_interval_sec: Optional[float],
raise_on_failed_checkpoint: bool,
write_progress_summary: bool,
num_procs: int = 8,
):
if do_dryrun_estimation:
Expand All @@ -166,7 +170,19 @@ def _execute_from_state(

with ExitStack() as stack:
if show_progress:
progress_updater = stack.enter_context(progress_ctx_mngr(expected_operation_counts))
write_progress_to_path = None
if write_progress_summary: # pragma: no cover
zetta_user = os.environ["ZETTA_USER"]
info_path = os.environ.get("EXECUTION_CHECKPOINT_PATH", EXECUTION_CHECKPOINT_PATH)
write_progress_to_path = os.path.join(
info_path, zetta_user, execution_id, "progress.html"
)

progress_updater = stack.enter_context(
progress_ctx_mngr(
expected_operation_counts, write_progress_to_path=write_progress_to_path
)
)
else:
progress_updater = lambda *args, **kwargs: None # pylint: disable=C3001
with ThreadPoolExecutor(max_workers=num_procs) as pool:
Expand Down
21 changes: 20 additions & 1 deletion zetta_utils/mazepa/progress_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import os
import signal
import sys
import time
from typing import Generator, Protocol

import fsspec
import rich
from rich import progress
from rich.console import Console

from zetta_utils.common import custom_signal_handler_ctx, get_user_confirmation

Expand Down Expand Up @@ -45,7 +48,9 @@ def __call__(self, progress_reports: dict[str, ProgressReport]) -> None:

@contextlib.contextmanager
def progress_ctx_mngr(
expected_total_counts: dict[str, int]
expected_total_counts: dict[str, int],
write_progress_to_path: str | None = None,
write_progress_interval_sec: int = 5,
) -> Generator[ProgressUpdateFN, None, None]: # pragma: no cover
progress_bar = progress.Progress(
progress.SpinnerColumn(),
Expand Down Expand Up @@ -83,6 +88,7 @@ def custom_debugger_hook():
return pdb.Pdb().set_trace(sys._getframe().f_back) # pylint: disable=protected-access

sys.breakpointhook = custom_debugger_hook
last_progress_writeout_ts = 0.0

with progress_bar as progress_bar:
with custom_signal_handler_ctx(get_confirm_sigint_fn(progress_bar), signal.SIGINT):
Expand All @@ -100,6 +106,7 @@ def custom_debugger_hook():
execution_tracker_ids: dict[str, progress.TaskID] = {}

def update_fn(progress_reports: dict[str, ProgressReport]) -> None:
nonlocal last_progress_writeout_ts
if not hasattr(sys, "gettrace") or sys.gettrace() is None:
progress_bar.start()

Expand All @@ -120,6 +127,18 @@ def update_fn(progress_reports: dict[str, ProgressReport]) -> None:
progress_bar.update(execution_tracker_ids[k], completed=v.completed_count)

progress_bar.refresh()
if (write_progress_to_path is not None) and (
time.time() - last_progress_writeout_ts > write_progress_interval_sec
):
temp_console = Console(record=True, width=80)
for line in progress_bar.get_renderables():
temp_console.print(line)
progress_html = temp_console.export_html(inline_styles=True)

with fsspec.open(write_progress_to_path, "w") as f:
f.write(progress_html)

last_progress_writeout_ts = time.time()

yield update_fn
try:
Expand Down
2 changes: 2 additions & 0 deletions zetta_utils/mazepa_addons/configurations/execute_locally.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def execute_locally(
num_procs: int = 1,
semaphores_spec: dict[SemaphoreType, int] | None = None,
debug: bool = False,
write_progress_summary: bool = False,
):

queues_dir_ = queues_dir if queues_dir else ""
Expand Down Expand Up @@ -73,4 +74,5 @@ def execute_locally(
checkpoint=checkpoint,
checkpoint_interval_sec=checkpoint_interval_sec,
raise_on_failed_checkpoint=raise_on_failed_checkpoint,
write_progress_summary=write_progress_summary,
)
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals
checkpoint: Optional[str] = None,
checkpoint_interval_sec: float = 300.0,
raise_on_failed_checkpoint: bool = True,
write_progress_summary: bool = False,
):
if debug and not local_test:
raise ValueError("`debug` can only be set to `True` when `local_test` is also `True`.")
Expand All @@ -247,6 +248,7 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals
num_procs=num_procs,
semaphores_spec=semaphores_spec,
debug=debug,
write_progress_summary=write_progress_summary,
)
else:
assert gcloud.check_image_exists(worker_image), worker_image
Expand Down Expand Up @@ -307,4 +309,5 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals
checkpoint=checkpoint,
checkpoint_interval_sec=checkpoint_interval_sec,
raise_on_failed_checkpoint=raise_on_failed_checkpoint,
write_progress_summary=write_progress_summary,
)

0 comments on commit 6d8866b

Please sign in to comment.