From 52c7f11a6f70a906195ec80a788052b5227ce361 Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Tue, 24 Sep 2024 16:39:59 -0700
Subject: [PATCH 01/19] add timestamp forced aligner

---
 Justfile                          |  32 +++++++
 poetry.lock                       |  16 +++-
 pyproject.toml                    |   1 +
 ultravox/tools/ds_tool/ds_tool.py | 152 +++++++++++++++++++++++++++++-
 4 files changed, 196 insertions(+), 5 deletions(-)

diff --git a/Justfile b/Justfile
index bafe39db..2348cc9a 100644
--- a/Justfile
+++ b/Justfile
@@ -3,6 +3,7 @@ export WANDB_LOG_MODEL:="checkpoint"
 export PROJECT_DIR:="ultravox"
 export MCLOUD_CLUSTER:="r7z22p1"
 export MCLOUD_INSTANCE:="oci.bm.gpu.b4.8"
+export MFA_ENV_NAME:="aligner"
 
 default: format check test
 
@@ -62,3 +63,34 @@ run *FLAGS:
 
 mcloud *FLAGS:
     poetry run mcli interactive {{FLAGS}} --cluster ${MCLOUD_CLUSTER} --instance ${MCLOUD_INSTANCE} --name `whoami` --command "bash -c \"$(cat setup.sh)\""
+
+@check_conda:
+    if ! command -v conda &> /dev/null; then \
+        echo "Conda is not installed."; \
+        mkdir -p ~/miniconda3; \
+        if [[ "$OSTYPE" == "darwin"* ]]; then \
+            echo "Downloading MacOS Miniconda."; \
+            curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh; \
+        elif [[ "$OSTYPE" == "linux-gnu"* ]]; then \
+            echo "Downloading Linux Miniconda."; \
+            wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh; \
+        else \
+            echo "Unknown operating system."; \
+        fi; \
+        bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3; \
+        rm ~/miniconda3/miniconda.sh; \
+    else \
+        echo "Conda is installed."; \
+    fi
+
+@install_mfa: check_conda
+    if conda env list | grep -q "$MFA_ENV_NAME"; then \
+        echo "Environment '$MFA_ENV_NAME' already exists."; \
+    else \
+        echo "Creating environment '$MFA_ENV_NAME'."; \
+        conda create --name "$MFA_ENV_NAME" python=3.8 -y; \
+        conda install -n "$MFA_ENV_NAME" -c conda-forge montreal-forced-aligner -y; \
+        conda activate "$MFA_ENV_NAME"; \
+        mfa model download acoustic english_mfa; \
+        mfa model download dictionary english_mfa; \
+    fi
\ No newline at end of file
diff --git a/poetry.lock b/poetry.lock
index 6ee9b992..7465e884 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -5124,6 +5124,20 @@ docs = ["sphinx (>=1.7.1)"]
 redis = ["redis"]
 tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"]
 
+[[package]]
+name = "praatio"
+version = "6.2.0"
+description = "A library for working with praat, textgrids, time aligned audio transcripts, and audio files."
+optional = false
+python-versions = ">3.6.0"
+files = [
+    {file = "praatio-6.2.0-py3-none-any.whl", hash = "sha256:6541018791a3f0b087a8168d1746a165937c3fff1f94c7a6883b3f470e0cf405"},
+    {file = "praatio-6.2.0.tar.gz", hash = "sha256:7d2a7f8135a3e0691743ada0af84308b64e637f07038cea77d814b8aa2fa2e40"},
+]
+
+[package.dependencies]
+typing-extensions = "*"
+
 [[package]]
 name = "preshed"
 version = "3.0.9"
@@ -8883,4 +8897,4 @@ files = [
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<4.0"
-content-hash = "798d26eeecb0625e6e6b655f7209286319924de573c9cdd9a30416593a492cb5"
+content-hash = "f1d462cee8239c355f81406ff7ee88e42fd85b955abec54aac629ef2cd4a4cce"
diff --git a/pyproject.toml b/pyproject.toml
index 21c6734c..d11c699c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ wandb = "~0.17.1"
 sacrebleu = "^2.4.2"
 tenacity = "^9.0.0"
 evals = {git = "https://github.com/fixie-ai/evals", rev = "0c66bf85df7a4b903ecb202b23c2a826b749fd71"}
+praatio = "^6.2.0"
 
 [tool.poetry.group.dev.dependencies]
 black = "~24.4.2"
diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 750f62e4..b0ea08c2 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -2,12 +2,18 @@ import json
 import math
 import os
+import tempfile
+import subprocess
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import datasets
 import jinja2
+import librosa
 import openai
 import simple_parsing
+import soundfile as sf
+from praatio import textgrid
 import yaml
 from tenacity import retry
 from tenacity import stop_after_attempt
@@ -21,6 +27,8 @@ tts_client: caching.CachingTtsWrapper
 chat_client: caching.CachingChatWrapper
 
+MFA_ENV_NAME = "aligner"
+
 
 @dataclasses.dataclass
 class TtsTask:
@@ -88,6 +96,134 @@ def _map_sample(self, sample):
         return sample
 
 
+@dataclasses.dataclass
+class TimestampGenerationTask:
+    language: str = simple_parsing.field(default="en", alias="-l")
+    audio_column_name: str = simple_parsing.field(default="audio", alias="-a")
+    template: str = simple_parsing.field(
+        default="{{text_proc.format_asr_text(text)}}", alias="-T"
+    )
+    timestamp_column_name: str = simple_parsing.field(
+        default="timestamps", alias="-tsc"
+    )
+    sample_rate: int = simple_parsing.field(default=16000, alias="-r")
+    temp_dir: Optional[str] = None
+    aligned_ratio_check: float = simple_parsing.field(default=0.95, alias="-ar")
+
+    def __post_init__(self):
+        try:
+            # Make sure the MFA environment is installed
+            subprocess.run(["conda", "run", "-n", MFA_ENV_NAME, "echo"], check=True)
+        except subprocess.CalledProcessError:
+            raise Exception(
+                "Please install the MFA environment using `just install_mfa` first."
+            )
+
+        if self.template.startswith("@"):
+            with open(self.template[1:], "r") as template_file:
+                self.template = template_file.read()
+
+    def map_split(
+        self,
+        ds_split: datasets.Dataset,
+        num_proc: int,
+        writer_batch_size: int,
+        exclude_fields: List[str],
+    ) -> datasets.Dataset:
+        ds_mapped = ds_split.map(
+            self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size
+        ).filter(
+            lambda sample: sample[self.timestamp_column_name] is not None,
+            num_proc=num_proc,
+            writer_batch_size=writer_batch_size,
+        )
+        if len(ds_split) * self.aligned_ratio_check > len(ds_mapped):
+            raise Exception(
+                f"Found too many samples without timestamps: {len(ds_mapped)}/{len(ds_split)} aligned."
+            )
+        return ds_mapped
+
+    def _map_sample(self, sample):
+        # find the timestamps for the audio and populate the timestamps column
+        sample_id = self.get_id(sample)
+        text_path = os.path.join(self.temp_dir, f"{sample_id}.TextGrid")
+        if not os.path.exists(text_path):
+            sample[self.timestamp_column_name] = None
+            return sample
+
+        tg = textgrid.openTextgrid(text_path, False)
+        timestamps = tg.getTier("words").entries
+        sample[self.timestamp_column_name] = [
+            {"start": entry.start, "end": entry.end, "text": entry.label}
+            for entry in timestamps
+        ]
+        return sample
+
+    @staticmethod
+    def get_id(sample):
+        for key in ["id", "segment_id"]:
+            if key in sample and isinstance(sample[key], str):
+                return str(sample[key])
+        for key in ["file", "path", "audio_file"]:
+            if key in sample and isinstance(sample[key], str):
+                return Path(sample[key]).stem
+        raise ValueError("Could not find an ID in the sample")
+
+    def preprocess(self, ds_split: datasets.Dataset) -> datasets.Dataset:
+        if self.temp_dir is None:
+            if hasattr(self, "_temp_dir"):
+                self._temp_dir.cleanup()
+            self._temp_dir = tempfile.TemporaryDirectory()
+            self.temp_dir = self._temp_dir.name
+
+        os.makedirs(self.temp_dir, exist_ok=True)
+
+        # 1. copy all audio-text pairs into a temp directory
+        for sample in ds_split:
+            sample_id = self.get_id(sample)
+            audio_path = os.path.join(self.temp_dir, f"{sample_id}.wav")
+            with open(audio_path, "wb") as f:
+                audio = sample[self.audio_column_name]
+                if audio["sampling_rate"] != self.sample_rate:
+                    audio["array"] = librosa.resample(
+                        audio["array"],
+                        orig_sr=audio["sampling_rate"],
+                        target_sr=self.sample_rate,
+                    )
+                sf.write(f, audio["array"], self.sample_rate, format="WAV", subtype="PCM_16")
+
+            text_path = os.path.join(self.temp_dir, f"{sample_id}.txt")
+            with open(text_path, "w") as f:
+                # TODO: exclude_fields
+                text = jinja2.Template(
+                    self.template, undefined=jinja2.StrictUndefined
+                ).render(**sample, json_dump=json.dumps, text_proc=text_proc)
+                f.write(text)
+
+        # 2. run forced alignment on the temp directory
+        subprocess.run(
+            [
+                "conda",
+                "run",
+                "--no-capture-output",
+                "-n",
+                MFA_ENV_NAME,
+                "mfa",
+                "align",
+                "--clean",
+                self.temp_dir,
+                "english_mfa",
+                "english_mfa",
+                self.temp_dir,
+            ],
+            check=True,
+        )
+
+    def __del__(self):
+        if hasattr(self, "_temp_dir"):
+            self._temp_dir.cleanup()
+
+
 @dataclasses.dataclass
 class TextGenerationTask:
     new_column_name: str = simple_parsing.field(alias="-c")
@@ -218,10 +354,12 @@ class DatasetToolArgs:
         default_factory=lambda: ["audio"]
     )
 
-    task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups(
-        {"tts": TtsTask, "textgen": TextGenerationTask},  # type: ignore
-        default_factory=TtsTask,
-        positional=True,
+    task: Union[TtsTask, TextGenerationTask, TimestampGenerationTask] = (
+        simple_parsing.subgroups(
+            {"tts": TtsTask, "textgen": TextGenerationTask, "ts": TimestampGenerationTask},  # type: ignore
+            default_factory=TtsTask,
+            positional=True,
+        )
     )
 
     def __post_init__(self):
@@ -300,7 +438,10 @@ def process_and_upload_split_rescursive(
             # If the error is unsupported operand type(s) for -=: 'NoneType' and 'float',
             # then the huggingface README needs to be updated to have the
             # download_size, and dataset_size fields present under dataset_info (could be initialized to 0)
+            import traceback
+
             print(f"Failed to upload chunk {ds_chunk_name}: {e}. Retrying later.")
+            print(traceback.format_exc())
             if total_chunks == 1:
                 print(
                     f"Finished processing and uploading 0/1 chunks for range [{start_index}, {end_index})"
@@ -390,6 +531,9 @@ def main(args: DatasetToolArgs):
         if args.num_samples:
             ds_split = ds_split.select(range(args.num_samples))
 
+        if hasattr(args.task, "preprocess"):
+            args.task.preprocess(ds_split)
+
         ds_chunk_proc.process_and_upload_split_rescursive(
             split_name, ds_split, 0, len(ds_split)
         )
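A note on the TextGrid handling introduced above: MFA writes one TextGrid file per input wav, and `_map_sample` reads the word tier back into the timestamps column. Below is a minimal standalone sketch of that read-back, using the same praatio calls as the patch; the file name is a hypothetical example, not project data.

    from praatio import textgrid

    def read_word_timestamps(textgrid_path: str) -> list:
        # The "words" tier holds one interval per aligned word,
        # with start/end offsets in seconds.
        tg = textgrid.openTextgrid(textgrid_path, includeEmptyIntervals=False)
        return [
            {"start": entry.start, "end": entry.end, "text": entry.label}
            for entry in tg.getTier("words").entries
        ]

    # read_word_timestamps("sample-0001.TextGrid") might return
    # [{"start": 0.25, "end": 0.58, "text": "hello"}, ...]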
+ ) + return ds_mapped + + def _map_sample(self, sample): + # find the timestamps for the audio and populate the timestamps column + sample_id = self.get_id(sample) + text_path = os.path.join(self.temp_dir, f"{sample_id}.TextGrid") + if not os.path.exists(text_path): + sample[self.timestamp_column_name] = None + return sample + + tg = textgrid.openTextgrid(text_path, False) + timestamps = tg.getTier("words").entries + sample[self.timestamp_column_name] = [ + {"start": entry.start, "end": entry.end, "text": entry.label} + for entry in timestamps + ] + return sample + + @staticmethod + def get_id(sample): + for key in ["id", "segment_id"]: + if key in sample and isinstance(sample[key], str): + return str(sample[key]) + for key in ["file", "path", "audio_file"]: + if key in sample and isinstance(sample[key], str): + return Path(sample[key]).stem + raise ValueError("Could not find an ID in the sample") + + def preprocess(self, ds_split: datasets.Dataset) -> datasets.Dataset: + if self.temp_dir is None: + if hasattr(self, "_temp_dir"): + self._temp_dir.cleanup() + self._temp_dir = tempfile.TemporaryDirectory() + self.temp_dir = self._temp_dir.name + + os.makedirs(self.temp_dir, exist_ok=True) + + # 1. copy all audio-text pairs into a temp directory + for sample in ds_split: + sample_id = self.get_id(sample) + audio_path = os.path.join(self.temp_dir, f"{sample_id}.wav") + with open(audio_path, "wb") as f: + audio = sample[self.audio_column_name] + if audio["sampling_rate"] != self.sample_rate: + audio["array"] = librosa.resample( + audio["array"], + orig_sr=audio["sampling_rate"], + target_sr=self.sample_rate, + ) + sf.write(f, audio["array"], 16000, format="WAV", subtype="PCM_16") + + text_path = os.path.join(self.temp_dir, f"{sample_id}.txt") + with open(text_path, "w") as f: + # TODO: exclude_fields + text = jinja2.Template( + self.template, undefined=jinja2.StrictUndefined + ).render(**sample, json_dump=json.dumps, text_proc=text_proc) + f.write(text) + + # 2. run forced alignment on the temp directory + subprocess.run( + [ + "conda", + "run", + "--no-capture-output", + "-n", + MFA_ENV_NAME, + "mfa", + "align", + "--clean", + self.temp_dir, + "english_mfa", + "english_mfa", + self.temp_dir, + ], + check=True, + ) + + def __del__(self): + if hasattr(self, "_temp_dir"): + self._temp_dir.cleanup() + + @dataclasses.dataclass class TextGenerationTask: new_column_name: str = simple_parsing.field(alias="-c") @@ -218,10 +354,12 @@ class DatasetToolArgs: default_factory=lambda: ["audio"] ) - task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups( - {"tts": TtsTask, "textgen": TextGenerationTask}, # type: ignore - default_factory=TtsTask, - positional=True, + task: Union[TtsTask, TextGenerationTask, TimestampGenerationTask] = ( + simple_parsing.subgroups( + {"tts": TtsTask, "textgen": TextGenerationTask, "ts": TimestampGenerationTask}, # type: ignore + default_factory=TtsTask, + positional=True, + ) ) def __post_init__(self): @@ -300,7 +438,10 @@ def process_and_upload_split_rescursive( # If the error is unsupported operand type(s) for -=: 'NoneType' and 'float', # then the huggingface README needs to be updated to have the # download_size, and dataset_size fields present under dataset_info (could be initalized to 0) + import traceback + print(f"Failed to upload chunk {ds_chunk_name}: {e}. 
Retrying later.") + print(traceback.format_exc()) if total_chunks == 1: print( f"Finished processing and uploading 0/1 chunks for range [{start_index}, {end_index})" @@ -390,6 +531,9 @@ def main(args: DatasetToolArgs): if args.num_samples: ds_split = ds_split.select(range(args.num_samples)) + if hasattr(args.task, "preprocess"): + args.task.preprocess(ds_split) + ds_chunk_proc.process_and_upload_split_rescursive( split_name, ds_split, 0, len(ds_split) ) From fe2a184a8677b5dcd91a0ed25191ed6c08992b3a Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 24 Sep 2024 16:42:55 -0700 Subject: [PATCH 02/19] fix download conda on linux --- Justfile | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/Justfile b/Justfile index 2348cc9a..269bdbba 100644 --- a/Justfile +++ b/Justfile @@ -68,10 +68,10 @@ mcloud *FLAGS: if ! command -v conda &> /dev/null; then \ echo "Conda is not installed."; \ mkdir -p ~/miniconda3; \ - if [[ "$OSTYPE" == "darwin"* ]]; then \ + if [ "$(uname)" = "Darwin" ]; then \ echo "Downloading MacOS Miniconda."; \ curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh; \ - elif [[ "$OSTYPE" == "linux-gnu"* ]]; then \ + elif [ "$(uname)" = "Linux" ]; then \ echo "Downloading Linux Miniconda."; \ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh \ else \ @@ -90,7 +90,6 @@ mcloud *FLAGS: echo "Creating environment '$MFA_ENV_NAME'."; \ conda create --name "$MFA_ENV_NAME" python=3.8 -y; \ conda create -n "$MFA_ENV_NAME" -c conda-forge montreal-forced-aligner; \ - conda activate "$MFA_ENV_NAME"; \ - mfa model download acoustic english_mfa; \ - mfa model download dictionary english_mfa; \ + conda run -n "$MFA_ENV_NAME" mfa model download acoustic english_mfa; \ + conda run -n "$MFA_ENV_NAME" mfa model download dictionary english_mfa; \ fi \ No newline at end of file From 6226f9b9cbd2f47fe6e7c2b7300deaf5c0da1034 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 24 Sep 2024 17:01:40 -0700 Subject: [PATCH 03/19] formatting --- ultravox/tools/ds_tool/ds_tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index b0ea08c2..0f799a70 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -2,8 +2,8 @@ import json import math import os -import tempfile import subprocess +import tempfile from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union @@ -13,8 +13,8 @@ import openai import simple_parsing import soundfile as sf -from praatio import textgrid import yaml +from praatio import textgrid from tenacity import retry from tenacity import stop_after_attempt from tenacity import wait_fixed From b036218dc770e9417407dba05c5bbb7c1fbcb6b7 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Wed, 25 Sep 2024 13:24:35 -0700 Subject: [PATCH 04/19] move preprocess inside map_split --- ultravox/tools/ds_tool/ds_tool.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index 0f799a70..da3a49ad 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -111,6 +111,8 @@ class TimestampGenerationTask: aligned_ratio_check: float = simple_parsing.field(default=0.95, alias="-ar") def __post_init__(self): + self._temp_dir: Optional[tempfile.TemporaryDirectory] = None + try: # Make 
sure the MFA environment is installed subprocess.run(["conda", "run", "-n", MFA_ENV_NAME, "echo"], check=True) @@ -130,6 +132,9 @@ def map_split( writer_batch_size: int, exclude_fields: List[str], ) -> datasets.Dataset: + # Main task: generate timestamps for the audio samples + self._preprocess(ds_split) + ds_mapped = ds_split.map( self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size ).filter( @@ -169,9 +174,9 @@ def get_id(sample): return Path(sample[key]).stem raise ValueError("Could not find an ID in the sample") - def preprocess(self, ds_split: datasets.Dataset) -> datasets.Dataset: + def _preprocess(self, ds_split: datasets.Dataset) -> datasets.Dataset: if self.temp_dir is None: - if hasattr(self, "_temp_dir"): + if self._temp_dir is not None: self._temp_dir.cleanup() self._temp_dir = tempfile.TemporaryDirectory() self.temp_dir = self._temp_dir.name @@ -220,8 +225,9 @@ def preprocess(self, ds_split: datasets.Dataset) -> datasets.Dataset: ) def __del__(self): - if hasattr(self, "_temp_dir"): + if self._temp_dir is not None: self._temp_dir.cleanup() + self._temp_dir = None @dataclasses.dataclass @@ -531,8 +537,8 @@ def main(args: DatasetToolArgs): if args.num_samples: ds_split = ds_split.select(range(args.num_samples)) - if hasattr(args.task, "preprocess"): - args.task.preprocess(ds_split) + # if hasattr(args.task, "preprocess"): + # args.task.preprocess(ds_split) ds_chunk_proc.process_and_upload_split_rescursive( split_name, ds_split, 0, len(ds_split) From cddac21b1e73fef399cb30a9da1c6a1a1490376d Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Thu, 26 Sep 2024 12:00:55 -0700 Subject: [PATCH 05/19] fix temp_dir --- ultravox/tools/ds_tool/ds_tool.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index da3a49ad..a8510ff4 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -107,9 +107,14 @@ class TimestampGenerationTask: default="timestamps", alias="-tsc" ) sample_rate: int = simple_parsing.field(default=16000, alias="-r") - temp_dir: Optional[str] = None aligned_ratio_check: float = simple_parsing.field(default=0.95, alias="-ar") + def _get_new_temp_dir(self) -> str: + if self._temp_dir is not None: + self._temp_dir.cleanup() + self._temp_dir = tempfile.TemporaryDirectory() + self.temp_dir = self._temp_dir.name + def __post_init__(self): self._temp_dir: Optional[tempfile.TemporaryDirectory] = None @@ -175,13 +180,7 @@ def get_id(sample): raise ValueError("Could not find an ID in the sample") def _preprocess(self, ds_split: datasets.Dataset) -> datasets.Dataset: - if self.temp_dir is None: - if self._temp_dir is not None: - self._temp_dir.cleanup() - self._temp_dir = tempfile.TemporaryDirectory() - self.temp_dir = self._temp_dir.name - - os.makedirs(self.temp_dir, exist_ok=True) + self._get_new_temp_dir() # 1. 
copy all audio-text pairs into a temp directory for sample in ds_split: @@ -537,9 +536,6 @@ def main(args: DatasetToolArgs): if args.num_samples: ds_split = ds_split.select(range(args.num_samples)) - # if hasattr(args.task, "preprocess"): - # args.task.preprocess(ds_split) - ds_chunk_proc.process_and_upload_split_rescursive( split_name, ds_split, 0, len(ds_split) ) From 583feb133a04c6f0e77304bc937905cfc32b24e4 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Thu, 26 Sep 2024 15:56:10 -0700 Subject: [PATCH 06/19] faster TimestampGenerationTask --- ultravox/tools/ds_tool/ds_tool.py | 82 +++++++++++++++++++------------ 1 file changed, 51 insertions(+), 31 deletions(-) diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index a8510ff4..becb44a6 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -5,7 +5,7 @@ import subprocess import tempfile from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import datasets import jinja2 @@ -137,23 +137,41 @@ def map_split( writer_batch_size: int, exclude_fields: List[str], ) -> datasets.Dataset: - # Main task: generate timestamps for the audio samples - self._preprocess(ds_split) + exclude_fields = set(exclude_fields) + self._get_new_temp_dir() + + # 1. copy all audio-text pairs into a temp directory + ds_split.map( + self._store_sample_as_files, + num_proc=num_proc, + fn_kwargs={"exclude_fields": exclude_fields}, + ) + + # 2. run the alignment + self._run_alignment(self.temp_dir) + # 3. retrieve the timestamps ds_mapped = ds_split.map( - self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size - ).filter( + self._retrieve_timestamps, + num_proc=num_proc, + writer_batch_size=writer_batch_size, + ) + + # 4. filter out samples without timestamps (should be a small number) + ds_mapped = ds_mapped.filter( lambda sample: sample[self.timestamp_column_name] is not None, num_proc=num_proc, writer_batch_size=writer_batch_size, ) + + # 5. make sure most samples have timestamps if len(ds_split) * self.aligned_ratio_check > len(ds_mapped): raise Exception( f"Found too many samples without timestamps: {len(ds_mapped)}/{len(ds_split)} aligned." ) return ds_mapped - def _map_sample(self, sample): + def _retrieve_timestamps(self, sample): # find the timestamps for the audio and populate the timestamps column sample_id = self.get_id(sample) text_path = os.path.join(self.temp_dir, f"{sample_id}.TextGrid") @@ -179,32 +197,32 @@ def get_id(sample): return Path(sample[key]).stem raise ValueError("Could not find an ID in the sample") - def _preprocess(self, ds_split: datasets.Dataset) -> datasets.Dataset: - self._get_new_temp_dir() + def _store_sample_as_files(self, sample, exclude_fields: Set[str]): + sample_id = self.get_id(sample) + audio_path = os.path.join(self.temp_dir, f"{sample_id}.wav") + with open(audio_path, "wb") as f: + audio = sample[self.audio_column_name] + if audio["sampling_rate"] != self.sample_rate: + audio["array"] = librosa.resample( + audio["array"], + orig_sr=audio["sampling_rate"], + target_sr=self.sample_rate, + ) + sf.write(f, audio["array"], 16000, format="WAV", subtype="PCM_16") - # 1. 
copy all audio-text pairs into a temp directory - for sample in ds_split: - sample_id = self.get_id(sample) - audio_path = os.path.join(self.temp_dir, f"{sample_id}.wav") - with open(audio_path, "wb") as f: - audio = sample[self.audio_column_name] - if audio["sampling_rate"] != self.sample_rate: - audio["array"] = librosa.resample( - audio["array"], - orig_sr=audio["sampling_rate"], - target_sr=self.sample_rate, - ) - sf.write(f, audio["array"], 16000, format="WAV", subtype="PCM_16") + text_path = os.path.join(self.temp_dir, f"{sample_id}.txt") + with open(text_path, "w") as f: + filtered_sample = { + k: sample[k] for k in sample.keys() if k not in exclude_fields + } + text = jinja2.Template( + self.template, undefined=jinja2.StrictUndefined + ).render(**filtered_sample, json_dump=json.dumps, text_proc=text_proc) + f.write(text) - text_path = os.path.join(self.temp_dir, f"{sample_id}.txt") - with open(text_path, "w") as f: - # TODO: exclude_fields - text = jinja2.Template( - self.template, undefined=jinja2.StrictUndefined - ).render(**sample, json_dump=json.dumps, text_proc=text_proc) - f.write(text) + return None - # 2. run forced alignment on the temp directory + def _run_alignment(self, temp_dir: str) -> None: subprocess.run( [ "conda", @@ -215,10 +233,12 @@ def _preprocess(self, ds_split: datasets.Dataset) -> datasets.Dataset: "mfa", "align", "--clean", - self.temp_dir, + "--use_mp", + "--single_speaker", + temp_dir, "english_mfa", "english_mfa", - self.temp_dir, + temp_dir, ], check=True, ) From d276eb83b8a9dd09c8bf87ad2391d3e07720e344 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Thu, 26 Sep 2024 16:31:41 -0700 Subject: [PATCH 07/19] remove single_speaker --- ultravox/tools/ds_tool/ds_tool.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index becb44a6..2dc6ad4c 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -234,7 +234,6 @@ def _run_alignment(self, temp_dir: str) -> None: "align", "--clean", "--use_mp", - "--single_speaker", temp_dir, "english_mfa", "english_mfa", From d1812281b958681e1dcd41edee47cfe14b0becd8 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Thu, 26 Sep 2024 16:48:03 -0700 Subject: [PATCH 08/19] set num_proc for ds_tool --- ultravox/tools/ds_tool/ds_tool.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index 2dc6ad4c..50f25813 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -148,7 +148,7 @@ def map_split( ) # 2. run the alignment - self._run_alignment(self.temp_dir) + self._run_alignment(self.temp_dir, num_proc=num_proc) # 3. 
retrieve the timestamps ds_mapped = ds_split.map( @@ -222,7 +222,7 @@ def _store_sample_as_files(self, sample, exclude_fields: Set[str]): return None - def _run_alignment(self, temp_dir: str) -> None: + def _run_alignment(self, temp_dir: str, num_proc: int = 16) -> None: subprocess.run( [ "conda", @@ -234,6 +234,8 @@ def _run_alignment(self, temp_dir: str) -> None: "align", "--clean", "--use_mp", + "-j", + str(num_proc), temp_dir, "english_mfa", "english_mfa", From 245daaf33ab894065f59aabc7659c775ff622baf Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Fri, 27 Sep 2024 16:39:01 -0700 Subject: [PATCH 09/19] bugfixes and refactoring --- ultravox/tools/ds_tool/ds_tool.py | 281 ++++++++++++++++-------------- 1 file changed, 147 insertions(+), 134 deletions(-) diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index 50f25813..8b906291 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -1,4 +1,5 @@ import dataclasses +import glob import json import math import os @@ -30,6 +31,39 @@ MFA_ENV_NAME = "aligner" +def apply_jinja_template( + template: str, sample: Dict[str, Any], exclude_fields: Optional[Set[str]] = None +): + """ + Apply a Jinja template to a sample, rendering it into text. + Jinja template allows for added flexibility as template can include variables and functions. + + Args: + template: The Jinja template to apply. It can include variables, functions, and control structures. + Example: + {{ text }} + {{ text_proc.format_asr_text(text) }} + sample: The sample to apply the template to. + exclude_fields: Fields to exclude from the sample before rendering the template, to avoid loading large fields into memory. + """ + if exclude_fields: + # Filter out big fields like audio before the sample is passed into the jinja template + # otherwise it can lead to unnecessary memory usage. + sample = {k: sample[k] for k in sample.keys() if k not in exclude_fields} + + try: + return jinja2.Template(template, undefined=jinja2.StrictUndefined).render( + **sample, json_dump=json.dumps, text_proc=text_proc + ) + except jinja2.TemplateError as e: + print(f"Error rendering template: {e}") + print(f"template: {template}") + print(f"sample keys: {list(sample.keys())}") + raise ValueError( + f"Template rendering failed. Make sure all keys in the template exist in the sample." 
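The next patch consolidates the repeated Jinja rendering into a shared apply_jinja_template helper. For readers unfamiliar with the pattern, here is a minimal sketch of the underlying jinja2 usage; the sample dict is a made-up stand-in for a dataset row.

    import json

    import jinja2

    sample = {"sentence": "hello world", "id": "abc"}  # hypothetical dataset row

    # StrictUndefined makes rendering fail loudly when the template references
    # a key missing from the sample, instead of silently emitting nothing.
    template = jinja2.Template("{{ sentence }}", undefined=jinja2.StrictUndefined)
    print(template.render(**sample, json_dump=json.dumps))  # -> hello world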
From 245daaf33ab894065f59aabc7659c775ff622baf Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Fri, 27 Sep 2024 16:39:01 -0700
Subject: [PATCH 09/19] bugfixes and refactoring

---
 ultravox/tools/ds_tool/ds_tool.py | 281 ++++++++++++++++--------------
 1 file changed, 147 insertions(+), 134 deletions(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 50f25813..8b906291 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -1,4 +1,5 @@
 import dataclasses
+import glob
 import json
 import math
 import os
@@ -30,6 +31,39 @@
 MFA_ENV_NAME = "aligner"
 
 
+def apply_jinja_template(
+    template: str, sample: Dict[str, Any], exclude_fields: Optional[Set[str]] = None
+):
+    """
+    Apply a Jinja template to a sample, rendering it into text.
+    Jinja template allows for added flexibility as template can include variables and functions.
+
+    Args:
+        template: The Jinja template to apply. It can include variables, functions, and control structures.
+            Example:
+                {{ text }}
+                {{ text_proc.format_asr_text(text) }}
+        sample: The sample to apply the template to.
+        exclude_fields: Fields to exclude from the sample before rendering the template, to avoid loading large fields into memory.
+    """
+    if exclude_fields:
+        # Filter out big fields like audio before the sample is passed into the jinja template
+        # otherwise it can lead to unnecessary memory usage.
+        sample = {k: sample[k] for k in sample.keys() if k not in exclude_fields}
+
+    try:
+        return jinja2.Template(template, undefined=jinja2.StrictUndefined).render(
+            **sample, json_dump=json.dumps, text_proc=text_proc
+        )
+    except jinja2.TemplateError as e:
+        print(f"Error rendering template: {e}")
+        print(f"template: {template}")
+        print(f"sample keys: {list(sample.keys())}")
+        raise ValueError(
+            f"Template rendering failed. Make sure all keys in the template exist in the sample."
+        ) from e
+
+
 @dataclasses.dataclass
 class TtsTask:
     # Jinja template for the text that needs to be converted to audio
@@ -63,7 +97,10 @@ def map_split(
     ) -> datasets.Dataset:
         print(f'TTS mapping "{self.template}" to "{self.audio_column_name}"...')
         ds_split = ds_split.map(
-            self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size
+            self._map_sample,
+            num_proc=num_proc,
+            writer_batch_size=writer_batch_size,
+            fn_kwargs={"exclude_fields": exclude_fields},
         )
         column_type = datasets.Audio(sampling_rate=self.sample_rate)
         if self.json_mode and isinstance(
@@ -72,20 +109,10 @@ def map_split(
             column_type = datasets.Sequence(column_type)
         return ds_split.cast_column(self.audio_column_name, column_type)
 
-    def _map_sample(self, sample):
+    def _map_sample(self, sample, exclude_fields: Set[str]):
         # using a Jinja template for some added flexibility, template can include variables and functions
         # e.g., {{ text }} or {{ text_proc.format_asr_text(text) }}
-        try:
-            text_or_texts = jinja2.Template(
-                self.template, undefined=jinja2.StrictUndefined
-            ).render(**sample, json_dump=json.dumps, text_proc=text_proc)
-        except jinja2.TemplateError as e:
-            print(f"Error rendering template: {e}")
-            print(f"template: {self.template}")
-            print(f"sample keys: {list(sample.keys())}")
-            raise ValueError(
-                f"Template rendering failed. Make sure column_name exists in the sample."
-            ) from e
+        text_or_texts = apply_jinja_template(self.template, sample, exclude_fields)
 
         if self.json_mode:
             text_or_texts = yaml.safe_load(text_or_texts)
@@ -96,28 +123,101 @@ def _map_sample(self, sample):
         return sample
 
 
+@dataclasses.dataclass
+class TextGenerationTask:
+    new_column_name: str = simple_parsing.field(alias="-c")
+    template: str = simple_parsing.field(alias="-T")
+    json_mode: bool = simple_parsing.field(default=False, alias="-j")
+
+    language_model: str = simple_parsing.field(default="gpt-4o", alias="-m")
+    base_url: Optional[str] = simple_parsing.field(default=None, alias="-b")
+    api_key: Optional[str] = simple_parsing.field(default=None, alias="-k")
+    max_tokens: int = 128
+    temperature: float = 0
+
+    def __post_init__(self):
+        # The OAI client is separate from the task to avoid pickling issues when multiprocessing.
+        global chat_client
+        # Caching the client to avoid repeated calls to the API if the tool fails.
+        chat_client = caching.CachingChatWrapper(
+            openai.Client(base_url=self.base_url, api_key=self.api_key),
+            unique_id=f"{self.base_url}__{self.language_model}",
+        )
+        if self.template.startswith("@"):
+            with open(self.template[1:], "r") as template_file:
+                self.template = template_file.read()
+
+    def map_split(
+        self,
+        ds_split: datasets.Dataset,
+        num_proc: int,
+        writer_batch_size: int,
+        exclude_fields: List[str],
+    ) -> datasets.Dataset:
+        print(f'Generating "{self.new_column_name}" with template:\n{self.template}')
+        ds_mapped = ds_split.map(
+            lambda sample: self._map_sample(sample, set(exclude_fields)),
+            num_proc=num_proc,
+            writer_batch_size=writer_batch_size,
+        )
+
+        # Filter out samples where new_column_name is None
+        return ds_mapped.filter(
+            lambda sample: sample[self.new_column_name] is not None,
+            num_proc=num_proc,
+            writer_batch_size=writer_batch_size,
+        )
+
+    def _map_sample(self, sample, exclude_fields):
+        # using a Jinja template for some added flexibility, template can include variables and functions
+        # e.g., {{ text }} or {{ text_proc.format_asr_text(text) }}
+        try:
+            rendered = apply_jinja_template(self.template, sample, exclude_fields)
+        except text_proc.FormatASRError as e:
+            print(f"Format ASR Error {e}. Filtering out sample.")
+            sample[self.new_column_name] = None
+            return sample
+
+        if self.json_mode:
+            turns = yaml.safe_load(rendered)
+            assert isinstance(turns, list)
+            assert all(isinstance(turn, dict) for turn in turns)
+            assert len(turns) > 0
+            assert turns[-1].get("role", None) == "user"
+        else:
+            turns = [{"role": "user", "content": rendered}]
+
+        sample[self.new_column_name] = chat_client.chat_completion(
+            model=self.language_model,
+            messages=turns,
+            max_tokens=self.max_tokens,
+            temperature=self.temperature,
+        )
+
+        return sample
+
+
 @dataclasses.dataclass
 class TimestampGenerationTask:
+    """
+    This task is used to generate timestamps for the text transcription.
+    It uses the Montreal Forced Aligner (MFA) to align the text with the audio. The result is a
+    list of timestamps for each word in the text transcription. The timestamps are stored in a new
+    column, in a dictionary format: {"start": float in seconds, "end": float in seconds, "text": word str}.
+    """
+
+    # Jinja template for the text transcription that needs to be aligned
+    template: str = simple_parsing.field(alias="-T")
-    language: str = simple_parsing.field(default="en", alias="-l")
     audio_column_name: str = simple_parsing.field(default="audio", alias="-a")
-    template: str = simple_parsing.field(
-        default="{{text_proc.format_asr_text(text)}}", alias="-T"
-    )
-    timestamp_column_name: str = simple_parsing.field(
-        default="timestamps", alias="-tsc"
-    )
     sample_rate: int = simple_parsing.field(default=16000, alias="-r")
+    # The column name to store the timestamps in
+    timestamp_column_name: str = simple_parsing.field(default="timestamps", alias="-ts")
+    # The language to use for the MFA alignment. Make sure the dictionary and acoustic model are installed.
+    # See just install_mfa as it downloads the English models.
+    language: str = simple_parsing.field(default="english", alias="-l")
     aligned_ratio_check: float = simple_parsing.field(default=0.95, alias="-ar")
 
     def __post_init__(self):
         try:
             # Make sure the MFA environment is installed
             subprocess.run(["conda", "run", "-n", MFA_ENV_NAME, "echo"], check=True)
@@ -130,16 +237,23 @@ def map_split(
         writer_batch_size: int,
         exclude_fields: List[str],
     ) -> datasets.Dataset:
-        exclude_fields = set(exclude_fields)
-        self._get_new_temp_dir()
+        # 0. create a temp directory to store the audio and text files
+        _temp_dir = tempfile.TemporaryDirectory()
+        self.temp_dir = _temp_dir.name
+        os.makedirs(self.temp_dir, exist_ok=True)
 
-        # 1. copy all audio-text pairs into a temp directory
+        # 1. copy all audio-text pairs into the temp directory
         ds_split.map(
             self._store_sample_as_files,
             num_proc=num_proc,
-            fn_kwargs={"exclude_fields": exclude_fields},
+            fn_kwargs={"exclude_fields": set(exclude_fields)},
         )
 
+        count_wavs = len(glob.glob(os.path.join(self.temp_dir, "*.wav")))
+        assert count_wavs >= len(
+            ds_split
+        ), "Not all samples were stored as files. The id is likely not unique."
+
         # 2. run the alignment
         self._run_alignment(self.temp_dir, num_proc=num_proc)
 
         # 3. retrieve the timestamps
         ds_mapped = ds_split.map(
             self._retrieve_timestamps,
             num_proc=num_proc,
             writer_batch_size=writer_batch_size,
         )
 
         # 4. filter out samples without timestamps (should be a small number)
         ds_mapped = ds_mapped.filter(
             lambda sample: sample[self.timestamp_column_name] is not None,
             num_proc=num_proc,
             writer_batch_size=writer_batch_size,
         )
 
         # 5. make sure most samples have timestamps
         if len(ds_split) * self.aligned_ratio_check > len(ds_mapped):
             raise Exception(
                 f"Found too many samples without timestamps: {len(ds_mapped)}/{len(ds_split)} aligned."
             )
+
+        # 6. cleanup
+        _temp_dir.cleanup()
+
         return ds_mapped
 
     def _retrieve_timestamps(self, sample):
@@ -211,17 +322,10 @@ def _store_sample_as_files(self, sample, exclude_fields: Set[str]):
             sf.write(f, audio["array"], self.sample_rate, format="WAV", subtype="PCM_16")
 
         text_path = os.path.join(self.temp_dir, f"{sample_id}.txt")
+        text = apply_jinja_template(self.template, sample, exclude_fields)
         with open(text_path, "w") as f:
-            filtered_sample = {
-                k: sample[k] for k in sample.keys() if k not in exclude_fields
-            }
-            text = jinja2.Template(
-                self.template, undefined=jinja2.StrictUndefined
-            ).render(**filtered_sample, json_dump=json.dumps, text_proc=text_proc)
             f.write(text)
 
-        return None
-
     def _run_alignment(self, temp_dir: str, num_proc: int = 16) -> None:
         subprocess.run(
             [
                 "conda",
                 "run",
                 "--no-capture-output",
                 "-n",
                 MFA_ENV_NAME,
                 "mfa",
                 "align",
                 "--clean",
+                "--single_speaker",
                 "--use_mp",
                 "-j",
                 str(num_proc),
                 temp_dir,
-                "english_mfa",
-                "english_mfa",
+                f"{self.language}_mfa",
+                f"{self.language}_mfa",
                 temp_dir,
             ],
             check=True,
         )
 
-
-@dataclasses.dataclass
-class TextGenerationTask:
-    new_column_name: str = simple_parsing.field(alias="-c")
-    template: str = simple_parsing.field(alias="-T")
-    json_mode: bool = simple_parsing.field(default=False, alias="-j")
-
-    language_model: str = simple_parsing.field(default="gpt-4o", alias="-m")
-    base_url: Optional[str] = simple_parsing.field(default=None, alias="-b")
-    api_key: Optional[str] = simple_parsing.field(default=None, alias="-k")
-    max_tokens: int = 128
-    temperature: float = 0
-
-    def __post_init__(self):
-        # The OAI client is separate from the task to avoid pickling issues when multiprocessing.
-        global chat_client
-        # Caching the client to avoid repeated calls to the API if the tool fails.
-        chat_client = caching.CachingChatWrapper(
-            openai.Client(base_url=self.base_url, api_key=self.api_key),
-            unique_id=f"{self.base_url}__{self.language_model}",
-        )
-        if self.template.startswith("@"):
-            with open(self.template[1:], "r") as template_file:
-                self.template = template_file.read()
-
-    def map_split(
-        self,
-        ds_split: datasets.Dataset,
-        num_proc: int,
-        writer_batch_size: int,
-        exclude_fields: List[str],
-    ) -> datasets.Dataset:
-        print(f'Generating "{self.new_column_name}" with template:\n{self.template}')
-        ds_mapped = ds_split.map(
-            lambda sample: self._map_sample(sample, set(exclude_fields)),
-            num_proc=num_proc,
-            writer_batch_size=writer_batch_size,
-        )
-
-        # Filter out samples where new_column_name is None
-        return ds_mapped.filter(
-            lambda sample: sample[self.new_column_name] is not None,
-            num_proc=num_proc,
-            writer_batch_size=writer_batch_size,
-        )
-
-    def _map_sample(self, sample, exclude_fields):
-        # using a Jinja template for some added flexibility, template can include variables and functions
-        # e.g., {{ text }} or {{ text_proc.format_asr_text(text) }}
-        try:
-            # Filter out the audio before the sample is passed into the jinja template, or it will get loaded into memory.
-            filtered_sample = {
-                k: sample[k] for k in sample.keys() if k not in exclude_fields
-            }
-            rendered = jinja2.Template(
-                self.template, undefined=jinja2.StrictUndefined
-            ).render(**filtered_sample, json_dump=json.dumps, text_proc=text_proc)
-        except text_proc.FormatASRError as e:
-            print(f"Format ASR Error {e}. Filtering out sample.")
-            sample[self.new_column_name] = None
-            return sample
-        except jinja2.TemplateError as e:
-            print(f"Error rendering template: {e}")
-            print(f"template: {self.template}")
-            print(f"sample keys: {list(filtered_sample.keys())}")
-            raise ValueError(
-                f"Template rendering failed. Make sure all keys in the template exist in the sample."
-            ) from e
-
-        if self.json_mode:
-            turns = yaml.safe_load(rendered)
-            assert isinstance(turns, list)
-            assert all(isinstance(turn, dict) for turn in turns)
-            assert len(turns) > 0
-            assert turns[-1].get("role", None) == "user"
-        else:
-            turns = [{"role": "user", "content": rendered}]
-
-        sample[self.new_column_name] = chat_client.chat_completion(
-            model=self.language_model,
-            messages=turns,
-            max_tokens=self.max_tokens,
-            temperature=self.temperature,
-        )
-
-        return sample
-
 
 # This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model.
 # just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -T {{question}} -a audio --token $HF_WRITE_TOKEN

From 7b132c8d53d3a922c15d2d2788652d6dfd01e3d6 Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Fri, 27 Sep 2024 16:41:03 -0700
Subject: [PATCH 10/19] move import to file start

---
 ultravox/tools/ds_tool/ds_tool.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 8b906291..3a8d08bc 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -5,6 +5,7 @@
 import os
 import subprocess
 import tempfile
+import traceback
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
@@ -477,8 +478,6 @@ def process_and_upload_split_rescursive(
             # If the error is unsupported operand type(s) for -=: 'NoneType' and 'float',
             # then the huggingface README needs to be updated to have the
             # download_size, and dataset_size fields present under dataset_info (could be initialized to 0)
-            import traceback
-
             print(f"Failed to upload chunk {ds_chunk_name}: {e}. Retrying later.")
             print(traceback.format_exc())
             if total_chunks == 1:

From 6a340c1fd3a656159689e45722e388a9e43d99eb Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Fri, 27 Sep 2024 16:42:39 -0700
Subject: [PATCH 11/19] rename ts to timestamp

---
 ultravox/tools/ds_tool/ds_tool.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 3a8d08bc..bb83245c 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -396,7 +396,7 @@ class DatasetToolArgs:
     task: Union[TtsTask, TextGenerationTask, TimestampGenerationTask] = (
         simple_parsing.subgroups(
-            {"tts": TtsTask, "textgen": TextGenerationTask, "ts": TimestampGenerationTask},  # type: ignore
+            {"tts": TtsTask, "textgen": TextGenerationTask, "timestamp": TimestampGenerationTask},  # type: ignore
             default_factory=TtsTask,
             positional=True,
         )

From 6e24b5bf2accb51097f9fbe18a0d6d35b09e6ca0 Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Fri, 27 Sep 2024 16:53:53 -0700
Subject: [PATCH 12/19] improve cli arg names

---
 ultravox/tools/ds_tool/ds_tool.py | 19 ++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index bb83245c..515202e3 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -209,16 +209,23 @@ class TimestampGenerationTask:
 
     # Jinja template for the text transcription that needs to be aligned
     template: str = simple_parsing.field(alias="-T")
+    # The acoustic model to use for MFA alignment.
+    # Make sure the dictionary and acoustic model are installed. See just install_mfa for an example (English).
+    # Model index: https://mfa-models.readthedocs.io/en/latest/acoustic/index.html
+    # For many languages there exists a {language}_mfa model that you can use, e.g. "english_mfa"
+    mfa_acoustic_model: str = simple_parsing.field(alias="-m")
+    # The dictionary to use for MFA alignment. Defaults to the same name as the acoustic model.
+    mfa_dictionary: str = simple_parsing.field(default=None, alias="-d")
     audio_column_name: str = simple_parsing.field(default="audio", alias="-a")
     sample_rate: int = simple_parsing.field(default=16000, alias="-r")
     # The column name to store the timestamps in
     timestamp_column_name: str = simple_parsing.field(default="timestamps", alias="-ts")
-    # The language to use for the MFA alignment. Make sure the dictionary and acoustic model are installed.
-    # See just install_mfa as it downloads the English models.
-    language: str = simple_parsing.field(default="english", alias="-l")
     aligned_ratio_check: float = simple_parsing.field(default=0.95, alias="-ar")
 
     def __post_init__(self):
+        if self.mfa_dictionary is None:
+            self.mfa_dictionary = self.mfa_acoustic_model
+
         try:
             # Make sure the MFA environment is installed
             subprocess.run(["conda", "run", "-n", MFA_ENV_NAME, "echo"], check=True)
@@ -345,8 +352,8 @@ def _run_alignment(self, temp_dir: str, num_proc: int = 16) -> None:
                 "-j",
                 str(num_proc),
                 temp_dir,
-                f"{self.language}_mfa",
-                f"{self.language}_mfa",
+                self.mfa_acoustic_model,
+                self.mfa_dictionary,
                 temp_dir,
             ],
             check=True,
         )
@@ -359,6 +366,8 @@
 # --shuffle --upload_name fixie-ai/librispeech_asr --private --base_url https://api.fireworks.ai/inference/v1 \
 # --api_key $FIREWORKS_API_KEY --token $HF_TOKEN --language_model accounts/fireworks/models/llama-v3-8b-instruct \
 # --template @ultravox/tools/ds_tool/continuation.jinja --max_tokens 64 --num_workers 30 --writer_batch_size 30
+# just ds_tool timestamp -d fixie-ai/common_voice_17_0 -S en --upload_name fixie-ai/cv_ts \
+#   -m english_mfa -T "\"{{text_proc.format_asr_text(sentence)}}\""
 @dataclasses.dataclass
 class DatasetToolArgs:
     # HF source dataset parameters

From f91e42f7c7365a2f7f2e8c13664c694b6b27f6e6 Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Fri, 27 Sep 2024 16:57:53 -0700
Subject: [PATCH 13/19] add warning for updating dataset in-place

---
 ultravox/tools/ds_tool/ds_tool.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 515202e3..216e157e 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -419,6 +419,11 @@ def __post_init__(self):
         if self.dataset_split and not self.upload_split:
             self.upload_split = self.dataset_split
 
+        if self.upload_name == self.dataset_name:
+            raise ValueError(
+                "Updating datasets in-place is not well-supported and hence frowned upon."
+            )
+
 
 class DatasetChunkProcessor:
     args: DatasetToolArgs

From 0e9458187a1f95f176572a262d5620cf8e81a2ad Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Fri, 27 Sep 2024 17:01:45 -0700
Subject: [PATCH 14/19] update docs

---
 ultravox/tools/ds_tool/ds_tool.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 216e157e..36363e26 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -204,7 +204,8 @@ class TimestampGenerationTask:
     This task is used to generate timestamps for the text transcription.
     It uses the Montreal Forced Aligner (MFA) to align the text with the audio. The result is a
     list of timestamps for each word in the text transcription. The timestamps are stored in a new
-    column, in a dictionary format: {"start": float in seconds, "end": float in seconds, "text": word str}.
+    column, in a list of dict format:
+        [ {"start": float in seconds, "end": float in seconds, "text": first word str}, ... ]
     """
 
     # Jinja template for the text transcription that needs to be aligned

From 95478a304255b1608de8ff8dc0b9c99fc3815c17 Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Fri, 27 Sep 2024 17:04:19 -0700
Subject: [PATCH 15/19] fix mypy checks

---
 ultravox/tools/ds_tool/ds_tool.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 36363e26..db6ae08b 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -216,7 +216,7 @@ class TimestampGenerationTask:
     mfa_acoustic_model: str = simple_parsing.field(alias="-m")
     # The dictionary to use for MFA alignment. Defaults to the same name as the acoustic model.
-    mfa_dictionary: str = simple_parsing.field(default=None, alias="-d")
+    mfa_dictionary: Optional[str] = simple_parsing.field(default=None, alias="-d")
     audio_column_name: str = simple_parsing.field(default="audio", alias="-a")
     sample_rate: int = simple_parsing.field(default=16000, alias="-r")
     # The column name to store the timestamps in
@@ -352,7 +352,7 @@ def _run_alignment(self, temp_dir: str, num_proc: int = 16) -> None:
             str(num_proc),
             temp_dir,
             self.mfa_acoustic_model,
-            self.mfa_dictionary,
+            str(self.mfa_dictionary),
             temp_dir,
         ],
         check=True,

From 9036b839220ec4359ccb6fcd8ff9d9e1046ea247 Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Tue, 1 Oct 2024 10:51:39 -0700
Subject: [PATCH 16/19] add exclude_fields to logs

---
 ultravox/tools/ds_tool/ds_tool.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index db6ae08b..2c3d0c69 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -59,7 +59,7 @@ def apply_jinja_template(
     except jinja2.TemplateError as e:
         print(f"Error rendering template: {e}")
         print(f"template: {template}")
-        print(f"sample keys: {list(sample.keys())}")
+        print(f"sample keys: {list(sample.keys())}, excluded keys: {exclude_fields}")
         raise ValueError(
             f"Template rendering failed. Make sure all keys in the template exist in the sample."
         ) from e

From 93259fc0d79aa05ec702a63a06f56253f6576912 Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Tue, 1 Oct 2024 11:58:03 -0700
Subject: [PATCH 17/19] tighten wav count condition

---
 ultravox/tools/ds_tool/ds_tool.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 2c3d0c69..7d10dca3 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -259,7 +259,7 @@ def map_split(
         )
 
         count_wavs = len(glob.glob(os.path.join(self.temp_dir, "*.wav")))
-        assert count_wavs >= len(
+        assert count_wavs == len(
             ds_split
         ), "Not all samples were stored as files. The id is likely not unique."
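The tightened assert above relies on get_id producing a unique file stem per sample. Below is a small illustration of that convention, mirroring the get_id code from the patches; the sample dicts are made-up examples.

    from pathlib import Path

    def get_id(sample):
        # Mirrors TimestampGenerationTask.get_id: prefer explicit id fields,
        # then fall back to the stem of a file path.
        for key in ["id", "segment_id"]:
            if key in sample and isinstance(sample[key], str):
                return str(sample[key])
        for key in ["file", "path", "audio_file"]:
            if key in sample and isinstance(sample[key], str):
                return Path(sample[key]).stem
        raise ValueError("Could not find an ID in the sample")

    print(get_id({"id": "utt-42"}))               # -> utt-42
    print(get_id({"path": "/data/a/clip7.mp3"}))  # -> clip7
    # Samples at ".../a/clip7.mp3" and ".../b/clip7.mp3" collide on "clip7":
    # their wav/txt pairs overwrite each other and the count_wavs assert fires.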
From 9df2ca3b71bd06301f7442c96fa4b369e6f7df53 Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Tue, 1 Oct 2024 13:30:58 -0700
Subject: [PATCH 18/19] use with tempdir to cleanup on exceptions

---
 ultravox/tools/ds_tool/ds_tool.py | 68 +++++++++++++++----------------
 1 file changed, 33 insertions(+), 35 deletions(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 7d10dca3..cebc85cc 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -247,47 +247,45 @@ def map_split(
         exclude_fields: List[str],
     ) -> datasets.Dataset:
         # 0. create a temp directory to store the audio and text files
-        _temp_dir = tempfile.TemporaryDirectory()
-        self.temp_dir = _temp_dir.name
-        os.makedirs(self.temp_dir, exist_ok=True)
-
-        # 1. copy all audio-text pairs into the temp directory
-        ds_split.map(
-            self._store_sample_as_files,
-            num_proc=num_proc,
-            fn_kwargs={"exclude_fields": set(exclude_fields)},
-        )
+        # The files will be deleted when the with block ends or when an exception is raised
+        with tempfile.TemporaryDirectory() as _temp_dir:
+            self.temp_dir = _temp_dir.name
+            os.makedirs(self.temp_dir, exist_ok=True)
+
+            # 1. copy all audio-text pairs into the temp directory
+            ds_split.map(
+                self._store_sample_as_files,
+                num_proc=num_proc,
+                fn_kwargs={"exclude_fields": set(exclude_fields)},
+            )
 
-        count_wavs = len(glob.glob(os.path.join(self.temp_dir, "*.wav")))
-        assert count_wavs == len(
-            ds_split
-        ), "Not all samples were stored as files. The id is likely not unique."
+            count_wavs = len(glob.glob(os.path.join(self.temp_dir, "*.wav")))
+            assert count_wavs == len(
+                ds_split
+            ), "Not all samples were stored as files. The id is likely not unique."
 
-        # 2. run the alignment
-        self._run_alignment(self.temp_dir, num_proc=num_proc)
+            # 2. run the alignment
+            self._run_alignment(self.temp_dir, num_proc=num_proc)
 
-        # 3. retrieve the timestamps
-        ds_mapped = ds_split.map(
-            self._retrieve_timestamps,
-            num_proc=num_proc,
-            writer_batch_size=writer_batch_size,
-        )
+            # 3. retrieve the timestamps
+            ds_mapped = ds_split.map(
+                self._retrieve_timestamps,
+                num_proc=num_proc,
+                writer_batch_size=writer_batch_size,
+            )
 
-        # 4. filter out samples without timestamps (should be a small number)
-        ds_mapped = ds_mapped.filter(
-            lambda sample: sample[self.timestamp_column_name] is not None,
-            num_proc=num_proc,
-            writer_batch_size=writer_batch_size,
-        )
+            # 4. filter out samples without timestamps (should be a small number)
+            ds_mapped = ds_mapped.filter(
+                lambda sample: sample[self.timestamp_column_name] is not None,
+                num_proc=num_proc,
+                writer_batch_size=writer_batch_size,
+            )
 
-        # 5. make sure most samples have timestamps
-        if len(ds_split) * self.aligned_ratio_check > len(ds_mapped):
-            raise Exception(
-                f"Found too many samples without timestamps: {len(ds_mapped)}/{len(ds_split)} aligned."
-            )
-
-        # 6. cleanup
-        _temp_dir.cleanup()
+            # 5. make sure most samples have timestamps
+            if len(ds_split) * self.aligned_ratio_check > len(ds_mapped):
+                raise Exception(
+                    f"Found too many samples without timestamps: {len(ds_mapped)}/{len(ds_split)} aligned."
+                )
 
         return ds_mapped
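One standard-library subtlety motivates the final patch: used as a context manager, tempfile.TemporaryDirectory yields the directory path as a plain str, so the `_temp_dir.name` access kept above only exists on the object form, not on the value bound by the with statement. A two-line illustration of that behavior (stdlib only, not project code):

    import tempfile

    with tempfile.TemporaryDirectory() as temp_dir:
        # temp_dir is the path itself, not a TemporaryDirectory object.
        print(type(temp_dir))  # <class 'str'>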
From 9302f2d672136ea094aea781b741331806ddf22b Mon Sep 17 00:00:00 2001
From: Farzad Abdolhosseini
Date: Tue, 1 Oct 2024 13:52:03 -0700
Subject: [PATCH 19/19] make self.temp_dir a local var

---
 ultravox/tools/ds_tool/ds_tool.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index cebc85cc..50d7e88e 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -248,30 +248,28 @@ def map_split(
     ) -> datasets.Dataset:
         # 0. create a temp directory to store the audio and text files
         # The files will be deleted when the with block ends or when an exception is raised
-        with tempfile.TemporaryDirectory() as _temp_dir:
-            self.temp_dir = _temp_dir.name
-            os.makedirs(self.temp_dir, exist_ok=True)
-
+        with tempfile.TemporaryDirectory() as temp_dir:
             # 1. copy all audio-text pairs into the temp directory
             ds_split.map(
                 self._store_sample_as_files,
                 num_proc=num_proc,
-                fn_kwargs={"exclude_fields": set(exclude_fields)},
+                fn_kwargs={"exclude_fields": set(exclude_fields), "temp_dir": temp_dir},
             )
 
-            count_wavs = len(glob.glob(os.path.join(self.temp_dir, "*.wav")))
+            count_wavs = len(glob.glob(os.path.join(temp_dir, "*.wav")))
             assert count_wavs == len(
                 ds_split
             ), "Not all samples were stored as files. The id is likely not unique."
 
             # 2. run the alignment
-            self._run_alignment(self.temp_dir, num_proc=num_proc)
+            self._run_alignment(temp_dir, num_proc=num_proc)
 
             # 3. retrieve the timestamps
             ds_mapped = ds_split.map(
                 self._retrieve_timestamps,
                 num_proc=num_proc,
                 writer_batch_size=writer_batch_size,
+                fn_kwargs={"temp_dir": temp_dir},
             )
 
             # 4. filter out samples without timestamps (should be a small number)
@@ -289,10 +287,10 @@ def map_split(
 
         return ds_mapped
 
-    def _retrieve_timestamps(self, sample):
+    def _retrieve_timestamps(self, sample, temp_dir: str):
         # find the timestamps for the audio and populate the timestamps column
         sample_id = self.get_id(sample)
-        text_path = os.path.join(self.temp_dir, f"{sample_id}.TextGrid")
+        text_path = os.path.join(temp_dir, f"{sample_id}.TextGrid")
         if not os.path.exists(text_path):
             sample[self.timestamp_column_name] = None
             return sample
@@ -315,9 +313,9 @@ def get_id(sample):
                 return Path(sample[key]).stem
         raise ValueError("Could not find an ID in the sample")
 
-    def _store_sample_as_files(self, sample, exclude_fields: Set[str]):
+    def _store_sample_as_files(self, sample, temp_dir: str, exclude_fields: Set[str]):
         sample_id = self.get_id(sample)
-        audio_path = os.path.join(self.temp_dir, f"{sample_id}.wav")
+        audio_path = os.path.join(temp_dir, f"{sample_id}.wav")
         with open(audio_path, "wb") as f:
             audio = sample[self.audio_column_name]
             if audio["sampling_rate"] != self.sample_rate:
                 audio["array"] = librosa.resample(
                     audio["array"],
                     orig_sr=audio["sampling_rate"],
                     target_sr=self.sample_rate,
                 )
             sf.write(f, audio["array"], self.sample_rate, format="WAV", subtype="PCM_16")
 
-        text_path = os.path.join(self.temp_dir, f"{sample_id}.txt")
+        text_path = os.path.join(temp_dir, f"{sample_id}.txt")
         text = apply_jinja_template(self.template, sample, exclude_fields)
         with open(text_path, "w") as f:
             f.write(text)
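With the series applied, the task is driven via the `just ds_tool timestamp ...` invocation shown in the script header comments. Below is a consumer-side sketch of the resulting column; the dataset name and config reuse the hypothetical upload from that usage comment, and the per-word layout follows the task docstring.

    import datasets

    # Hypothetical upload produced by the timestamp task (see the usage comment above).
    ds = datasets.load_dataset("fixie-ai/cv_ts", "en", split="train")
    sample = ds[0]
    audio = sample["audio"]  # decoded as {"array": ..., "sampling_rate": ...}

    sr = audio["sampling_rate"]
    for word in sample["timestamps"]:
        # Each entry is {"start": seconds, "end": seconds, "text": word}.
        clip = audio["array"][int(word["start"] * sr) : int(word["end"] * sr)]
        print(f"{word['text']}: {word['end'] - word['start']:.2f}s ({len(clip)} samples)")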