Skip to content

Commit

Permalink
First
Browse files Browse the repository at this point in the history
  • Loading branch information
liPatrick committed Sep 13, 2024
1 parent ae39709 commit a4a2854
Showing 1 changed file with 81 additions and 4 deletions.
85 changes: 81 additions & 4 deletions ultravox/tools/ds_tool/ds_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import datasets
import jinja2
import numpy as np
import openai
import simple_parsing
import yaml
Expand Down Expand Up @@ -175,6 +176,80 @@ def _map_sample(self, sample, exclude_fields):
return sample


@dataclasses.dataclass
class AudioExtensionTask:
audio_column_name: str = simple_parsing.field(default="audio", alias="-a")
asr_column_name: str = simple_parsing.field(default="sentence", alias="-A")
translation_column_name: str = simple_parsing.field(
default="translation", alias="-T"
)
id_column_name: str = simple_parsing.field(default="id", alias="-i")
extend_type: str = simple_parsing.field(
default="repeat", alias="-e", choices=["repeat", "combine"]
)
multiplier: int = simple_parsing.field(default=2, alias="-m")

def map_split(
self,
ds_split: datasets.Dataset,
num_proc: int,
writer_batch_size: int,
exclude_fields: List[str],
) -> datasets.Dataset:
print(
f'Extending audio using "{self.extend_type}" method with multiplier {self.multiplier}'
)

if self.extend_type == "repeat":
return ds_split.map(
function=self._map_sample_repeat,
num_proc=num_proc,
writer_batch_size=writer_batch_size,
)
elif self.extend_type == "combine":
return ds_split.map(
function=self._map_batch_combine,
batched=True,
batch_size=self.multiplier,
num_proc=num_proc,
writer_batch_size=writer_batch_size,
remove_columns=ds_split.column_names,
)
else:
raise ValueError(f"Unknown extend_type: {self.extend_type}")

def _map_sample_repeat(self, sample):
audio = sample[self.audio_column_name]
if isinstance(audio, dict):
audio_data = audio["array"]
else:
raise ValueError(f"Unsupported audio format: {type(audio)}")

repeated_audio = np.tile(audio_data, self.multiplier)
sample[self.audio_column_name]["array"] = repeated_audio

return sample

def _map_batch_combine(self, batch):
audios = batch[self.audio_column_name]
sentences = batch[self.asr_column_name]
translations = batch[self.translation_column_name]
ids = batch["id"]

combined_audios = np.concatenate(audios)
combined_sentences = " ".join(sentences)
combined_translations = " ".join(translations)
combined_ids = "+".join(ids)

new_batch = {
self.audio_column_name: combined_audios,
self.asr_column_name: combined_sentences,
self.translation_column_name: combined_translations,
self.id_column_name: combined_ids,
}
return new_batch


# This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model.
# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -T {{question}} -a audio --token $HF_WRITE_TOKEN
# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -T {{explanation}} -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
Expand Down Expand Up @@ -218,10 +293,12 @@ class DatasetToolArgs:
default_factory=lambda: ["audio"]
)

task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups(
{"tts": TtsTask, "textgen": TextGenerationTask}, # type: ignore
default_factory=TtsTask,
positional=True,
task: Union[TtsTask, TextGenerationTask, AudioExtensionTask] = (
simple_parsing.subgroups(
{"tts": TtsTask, "textgen": TextGenerationTask, "audioext": AudioExtensionTask}, # type: ignore
default_factory=TtsTask,
positional=True,
)
)

def __post_init__(self):
Expand Down

0 comments on commit a4a2854

Please sign in to comment.