diff --git a/.faq/suggest.md b/.faq/suggest.md index 0a9233998..19e7b6d69 100644 --- a/.faq/suggest.md +++ b/.faq/suggest.md @@ -1,3 +1,5 @@ +Thank you for your issue. + {%- if questions -%} {% if questions|length == 1 %} We found the following entry in the [FAQ]({{ faq_url }}) which you may find helpful: @@ -9,12 +11,24 @@ We found the following entries in the [FAQ]({{ faq_url }}) which you may find he - [{{ question.title }}]({{ faq_url }}#{{ question.slug }}) {%- endfor %} -Feel free to close this issue if you found an answer in the FAQ. Otherwise, please give us a little time to review. - {%- else -%} -Thank you for your issue. Give us a little time to review it. - -PS. You might want to check the [FAQ]({{ faq_url }}) if you haven't done so already. +You might want to check the [FAQ]({{ faq_url }}) if you haven't done so already. {%- endif %} -This is an automated reply, generated by [FAQtory](https://github.com/willmcgugan/faqtory) +Feel free to close this issue if you found an answer in the FAQ. + +If your issue is a feature request, please read [this](https://xyproblem.info/) first and update your request accordingly, if needed. + +If your issue is a bug report, please provide a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) as a link to a self-contained [Google Colab](https://colab.research.google.com/) notebook containing everything needed to reproduce the bug: + - installation + - data preparation + - model download + - etc. + +Providing an MRE will increase your chance of getting an answer from the community (either maintainers or other power users). + +Companies relying on `pyannote.audio` in production may contact [me](https://herve.niderb.fr) via email regarding: +* paid scientific consulting around speaker diarization and speech processing in general; +* custom models and tailored features (via the local tech transfer office). 
+ +> This is an automated reply, generated by [FAQtory](https://github.com/willmcgugan/faqtory) diff --git a/.github/workflows/new_issue.yml b/.github/workflows/new_issue.yml index a67bcdcd5..b8477dc16 100644 --- a/.github/workflows/new_issue.yml +++ b/.github/workflows/new_issue.yml @@ -14,7 +14,9 @@ jobs: - name: Install FAQtory run: pip install FAQtory - name: Run Suggest - run: faqtory suggest "${{ github.event.issue.title }}" > suggest.md + env: + TITLE: ${{ github.event.issue.title }} + run: faqtory suggest "$TITLE" > suggest.md - name: Read suggest.md id: suggest uses: juliangruber/read-file-action@v1 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b266179eb..df1182cf3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,5 +29,4 @@ jobs: pip install -e .[dev,testing] - name: Test with pytest run: | - export PYANNOTE_DATABASE_CONFIG=$GITHUB_WORKSPACE/tests/data/database.yml pytest diff --git a/CHANGELOG.md b/CHANGELOG.md index d7e50dba7..6e7d220fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,10 +14,10 @@ ### Breaking changes - BREAKING(task): rename `Segmentation` task to `SpeakerDiarization` - - BREAKING(task): remove support for variable chunk duration + - BREAKING(task): remove support for variable chunk duration for segmentation tasks - BREAKING(pipeline): pipeline defaults to CPU (use `pipeline.to(device)`) - BREAKING(pipeline): remove `SpeakerSegmentation` pipeline (use `SpeakerDiarization` pipeline) - - BREAKING(pipeline): remove support `FINCHClustering` and `HiddenMarkovModelClustering` + - BREAKING(pipeline): remove support for `FINCHClustering` and `HiddenMarkovModelClustering` - BREAKING(pipeline): remove `segmentation_duration` parameter from `SpeakerDiarization` pipeline (defaults to `duration` of segmentation model) - BREAKING(setup): drop support for Python 3.7 - BREAKING(io): channels are now 0-indexed (used to be 1-indexed) @@ -26,12 +26,17 @@ * replace `Audio()` by `Audio(mono="downmix")`; * replace `Audio(mono=True)` by `Audio(mono="downmix")`; * replace `Audio(mono=False)` by `Audio()`. + - BREAKING(model): get rid of (flaky) `Model.introspection` + If, for some weird reason, you wrote some custom code based on that, + you should instead rely on `Model.example_output`. 
### Features and improvements + - feat(task): add support for multi-task models - feat(pipeline): send pipeline to device with `pipeline.to(device)` - feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`) - feat(task): add [powerset](https://arxiv.org/PLACEHOLDER) support to `SpeakerDiarization` task + - feat(pipeline): add `return_embeddings` option to `SpeakerDiarization` pipeline - feat(pipeline): add progress hook to pipelines - feat(pipeline): check version compatibility at load time - feat(task): add support for label scope in speaker diarization task @@ -44,6 +49,7 @@ - fix(pipeline): fix reproducibility issue with Ampere CUDA devices - fix(pipeline): fix support for IOBase audio - fix(pipeline): fix corner case with no speaker + - fix(train): prevent metadata preparation from happening twice - improve(task): shorten and improve structure of Tensorboard tags ### Dependencies @@ -82,7 +88,7 @@ - last release before complete rewriting -## Version 1.0.1 (2018--07-19) +## Version 1.0.1 (2018-07-19) - fix: fix regression in Precomputed.__call__ (#110, #105) diff --git a/README.md b/README.md index c3f9a8dcc..1d314cd92 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ -# Neural speaker diarization with `pyannote.audio` +> [!IMPORTANT] +> I propose (paid) scientific [consulting services](https://herve.niderb.fr/consulting.html) to companies willing to make the most of their data and open-source speech processing toolkits (and `pyannote` in particular). + +# Speaker diarization with `pyannote.audio` `pyannote.audio` is an open-source toolkit written in Python for speaker diarization. Based on [PyTorch](pytorch.org) machine learning framework, it provides a set of trainable end-to-end neural building blocks that can be combined and jointly optimized to build speaker diarization pipelines. @@ -126,9 +129,8 @@ pip install -e .[dev,testing] pre-commit install ``` -Tests rely on a set of debugging files available in [`test/data`](test/data) directory. 
-Set `PYANNOTE_DATABASE_CONFIG` environment variable to `test/data/database.yml` before running tests: +## Test ```bash -PYANNOTE_DATABASE_CONFIG=tests/data/database.yml pytest +pytest ``` diff --git a/pyannote/audio/core/inference.py b/pyannote/audio/core/inference.py index c38fef0f9..dcf21868d 100644 --- a/pyannote/audio/core/inference.py +++ b/pyannote/audio/core/inference.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -27,19 +27,20 @@ import numpy as np import torch +import torch.nn as nn +import torch.nn.functional as F from einops import rearrange from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature from pytorch_lightning.utilities.memory import is_oom_error from pyannote.audio.core.io import AudioFile -from pyannote.audio.core.model import Model +from pyannote.audio.core.model import Model, Specifications from pyannote.audio.core.task import Resolution +from pyannote.audio.utils.multi_task import map_with_specifications from pyannote.audio.utils.permutation import mae_cost_func, permutate from pyannote.audio.utils.powerset import Powerset from pyannote.audio.utils.reproducibility import fix_reproducibility -TaskName = Union[Text, None] - class BaseInference: pass @@ -68,10 +69,10 @@ class Inference(BaseInference): skip_aggregation : bool, optional Do not aggregate outputs when using "sliding" window. Defaults to False. skip_conversion: bool, optional - In case `model` has been trained with `powerset` mode, its output is automatically + In case a task has been trained with `powerset` mode, output is automatically converted to `multi-label`, unless `skip_conversion` is set to True. batch_size : int, optional - Batch size. Larger values make inference faster. Defaults to 32. + Batch size. Larger values (should) make inference faster. Defaults to 32. device : torch.device, optional Device used for inference. Defaults to `model.device`. In case `device` and `model.device` are different, model is sent to device. @@ -94,6 +95,7 @@ def __init__( batch_size: int = 32, use_auth_token: Union[Text, None] = None, ): + # ~~~~ model ~~~~~ self.model = ( model @@ -106,50 +108,70 @@ def __init__( ) ) - if window not in ["sliding", "whole"]: - raise ValueError('`window` must be "sliding" or "whole".') - - specifications = self.model.specifications - if specifications.resolution == Resolution.FRAME and window == "whole": - warnings.warn( - 'Using "whole" `window` inference with a frame-based model might lead to bad results ' - 'and huge memory consumption: it is recommended to set `window` to "sliding".' 
- ) - - self.window = window - self.skip_aggregation = skip_aggregation - if device is None: device = self.model.device self.device = device - self.pre_aggregation_hook = pre_aggregation_hook - self.model.eval() self.model.to(self.device) - # chunk duration used during training specifications = self.model.specifications - training_duration = specifications.duration - if duration is None: - duration = training_duration - elif training_duration != duration: + # ~~~~ sliding window ~~~~~ + + if window not in ["sliding", "whole"]: + raise ValueError('`window` must be "sliding" or "whole".') + + if window == "whole" and any( + s.resolution == Resolution.FRAME for s in specifications + ): + warnings.warn( + 'Using "whole" `window` inference with a frame-based model might lead to bad results ' + 'and huge memory consumption: it is recommended to set `window` to "sliding".' + ) + self.window = window + + training_duration = next(iter(specifications)).duration + duration = duration or training_duration + if training_duration != duration: warnings.warn( f"Model was trained with {training_duration:g}s chunks, and you requested " f"{duration:g}s chunks for inference: this might lead to suboptimal results." ) self.duration = duration - self.warm_up = specifications.warm_up + # ~~~~ powerset to multilabel conversion ~~~~ + + self.skip_conversion = skip_conversion + + conversion = list() + for s in specifications: + if s.powerset and not skip_conversion: + c = Powerset(len(s.classes), s.powerset_max_classes) + else: + c = nn.Identity() + conversion.append(c.to(self.device)) + + if isinstance(specifications, Specifications): + self.conversion = conversion[0] + else: + self.conversion = nn.ModuleList(conversion) + + # ~~~~ overlap-add aggregation ~~~~~ + + self.skip_aggregation = skip_aggregation + self.pre_aggregation_hook = pre_aggregation_hook + + self.warm_up = next(iter(specifications)).warm_up # Use that many seconds on the left- and rightmost parts of each chunk # to warm up the model. While the model does process those left- and right-most # parts, only the remaining central part of each chunk is used for aggregating # scores during inference. 
# step between consecutive chunks - if step is None: - step = 0.1 * self.duration if self.warm_up[0] == 0.0 else self.warm_up[0] + step = step or ( + 0.1 * self.duration if self.warm_up[0] == 0.0 else self.warm_up[0] + ) if step > self.duration: raise ValueError( @@ -160,23 +182,21 @@ def __init__( self.step = step self.batch_size = batch_size - self.skip_conversion = skip_conversion - if specifications.powerset and not self.skip_conversion: - self._powerset = Powerset( - len(specifications.classes), specifications.powerset_max_classes - ) - self._powerset.to(self.device) - def to(self, device: torch.device): + def to(self, device: torch.device) -> "Inference": """Send internal model to `device`""" + if not isinstance(device, torch.device): + raise TypeError( + f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`" + ) + self.model.to(device) - if self.model.specifications.powerset and not self.skip_conversion: - self._powerset.to(device) + self.conversion.to(device) self.device = device return self - def infer(self, chunks: torch.Tensor) -> np.ndarray: + def infer(self, chunks: torch.Tensor) -> Union[np.ndarray, Tuple[np.ndarray]]: """Forward pass Takes care of sending chunks to right device and outputs back to CPU @@ -188,11 +208,11 @@ def infer(self, chunks: torch.Tensor) -> np.ndarray: Returns ------- - outputs : (batch_size, ...) np.ndarray + outputs : (tuple of) (batch_size, ...) np.ndarray Model output. """ - with torch.no_grad(): + with torch.inference_mode(): try: outputs = self.model(chunks.to(self.device)) except RuntimeError as exception: @@ -204,22 +224,19 @@ def infer(self, chunks: torch.Tensor) -> np.ndarray: else: raise exception - # convert powerset to multi-label unless specifically requested not to - if self.model.specifications.powerset and not self.skip_conversion: - powerset = torch.nn.functional.one_hot( - torch.argmax(outputs, dim=-1), - self.model.specifications.num_powerset_classes, - ).float() - outputs = self._powerset.to_multilabel(powerset) + def __convert(output: torch.Tensor, conversion: nn.Module, **kwargs): + return conversion(output).cpu().numpy() - return outputs.cpu().numpy() + return map_with_specifications( + self.model.specifications, __convert, outputs, self.conversion + ) def slide( self, waveform: torch.Tensor, sample_rate: int, hook: Optional[Callable], - ) -> SlidingWindowFeature: + ) -> Union[SlidingWindowFeature, Tuple[SlidingWindowFeature]]: """Slide model on a waveform Parameters @@ -236,23 +253,27 @@ def slide( Returns ------- - output : SlidingWindowFeature + output : (tuple of) SlidingWindowFeature Model output. Shape is (num_chunks, dimension) for chunk-level tasks, and (num_frames, dimension) for frame-level tasks. 
""" - window_size: int = round(self.duration * sample_rate) + window_size: int = self.model.audio.get_num_samples(self.duration) step_size: int = round(self.step * sample_rate) _, num_samples = waveform.shape - specifications = self.model.specifications - resolution = specifications.resolution - introspection = self.model.introspection - if resolution == Resolution.CHUNK: - frames = SlidingWindow(start=0.0, duration=self.duration, step=self.step) - elif resolution == Resolution.FRAME: - frames = introspection.frames - num_frames_per_chunk, dimension = introspection(window_size) + def __frames( + example_output, specifications: Optional[Specifications] = None + ) -> SlidingWindow: + if specifications.resolution == Resolution.CHUNK: + return SlidingWindow(start=0.0, duration=self.duration, step=self.step) + return example_output.frames + + frames: Union[SlidingWindow, Tuple[SlidingWindow]] = map_with_specifications( + self.model.specifications, + __frames, + self.model.example_output, + ) # prepare complete chunks if num_samples >= window_size: @@ -269,75 +290,113 @@ def slide( num_samples - window_size ) % step_size > 0 if has_last_chunk: + # pad last chunk with zeros last_chunk: torch.Tensor = waveform[:, num_chunks * step_size :] + _, last_window_size = last_chunk.shape + last_pad = window_size - last_window_size + last_chunk = F.pad(last_chunk, (0, last_pad)) + + def __empty_list(**kwargs): + return list() - outputs: Union[List[np.ndarray], np.ndarray] = list() + outputs: Union[ + List[np.ndarray], Tuple[List[np.ndarray]] + ] = map_with_specifications(self.model.specifications, __empty_list) if hook is not None: hook(completed=0, total=num_chunks + has_last_chunk) + def __append_batch(output, batch_output, **kwargs) -> None: + output.append(batch_output) + return + # slide over audio chunks in batch for c in np.arange(0, num_chunks, self.batch_size): batch: torch.Tensor = chunks[c : c + self.batch_size] - outputs.append(self.infer(batch)) + + batch_outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(batch) + + _ = map_with_specifications( + self.model.specifications, __append_batch, outputs, batch_outputs + ) + if hook is not None: hook(completed=c + self.batch_size, total=num_chunks + has_last_chunk) # process orphan last chunk if has_last_chunk: + last_outputs = self.infer(last_chunk[None]) - last_output = self.infer(last_chunk[None]) - - if specifications.resolution == Resolution.FRAME: - pad = num_frames_per_chunk - last_output.shape[1] - last_output = np.pad(last_output, ((0, 0), (0, pad), (0, 0))) + _ = map_with_specifications( + self.model.specifications, __append_batch, outputs, last_outputs + ) - outputs.append(last_output) if hook is not None: hook( completed=num_chunks + has_last_chunk, total=num_chunks + has_last_chunk, ) - outputs = np.vstack(outputs) - - # skip aggregation when requested, - # or when model outputs just one vector per chunk - # or when model is permutation-invariant (and not post-processed) - if ( - self.skip_aggregation - or specifications.resolution == Resolution.CHUNK - or ( - specifications.permutation_invariant - and self.pre_aggregation_hook is None - ) - ): - frames = SlidingWindow(start=0.0, duration=self.duration, step=self.step) - return SlidingWindowFeature(outputs, frames) - - if self.pre_aggregation_hook is not None: - outputs = self.pre_aggregation_hook(outputs) - - aggregated = self.aggregate( - SlidingWindowFeature( - outputs, - SlidingWindow(start=0.0, duration=self.duration, step=self.step), - ), - frames=frames, - 
warm_up=self.warm_up, - hamming=True, - missing=0.0, + def __vstack(output: List[np.ndarray], **kwargs) -> np.ndarray: + return np.vstack(output) + + outputs: Union[np.ndarray, Tuple[np.ndarray]] = map_with_specifications( + self.model.specifications, __vstack, outputs ) - if has_last_chunk: - num_frames = aggregated.data.shape[0] - aggregated.data = aggregated.data[: num_frames - pad, :] + def __aggregate( + outputs: np.ndarray, + frames: SlidingWindow, + specifications: Optional[Specifications] = None, + ) -> SlidingWindowFeature: + # skip aggregation when requested, + # or when model outputs just one vector per chunk + # or when model is permutation-invariant (and not post-processed) + if ( + self.skip_aggregation + or specifications.resolution == Resolution.CHUNK + or ( + specifications.permutation_invariant + and self.pre_aggregation_hook is None + ) + ): + frames = SlidingWindow( + start=0.0, duration=self.duration, step=self.step + ) + return SlidingWindowFeature(outputs, frames) + + if self.pre_aggregation_hook is not None: + outputs = self.pre_aggregation_hook(outputs) + + aggregated = self.aggregate( + SlidingWindowFeature( + outputs, + SlidingWindow(start=0.0, duration=self.duration, step=self.step), + ), + frames=frames, + warm_up=self.warm_up, + hamming=True, + missing=0.0, + ) + + # remove padding that was added to last chunk + if has_last_chunk: + aggregated.data = aggregated.crop( + Segment(0.0, num_samples / sample_rate), mode="loose" + ) - return aggregated + return aggregated + + return map_with_specifications( + self.model.specifications, __aggregate, outputs, frames + ) def __call__( self, file: AudioFile, hook: Optional[Callable] = None - ) -> Union[SlidingWindowFeature, np.ndarray]: + ) -> Union[ + Tuple[Union[SlidingWindowFeature, np.ndarray]], + Union[SlidingWindowFeature, np.ndarray], + ]: """Run inference on a whole file Parameters @@ -352,7 +411,7 @@ def __call__( Returns ------- - output : SlidingWindowFeature or np.ndarray + output : (tuple of) SlidingWindowFeature or np.ndarray Model output, as `SlidingWindowFeature` if `window` is set to "sliding" and `np.ndarray` if is set to "whole". @@ -365,7 +424,14 @@ def __call__( if self.window == "sliding": return self.slide(waveform, sample_rate, hook=hook) - return self.infer(waveform[None])[0] + outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(waveform[None]) + + def __first_sample(outputs: np.ndarray, **kwargs) -> np.ndarray: + return outputs[0] + + return map_with_specifications( + self.model.specifications, __first_sample, outputs + ) def crop( self, @@ -373,7 +439,10 @@ def crop( chunk: Union[Segment, List[Segment]], duration: Optional[float] = None, hook: Optional[Callable] = None, - ) -> Union[SlidingWindowFeature, np.ndarray]: + ) -> Union[ + Tuple[Union[SlidingWindowFeature, np.ndarray]], + Union[SlidingWindowFeature, np.ndarray], + ]: """Run inference on a chunk or a list of chunks Parameters @@ -398,7 +467,7 @@ def crop( Returns ------- - output : SlidingWindowFeature or np.ndarray + output : (tuple of) SlidingWindowFeature or np.ndarray Model output, as `SlidingWindowFeature` if `window` is set to "sliding" and `np.ndarray` if is set to "whole". 
@@ -415,7 +484,6 @@ def crop( fix_reproducibility(self.device) if self.window == "sliding": - if not isinstance(chunk, Segment): start = min(c.start for c in chunk) end = max(c.end for c in chunk) @@ -424,32 +492,37 @@ def crop( waveform, sample_rate = self.model.audio.crop( file, chunk, duration=duration ) - output = self.slide(waveform, sample_rate, hook=hook) - - frames = output.sliding_window - shifted_frames = SlidingWindow( - start=chunk.start, duration=frames.duration, step=frames.step - ) - return SlidingWindowFeature(output.data, shifted_frames) - - elif self.window == "whole": - - if isinstance(chunk, Segment): - waveform, sample_rate = self.model.audio.crop( - file, chunk, duration=duration - ) - else: - waveform = torch.cat( - [self.model.audio.crop(file, c)[0] for c in chunk], dim=1 + outputs: Union[ + SlidingWindowFeature, Tuple[SlidingWindowFeature] + ] = self.slide(waveform, sample_rate, hook=hook) + + def __shift(output: SlidingWindowFeature, **kwargs) -> SlidingWindowFeature: + frames = output.sliding_window + shifted_frames = SlidingWindow( + start=chunk.start, duration=frames.duration, step=frames.step ) + return SlidingWindowFeature(output.data, shifted_frames) - return self.infer(waveform[None])[0] + return map_with_specifications(self.model.specifications, __shift, outputs) + if isinstance(chunk, Segment): + waveform, sample_rate = self.model.audio.crop( + file, chunk, duration=duration + ) else: - raise NotImplementedError( - f"Unsupported window type '{self.window}': should be 'sliding' or 'whole'." + waveform = torch.cat( + [self.model.audio.crop(file, c)[0] for c in chunk], dim=1 ) + outputs: Union[np.ndarray, Tuple[np.ndarray]] = self.infer(waveform[None]) + + def __first_sample(outputs: np.ndarray, **kwargs) -> np.ndarray: + return outputs[0] + + return map_with_specifications( + self.model.specifications, __first_sample, outputs + ) + @staticmethod def aggregate( scores: SlidingWindowFeature, @@ -691,7 +764,6 @@ def always_match(this: np.ndarray, that: np.ndarray, cost: float): stitches = [] for C, (chunk, activation) in enumerate(activations): - local_stitch = np.NAN * np.zeros( (sum(lookahead) + 1, num_frames, num_classes) ) @@ -699,7 +771,6 @@ def always_match(this: np.ndarray, that: np.ndarray, cost: float): for c in range( max(0, C - lookahead[0]), min(num_chunks, C + lookahead[1] + 1) ): - # extract common temporal support shift = round((C - c) * num_frames * chunks.step / chunks.duration) @@ -720,7 +791,6 @@ def always_match(this: np.ndarray, that: np.ndarray, cost: float): ) for this, that in enumerate(permutation): - # only stitch under certain condiditions matching = (c == C) or ( match_func( diff --git a/pyannote/audio/core/io.py b/pyannote/audio/core/io.py index b2e8842b1..0a44e75ea 100644 --- a/pyannote/audio/core/io.py +++ b/pyannote/audio/core/io.py @@ -150,7 +150,6 @@ def validate_file(file: AudioFile) -> Mapping: raise ValueError(AudioFileDocString) if "waveform" in file: - waveform: Union[np.ndarray, Tensor] = file["waveform"] if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]: raise ValueError( @@ -166,7 +165,6 @@ def validate_file(file: AudioFile) -> Mapping: file.setdefault("uri", "waveform") elif "audio" in file: - if isinstance(file["audio"], IOBase): return file @@ -177,7 +175,6 @@ def validate_file(file: AudioFile) -> Mapping: file.setdefault("uri", path.stem) else: - raise ValueError( "Neither 'waveform' nor 'audio' is available for this file." 
) @@ -185,7 +182,6 @@ return file def __init__(self, sample_rate=None, mono=None): - super().__init__() self.sample_rate = sample_rate self.mono = mono @@ -257,6 +253,18 @@ def get_duration(self, file: AudioFile) -> float: return frames / sample_rate + def get_num_samples(self, duration: float, sample_rate: int = None) -> int: + """Deterministic number of samples from duration and sample rate""" + + sample_rate = sample_rate or self.sample_rate + + if sample_rate is None: + raise ValueError( + "`sample_rate` must be provided to compute number of samples." + ) + + return math.floor(duration * sample_rate) + def __call__(self, file: AudioFile) -> Tuple[Tensor, int]: """Obtain waveform @@ -359,7 +367,6 @@ def crop( num_frames = end_frame - start_frame if mode == "raise": - if num_frames > frames: raise ValueError( f"requested fixed duration ({duration:6f}s, or {num_frames:d} frames) is longer " @@ -400,7 +407,6 @@ if isinstance(file["audio"], IOBase): file["audio"].seek(0) except RuntimeError: - if isinstance(file["audio"], IOBase): msg = "torchaudio failed to seek-and-read in file-like object." raise RuntimeError(msg) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 18b301086..bedb7f6c4 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,8 @@ import os import warnings +from dataclasses import dataclass +from functools import cached_property from importlib import import_module from pathlib import Path from typing import Any, Dict, List, Optional, Text, Tuple, Union @@ -49,6 +51,7 @@ Task, UnknownSpecificationsError, ) +from pyannote.audio.utils.multi_task import map_with_specifications from pyannote.audio.utils.version import check_version CACHE_DIR = os.getenv( @@ -59,195 +62,16 @@ HF_LIGHTNING_CONFIG_NAME = "config.yaml" +# NOTE: needed for backward compatibility to load models trained before pyannote.audio 3.x class Introspection: - """Model introspection + pass - - Parameters - ---------- - min_num_samples: int - Minimum number of input samples - min_num_frames: int - Corresponding minimum number of output frames - inc_num_samples: int - Number of input samples leading to an increase of number of output frames - inc_num_frames: int - Corresponding increase in number of output frames - dimension: int - Output dimension - sample_rate: int - Expected input sample rate - - Usage - ----- - >>> introspection = Introspection.from_model(model) - >>> isinstance(introspection.frames, SlidingWindow) - >>> num_samples = 16000 # 1s at 16kHz - >>> num_frames, dimension = introspection(num_samples) - """ - - def __init__( - self, - min_num_samples: int, - min_num_frames: int, - inc_num_samples: int, - inc_num_frames: int, - dimension: int, - sample_rate: int, - ): - super().__init__() - self.min_num_samples = min_num_samples - self.min_num_frames = min_num_frames - self.inc_num_samples = inc_num_samples - self.inc_num_frames = inc_num_frames - self.dimension = dimension - self.sample_rate = sample_rate - - @classmethod - def from_model(cls, model: "Model", task: str = None) -> Introspection: - - specifications = model.specifications - if task is not None: - specifications = specifications[task] - - example_input_array = 
model.example_input_array - batch_size, num_channels, num_samples = example_input_array.shape - example_input_array = torch.randn( - (batch_size, num_channels, num_samples), - dtype=example_input_array.dtype, - layout=example_input_array.layout, - device=example_input_array.device, - requires_grad=False, - ) - - # dichotomic search of "min_num_samples" - lower, upper, min_num_samples = 1, num_samples, None - while True: - num_samples = (lower + upper) // 2 - try: - with torch.no_grad(): - frames = model(example_input_array[:, :, :num_samples]) - if task is not None: - frames = frames[task] - except Exception: - lower = num_samples - else: - min_num_samples = num_samples - if specifications.resolution == Resolution.FRAME: - _, min_num_frames, dimension = frames.shape - elif specifications.resolution == Resolution.CHUNK: - _, dimension = frames.shape - else: - # should never happen - pass - upper = num_samples - - if lower + 1 == upper: - break - - # if "min_num_samples" is still None at this point, it means that - # the forward pass always failed and raised an exception. most likely, - # it means that there is a problem with the model definition. - # we try again without catching the exception to help the end user debug - # their model - if min_num_samples is None: - frames = model(example_input_array) - - # corner case for chunk-level tasks - if specifications.resolution == Resolution.CHUNK: - return cls( - min_num_samples=min_num_samples, - min_num_frames=1, - inc_num_samples=0, - inc_num_frames=0, - dimension=dimension, - sample_rate=model.hparams.sample_rate, - ) - - # search reasonable upper bound for "inc_num_samples" - while True: - num_samples = 2 * min_num_samples - example_input_array = torch.randn( - (batch_size, num_channels, num_samples), - dtype=example_input_array.dtype, - layout=example_input_array.layout, - device=example_input_array.device, - requires_grad=False, - ) - with torch.no_grad(): - frames = model(example_input_array) - if task is not None: - frames = frames[task] - num_frames = frames.shape[1] - if num_frames > min_num_frames: - break - - # dichotomic search of "inc_num_samples" - lower, upper = min_num_samples, num_samples - while True: - num_samples = (lower + upper) // 2 - example_input_array = torch.randn( - (batch_size, num_channels, num_samples), - dtype=example_input_array.dtype, - layout=example_input_array.layout, - device=example_input_array.device, - requires_grad=False, - ) - with torch.no_grad(): - frames = model(example_input_array) - if task is not None: - frames = frames[task] - num_frames = frames.shape[1] - if num_frames > min_num_frames: - inc_num_frames = num_frames - min_num_frames - inc_num_samples = num_samples - min_num_samples - upper = num_samples - else: - lower = num_samples - - if lower + 1 == upper: - break - return cls( - min_num_samples=min_num_samples, - min_num_frames=min_num_frames, - inc_num_samples=inc_num_samples, - inc_num_frames=inc_num_frames, - dimension=dimension, - sample_rate=model.hparams.sample_rate, - ) - - def __call__(self, num_samples: int) -> Tuple[int, int]: - """Predict output shape, given number of input samples - - Parameters - ---------- - num_samples : int - Number of input samples. 
- - Returns - ------- - num_frames : int - Number of output frames - dimension : int - Dimension of output frames - """ - - if num_samples < self.min_num_samples: - return 0, self.dimension - - return ( - self.min_num_frames - + self.inc_num_frames - * ((num_samples - self.min_num_samples + 1) // self.inc_num_samples), - self.dimension, - ) - - @property - def frames(self) -> SlidingWindow: - # HACK to support model trained before 'sample_rate' was an Introspection attribute - sample_rate = getattr(self, "sample_rate", 16000) - step = (self.inc_num_samples / self.inc_num_frames) / sample_rate - return SlidingWindow(start=0.0, step=step, duration=step) +@dataclass +class Output: + num_frames: int + dimension: int + frames: SlidingWindow class Model(pl.LightningModule): @@ -281,31 +105,26 @@ def __init__( self.audio = Audio(sample_rate=self.hparams.sample_rate, mono="downmix") @property - def example_input_array(self) -> torch.Tensor: - batch_size = 3 if self.task is None else self.task.batch_size - duration = 2.0 if self.task is None else self.task.duration - - return torch.randn( - ( - batch_size, - self.hparams.num_channels, - int(self.hparams.sample_rate * duration), - ), - device=self.device, - ) - - @property - def task(self): + def task(self) -> Task: return self._task @task.setter - def task(self, task): - self._task = task - del self.introspection + def task(self, task: Task): + # reset (cached) properties when task changes del self.specifications + try: + del self.example_output + except AttributeError: + pass + self._task = task + + def build(self): + # use this method to add task-dependent layers to the model + # (e.g. the final classification and activation layers) + pass @property - def specifications(self): + def specifications(self) -> Union[Specifications, Tuple[Specifications]]: if self.task is None: try: specifications = self._specifications @@ -330,7 +149,22 @@ def specifications(self): return specifications @specifications.setter - def specifications(self, specifications): + def specifications( + self, specifications: Union[Specifications, Tuple[Specifications]] + ): + if not isinstance(specifications, (Specifications, tuple)): + raise ValueError( + "Only regular specifications or tuple of specifications are supported." + ) + + durations = set(s.duration for s in specifications) + if len(durations) > 1: + raise ValueError("All tasks must share the same (maximum) duration.") + + min_durations = set(s.min_duration for s in specifications) + if len(min_durations) > 1: + raise ValueError("All tasks must share the same minimum duration.") + self._specifications = specifications @specifications.deleter @@ -338,39 +172,54 @@ def specifications(self): if hasattr(self, "_specifications"): del self._specifications - def build(self): - # use this method to add task-dependent layers to the model - # (e.g. 
the final classification and activation layers) - pass + def __example_input_array(self, duration: Optional[float] = None) -> torch.Tensor: + duration = duration or next(iter(self.specifications)).duration + return torch.randn( + ( + 1, + self.hparams.num_channels, + self.audio.get_num_samples(duration), + ), + device=self.device, + ) @property - def introspection(self) -> Introspection: - """Introspection - - Returns - ------- - introspection: Introspection - Model introspection - """ - - if not hasattr(self, "_introspection"): - self._introspection = Introspection.from_model(self) - - return self._introspection + def example_input_array(self) -> torch.Tensor: + return self.__example_input_array() + + @cached_property + def example_output(self) -> Union[Output, Tuple[Output]]: + """Example output""" + example_input_array = self.__example_input_array() + with torch.inference_mode(): + example_output = self(example_input_array) + + def __example_output( + example_output: torch.Tensor, + specifications: Specifications = None, + ) -> Output: + if specifications.resolution == Resolution.FRAME: + _, num_frames, dimension = example_output.shape + frame_duration = specifications.duration / num_frames + frames = SlidingWindow(step=frame_duration, duration=frame_duration) + else: + _, dimension = example_output.shape + num_frames = None + frames = None - @introspection.setter - def introspection(self, introspection): - self._introspection = introspection + return Output( + num_frames=num_frames, + dimension=dimension, + frames=frames, + ) - @introspection.deleter - def introspection(self): - if hasattr(self, "_introspection"): - del self._introspection + return map_with_specifications( + self.specifications, __example_output, example_output + ) def setup(self, stage=None): - if stage == "fit": - self.task.setup() + self.task.setup_metadata() # list of layers before adding task-dependent layers before = set((name, id(module)) for name, module in self.named_modules()) @@ -411,8 +260,8 @@ def setup(self, stage=None): # setup custom validation metrics self.task.setup_validation_metric() - # this is to make sure introspection is performed here, once and for all - _ = self.introspection + # cache for later (and to avoid later CUDA error with multiprocessing) + _ = self.example_output # list of layers after adding task-dependent layers after = set((name, id(module)) for name, module in self.named_modules()) @@ -421,7 +270,6 @@ def setup(self, stage=None): self.task_dependent = list(name for name, _ in after - before) def on_save_checkpoint(self, checkpoint): - # put everything pyannote.audio-specific under pyannote.audio # to avoid any future conflicts with pytorch-lightning updates checkpoint["pyannote.audio"] = { @@ -433,12 +281,10 @@ def on_save_checkpoint(self, checkpoint): "module": self.__class__.__module__, "class": self.__class__.__name__, }, - "introspection": self.introspection, "specifications": self.specifications, } def on_load_checkpoint(self, checkpoint: Dict[str, Any]): - check_version( "pyannote.audio", checkpoint["pyannote.audio"]["versions"]["pyannote.audio"], @@ -462,43 +308,17 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]): self.specifications = checkpoint["pyannote.audio"]["specifications"] + # add task-dependent (e.g. 
final classifier) layers self.setup() - self.introspection = checkpoint["pyannote.audio"]["introspection"] - - def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + def forward( + self, waveforms: torch.Tensor, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: msg = "Class {self.__class__.__name__} should define a `forward` method." raise NotImplementedError(msg) - def helper_default_activation(self, specifications: Specifications) -> nn.Module: - """Helper function for default_activation - - Parameters - ---------- - specifications: Specifications - Task specification. - - Returns - ------- - activation : nn.Module - Default activation function. - """ - - if specifications.problem == Problem.BINARY_CLASSIFICATION: - return nn.Sigmoid() - - elif specifications.problem == Problem.MONO_LABEL_CLASSIFICATION: - return nn.LogSoftmax(dim=-1) - - elif specifications.problem == Problem.MULTI_LABEL_CLASSIFICATION: - return nn.Sigmoid() - - else: - msg = "TODO: implement default activation for other types of problems" - raise NotImplementedError(msg) - # convenience function to automate the choice of the final activation function - def default_activation(self) -> nn.Module: + def default_activation(self) -> Union[nn.Module, Tuple[nn.Module]]: """Guess default activation function according to task specification * sigmoid for binary classification @@ -507,10 +327,25 @@ def default_activation(self) -> nn.Module: Returns ------- - activation : nn.Module + activation : (tuple of) nn.Module Activation. """ - return self.helper_default_activation(self.specifications) + + def __default_activation(specifications: Specifications = None) -> nn.Module: + if specifications.problem == Problem.BINARY_CLASSIFICATION: + return nn.Sigmoid() + + elif specifications.problem == Problem.MONO_LABEL_CLASSIFICATION: + return nn.LogSoftmax(dim=-1) + + elif specifications.problem == Problem.MULTI_LABEL_CLASSIFICATION: + return nn.Sigmoid() + + else: + msg = "TODO: implement default activation for other types of problems" + raise NotImplementedError(msg) + + return map_with_specifications(self.specifications, __default_activation) # training data logic is delegated to the task because the # model does not really need to know how it is being used. @@ -535,9 +370,7 @@ def validation_step(self, batch, batch_idx): def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=1e-3) - def _helper_up_to( - self, module_name: Text, requires_grad: bool = False - ) -> List[Text]: + def __up_to(self, module_name: Text, requires_grad: bool = False) -> List[Text]: """Helper function for freeze_up_to and unfreeze_up_to""" tokens = module_name.split(".") @@ -594,7 +427,7 @@ def freeze_up_to(self, module_name: Text) -> List[Text]: If your model does not follow a sequential structure, you might want to use freeze_by_name for more control. """ - return self._helper_up_to(module_name, requires_grad=False) + return self.__up_to(module_name, requires_grad=False) def unfreeze_up_to(self, module_name: Text) -> List[Text]: """Unfreeze model up to specific module @@ -619,9 +452,9 @@ def unfreeze_up_to(self, module_name: Text) -> List[Text]: If your model does not follow a sequential structure, you might want to use freeze_by_name for more control. 
""" - return self._helper_up_to(module_name, requires_grad=True) + return self.__up_to(module_name, requires_grad=True) - def _helper_by_name( + def __by_name( self, modules: Union[List[Text], Text], recurse: bool = True, @@ -636,7 +469,6 @@ def _helper_by_name( modules = [modules] for name, module in ModelSummary(self, max_depth=-1).named_modules: - if name not in modules: continue @@ -678,7 +510,7 @@ def freeze_by_name( ValueError if at least one of `modules` does not exist. """ - return self._helper_by_name( + return self.__by_name( modules, recurse=recurse, requires_grad=False, @@ -709,7 +541,7 @@ def unfreeze_by_name( ValueError if at least one of `modules` does not exist. """ - return self._helper_by_name(modules, recurse=recurse, requires_grad=True) + return self.__by_name(modules, recurse=recurse, requires_grad=True) @classmethod def from_pretrained( @@ -826,7 +658,6 @@ def from_pretrained( # HACK do not use it. Fails silently in case model does not # HACK have a config.yaml file. try: - _ = hf_hub_download( model_id, HF_LIGHTNING_CONFIG_NAME, diff --git a/pyannote/audio/core/pipeline.py b/pyannote/audio/core/pipeline.py index a5b4b1bc5..f844d584f 100644 --- a/pyannote/audio/core/pipeline.py +++ b/pyannote/audio/core/pipeline.py @@ -324,9 +324,14 @@ def __call__(self, file: AudioFile, **kwargs): return self.apply(file, **kwargs) - def to(self, device): + def to(self, device: torch.device): """Send pipeline to `device`""" + if not isinstance(device, torch.device): + raise TypeError( + f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`" + ) + for _, pipeline in self._pipelines.items(): if hasattr(pipeline, "to"): _ = pipeline.to(device) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 9bf93bf1c..1edfbc35c 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -72,9 +72,11 @@ class Specifications: problem: Problem resolution: Resolution - # chunk duration in seconds. - # use None for variable-length chunks - duration: Optional[float] = None + # (maximum) chunk duration in seconds + duration: float + + # (for variable-duration tasks only) minimum chunk duration in seconds + min_duration: Optional[float] = None # use that many seconds on the left- and rightmost parts of each chunk # to warm up the model. This is mostly useful for segmentation tasks. @@ -95,7 +97,7 @@ class Specifications: permutation_invariant: bool = False @cached_property - def powerset(self): + def powerset(self) -> bool: if self.powerset_max_classes is None: return False @@ -118,6 +120,12 @@ def num_powerset_classes(self) -> int: ) ) + def __len__(self): + return 1 + + def __iter__(self): + yield self + class TrainDataset(IterableDataset): def __init__(self, task: Task): @@ -191,7 +199,7 @@ class Task(pl.LightningDataModule): Attributes ---------- - specifications : Specifications or dict of Specifications + specifications : Specifications or tuple of Specifications Task specifications (available after `Task.setup` has been called.) 
""" @@ -260,7 +268,28 @@ def prepare_data(self): """ pass - def setup(self, stage: Optional[str] = None): + @property + def specifications(self) -> Union[Specifications, Tuple[Specifications]]: + # setup metadata on-demand the first time specifications are requested and missing + if not hasattr(self, "_specifications"): + self.setup_metadata() + return self._specifications + + @specifications.setter + def specifications( + self, specifications: Union[Specifications, Tuple[Specifications]] + ): + self._specifications = specifications + + @property + def has_setup_metadata(self): + return getattr(self, "_has_setup_metadata", False) + + @has_setup_metadata.setter + def has_setup_metadata(self, value: bool): + self._has_setup_metadata = value + + def setup_metadata(self): """Called at the beginning of training at the very beginning of Model.setup(stage="fit") Notes @@ -270,7 +299,10 @@ def setup(self, stage: Optional[str] = None): If `specifications` attribute has not been set in `__init__`, `setup` is your last chance to set it. """ - pass + + if not self.has_setup_metadata: + self.setup() + self.has_setup_metadata = True def setup_loss_func(self): pass @@ -362,6 +394,11 @@ def common_step(self, batch, batch_idx: int, stage: Literal["train", "val"]): {"loss": loss} """ + if isinstance(self.specifications, tuple): + raise NotImplementedError( + "Default training/validation step is not implemented for multi-task." + ) + # forward pass y_pred = self.model(batch["X"]) diff --git a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py index 1b68a32a9..5af3734b1 100644 --- a/pyannote/audio/models/segmentation/PyanNet.py +++ b/pyannote/audio/models/segmentation/PyanNet.py @@ -80,7 +80,6 @@ def __init__( num_channels: int = 1, task: Optional[Task] = None, ): - super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet) @@ -140,7 +139,6 @@ def __init__( ) def build(self): - if self.hparams.linear["num_layers"] > 0: in_features = self.hparams.linear["hidden_size"] else: @@ -148,6 +146,9 @@ def build(self): 2 if self.hparams.lstm["bidirectional"] else 1 ) + if isinstance(self.specifications, tuple): + raise ValueError("PyanNet does not support multi-tasking.") + if self.specifications.powerset: out_features = self.specifications.num_powerset_classes else: diff --git a/pyannote/audio/models/segmentation/debug.py b/pyannote/audio/models/segmentation/debug.py index 498faee27..89512320c 100644 --- a/pyannote/audio/models/segmentation/debug.py +++ b/pyannote/audio/models/segmentation/debug.py @@ -39,7 +39,6 @@ def __init__( num_channels: int = 1, task: Optional[Task] = None, ): - super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) self.mfcc = MFCC( @@ -60,7 +59,16 @@ def __init__( def build(self): # define task-dependent layers - self.classifier = nn.Linear(32 * 2, len(self.specifications.classes)) + + if isinstance(self.specifications, tuple): + raise ValueError("SimpleSegmentationModel does not support multi-tasking.") + + if self.specifications.powerset: + out_features = self.specifications.num_powerset_classes + else: + out_features = len(self.specifications.classes) + + self.classifier = nn.Linear(32 * 2, out_features) self.activation = self.default_activation() def forward(self, waveforms: torch.Tensor) -> torch.Tensor: diff --git a/pyannote/audio/pipelines/clustering.py b/pyannote/audio/pipelines/clustering.py index f282ea39c..a779016cb 100644 --- 
a/pyannote/audio/pipelines/clustering.py +++ b/pyannote/audio/pipelines/clustering.py @@ -48,7 +48,6 @@ def __init__( max_num_embeddings: int = 1000, constrained_assignment: bool = False, ): - super().__init__() self.metric = metric self.max_num_embeddings = max_num_embeddings @@ -61,7 +60,6 @@ def set_num_clusters( min_clusters: int = None, max_clusters: int = None, ): - min_clusters = num_clusters or min_clusters or 1 min_clusters = max(1, min(num_embeddings, min_clusters)) max_clusters = num_clusters or max_clusters or num_embeddings @@ -113,7 +111,6 @@ def filter_embeddings( return embeddings[chunk_idx, speaker_idx], chunk_idx, speaker_idx def constrained_argmax(self, soft_clusters: np.ndarray) -> np.ndarray: - soft_clusters = np.nan_to_num(soft_clusters, nan=np.nanmin(soft_clusters)) num_chunks, num_speakers, num_clusters = soft_clusters.shape # num_chunks, num_speakers, num_clusters @@ -156,6 +153,8 @@ def assign_embeddings( ------- soft_clusters : (num_chunks, num_speakers, num_clusters)-shaped array hard_clusters : (num_chunks, num_speakers)-shaped array + centroids : (num_clusters, dimension)-shaped array + Clusters centroids """ # TODO: option to add a new (dummy) cluster in case num_clusters < max(frame_speaker_count) @@ -191,10 +190,11 @@ def assign_embeddings( else: hard_clusters = np.argmax(soft_clusters, axis=2) - # TODO: add a flag to revert argmax for trainign subset - # hard_clusters[train_chunk_idx, train_speaker_idx] = train_clusters + # NOTE: train_embeddings might be reassigned to a different cluster + # in the process. based on experiments, this seems to lead to better + # results than sticking to the original assignment. - return hard_clusters, soft_clusters + return hard_clusters, soft_clusters, centroids def __call__( self, @@ -230,6 +230,8 @@ def __call__( soft_clusters : (num_chunks, num_speakers, num_clusters) array Soft cluster assignment (the higher soft_clusters[c, s, k], the most likely the sth speaker of cth chunk belongs to kth cluster) + centroids : (num_clusters, dimension) array + Centroid vectors of each cluster """ train_embeddings, train_chunk_idx, train_speaker_idx = self.filter_embeddings( @@ -250,7 +252,9 @@ def __call__( num_chunks, num_speakers, _ = embeddings.shape hard_clusters = np.zeros((num_chunks, num_speakers), dtype=np.int8) soft_clusters = np.ones((num_chunks, num_speakers, 1)) - return hard_clusters, soft_clusters + centroids = np.mean(train_embeddings, axis=0, keepdims=True) + + return hard_clusters, soft_clusters, centroids train_clusters = self.cluster( train_embeddings, @@ -259,7 +263,7 @@ def __call__( num_clusters=num_clusters, ) - hard_clusters, soft_clusters = self.assign_embeddings( + hard_clusters, soft_clusters, centroids = self.assign_embeddings( embeddings, train_chunk_idx, train_speaker_idx, @@ -267,7 +271,7 @@ def __call__( constrained=self.constrained_assignment, ) - return hard_clusters, soft_clusters + return hard_clusters, soft_clusters, centroids class AgglomerativeClustering(BaseClustering): @@ -286,19 +290,6 @@ class AgglomerativeClustering(BaseClustering): Clustering threshold. min_cluster_size : int in range [1, 20] Minimum cluster size - - Usage - ----- - >>> clustering = AgglomerativeClustering(metric="cosine") - >>> clustering.instantiate({"method": "average", - ... "threshold": 1.0, - ... "min_cluster_size": 1}) - >>> clusters, _ = clustering(embeddings, # shape - ... num_clusters=None, - ... min_clusters=None, - ... 
max_clusters=None) - where `embeddings` is a np.ndarray with shape (num_embeddings, embedding_dimension) - and `clusters` is a np.ndarray with shape (num_embeddings, ) """ def __init__( @@ -307,7 +298,6 @@ def __init__( max_num_embeddings: int = np.inf, constrained_assignment: bool = False, ): - super().__init__( metric=metric, max_num_embeddings=max_num_embeddings, @@ -397,7 +387,6 @@ def cluster( num_clusters = max_clusters if num_clusters is not None: - # switch stopping criterion from "inter-cluster distance" stopping to "iteration index" _dendrogram = np.copy(dendrogram) _dendrogram[:, 2] = np.arange(num_embeddings - 1) @@ -409,7 +398,6 @@ def cluster( # from the "optimal" threshold for iteration in np.argsort(np.abs(dendrogram[:, 2] - self.threshold)): - # only consider iterations that might have resulted # in changing the number of (large) clusters new_cluster_size = _dendrogram[iteration, 3] @@ -481,6 +469,7 @@ class OracleClustering(BaseClustering): def __call__( self, + embeddings: np.ndarray = None, segmentations: SlidingWindowFeature = None, file: AudioFile = None, frames: SlidingWindow = None, @@ -490,6 +479,9 @@ def __call__( Parameters ---------- + embeddings : (num_chunks, num_speakers, dimension) array, optional + Sequence of embeddings. When provided, compute speaker centroids + based on these embeddings. segmentations : (num_chunks, num_frames, num_speakers) array Binary segmentations. file : AudioFile @@ -503,6 +495,8 @@ def __call__( soft_clusters : (num_chunks, num_speakers, num_clusters) array Soft cluster assignment (the higher soft_clusters[c, s, k], the most likely the sth speaker of cth chunk belongs to kth cluster) + centroids : (num_clusters, dimension), optional + Clusters centroids if `embeddings` is provided, None otherwise. 
""" num_chunks, num_frames, num_speakers = segmentations.data.shape @@ -532,7 +526,27 @@ def __call__( hard_clusters[c, i] = j soft_clusters[c, i, j] = 1.0 - return hard_clusters, soft_clusters + if embeddings is None: + return hard_clusters, soft_clusters, None + + ( + train_embeddings, + train_chunk_idx, + train_speaker_idx, + ) = self.filter_embeddings( + embeddings, + segmentations=segmentations, + ) + + train_clusters = hard_clusters[train_chunk_idx, train_speaker_idx] + centroids = np.vstack( + [ + np.mean(train_embeddings[train_clusters == k], axis=0) + for k in range(num_clusters) + ] + ) + + return hard_clusters, soft_clusters, centroids class Clustering(Enum): diff --git a/pyannote/audio/pipelines/overlapped_speech_detection.py b/pyannote/audio/pipelines/overlapped_speech_detection.py index 9b14ee10f..064cae1be 100644 --- a/pyannote/audio/pipelines/overlapped_speech_detection.py +++ b/pyannote/audio/pipelines/overlapped_speech_detection.py @@ -128,7 +128,7 @@ def __init__( # load model model = get_model(segmentation, use_auth_token=use_auth_token) - if model.introspection.dimension > 1: + if model.example_output.dimension > 1: inference_kwargs["pre_aggregation_hook"] = lambda scores: np.partition( scores, -2, axis=-1 )[:, :, -2, np.newaxis] diff --git a/pyannote/audio/pipelines/resegmentation.py b/pyannote/audio/pipelines/resegmentation.py index 57cf9004b..bb71abf22 100644 --- a/pyannote/audio/pipelines/resegmentation.py +++ b/pyannote/audio/pipelines/resegmentation.py @@ -88,7 +88,6 @@ def __init__( der_variant: dict = None, use_auth_token: Union[Text, None] = None, ): - super().__init__() self.segmentation = segmentation @@ -96,7 +95,7 @@ def __init__( model: Model = get_model(segmentation, use_auth_token=use_auth_token) self._segmentation = Inference(model) - self._frames = self._segmentation.model.introspection.frames + self._frames = self._segmentation.model.example_output.frames self._audio = model.audio diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index 6bc81f28a..18b6565d3 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -89,11 +89,20 @@ class SpeakerDiarization(SpeakerDiarizationMixin, Pipeline): Usage ----- - >>> pipeline = SpeakerDiarization() + # perform (unconstrained) diarization >>> diarization = pipeline("/path/to/audio.wav") + + # perform diarization, targetting exactly 4 speakers >>> diarization = pipeline("/path/to/audio.wav", num_speakers=4) + + # perform diarization, with at least 2 speakers and at most 10 speakers >>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10) + # perform diarization and get one representative embedding per speaker + >>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True) + >>> for s, speaker in enumerate(diarization.labels()): + ... 
# embeddings[s] is the embedding of speaker `speaker` + Hyper-parameters ---------------- segmentation.threshold @@ -136,7 +145,7 @@ def __init__( skip_aggregation=True, batch_size=segmentation_batch_size, ) - self._frames: SlidingWindow = self._segmentation.model.introspection.frames + self._frames: SlidingWindow = self._segmentation.model.example_output.frames if self._segmentation.model.specifications.powerset: self.segmentation = ParamDict( @@ -417,6 +426,7 @@ def apply( num_speakers: int = None, min_speakers: int = None, max_speakers: int = None, + return_embeddings: bool = False, hook: Optional[Callable] = None, ) -> Annotation: """Apply speaker diarization @@ -431,6 +441,8 @@ def apply( Minimum number of speakers. Has no effect when `num_speakers` is provided. max_speakers : int, optional Maximum number of speakers. Has no effect when `num_speakers` is provided. + return_embeddings : bool, optional + Return representative speaker embeddings. hook : callable, optional Callback called after each major steps of the pipeline as follows: hook(step_name, # human-readable name of current step @@ -444,6 +456,10 @@ def apply( ------- diarization : Annotation Speaker diarization + embeddings : np.array, optional + Representative speaker embeddings such that `embeddings[i]` is the + speaker embedding for i-th speaker in diarization.labels(). + Only returned when `return_embeddings` is True. """ # setup hook (e.g. for debugging purposes) @@ -466,6 +482,7 @@ def apply( if self._segmentation.model.specifications.powerset else self.segmentation.threshold, frames=self._frames, + warm_up=(0.0, 0.0), ) hook("speaker_counting", count) # shape: (num_frames, 1) @@ -473,7 +490,11 @@ def apply( # exit early when no speaker is ever active if np.nanmax(count.data) == 0.0: - return Annotation(uri=file["uri"]) + diarization = Annotation(uri=file["uri"]) + if return_embeddings: + return diarization, np.zeros((0, self._embedding.dimension)) + + return diarization # binarize segmentation if self._segmentation.model.specifications.powerset: @@ -485,7 +506,7 @@ def apply( initial_state=False, ) - if self.klustering == "OracleClustering": + if self.klustering == "OracleClustering" and not return_embeddings: embeddings = None else: embeddings = self.get_embeddings( @@ -497,7 +518,7 @@ def apply( hook("embeddings", embeddings) # shape: (num_chunks, local_num_speakers, dimension) - hard_clusters, _ = self.clustering( + hard_clusters, _, centroids = self.clustering( embeddings=embeddings, segmentations=binarized_segmentations, num_clusters=num_speakers, @@ -506,7 +527,8 @@ def apply( file=file, # <== for oracle clustering frames=self._frames, # <== for oracle clustering ) - # hard_clusters: (num_chunks, num_speakers) + # hard_clusters: (num_chunks, num_speakers) + # centroids: (num_speakers, dimension) # reconstruct discrete diarization from raw hard clusters @@ -530,20 +552,52 @@ def apply( ) diarization.uri = file["uri"] - # when reference is available, use it to map hypothesized speakers - # to reference speakers (this makes later error analysis easier - # but does not modify the actual output of the diarization pipeline) + # at this point, `diarization` speaker labels are integers + # from 0 to `num_speakers - 1`, aligned with `centroids` rows. 
+ if "annotation" in file and file["annotation"]: - return self.optimal_mapping(file["annotation"], diarization) + # when reference is available, use it to map hypothesized speakers + # to reference speakers (this makes later error analysis easier + # but does not modify the actual output of the diarization pipeline) + _, mapping = self.optimal_mapping( + file["annotation"], diarization, return_mapping=True + ) + + # in case there are more speakers in the hypothesis than in + # the reference, those extra speakers are missing from `mapping`. + # we add them back here + mapping = {key: mapping.get(key, key) for key in diarization.labels()} - # when reference is not available, rename hypothesized speakers - # to human-readable SPEAKER_00, SPEAKER_01, ... - return diarization.rename_labels( - { + else: + # when reference is not available, rename hypothesized speakers + # to human-readable SPEAKER_00, SPEAKER_01, ... + mapping = { label: expected_label for label, expected_label in zip(diarization.labels(), self.classes()) } - ) + + diarization = diarization.rename_labels(mapping=mapping) + + # at this point, `diarization` speaker labels are strings (or mix of + # strings and integers when reference is available and some hypothesis + # speakers are not present in the reference) + + if not return_embeddings: + return diarization + + # re-order centroids so that they match + # the order given by diarization.labels() + inverse_mapping = {label: index for index, label in mapping.items()} + centroids = centroids[ + [inverse_mapping[label] for label in diarization.labels()] + ] + + # FIXME: the number of centroids may be smaller than the number of speakers + # in the annotation. This can happen if the number of active speakers + # obtained from `speaker_count` for some frames is larger than the number + # of clusters obtained from `clustering`. Will be fixed in the future + + return diarization, centroids def get_metric(self) -> GreedyDiarizationErrorRate: return GreedyDiarizationErrorRate(**self.der_variant) diff --git a/pyannote/audio/pipelines/speaker_verification.py b/pyannote/audio/pipelines/speaker_verification.py index 1a672d614..b30ea2b21 100644 --- a/pyannote/audio/pipelines/speaker_verification.py +++ b/pyannote/audio/pipelines/speaker_verification.py @@ -28,6 +28,7 @@ import torch import torch.nn.functional as F import torchaudio +import torchaudio.compliance.kaldi as kaldi from torch.nn.utils.rnn import pad_sequence from pyannote.audio import Inference, Model, Pipeline @@ -57,6 +58,13 @@ except ImportError: NEMO_IS_AVAILABLE = False +try: + import onnxruntime as ort + + ONNX_IS_AVAILABLE = True +except ImportError: + ONNX_IS_AVAILABLE = False + class NeMoPretrainedSpeakerEmbedding(BaseInference): def __init__( @@ -64,7 +72,6 @@ def __init__( embedding: Text = "nvidia/speakerverification_en_titanet_large", device: torch.device = None, ): - if not NEMO_IS_AVAILABLE: raise ImportError( f"'NeMo' must be installed to use '{embedding}' embeddings. 
" @@ -80,6 +87,11 @@ def __init__( self.model_.to(self.device) def to(self, device: torch.device): + if not isinstance(device, torch.device): + raise TypeError( + f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`" + ) + self.model_.to(device) self.device = device return self @@ -90,7 +102,6 @@ def sample_rate(self) -> int: @cached_property def dimension(self) -> int: - input_signal = torch.rand(1, self.sample_rate).to(self.device) input_signal_length = torch.tensor([self.sample_rate]).to(self.device) _, embeddings = self.model_( @@ -105,7 +116,6 @@ def metric(self) -> str: @cached_property def min_num_samples(self) -> int: - lower, upper = 2, round(0.5 * self.sample_rate) middle = (lower + upper) // 2 while lower + 1 < upper: @@ -152,7 +162,6 @@ def __call__( wav_lens = signals.shape[1] * torch.ones(batch_size) else: - batch_size_masks, _ = masks.shape assert batch_size == batch_size_masks @@ -229,7 +238,6 @@ def __init__( device: torch.device = None, use_auth_token: Union[Text, None] = None, ): - if not SPEECHBRAIN_IS_AVAILABLE: raise ImportError( f"'speechbrain' must be installed to use '{embedding}' embeddings. " @@ -255,6 +263,11 @@ def __init__( ) def to(self, device: torch.device): + if not isinstance(device, torch.device): + raise TypeError( + f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`" + ) + self.classifier_ = SpeechBrain_EncoderClassifier.from_hparams( source=self.embedding, savedir=f"{CACHE_DIR}/speechbrain", @@ -281,19 +294,19 @@ def metric(self) -> str: @cached_property def min_num_samples(self) -> int: - - lower, upper = 2, round(0.5 * self.sample_rate) - middle = (lower + upper) // 2 - while lower + 1 < upper: - try: - _ = self.classifier_.encode_batch( - torch.randn(1, middle).to(self.device) - ) - upper = middle - except RuntimeError: - lower = middle - + with torch.inference_mode(): + lower, upper = 2, round(0.5 * self.sample_rate) middle = (lower + upper) // 2 + while lower + 1 < upper: + try: + _ = self.classifier_.encode_batch( + torch.randn(1, middle).to(self.device) + ) + upper = middle + except RuntimeError: + lower = middle + + middle = (lower + upper) // 2 return upper @@ -324,7 +337,6 @@ def __call__( wav_lens = signals.shape[1] * torch.ones(batch_size) else: - batch_size_masks, _ = masks.shape assert batch_size == batch_size_masks @@ -371,6 +383,214 @@ def __call__( return embeddings +class WeSpeakerPretrainedSpeakerEmbedding(BaseInference): + """Pretrained WeSpeaker speaker embedding + + Parameters + ---------- + embedding : str + Path to WeSpeaker pretrained speaker embedding + device : torch.device, optional + Device + + Usage + ----- + >>> get_embedding = WeSpeakerPretrainedSpeakerEmbedding("wespeaker.xxxx.onnx") + >>> assert waveforms.ndim == 3 + >>> batch_size, num_channels, num_samples = waveforms.shape + >>> assert num_channels == 1 + >>> embeddings = get_embedding(waveforms) + >>> assert embeddings.ndim == 2 + >>> assert embeddings.shape[0] == batch_size + + >>> assert binary_masks.ndim == 1 + >>> assert binary_masks.shape[0] == batch_size + >>> embeddings = get_embedding(waveforms, masks=binary_masks) + """ + + def __init__( + self, + embedding: Text = "speechbrain/spkrec-ecapa-voxceleb", + device: torch.device = None, + ): + if not ONNX_IS_AVAILABLE: + raise ImportError( + f"'onnxruntime' must be installed to use '{embedding}' embeddings. 
" + ) + + super().__init__() + + self.embedding = embedding + + self.to(device or torch.device("cpu")) + + def to(self, device: torch.device): + if not isinstance(device, torch.device): + raise TypeError( + f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`" + ) + + if device.type == "cpu": + providers = ["CPUExecutionProvider"] + elif device.type == "cuda": + providers = ["CUDAExecutionProvider"] + else: + warnings.warn( + f"Unsupported device type: {device.type}, falling back to CPU" + ) + device = torch.device("cpu") + providers = ["CPUExecutionProvider"] + + sess_options = ort.SessionOptions() + sess_options.inter_op_num_threads = 1 + sess_options.intra_op_num_threads = 1 + self.session_ = ort.InferenceSession( + self.embedding, sess_options=sess_options, providers=providers + ) + + self.device = device + return self + + @cached_property + def sample_rate(self) -> int: + return 16000 + + @cached_property + def dimension(self) -> int: + dummy_waveforms = torch.rand(1, 1, 16000) + features = self.compute_fbank(dummy_waveforms) + embeddings = self.session_.run( + output_names=["embs"], input_feed={"feats": features.numpy()} + )[0] + _, dimension = embeddings.shape + return dimension + + @cached_property + def metric(self) -> str: + return "cosine" + + @cached_property + def min_num_samples(self) -> int: + lower, upper = 2, round(0.5 * self.sample_rate) + middle = (lower + upper) // 2 + while lower + 1 < upper: + try: + features = self.compute_fbank(torch.randn(1, 1, middle)) + + except AssertionError: + lower = middle + middle = (lower + upper) // 2 + continue + + embeddings = self.session_.run( + output_names=["embs"], input_feed={"feats": features.numpy()} + )[0] + + if np.any(np.isnan(embeddings)): + lower = middle + else: + upper = middle + middle = (lower + upper) // 2 + + return upper + + @cached_property + def min_num_frames(self) -> int: + return self.compute_fbank(torch.randn(1, 1, self.min_num_samples)).shape[1] + + def compute_fbank( + self, + waveforms: torch.Tensor, + num_mel_bins: int = 80, + frame_length: int = 25, + frame_shift: int = 10, + dither: float = 0.0, + ) -> torch.Tensor: + """Extract fbank features + + Parameters + ---------- + waveforms : (batch_size, num_channels, num_samples) + + Returns + ------- + fbank : (batch_size, num_frames, num_mel_bins) + + Source: https://github.com/wenet-e2e/wespeaker/blob/45941e7cba2c3ea99e232d02bedf617fc71b0dad/wespeaker/bin/infer_onnx.py#L30C1-L50 + """ + + waveforms = waveforms * (1 << 15) + features = torch.stack( + [ + kaldi.fbank( + waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + sample_frequency=self.sample_rate, + window_type="hamming", + use_energy=False, + ) + for waveform in waveforms + ] + ) + return features - torch.mean(features, dim=1, keepdim=True) + + def __call__( + self, waveforms: torch.Tensor, masks: torch.Tensor = None + ) -> np.ndarray: + """ + + Parameters + ---------- + waveforms : (batch_size, num_channels, num_samples) + Only num_channels == 1 is supported. 
+ masks : (batch_size, num_samples), optional + + Returns + ------- + embeddings : (batch_size, dimension) + + """ + + batch_size, num_channels, num_samples = waveforms.shape + assert num_channels == 1 + + features = self.compute_fbank(waveforms) + _, num_frames, _ = features.shape + + if masks is None: + embeddings = self.session_.run( + output_names=["embs"], input_feed={"feats": features.numpy()} + )[0] + + return embeddings + + batch_size_masks, _ = masks.shape + assert batch_size == batch_size_masks + + imasks = F.interpolate( + masks.unsqueeze(dim=1), size=num_frames, mode="nearest" + ).squeeze(dim=1) + + imasks = imasks > 0.5 + + embeddings = np.NAN * np.zeros((batch_size, self.dimension)) + + for f, (feature, imask) in enumerate(zip(features, imasks)): + masked_feature = feature[imask] + if masked_feature.shape[0] < self.min_num_frames: + continue + + embeddings[f] = self.session_.run( + output_names=["embs"], + input_feed={"feats": masked_feature.numpy()[None]}, + )[0][0] + + return embeddings + + class PyannoteAudioPretrainedSpeakerEmbedding(BaseInference): """Pretrained pyannote.audio speaker embedding @@ -415,6 +635,11 @@ def __init__( self.model_.to(self.device) def to(self, device: torch.device): + if not isinstance(device, torch.device): + raise TypeError( + f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`" + ) + self.model_.to(device) self.device = device return self @@ -425,7 +650,7 @@ def sample_rate(self) -> int: @cached_property def dimension(self) -> int: - return self.model_.introspection.dimension + return self.model_.example_output.dimension @cached_property def metric(self) -> str: @@ -433,12 +658,24 @@ def metric(self) -> str: @cached_property def min_num_samples(self) -> int: - return self.model_.introspection.min_num_samples + with torch.inference_mode(): + lower, upper = 2, round(0.5 * self.sample_rate) + middle = (lower + upper) // 2 + while lower + 1 < upper: + try: + _ = self.model_(torch.randn(1, 1, middle).to(self.device)) + upper = middle + except RuntimeError: + lower = middle + + middle = (lower + upper) // 2 + + return upper def __call__( self, waveforms: torch.Tensor, masks: torch.Tensor = None ) -> np.ndarray: - with torch.no_grad(): + with torch.inference_mode(): if masks is None: embeddings = self.model_(waveforms.to(self.device)) else: @@ -494,6 +731,9 @@ def PretrainedSpeakerEmbedding( elif isinstance(embedding, str) and "nvidia" in embedding: return NeMoPretrainedSpeakerEmbedding(embedding, device=device) + elif isinstance(embedding, str) and "wespeaker" in embedding: + return WeSpeakerPretrainedSpeakerEmbedding(embedding, device=device) + else: return PyannoteAudioPretrainedSpeakerEmbedding( embedding, device=device, use_auth_token=use_auth_token @@ -557,7 +797,6 @@ def __init__( ) def apply(self, file: AudioFile) -> np.ndarray: - device = self.embedding_model_.device # read audio file and send it to GPU @@ -583,7 +822,6 @@ def main( embedding: str = "pyannote/embedding", segmentation: str = None, ): - import typer from pyannote.database import FileFinder, get_protocol from pyannote.metrics.binary_classification import det_curve @@ -601,7 +839,6 @@ def main( trials = getattr(protocol, f"{subset}_trial")() for t, trial in enumerate(tqdm(trials)): - audio1 = trial["file1"]["audio"] if audio1 not in emb: emb[audio1] = pipeline(audio1) diff --git a/pyannote/audio/pipelines/utils/diarization.py b/pyannote/audio/pipelines/utils/diarization.py index de07524e6..f494c6073 100644 --- 
a/pyannote/audio/pipelines/utils/diarization.py +++ b/pyannote/audio/pipelines/utils/diarization.py @@ -20,10 +20,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Mapping, Tuple, Union +from typing import Dict, Mapping, Tuple, Union import numpy as np from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature +from pyannote.core.utils.types import Label from pyannote.metrics.diarization import DiarizationErrorRate from pyannote.audio.core.inference import Inference @@ -74,8 +75,10 @@ def set_num_speakers( @staticmethod def optimal_mapping( - reference: Union[Mapping, Annotation], hypothesis: Annotation - ) -> Annotation: + reference: Union[Mapping, Annotation], + hypothesis: Annotation, + return_mapping: bool = False, + ) -> Union[Annotation, Tuple[Annotation, Dict[Label, Label]]]: """Find the optimal bijective mapping between reference and hypothesis labels Parameters @@ -84,13 +87,19 @@ def optimal_mapping( Reference annotation. Can be an Annotation instance or a mapping with an "annotation" key. hypothesis : Annotation + Hypothesized annotation. + return_mapping : bool, optional + Return the label mapping itself along with the mapped annotation. Defaults to False. Returns ------- mapped : Annotation Hypothesis mapped to reference speakers. - + mapping : dict, optional + Mapping between hypothesis (key) and reference (value) labels + Only returned if `return_mapping` is True. """ + if isinstance(reference, Mapping): reference = reference["annotation"] annotated = reference["annotated"] if "annotated" in reference else None @@ -100,7 +109,13 @@ def optimal_mapping( mapping = DiarizationErrorRate().optimal_mapping( reference, hypothesis, uem=annotated ) - return hypothesis.rename_labels(mapping=mapping) + mapped_hypothesis = hypothesis.rename_labels(mapping=mapping) + + if return_mapping: + return mapped_hypothesis, mapping + + else: + return mapped_hypothesis # TODO: get rid of onset/offset (binarization should be applied before calling speaker_count) # TODO: get rid of warm-up parameter (trimming should be applied before calling speaker_count) @@ -171,7 +186,8 @@ def to_annotation( Returns ------- continuous_diarization : Annotation - Continuous diarization + Continuous diarization, with speaker labels as integers, + corresponding to the speaker indices in the discrete diarization. """ binarize = Binarize( diff --git a/pyannote/audio/pipelines/utils/oracle.py b/pyannote/audio/pipelines/utils/oracle.py index 486b09274..44b4ded61 100644 --- a/pyannote/audio/pipelines/utils/oracle.py +++ b/pyannote/audio/pipelines/utils/oracle.py @@ -39,7 +39,7 @@ def oracle_segmentation( Simulates inference based on an (imaginary) oracle segmentation model: >>> oracle = Model.from_pretrained("oracle") - >>> assert frames == oracle.introspection.frames + >>> assert frames == oracle.example_output.frames >>> inference = Inference(oracle, duration=window.duration, step=window.step, skip_aggregation=True) >>> oracle_segmentation = inference(file) diff --git a/pyannote/audio/tasks/embedding/mixins.py b/pyannote/audio/tasks/embedding/mixins.py index f5e41d3ee..da164f04e 100644 --- a/pyannote/audio/tasks/embedding/mixins.py +++ b/pyannote/audio/tasks/embedding/mixins.py @@ -21,7 +21,7 @@ # SOFTWARE. 
import math -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Sequence, Union import torch import torch.nn.functional as F @@ -75,13 +75,10 @@ def batch_size(self) -> int: def batch_size(self, batch_size: int): self.batch_size_ = batch_size - def setup(self, stage: Optional[str] = None): + def setup(self): # loop over the training set, remove annotated regions shorter than # chunk duration, and keep track of the reference annotations, per class. - # FIXME: it looks like this time consuming step is called multiple times. - # it should not be... - self._train = dict() desc = f"Loading {self.protocol.name} training labels" @@ -118,6 +115,7 @@ def setup(self, stage: Optional[str] = None): problem=Problem.REPRESENTATION, resolution=Resolution.CHUNK, duration=self.duration, + min_duration=self.min_duration, classes=sorted(self._train), ) @@ -151,6 +149,7 @@ def train__iter__(self): classes = list(self.specifications.classes) + # select batch-wise duration at random batch_duration = rng.uniform(self.min_duration, self.duration) num_samples = 0 diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 3db93824d..142245ae8 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -25,7 +25,7 @@ import random import warnings from collections import defaultdict -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Sequence, Union import matplotlib.pyplot as plt import numpy as np @@ -71,14 +71,8 @@ def get_file(self, file_id): return file - def setup(self, stage: Optional[str] = None): - """Setup method - - Parameters - ---------- - stage : {'fit', 'validate', 'test'}, optional - Setup stage. Defaults to 'fit'. - """ + def setup(self): + """Setup""" # duration of training chunks # TODO: handle variable duration case diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index da6104386..8b9cef60c 100644 --- a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -20,16 +20,18 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +from functools import cached_property from typing import Dict, List, Optional, Sequence, Text, Tuple, Union +from einops import rearrange import numpy as np import torch import torch.nn.functional as F -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database import Protocol from pyannote.database.protocol import SegmentationProtocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform -from torchmetrics import Metric +from torchmetrics import Accuracy, F1Score, Metric, MetricCollection, Precision, Recall from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin @@ -77,8 +79,13 @@ class MultiLabelSegmentation(SegmentationTaskMixin, Task): torch_audiomentations waveform transform, used by dataloader during training. metric : optional - Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. - Defaults to AUROC (area under the ROC curve). + Multilabel validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Make sure the metric's ignore_index==-1 if your data contains un-annotated frames. 
+ Defaults to F1, Precision, Recall & Accuracy in macro mode. + metric_classwise: Union[Metric, Sequence[Metric], Dict[str, Metric]], optional + Validation metric(s) to compute for each class (binary). Can be anything supported by torchmetrics.MetricCollection. + No need for ignore_index=-1 here, as the metric is computed only on the labelled frames. + Defaults to F1, Precision, Recall & Accuracy. """ def __init__( @@ -94,6 +101,7 @@ def __init__( pin_memory: bool = False, augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric_classwise: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): if not isinstance(protocol, SegmentationProtocol): raise ValueError( @@ -114,19 +122,21 @@ def __init__( self.balance = balance self.weight = weight self.classes = classes + self._metric_classwise = metric_classwise # task specification depends on the data: we do not know in advance which # classes should be detected. therefore, we postpone the definition of # specifications to setup() - def setup(self, stage: Optional[str] = None): - super().setup(stage=stage) + def setup(self): + super().setup() self.specifications = Specifications( classes=self.classes, problem=Problem.MULTI_LABEL_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, ) @@ -167,14 +177,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - - # TODO: this should be cached - # use model introspection to predict how many frames it will output - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -185,19 +187,23 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # frame-level targets (-1 for un-annotated classes) - y = -np.ones((num_frames, len(self.classes)), dtype=np.int8) + y = -np.ones( + (self.model.example_output.num_frames, len(self.classes)), dtype=np.int8 + ) y[:, self.annotated_classes[file_id]] = 0 for start, end, label in zip( start_idx, end_idx, chunk_annotations["global_label_idx"] ): y[start:end, label] = 1 - sample["y"] = SlidingWindowFeature(y, frames, labels=self.classes) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output.frames, labels=self.classes + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} @@ -243,14 +249,51 @@ def validation_step(self, batch, batch_idx: int): # TODO: add support for frame weights # TODO: add support for class weights - # TODO: compute metrics for each class separately - # mask (frame, class) index for which label is missing mask: torch.Tensor = y_true != -1 - y_pred = y_pred[mask] - y_true = y_true[mask] - loss = 
F.binary_cross_entropy(y_pred, y_true.type(torch.float)) + y_pred_labelled = y_pred[mask] + y_true_labelled = y_true[mask] + loss = F.binary_cross_entropy( + y_pred_labelled, y_true_labelled.type(torch.float) + ) + # log global metric (multilabel) + self.model.validation_metric( + y_pred.reshape((-1, y_pred.shape[-1])), + y_true.reshape((-1, y_true.shape[-1])), + ) + self.model.log_dict( + self.model.validation_metric, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + # log metrics per class (binary) + for class_id, class_name in enumerate(self.classes): + mask: torch.Tensor = y_true[..., class_id] != -1 + if mask.sum() == 0: + continue + + y_pred_labelled = y_pred[..., class_id][mask] + y_true_labelled = y_true[..., class_id][mask] + + metric = self.model.validation_metric_classwise[class_name] + metric( + y_pred_labelled, + y_true_labelled, + ) + + self.model.log_dict( + metric, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + # log losses self.model.log( "loss/val", loss, @@ -261,6 +304,75 @@ def validation_step(self, batch, batch_idx: int): ) return {"loss": loss} + def default_metric( + self, + ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + class_count = len(self.classes) + if class_count > 1: # multilabel + return { + "F1": F1Score( + task="multilabel", + num_labels=class_count, + ignore_index=-1, + average="macro", + ), + "Precision": Precision( + task="multilabel", + num_labels=class_count, + ignore_index=-1, + average="macro", + ), + "Recall": Recall( + task="multilabel", + num_labels=class_count, + ignore_index=-1, + average="macro", + ), + "Accuracy": Accuracy( + task="multilabel", + num_labels=class_count, + ignore_index=-1, + average="macro", + ), + } + else: + # Binary classification, this case is handled by the per-class metric, see 'default_metric_per_class'/'metric_classwise' + return [] + + def default_metric_classwise( + self, + ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + return { + "F1": F1Score(task="binary"), + "Precision": Precision(task="binary"), + "Recall": Recall(task="binary"), + "Accuracy": Accuracy(task="binary"), + } + + @cached_property + def metric_classwise(self) -> MetricCollection: + if self._metric_classwise is None: + self._metric_classwise = self.default_metric_classwise() + + return MetricCollection(self._metric_classwise) + + def setup_validation_metric(self): + # setup global/multilabel validation metric + super().setup_validation_metric() + + # and then setup validation metric per class / classwise metrics + metric = self.metric_classwise + if metric is None: + return + + self.model.validation_metric_classwise = torch.nn.ModuleDict().to( + self.model.device + ) + for class_name in self.classes: + self.model.validation_metric_classwise[class_name] = metric.clone( + prefix=f"{class_name}/" + ) + @property def val_monitor(self): """Quantity (and direction) to monitor @@ -280,4 +392,4 @@ def val_monitor(self): pytorch_lightning.callbacks.EarlyStopping """ - return "ValLoss", "min" + return "loss/val", "min" diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 658c350a7..cd3711d61 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -24,7 +24,7 @@ from typing import Dict, Sequence, Text, Tuple, Union import numpy as np -from pyannote.core import Segment, SlidingWindow, 
SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric @@ -106,7 +106,6 @@ def __init__( augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): - super().__init__( protocol, duration=duration, @@ -122,6 +121,7 @@ def __init__( problem=Problem.BINARY_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, classes=[ "overlap", @@ -162,13 +162,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - # use model introspection to predict how many frames it will output - # TODO: this should be cached - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -179,17 +172,19 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # frame-level targets - y = np.zeros((num_frames, 1), dtype=np.uint8) + y = np.zeros((self.model.example_output.num_frames, 1), dtype=np.uint8) for start, end in zip(start_idx, end_idx): y[start:end, 0] += 1 y = 1 * (y > 1) - sample["y"] = SlidingWindowFeature(y, frames, labels=["speech"]) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output.frames, labels=["speech"] + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index b6838a71c..eac795a47 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -23,13 +23,13 @@ import math import warnings from collections import Counter -from typing import Dict, Literal, Optional, Sequence, Text, Tuple, Union +from typing import Dict, Literal, Sequence, Text, Tuple, Union import numpy as np import torch import torch.nn.functional from matplotlib import pyplot as plt -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database.protocol import SpeakerDiarizationProtocol from pyannote.database.protocol.protocol import Scope, Subset from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger @@ -186,8 +186,8 @@ def __init__( self.weight = weight self.vad_loss = vad_loss - def setup(self, stage: Optional[str] = None): - super().setup(stage=stage) + def setup(self): + super().setup() # estimate maximum number of speakers per chunk when not provided if self.max_speakers_per_chunk is None: @@ -276,6 +276,7 @@ def setup(self, 
stage: Optional[str] = None): else Problem.MONO_LABEL_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)], powerset_max_classes=self.max_speakers_per_frame, @@ -326,13 +327,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - # use model introspection to predict how many frames it will output - # TODO: this should be cached - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -343,9 +337,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # get list and number of labels for current scope labels = list(np.unique(chunk_annotations[label_scope_key])) @@ -355,7 +349,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): pass # initial frame-level targets - y = np.zeros((num_frames, num_labels), dtype=np.uint8) + y = np.zeros((self.model.example_output.num_frames, num_labels), dtype=np.uint8) # map labels to indices mapping = {label: idx for idx, label in enumerate(labels)} @@ -366,7 +360,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): mapped_label = mapping[label] y[start:end, mapped_label] = 1 - sample["y"] = SlidingWindowFeature(y, frames, labels=labels) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output.frames, labels=labels + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} @@ -553,11 +549,7 @@ def training_step(self, batch, batch_idx: int): weight[:, num_frames - warm_up_right :] = 0.0 if self.specifications.powerset: - powerset = torch.nn.functional.one_hot( - torch.argmax(prediction, dim=-1), - self.model.powerset.num_powerset_classes, - ).float() - multilabel = self.model.powerset.to_multilabel(powerset) + multilabel = self.model.powerset.to_multilabel(prediction) permutated_target, _ = permutate(multilabel, target) permutated_target_powerset = self.model.powerset.to_powerset( permutated_target.float() @@ -686,11 +678,7 @@ def validation_step(self, batch, batch_idx: int): weight[:, num_frames - warm_up_right :] = 0.0 if self.specifications.powerset: - powerset = torch.nn.functional.one_hot( - torch.argmax(prediction, dim=-1), - self.model.powerset.num_powerset_classes, - ).float() - multilabel = self.model.powerset.to_multilabel(powerset) + multilabel = self.model.powerset.to_multilabel(prediction) permutated_target, _ = permutate(multilabel, target) # FIXME: handle case where target have too many speakers? 
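
The `training_step`/`validation_step` hunks above drop the manual one-hot conversion and call `self.model.powerset.to_multilabel(prediction)` directly, since (per the `pyannote/audio/utils/powerset.py` hunk further down) `to_multilabel` now hardens the soft powerset scores itself before projecting them to the multi-label space. A minimal sketch of that conversion, using a hypothetical 2-speaker / at-most-1-active mapping matrix purely for illustration (the real matrix is built by `Powerset.build_mapping`):

```python
import torch

# hypothetical mapping: rows = powerset classes {∅, {spk1}, {spk2}},
# columns = speakers (spk1, spk2) -- for illustration only
mapping = torch.tensor([[0.0, 0.0],
                        [1.0, 0.0],
                        [0.0, 1.0]])

# soft powerset scores for one chunk with two frames
soft_powerset = torch.tensor([[[0.1, 0.7, 0.2],
                               [0.2, 0.3, 0.5]]])

# what the updated `to_multilabel` does: argmax + one-hot first, then project
hard_powerset = torch.nn.functional.one_hot(
    soft_powerset.argmax(dim=-1), num_classes=3
).float()
multilabel = hard_powerset @ mapping
print(multilabel)  # frame 1 -> spk1 active, frame 2 -> spk2 active
```
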
diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index 559ff24eb..967ea1f9b 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -23,7 +23,7 @@ from typing import Dict, Sequence, Text, Tuple, Union import numpy as np -from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from pyannote.core import Segment, SlidingWindowFeature from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric @@ -89,7 +89,6 @@ def __init__( augmentation: BaseWaveformTransform = None, metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): - super().__init__( protocol, duration=duration, @@ -108,6 +107,7 @@ def __init__( problem=Problem.BINARY_CLASSIFICATION, resolution=Resolution.FRAME, duration=self.duration, + min_duration=self.min_duration, warm_up=self.warm_up, classes=[ "speech", @@ -144,13 +144,6 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) - # use model introspection to predict how many frames it will output - # TODO: this should be cached - num_samples = sample["X"].shape[1] - num_frames, _ = self.model.introspection(num_samples) - resolution = duration / num_frames - frames = SlidingWindow(start=0.0, duration=resolution, step=resolution) - # gather all annotations of current file annotations = self.annotations[self.annotations["file_id"] == file_id] @@ -161,16 +154,18 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): # discretize chunk annotations at model output resolution start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - start_idx = np.floor(start / resolution).astype(int) + start_idx = np.floor(start / self.model.example_output.frames.step).astype(int) end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - end_idx = np.ceil(end / resolution).astype(int) + end_idx = np.ceil(end / self.model.example_output.frames.step).astype(int) # frame-level targets - y = np.zeros((num_frames, 1), dtype=np.uint8) + y = np.zeros((self.model.example_output.num_frames, 1), dtype=np.uint8) for start, end in zip(start_idx, end_idx): y[start:end, 0] = 1 - sample["y"] = SlidingWindowFeature(y, frames, labels=["speech"]) + sample["y"] = SlidingWindowFeature( + y, self.model.example_output.frames, labels=["speech"] + ) metadata = self.metadata[file_id] sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} diff --git a/pyannote/audio/utils/multi_task.py b/pyannote/audio/utils/multi_task.py new file mode 100644 index 000000000..3886a0eeb --- /dev/null +++ b/pyannote/audio/utils/multi_task.py @@ -0,0 +1,59 @@ +# MIT License +# +# Copyright (c) 2023- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the 
Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from typing import Any, Callable, Tuple, Union + +from pyannote.audio.core.model import Specifications + + +def map_with_specifications( + specifications: Union[Specifications, Tuple[Specifications]], + func: Callable, + *iterables, +) -> Union[Any, Tuple[Any]]: + """Compute the function using arguments from each of the iterables + + Returns a tuple if provided `specifications` is a tuple, + otherwise returns the function return value. + + Parameters + ---------- + specifications : (tuple of) Specifications + Specifications or tuple of specifications + func : callable + Function called for each specification with + `func(*iterables[i], specifications=specifications[i])` + *iterables : + List of iterables with same length as `specifications`. + + Returns + ------- + output : (tuple of) `func` return value(s) + """ + + if isinstance(specifications, Specifications): + return func(*iterables, specifications=specifications) + + return tuple( + func(*i, specifications=s) for s, *i in zip(specifications, *iterables) + ) diff --git a/pyannote/audio/utils/powerset.py b/pyannote/audio/utils/powerset.py index 215cb7946..0f5cfb5bc 100644 --- a/pyannote/audio/utils/powerset.py +++ b/pyannote/audio/utils/powerset.py @@ -85,25 +85,29 @@ def build_cardinality(self) -> torch.Tensor: return cardinality def to_multilabel(self, powerset: torch.Tensor) -> torch.Tensor: - """Convert (hard) predictions from powerset to multi-label + """Convert predictions from (soft) powerset to (hard) multi-label Parameter --------- powerset : (batch_size, num_frames, num_powerset_classes) torch.Tensor - Hard predictions in "powerset" space. + Soft predictions in "powerset" space. Returns ------- multi_label : (batch_size, num_frames, num_classes) torch.Tensor Hard predictions in "multi-label" space. - - Note - ---- - This method will not complain if `powerset` is provided a soft predictions - (e.g. the output of a softmax-ed classifier). However, in that particular - case, the resulting soft multi-label output will not make much sense. 
""" - return torch.matmul(powerset, self.mapping) + + hard_powerset = torch.nn.functional.one_hot( + torch.argmax(powerset, dim=-1), + self.num_powerset_classes, + ).float() + + return torch.matmul(hard_powerset, self.mapping) + + def forward(self, powerset: torch.Tensor) -> torch.Tensor: + """Alias for `to_multilabel`""" + return self.to_multilabel(powerset) def to_powerset(self, multilabel: torch.Tensor) -> torch.Tensor: """Convert (hard) predictions from multi-label to powerset diff --git a/pyannote/audio/utils/preview.py b/pyannote/audio/utils/preview.py index 6094c71cd..fcdf4d124 100644 --- a/pyannote/audio/utils/preview.py +++ b/pyannote/audio/utils/preview.py @@ -256,7 +256,7 @@ def make_frame(T: float): return IPythonVideo(video_path, embed=True) -def preview_training_samples( +def BROKEN_preview_training_samples( model: Model, blank: float = 1.0, video_fps: int = 5, diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..fe2a00e12 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,32 @@ +# MIT License +# +# Copyright (c) 2020- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +def pytest_sessionstart(session): + """ + Called after the Session object has been created and + before performing collection and entering the run test loop. 
+ """ + + from pyannote.database import registry + + registry.load_database("tests/data/database.yml") diff --git a/tests/inference_test.py b/tests/inference_test.py index 807f94cc1..bd5040394 100644 --- a/tests/inference_test.py +++ b/tests/inference_test.py @@ -1,13 +1,13 @@ import numpy as np import pytest import pytorch_lightning as pl +from pyannote.core import SlidingWindowFeature +from pyannote.database import FileFinder, get_protocol from pyannote.audio import Inference, Model from pyannote.audio.core.task import Resolution from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel from pyannote.audio.tasks import VoiceActivityDetection -from pyannote.core import SlidingWindowFeature -from pyannote.database import FileFinder, get_protocol HF_SAMPLE_MODEL_ID = "pyannote/TestModelForContinuousIntegration" @@ -29,8 +29,8 @@ def trained(): ) vad = VoiceActivityDetection(protocol, duration=2.0, batch_size=16, num_workers=4) model = SimpleSegmentationModel(task=vad) - trainer = pl.Trainer(fast_dev_run=True) - trainer.fit(model, vad) + trainer = pl.Trainer(fast_dev_run=True, accelerator="cpu") + trainer.fit(model) return protocol, model @@ -91,7 +91,6 @@ def test_on_file_path(trained): def test_skip_aggregation(pretrained_model, dev_file): - inference = Inference(pretrained_model, skip_aggregation=True) scores = inference(dev_file) assert len(scores.data.shape) == 3 diff --git a/tests/test_train.py b/tests/test_train.py index 79e7f071a..7a7bfe338 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -20,125 +20,119 @@ def protocol(): def test_train_segmentation(protocol): segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_train_voice_activity_detection(protocol): voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_train_overlapped_speech_detection(protocol): overlapped_speech_detection = OverlappedSpeechDetection(protocol) model = SimpleSegmentationModel(task=overlapped_speech_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_with_task_that_does_not_need_setup_for_specs(protocol): voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) voice_activity_detection = VoiceActivityDetection(protocol) model.task = voice_activity_detection - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_with_task_that_needs_setup_for_specs(protocol): segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_with_task_that_does_not_need_setup_for_specs(protocol): - segmentation = SpeakerDiarization(protocol) model = 
SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) voice_activity_detection = VoiceActivityDetection(protocol) model.task = voice_activity_detection - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_with_task_that_needs_setup_for_specs(protocol): - voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_freeze_with_task_that_needs_setup_for_specs(protocol): - segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_finetune_freeze_with_task_that_does_not_need_setup_for_specs(protocol): - vad = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=vad) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) vad = VoiceActivityDetection(protocol) model.task = vad model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_freeze_with_task_that_does_not_need_setup_for_specs(protocol): - segmentation = SpeakerDiarization(protocol) model = SimpleSegmentationModel(task=segmentation) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) voice_activity_detection = VoiceActivityDetection(protocol) model.task = voice_activity_detection model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) def test_transfer_freeze_with_task_that_needs_setup_for_specs(protocol): - voice_activity_detection = VoiceActivityDetection(protocol) model = SimpleSegmentationModel(task=voice_activity_detection) - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) segmentation = SpeakerDiarization(protocol) model.task = segmentation model.freeze_up_to("mfcc") - trainer = Trainer(fast_dev_run=True) + trainer = Trainer(fast_dev_run=True, accelerator="cpu") trainer.fit(model) diff --git a/tutorials/add_your_own_task.ipynb b/tutorials/add_your_own_task.ipynb index b2053f459..251846957 100644 --- a/tutorials/add_your_own_task.ipynb +++ b/tutorials/add_your_own_task.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -32,6 +33,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -48,6 +50,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -57,6 +60,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -82,6 +86,7 @@ ] }, { + "attachments": {}, 
"cell_type": "markdown", "metadata": {}, "source": [ @@ -125,6 +130,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -176,54 +182,52 @@ " augmentation=augmentation,\n", " )\n", "\n", - " def setup(self, stage=None):\n", - "\n", - " if stage == \"fit\":\n", - "\n", - " # load metadata for training subset\n", - " self.train_metadata_ = list()\n", - " for training_file in self.protocol.train():\n", - " self.training_metadata_.append({\n", - " # path to audio file (str)\n", - " \"audio\": training_file[\"audio\"],\n", - " # duration of audio file (float)\n", - " \"duration\": training_file[\"duration\"],\n", - " # reference annotation (pyannote.core.Annotation)\n", - " \"annotation\": training_file[\"annotation\"],\n", - " })\n", - "\n", - " # gather the list of classes\n", - " classes = set()\n", - " for training_file in self.train_metadata_:\n", - " classes.update(training_file[\"reference\"].labels())\n", - " classes = sorted(classes)\n", - "\n", - " # specify the addressed problem\n", - " self.specifications = Specifications(\n", - " # it is a multi-label classification problem\n", - " problem=Problem.MULTI_LABEL_CLASSIFICATION,\n", - " # we expect the model to output one prediction \n", - " # for the whole chunk\n", - " resolution=Resolution.CHUNK,\n", - " # the model will ingest chunks with that duration (in seconds)\n", - " duration=self.duration,\n", - " # human-readable names of classes\n", - " classes=classes)\n", - "\n", - " # `has_validation` is True iff protocol defines a development set\n", - " if not self.has_validation:\n", - " return\n", - "\n", - " # load metadata for validation subset\n", - " self.validation_metadata_ = list()\n", - " for validation_file in self.protocol.development():\n", - " self.validation_metadata_.append({\n", - " \"audio\": validation_file[\"audio\"],\n", - " \"num_samples\": math.floor(validation_file[\"duration\"] / self.duration),\n", - " \"annotation\": validation_file[\"annotation\"],\n", - " })\n", - " \n", - " \n", + " def setup(self):\n", + "\n", + " # load metadata for training subset\n", + " self.train_metadata_ = list()\n", + " for training_file in self.protocol.train():\n", + " self.training_metadata_.append({\n", + " # path to audio file (str)\n", + " \"audio\": training_file[\"audio\"],\n", + " # duration of audio file (float)\n", + " \"duration\": training_file[\"duration\"],\n", + " # reference annotation (pyannote.core.Annotation)\n", + " \"annotation\": training_file[\"annotation\"],\n", + " })\n", + "\n", + " # gather the list of classes\n", + " classes = set()\n", + " for training_file in self.train_metadata_:\n", + " classes.update(training_file[\"reference\"].labels())\n", + " classes = sorted(classes)\n", + "\n", + " # specify the addressed problem\n", + " self.specifications = Specifications(\n", + " # it is a multi-label classification problem\n", + " problem=Problem.MULTI_LABEL_CLASSIFICATION,\n", + " # we expect the model to output one prediction \n", + " # for the whole chunk\n", + " resolution=Resolution.CHUNK,\n", + " # the model will ingest chunks with that duration (in seconds)\n", + " duration=self.duration,\n", + " # human-readable names of classes\n", + " classes=classes)\n", + "\n", + " # `has_validation` is True iff protocol defines a development set\n", + " if not self.has_validation:\n", + " return\n", + "\n", + " # load metadata for validation subset\n", + " self.validation_metadata_ = list()\n", + " for validation_file in self.protocol.development():\n", + " 
self.validation_metadata_.append({\n", + " \"audio\": validation_file[\"audio\"],\n", + " \"num_samples\": math.floor(validation_file[\"duration\"] / self.duration),\n", + " \"annotation\": validation_file[\"annotation\"],\n", + " })\n", + " \n", + " \n", "\n", " def train__iter__(self):\n", " # this method generates training samples, one at a time, \"ad infinitum\". each worker \n", diff --git a/tutorials/overlapped_speech_detection.ipynb b/tutorials/overlapped_speech_detection.ipynb index 78c6372cb..1ad5d4090 100644 --- a/tutorials/overlapped_speech_detection.ipynb +++ b/tutorials/overlapped_speech_detection.ipynb @@ -20,6 +20,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -39,6 +40,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -49,6 +51,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -84,6 +87,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -103,6 +107,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -110,6 +115,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -130,6 +136,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -147,6 +154,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -161,10 +169,11 @@ "source": [ "import pytorch_lightning as pl\n", "trainer = pl.Trainer(max_epochs=10)\n", - "trainer.fit(model, osd)" + "trainer.fit(model)" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -185,6 +194,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -212,6 +222,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -219,6 +230,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -242,6 +254,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -258,6 +271,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -265,6 +279,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -297,6 +312,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [