Skip to content

Commit

Permalink
feat: add TimingHook and list of Hooks (#1503)
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin authored Oct 16, 2023
1 parent 6b066bf commit 26ca051
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 3 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## `develop` branch

- feat(pipeline): add `TimingHook` for profiling processing time
- feat(pipeline): add support for list of hooks with `Hooks`
- fix(pipeline): add missing "embedding" hook call in `SpeakerDiarization`

## Version 3.0.1 (2023-09-28)

- fix(pipeline): fix WeSpeaker GPU support
Expand Down
3 changes: 3 additions & 0 deletions pyannote/audio/pipelines/speaker_diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@ def iter_waveform_and_mask():

embedding_batches = []

if hook is not None:
hook("embeddings", None, total=batch_count, completed=0)

for i, batch in enumerate(batches, 1):
waveforms, masks = zip(*filter(lambda b: b[0] is not None, batch))

Expand Down
89 changes: 86 additions & 3 deletions pyannote/audio/pipelines/utils/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import time
from copy import deepcopy
from typing import Any, Mapping, Optional, Text

Expand Down Expand Up @@ -64,11 +65,9 @@ class ProgressHook:
"""

def __init__(self, transient: bool = False):
super().__init__()
self.transient = transient

def __enter__(self):

self.progress = Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
Expand All @@ -90,7 +89,6 @@ def __call__(
total: Optional[int] = None,
completed: Optional[int] = None,
):

if completed is None:
completed = total = 1

Expand All @@ -103,3 +101,88 @@ def __call__(
# force refresh when completed
if completed >= total:
self.progress.refresh()


class TimingHook:
"""Hook to compute processing time of internal steps
Parameters
----------
file_key: str, optional
Key used to store processing time in `file`.
Defaults to "timing_hook".
Usage
-----
>>> with TimingHook() as hook:
... output = pipeline(file, hook=hook)
# file["timing_hook"] contains processing time for each step
"""

def __init__(self, file_key: str = "timing_hook"):
self.file_key = file_key

def __enter__(self):
self._pipeline_start_time = time.time()
self._start_time = dict()
self._end_time = dict()
return self

def __exit__(self, *args):
_pipeline_end_time = time.time()
processing_time = dict()
processing_time["total"] = _pipeline_end_time - self._pipeline_start_time
for step_name, _start_time in self._start_time.items():
_end_time = self._end_time[step_name]
processing_time[step_name] = _end_time - _start_time

self._file[self.file_key] = processing_time

def __call__(
self,
step_name: Text,
step_artifact: Any,
file: Optional[Mapping] = None,
total: Optional[int] = None,
completed: Optional[int] = None,
):
if not hasattr(self, "_file"):
self._file = file

if completed is None:
return

if completed == 0:
self._start_time[step_name] = time.time()

if completed >= total:
self._end_time[step_name] = time.time()


class Hooks:
"""List of hooks
Usage
-----
>>> with Hooks(ProgessHook(), TimingHook()) as hook:
... output = pipeline("audio.wav", hook=hook)
"""

def __init__(self, *hooks):
self.hooks = hooks

def __enter__(self):
for hook in self.hooks:
if hasattr(hook, "__enter__"):
hook.__enter__()
return self

def __exit__(self, *args):
for hook in self.hooks:
if hasattr(hook, "__exit__"):
hook.__exit__(*args)

def __call__(self, *args: Any, **kwds: Any) -> Any:
for hook in self.hooks:
hook(*args, **kwds)

0 comments on commit 26ca051

Please sign in to comment.