Skip to content

Commit

Permalink
Merge pull request #54 from christoph-blessing/progress_bar
Browse files Browse the repository at this point in the history
Add progress bar
  • Loading branch information
christoph-blessing authored Nov 9, 2023
2 parents 776d3ee + df2167b commit 78ee173
Show file tree
Hide file tree
Showing 18 changed files with 357 additions and 72 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ Table().source
All the rows can be pulled like so:

```python
Table().source.pull()
Table().source.pull() # Hint: Pass display_progress=True to get a progress bar
```

That said usually we only want to pull rows that match a certain criteria:
Expand Down
64 changes: 64 additions & 0 deletions link/adapters/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Contains DataJoint-specific code for relaying progress information to the user."""
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable

from link.domain.custom_types import Identifier
from link.domain.state import Processes
from link.service.progress import ProgessDisplay

from .identification import IdentificationTranslator


class ProgressView(ABC):
"""Progress display."""

@abstractmethod
def open(self, description: str, total: int, unit: str) -> None:
"""Open the progress display showing information to the user."""

@abstractmethod
def update_current(self, new: str) -> None:
"""Update the display with new information regarding the current iteration."""

@abstractmethod
def update_iteration(self) -> None:
"""Update the display to reflect that the current iteration finished."""

@abstractmethod
def close(self) -> None:
"""Close the progress display."""

@abstractmethod
def enable(self) -> None:
"""Enable the view."""

@abstractmethod
def disable(self) -> None:
"""Disable the view."""


class DJProgressDisplayAdapter(ProgessDisplay):
"""DataJoint-specific adapter for the progress display."""

def __init__(self, translator: IdentificationTranslator, display: ProgressView) -> None:
"""Initialize the display."""
self._translator = translator
self._display = display

def start(self, process: Processes, to_be_processed: Iterable[Identifier]) -> None:
"""Start showing progress information to the user."""
self._display.open(process.name, len(list(to_be_processed)), "row")

def update_current(self, new: Identifier) -> None:
"""Update the display to reflect a new entity being currently processed."""
self._display.update_current(repr(self._translator.to_primary_key(new)))

def finish_current(self) -> None:
"""Update the display to reflect that the current entity finished processing."""
self._display.update_iteration()

def stop(self) -> None:
"""Stop showing progress information to the user."""
self._display.close()
14 changes: 14 additions & 0 deletions link/domain/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ class Command:
"""Base class for all commands."""


@dataclass(frozen=True)
class PullEntity(Command):
"""Pull the requested entity."""

requested: Identifier


@dataclass(frozen=True)
class DeleteEntity(Command):
"""Delete the requested entity."""

requested: Identifier


@dataclass(frozen=True)
class PullEntities(Command):
"""Pull the requested entities."""
Expand Down
34 changes: 33 additions & 1 deletion link/domain/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .custom_types import Identifier

if TYPE_CHECKING:
from .state import Commands, Operations, State, Transition
from .state import Commands, Operations, Processes, State, Transition


@dataclass(frozen=True)
Expand Down Expand Up @@ -43,3 +43,35 @@ class IdleEntitiesListed(Event):
"""Idle entities in a link have been listed."""

identifiers: frozenset[Identifier]


@dataclass(frozen=True)
class ProcessStarted(Event):
"""A process for an entity was started."""

process: Processes
identifier: Identifier


@dataclass(frozen=True)
class ProcessFinished(Event):
"""A process for an entity was finished."""

process: Processes
identifier: Identifier


@dataclass(frozen=True)
class BatchProcessingStarted(Event):
"""The processing of a batch of entities started."""

process: Processes
identifiers: frozenset[Identifier]


@dataclass(frozen=True)
class BatchProcessingFinished(Event):
"""The processing of a batch of entities finished."""

process: Processes
identifiers: frozenset[Identifier]
22 changes: 6 additions & 16 deletions link/domain/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,17 @@ def identifiers(self) -> frozenset[Identifier]:
"""Return the identifiers of all entities in the link."""
return frozenset(entity.identifier for entity in self)

def pull(self, requested: Iterable[Identifier]) -> None:
"""Pull the requested entities."""
requested = set(requested)
self._validate_requested(requested)
for entity in (entity for entity in self if entity.identifier in requested):
entity.pull()

def delete(self, requested: Iterable[Identifier]) -> None:
"""Delete the requested entities."""
requested = set(requested)
self._validate_requested(requested)
for entity in (entity for entity in self if entity.identifier in requested):
entity.delete()
def __getitem__(self, identifier: Identifier) -> Entity:
"""Return the entity with the given identifier."""
try:
return next(entity for entity in self if entity.identifier == identifier)
except StopIteration as error:
raise KeyError("Requested entity not present in link") from error

def list_idle_entities(self) -> frozenset[Identifier]:
"""List the identifiers of all idle entities in the link."""
return frozenset(entity.identifier for entity in self if entity.state is Idle)

def _validate_requested(self, requested: Iterable[Identifier]) -> None:
assert set(requested) <= self.identifiers, "Requested identifiers not present in link."

def __contains__(self, entity: object) -> bool:
"""Check if the link contains the given entity."""
return entity in self._entities
Expand Down
35 changes: 29 additions & 6 deletions link/infrastructure/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,27 @@
from link.adapters.gateway import DJLinkGateway
from link.adapters.identification import IdentificationTranslator
from link.adapters.present import create_idle_entities_updater, create_state_change_logger
from link.adapters.progress import DJProgressDisplayAdapter
from link.domain import commands, events
from link.service.handlers import delete, list_idle_entities, log_state_change, pull
from link.service.handlers import (
delete,
delete_entity,
inform_batch_processing_finished,
inform_batch_processing_started,
inform_current_process_finished,
inform_next_process_started,
list_idle_entities,
log_state_change,
pull,
pull_entity,
)
from link.service.messagebus import CommandHandlers, EventHandlers, MessageBus
from link.service.uow import UnitOfWork

from . import DJConfiguration, create_tables
from .facade import DJLinkFacade
from .mixin import create_local_endpoint
from .progress import TQDMProgressView
from .sequence import IterationCallbackList, create_content_replacer


Expand Down Expand Up @@ -48,21 +61,31 @@ def inner(obj: type) -> Any:
source_restriction: IterationCallbackList[PrimaryKey] = IterationCallbackList()
idle_entities_updater = create_idle_entities_updater(translator, create_content_replacer(source_restriction))
logger = logging.getLogger(obj.__name__)

command_handlers: CommandHandlers = {}
command_handlers[commands.PullEntities] = partial(pull, uow=uow)
command_handlers[commands.DeleteEntities] = partial(delete, uow=uow)
event_handlers: EventHandlers = {}
bus = MessageBus(uow, command_handlers, event_handlers)
command_handlers[commands.PullEntity] = partial(pull_entity, uow=uow, message_bus=bus)
command_handlers[commands.DeleteEntity] = partial(delete_entity, uow=uow, message_bus=bus)
command_handlers[commands.PullEntities] = partial(pull, message_bus=bus)
command_handlers[commands.DeleteEntities] = partial(delete, message_bus=bus)
command_handlers[commands.ListIdleEntities] = partial(
list_idle_entities, uow=uow, output_port=idle_entities_updater
)
event_handlers: EventHandlers = {}
progress_view = TQDMProgressView()
display = DJProgressDisplayAdapter(translator, progress_view)
event_handlers[events.ProcessStarted] = [partial(inform_next_process_started, display=display)]
event_handlers[events.ProcessFinished] = [partial(inform_current_process_finished, display=display)]
event_handlers[events.BatchProcessingStarted] = [partial(inform_batch_processing_started, display=display)]
event_handlers[events.BatchProcessingFinished] = [partial(inform_batch_processing_finished, display=display)]
event_handlers[events.StateChanged] = [
partial(log_state_change, log=create_state_change_logger(translator, logger.info))
]
event_handlers[events.InvalidOperationRequested] = [lambda event: None]
bus = MessageBus(uow, command_handlers, event_handlers)

controller = DJController(bus, translator)
source_restriction.callback = controller.list_idle_entities

return create_local_endpoint(controller, tables, source_restriction)
return create_local_endpoint(controller, tables, source_restriction, progress_view)

return inner
22 changes: 18 additions & 4 deletions link/infrastructure/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from link.adapters.controller import DJController
from link.adapters.custom_types import PrimaryKey
from link.adapters.progress import ProgressView

from . import DJTables

Expand All @@ -17,11 +18,15 @@ class SourceEndpoint(Table):

_controller: DJController
_outbound_table: Callable[[], Table]
_progress_view: ProgressView

def pull(self) -> None:
def pull(self, *, display_progress: bool = False) -> None:
"""Pull idle entities from the source table into the local table."""
if display_progress:
self._progress_view.enable()
primary_keys = self.proj().fetch(as_dict=True)
self._controller.pull(primary_keys)
self._progress_view.disable()

@property
def flagged(self) -> Sequence[PrimaryKey]:
Expand All @@ -34,6 +39,7 @@ def create_source_endpoint_factory(
source_table: Callable[[], Table],
outbound_table: Callable[[], Table],
restriction: Iterable[PrimaryKey],
progress_view: ProgressView,
) -> Callable[[], SourceEndpoint]:
"""Create a callable that returns the source endpoint when called."""

Expand All @@ -47,6 +53,7 @@ def create_source_endpoint() -> SourceEndpoint:
{
"_controller": controller,
"_outbound_table": staticmethod(outbound_table),
"_progress_view": progress_view,
},
)()
& restriction,
Expand All @@ -60,11 +67,15 @@ class LocalEndpoint(Table):

_controller: DJController
_source: Callable[[], SourceEndpoint]
_progress_view: ProgressView

def delete(self) -> None:
def delete(self, *, display_progress: bool = False) -> None:
"""Delete pulled entities from the local table."""
if display_progress:
self._progress_view.enable()
primary_keys = self.proj().fetch(as_dict=True)
self._controller.delete(primary_keys)
self._progress_view.disable()

@property
def source(self) -> SourceEndpoint:
Expand All @@ -73,7 +84,7 @@ def source(self) -> SourceEndpoint:


def create_local_endpoint(
controller: DJController, tables: DJTables, source_restriction: Iterable[PrimaryKey]
controller: DJController, tables: DJTables, source_restriction: Iterable[PrimaryKey], progress_view: ProgressView
) -> type[LocalEndpoint]:
"""Create the local endpoint."""
return cast(
Expand All @@ -87,8 +98,11 @@ def create_local_endpoint(
{
"_controller": controller,
"_source": staticmethod(
create_source_endpoint_factory(controller, tables.source, tables.outbound, source_restriction)
create_source_endpoint_factory(
controller, tables.source, tables.outbound, source_restriction, progress_view
),
),
"_progress_view": progress_view,
},
),
)
50 changes: 50 additions & 0 deletions link/infrastructure/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Contains views for showing progress information to the user."""
from __future__ import annotations

import logging
from typing import NoReturn

from tqdm.auto import tqdm

from link.adapters.progress import ProgressView

logger = logging.getLogger(__name__)


class TQDMProgressView(ProgressView):
"""A view that uses tqdm to show a progress bar."""

def __init__(self) -> None:
"""Initialize the view."""
self.__progress_bar: tqdm[NoReturn] | None = None
self._is_disabled: bool = False

@property
def _progress_bar(self) -> tqdm[NoReturn]:
assert self.__progress_bar
return self.__progress_bar

def open(self, description: str, total: int, unit: str) -> None:
"""Start showing the progress bar."""
self.__progress_bar = tqdm(total=total, desc=description, unit=unit, disable=self._is_disabled)

def update_current(self, new: str) -> None:
"""Update information about the current iteration shown at the end of the bar."""
self._progress_bar.set_postfix(current=new)

def update_iteration(self) -> None:
"""Update the bar to show an iteration finished."""
self._progress_bar.update()

def close(self) -> None:
"""Stop showing the progress bar."""
self._progress_bar.close()
self.__progress_bar = None

def enable(self) -> None:
"""Enable the progress bar."""
self._is_disabled = False

def disable(self) -> None:
"""Disable the progress bar."""
self._is_disabled = True
Loading

0 comments on commit 78ee173

Please sign in to comment.