From a4a285493c0c7a1ad1757db8fd49d3e48573d24b Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Fri, 13 Sep 2024 14:22:54 -0700 Subject: [PATCH] First --- ultravox/tools/ds_tool/ds_tool.py | 85 +++++++++++++++++++++++++++++-- 1 file changed, 81 insertions(+), 4 deletions(-) diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py index 750f62e4..db79a7ce 100644 --- a/ultravox/tools/ds_tool/ds_tool.py +++ b/ultravox/tools/ds_tool/ds_tool.py @@ -6,6 +6,7 @@ import datasets import jinja2 +import numpy as np import openai import simple_parsing import yaml @@ -175,6 +176,80 @@ def _map_sample(self, sample, exclude_fields): return sample +@dataclasses.dataclass +class AudioExtensionTask: + audio_column_name: str = simple_parsing.field(default="audio", alias="-a") + asr_column_name: str = simple_parsing.field(default="sentence", alias="-A") + translation_column_name: str = simple_parsing.field( + default="translation", alias="-T" + ) + id_column_name: str = simple_parsing.field(default="id", alias="-i") + extend_type: str = simple_parsing.field( + default="repeat", alias="-e", choices=["repeat", "combine"] + ) + multiplier: int = simple_parsing.field(default=2, alias="-m") + + def map_split( + self, + ds_split: datasets.Dataset, + num_proc: int, + writer_batch_size: int, + exclude_fields: List[str], + ) -> datasets.Dataset: + print( + f'Extending audio using "{self.extend_type}" method with multiplier {self.multiplier}' + ) + + if self.extend_type == "repeat": + return ds_split.map( + function=self._map_sample_repeat, + num_proc=num_proc, + writer_batch_size=writer_batch_size, + ) + elif self.extend_type == "combine": + return ds_split.map( + function=self._map_batch_combine, + batched=True, + batch_size=self.multiplier, + num_proc=num_proc, + writer_batch_size=writer_batch_size, + remove_columns=ds_split.column_names, + ) + else: + raise ValueError(f"Unknown extend_type: {self.extend_type}") + + def _map_sample_repeat(self, sample): + audio = sample[self.audio_column_name] + if isinstance(audio, dict): + audio_data = audio["array"] + else: + raise ValueError(f"Unsupported audio format: {type(audio)}") + + repeated_audio = np.tile(audio_data, self.multiplier) + sample[self.audio_column_name]["array"] = repeated_audio + + return sample + + def _map_batch_combine(self, batch): + audios = batch[self.audio_column_name] + sentences = batch[self.asr_column_name] + translations = batch[self.translation_column_name] + ids = batch["id"] + + combined_audios = np.concatenate(audios) + combined_sentences = " ".join(sentences) + combined_translations = " ".join(translations) + combined_ids = "+".join(ids) + + new_batch = { + self.audio_column_name: combined_audios, + self.asr_column_name: combined_sentences, + self.translation_column_name: combined_translations, + self.id_column_name: combined_ids, + } + return new_batch + + # This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model. # just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -T {{question}} -a audio --token $HF_WRITE_TOKEN # just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -T {{explanation}} -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct @@ -218,10 +293,12 @@ class DatasetToolArgs: default_factory=lambda: ["audio"] ) - task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups( - {"tts": TtsTask, "textgen": TextGenerationTask}, # type: ignore - default_factory=TtsTask, - positional=True, + task: Union[TtsTask, TextGenerationTask, AudioExtensionTask] = ( + simple_parsing.subgroups( + {"tts": TtsTask, "textgen": TextGenerationTask, "audioext": AudioExtensionTask}, # type: ignore + default_factory=TtsTask, + positional=True, + ) ) def __post_init__(self):