Dataset Tool to add Timestamps #121

Merged (19 commits) on Oct 1, 2024

Changes from all commits
31 changes: 31 additions & 0 deletions Justfile
@@ -3,6 +3,7 @@ export WANDB_LOG_MODEL:="checkpoint"
export PROJECT_DIR:="ultravox"
export MCLOUD_CLUSTER:="r7z22p1"
export MCLOUD_INSTANCE:="oci.bm.gpu.b4.8"
export MFA_ENV_NAME:="aligner"

default: format check test

@@ -62,3 +63,33 @@ run *FLAGS:

mcloud *FLAGS:
poetry run mcli interactive {{FLAGS}} --cluster ${MCLOUD_CLUSTER} --instance ${MCLOUD_INSTANCE} --name `whoami` --command "bash -c \"$(cat setup.sh)\""

@check_conda:
if ! command -v conda &> /dev/null; then \
echo "Conda is not installed."; \
mkdir -p ~/miniconda3; \
if [ "$(uname)" = "Darwin" ]; then \
echo "Downloading MacOS Miniconda."; \
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh; \
elif [ "$(uname)" = "Linux" ]; then \
echo "Downloading Linux Miniconda."; \
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh \
else \
echo "Unknown operating system."; \
fi; \
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3; \
rm ~/miniconda3/miniconda.sh; \
else \
echo "Conda is installed."; \
fi

@install_mfa: check_conda
if conda env list | grep -q "$MFA_ENV_NAME"; then \
echo "Environment '$MFA_ENV_NAME' already exists."; \
else \
echo "Creating environment '$MFA_ENV_NAME'."; \
conda create --name "$MFA_ENV_NAME" python=3.8 -y; \
conda install -n "$MFA_ENV_NAME" -c conda-forge montreal-forced-aligner -y; \
conda run -n "$MFA_ENV_NAME" mfa model download acoustic english_mfa; \
conda run -n "$MFA_ENV_NAME" mfa model download dictionary english_mfa; \
fi
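
A quick sanity check after running just install_mfa (a sketch; it assumes MFA 2.x, where "mfa version" is a valid subcommand) mirrors the probe that ds_tool.py performs before aligning:

import subprocess

# Probe the "aligner" conda env; check=True raises CalledProcessError
# if the env is missing or MFA is not installed inside it.
subprocess.run(["conda", "run", "-n", "aligner", "mfa", "version"], check=True)
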
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ wandb = "~0.17.1"
sacrebleu = "^2.4.2"
tenacity = "^9.0.0"
evals = {git = "https://github.com/fixie-ai/evals", rev = "0c66bf85df7a4b903ecb202b23c2a826b749fd71"}
praatio = "^6.2.0"

[tool.poetry.group.dev.dependencies]
black = "~24.4.2"
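
praatio is pulled in to parse the TextGrid files that MFA writes. A minimal sketch of that usage, mirroring _retrieve_timestamps in ds_tool.py below (the file name here is hypothetical):

from praatio import textgrid

# Open an MFA-produced alignment; the second argument (False) skips
# empty intervals, matching how ds_tool.py reads the "words" tier.
tg = textgrid.openTextgrid("sample.TextGrid", False)
for entry in tg.getTier("words").entries:
    print(entry.start, entry.end, entry.label)  # start/end in seconds, then the word
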
254 changes: 222 additions & 32 deletions ultravox/tools/ds_tool/ds_tool.py
@@ -1,14 +1,22 @@
import dataclasses
import glob
import json
import math
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import subprocess
import tempfile
import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import datasets
import jinja2
import librosa
import openai
import simple_parsing
import soundfile as sf
import yaml
from praatio import textgrid
from tenacity import retry
from tenacity import stop_after_attempt
from tenacity import wait_fixed
@@ -21,6 +29,41 @@
tts_client: caching.CachingTtsWrapper
chat_client: caching.CachingChatWrapper

MFA_ENV_NAME = "aligner"


def apply_jinja_template(
template: str, sample: Dict[str, Any], exclude_fields: Optional[Set[str]] = None
):
"""
Apply a Jinja template to a sample, rendering it into text.
Jinja templates allow for added flexibility, since a template can include variables, functions, and control structures.
Args:
template: The Jinja template to apply. It can include variables, functions, and control structures.
Example:
{{ text }}
{{ text_proc.format_asr_text(text) }}
sample: The sample to apply the template to.
exclude_fields: Fields to exclude from the sample before rendering the template, to avoid loading large fields into memory.
"""
if exclude_fields:
# Filter out big fields like audio before the sample is passed into the jinja template
# otherwise it can lead to unnecessary memory usage.
sample = {k: sample[k] for k in sample.keys() if k not in exclude_fields}

try:
return jinja2.Template(template, undefined=jinja2.StrictUndefined).render(
**sample, json_dump=json.dumps, text_proc=text_proc
)
except jinja2.TemplateError as e:
print(f"Error rendering template: {e}")
print(f"template: {template}")
print(f"sample keys: {list(sample.keys())}, excluded keys: {exclude_fields}")
raise ValueError(
f"Template rendering failed. Make sure all keys in the template exist in the sample."
) from e
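
A minimal usage sketch of the new helper (the sample dict below is hypothetical, not from the PR):

sample = {"sentence": "how are you", "audio": {"array": [0.0] * 16000}}
text = apply_jinja_template("{{ sentence }}", sample, exclude_fields={"audio"})
# text == "how are you"; the audio array was dropped before rendering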


@dataclasses.dataclass
class TtsTask:
@@ -55,7 +98,10 @@ def map_split(
) -> datasets.Dataset:
print(f'TTS mapping "{self.template}" to "{self.audio_column_name}"...')
ds_split = ds_split.map(
self._map_sample, num_proc=num_proc, writer_batch_size=writer_batch_size
self._map_sample,
num_proc=num_proc,
writer_batch_size=writer_batch_size,
fn_kwargs={"exclude_fields": exclude_fields},
)
column_type = datasets.Audio(sampling_rate=self.sample_rate)
if self.json_mode and isinstance(
@@ -64,20 +110,10 @@ def map_split(
column_type = datasets.Sequence(column_type)
return ds_split.cast_column(self.audio_column_name, column_type)

def _map_sample(self, sample):
def _map_sample(self, sample, exclude_fields: Set[str]):
# using a Jinja template for some added flexibility, template can include variables and functions
# e.g., {{ text }} or {{ text_proc.format_asr_text(text) }}
try:
text_or_texts = jinja2.Template(
self.template, undefined=jinja2.StrictUndefined
).render(**sample, json_dump=json.dumps, text_proc=text_proc)
except jinja2.TemplateError as e:
print(f"Error rendering template: {e}")
print(f"template: {self.template}")
print(f"sample keys: {list(sample.keys())}")
raise ValueError(
f"Template rendering failed. Make sure column_name exists in the sample."
) from e
text_or_texts = apply_jinja_template(self.template, sample, exclude_fields)

if self.json_mode:
text_or_texts = yaml.safe_load(text_or_texts)
@@ -137,24 +173,11 @@ def _map_sample(self, sample, exclude_fields):
# using a Jinja template for some added flexibility, template can include variables and functions
# e.g., {{ text }} or {{ text_proc.format_asr_text(text) }}
try:
# Filter out the audio before the sample is passed into the jinja template, or it will get loaded into memory.
filtered_sample = {
k: sample[k] for k in sample.keys() if k not in exclude_fields
}
rendered = jinja2.Template(
self.template, undefined=jinja2.StrictUndefined
).render(**filtered_sample, json_dump=json.dumps, text_proc=text_proc)
rendered = apply_jinja_template(self.template, sample, exclude_fields)
except text_proc.FormatASRError as e:
print(f"Format ASR Error {e}. Filtering out sample.")
sample[self.new_column_name] = None
return sample
except jinja2.TemplateError as e:
print(f"Error rendering template: {e}")
print(f"template: {self.template}")
print(f"sample keys: {list(filtered_sample.keys())}")
raise ValueError(
f"Template rendering failed. Make sure all keys in the template exist in the sample."
) from e

if self.json_mode:
turns = yaml.safe_load(rendered)
@@ -175,6 +198,163 @@ def _map_sample(self, sample, exclude_fields):
return sample


@dataclasses.dataclass
class TimestampGenerationTask:
"""
This task is used to generate timestamps for the text transcription.
It uses the Montreal Forced Aligner (MFA) to align the text with the audio. The result is a
list of timestamps for each word in the text transcription. The timestamps are stored in a new
column, in a list of dict format:
[ {"start": float in seconds, "end": float in seconds, "text": first word str}, ... ]
"""

# Jinja template for the text transcription that needs to be aligned
template: str = simple_parsing.field(alias="-T")
# The acoustic model to use for MFA alignment.
# Make sure the dictionary and acoustic model are installed. See just install_mfa for an example (English).
# Model index: https://mfa-models.readthedocs.io/en/latest/acoustic/index.html
# For many languages there exists a {language}_mfa model that you can use, e.g. "english_mfa"
mfa_acoustic_model: str = simple_parsing.field(alias="-m")
# The dictionary to use for MFA alignment. Defaults to the same name as the acoustic model.
mfa_dictionary: Optional[str] = simple_parsing.field(default=None, alias="-d")
audio_column_name: str = simple_parsing.field(default="audio", alias="-a")
sample_rate: int = simple_parsing.field(default=16000, alias="-r")
# The column name to store the timestamps in
timestamp_column_name: str = simple_parsing.field(default="timestamps", alias="-ts")
aligned_ratio_check: float = simple_parsing.field(default=0.95, alias="-ar")

def __post_init__(self):
if self.mfa_dictionary is None:
self.mfa_dictionary = self.mfa_acoustic_model

try:
# Make sure the MFA environment is installed
subprocess.run(["conda", "run", "-n", MFA_ENV_NAME, "echo"], check=True)
except subprocess.CalledProcessError:
raise Exception(
"Please install the MFA environment using `just install_mfa` first."
)

if self.template.startswith("@"):
with open(self.template[1:], "r") as template_file:
self.template = template_file.read()

def map_split(
self,
ds_split: datasets.Dataset,
num_proc: int,
writer_batch_size: int,
exclude_fields: List[str],
) -> datasets.Dataset:
# 0. create a temp directory to store the audio and text files
# The files will be deleted when the with block ends or when an exception is raised
with tempfile.TemporaryDirectory() as temp_dir:
# 1. copy all audio-text pairs into the temp directory
ds_split.map(
self._store_sample_as_files,
num_proc=num_proc,
fn_kwargs={"exclude_fields": set(exclude_fields), "temp_dir": temp_dir},
)

count_wavs = len(glob.glob(os.path.join(temp_dir, "*.wav")))
assert count_wavs == len(
ds_split
), "Not all samples were stored as files. The id is likely not unique."

# 2. run the alignment
self._run_alignment(temp_dir, num_proc=num_proc)

# 3. retrieve the timestamps
ds_mapped = ds_split.map(
self._retrieve_timestamps,
num_proc=num_proc,
writer_batch_size=writer_batch_size,
fn_kwargs={"temp_dir": temp_dir},
)

# 4. filter out samples without timestamps (should be a small number)
ds_mapped = ds_mapped.filter(
lambda sample: sample[self.timestamp_column_name] is not None,
num_proc=num_proc,
writer_batch_size=writer_batch_size,
)

# 5. make sure most samples have timestamps
if len(ds_split) * self.aligned_ratio_check > len(ds_mapped):
raise Exception(
f"Found too many samples without timestamps: {len(ds_mapped)}/{len(ds_split)} aligned."
)

return ds_mapped

def _retrieve_timestamps(self, sample, temp_dir: str):
# find the timestamps for the audio and populate the timestamps column
sample_id = self.get_id(sample)
text_path = os.path.join(temp_dir, f"{sample_id}.TextGrid")
if not os.path.exists(text_path):
sample[self.timestamp_column_name] = None
return sample

tg = textgrid.openTextgrid(text_path, False)
timestamps = tg.getTier("words").entries
sample[self.timestamp_column_name] = [
{"start": entry.start, "end": entry.end, "text": entry.label}
for entry in timestamps
]
return sample

@staticmethod
def get_id(sample):
for key in ["id", "segment_id"]:
if key in sample and isinstance(sample[key], str):
return str(sample[key])
for key in ["file", "path", "audio_file"]:
if key in sample and isinstance(sample[key], str):
return Path(sample[key]).stem
raise ValueError("Could not find an ID in the sample")

def _store_sample_as_files(self, sample, temp_dir: str, exclude_fields: Set[str]):
sample_id = self.get_id(sample)
audio_path = os.path.join(temp_dir, f"{sample_id}.wav")
with open(audio_path, "wb") as f:
audio = sample[self.audio_column_name]
if audio["sampling_rate"] != self.sample_rate:
audio["array"] = librosa.resample(
audio["array"],
orig_sr=audio["sampling_rate"],
target_sr=self.sample_rate,
)
sf.write(f, audio["array"], self.sample_rate, format="WAV", subtype="PCM_16")

text_path = os.path.join(temp_dir, f"{sample_id}.txt")
text = apply_jinja_template(self.template, sample, exclude_fields)
with open(text_path, "w") as f:
f.write(text)

def _run_alignment(self, temp_dir: str, num_proc: int = 16) -> None:
subprocess.run(
[
"conda",
"run",
"--no-capture-output",
"-n",
MFA_ENV_NAME,
"mfa",
"align",
"--clean",
"--single_speaker",
"--use_mp",
"-j",
str(num_proc),
temp_dir,
self.mfa_acoustic_model,
str(self.mfa_dictionary),
temp_dir,
],
check=True,
)
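
Concretely, a sample coming out of this task carries a timestamps column shaped as the class docstring describes; the values below are illustrative, not real output:

sample["timestamps"] = [
    {"start": 0.31, "end": 0.55, "text": "how"},
    {"start": 0.55, "end": 0.73, "text": "are"},
    {"start": 0.73, "end": 1.02, "text": "you"},
]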


# This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model.
# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -T {{question}} -a audio --token $HF_WRITE_TOKEN
# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -T {{explanation}} -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
@@ -183,6 +363,8 @@ def _map_sample(self, sample, exclude_fields):
# --shuffle --upload_name fixie-ai/librispeech_asr --private --base_url https://api.fireworks.ai/inference/v1 \
# --api_key $FIREWORKS_API_KEY --token $HF_TOKEN --language_model accounts/fireworks/models/llama-v3-8b-instruct \
# --template @ultravox/tools/ds_tool/continuation.jinja --max_tokens 64 --num_workers 30 --writer_batch_size 30
# just ds_tool timestamp -d fixie-ai/common_voice_17_0 -S en --upload_name fixie-ai/cv_ts \
# -m english_mfa -T "\"{{text_proc.format_asr_text(sentence)}}\""
@dataclasses.dataclass
class DatasetToolArgs:
# HF source dataset parameters
@@ -218,10 +400,12 @@ class DatasetToolArgs:
default_factory=lambda: ["audio"]
)

task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups(
{"tts": TtsTask, "textgen": TextGenerationTask}, # type: ignore
default_factory=TtsTask,
positional=True,
task: Union[TtsTask, TextGenerationTask, TimestampGenerationTask] = (
simple_parsing.subgroups(
{"tts": TtsTask, "textgen": TextGenerationTask, "timestamp": TimestampGenerationTask}, # type: ignore
default_factory=TtsTask,
positional=True,
)
)

def __post_init__(self):
@@ -232,6 +416,11 @@ def __post_init__(self):
if self.dataset_split and not self.upload_split:
self.upload_split = self.dataset_split

if self.upload_name == self.dataset_name:
raise ValueError(
"Updating datasets in-place is not well-supported and hence frowned upon."
)


class DatasetChunkProcessor:
args: DatasetToolArgs
@@ -301,6 +490,7 @@ def process_and_upload_split_rescursive(
# then the huggingface README needs to be updated to have the
# download_size and dataset_size fields present under dataset_info (could be initialized to 0)
print(f"Failed to upload chunk {ds_chunk_name}: {e}. Retrying later.")
print(traceback.format_exc())
if total_chunks == 1:
print(
f"Finished processing and uploading 0/1 chunks for range [{start_index}, {end_index})"