-
Notifications
You must be signed in to change notification settings - Fork 220
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Break up datasets.py This splits out types.py and registry.py to move the list of pre-defined datasets to its own file and avoid circular refs. An __all__ import is used to minimize changes to surrounding code. * sr * cr * merge * restore typing
- Loading branch information
Showing
20 changed files
with
808 additions
and
754 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from ultravox.data.data_sample import * | ||
from ultravox.data.datasets import * | ||
from ultravox.data.registry import * | ||
from ultravox.data.types import * | ||
|
||
__all__ = [ | ||
"SizedIterableDataset", | ||
"EmptyDataset", | ||
"InterleaveDataset", | ||
"Range", | ||
"Dataproc", | ||
"VoiceDataset", | ||
"VoiceDatasetArgs", | ||
"VoiceSample", | ||
"create_dataset", | ||
"register_datasets", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import base64 | ||
import dataclasses | ||
import io | ||
from typing import Any, Dict, List, Optional | ||
|
||
import librosa | ||
import numpy as np | ||
import soundfile as sf | ||
from numpy import typing as npt | ||
|
||
SAMPLE_RATE = 16000 | ||
|
||
|
||
def audio_from_file(path: str) -> np.ndarray: | ||
"""Load audio from a file, converting to float32 PCM @ 16 kHz.""" | ||
audio, _ = librosa.load(path, sr=SAMPLE_RATE) | ||
assert audio.dtype == np.float32 | ||
return audio | ||
|
||
|
||
def audio_from_buf(buf: bytes) -> np.ndarray: | ||
"""Load audio from a buffer, converting to float32 PCM @ 16 kHz.""" | ||
audio, _ = librosa.load(io.BytesIO(buf), sr=SAMPLE_RATE) | ||
assert audio.dtype == np.float32 | ||
return audio | ||
|
||
|
||
def audio_to_wav(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> bytes: | ||
"""Convert audio to WAV format, 16-bit PCM @ 16 kHz.""" | ||
assert audio.dtype == np.float32 | ||
with io.BytesIO() as buf: | ||
sf.write(buf, audio, sample_rate, format="WAV", subtype="PCM_16") | ||
return buf.getvalue() | ||
|
||
|
||
def audio_to_wav_base64(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> str: | ||
"""Convert audio to a base64-encoded WAV file.""" | ||
return base64.b64encode(audio_to_wav(audio, sample_rate)).decode("utf-8") | ||
|
||
|
||
def audio_to_data_uri(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> str: | ||
"""Convert audio to a data URI.""" | ||
return f"data:audio/wav;base64,{audio_to_wav_base64(audio, sample_rate)}" | ||
|
||
|
||
def messages_from_prompt(prompt: str) -> List[Dict[str, str]]: | ||
return [{"role": "user", "content": prompt}] | ||
|
||
|
||
@dataclasses.dataclass | ||
class VoiceSample: | ||
@staticmethod | ||
def from_json(data: Dict[str, Any]) -> "VoiceSample": | ||
"""Convert from JSON format; audio is expected as base64ed WAV.""" | ||
bytes = base64.b64decode(data["audio"]) | ||
return VoiceSample(data["messages"], audio_from_buf(bytes)) | ||
|
||
@staticmethod | ||
def from_prompt(prompt: str) -> "VoiceSample": | ||
"""Create a VoiceSample from a prompt only.""" | ||
return VoiceSample(messages_from_prompt(prompt), None) | ||
|
||
@staticmethod | ||
def from_prompt_and_file(prompt: str, path: str) -> "VoiceSample": | ||
"""Create a VoiceSample from a prompt and an audio file.""" | ||
return VoiceSample(messages_from_prompt(prompt), audio_from_file(path)) | ||
|
||
@staticmethod | ||
def from_prompt_and_buf(prompt: str, buf: bytes) -> "VoiceSample": | ||
"""Create a VoiceSample from a prompt and an encoded audio buffer.""" | ||
return VoiceSample(messages_from_prompt(prompt), audio_from_buf(buf)) | ||
|
||
@staticmethod | ||
def from_prompt_and_raw( | ||
prompt: str, buf: np.ndarray, sample_rate: int | ||
) -> "VoiceSample": | ||
"""Create a VoiceSample from a prompt and raw audio data with sample rate.""" | ||
# Keep in native sample rate; we'll resample later if needed. | ||
return VoiceSample(messages_from_prompt(prompt), buf, sample_rate) | ||
|
||
def to_json(self) -> Dict[str, Any]: | ||
"""Convert to JSON format; audio is written as base64ed WAV.""" | ||
obj: Dict[str, Any] = {"messages": self.messages} | ||
if self.audio is not None: | ||
obj["audio"] = audio_to_wav_base64(self.audio, self.sample_rate) | ||
return obj | ||
|
||
def __post_init__(self): | ||
"""Ensure audio is float32 PCM.""" | ||
if self.audio is not None: | ||
if self.audio.dtype == np.float64: | ||
self.audio = self.audio.astype(np.float32) | ||
elif self.audio.dtype == np.int16: | ||
self.audio = self.audio.astype(np.float32) / np.float32(32768.0) | ||
elif self.audio.dtype == np.int32: | ||
self.audio = self.audio.astype(np.float32) / np.float32(2147483648.0) | ||
assert ( | ||
self.audio.dtype == np.float32 | ||
), f"Unexpected audio dtype: {self.audio.dtype}" | ||
assert self.audio.ndim == 1, f"Unexpected audio shape: {self.audio.shape}" | ||
|
||
def add_past_messages(self, past_messages: List[Dict[str, str]]): | ||
self.messages = past_messages + self.messages | ||
|
||
messages: List[Dict[str, str]] | ||
"""List of messages, each with a "role" and "content" field.""" | ||
audio: Optional[npt.NDArray[np.float32]] = None | ||
"""Audio data as float32 PCM @ `sample_rate`.""" | ||
sample_rate: int = SAMPLE_RATE | ||
"""Audio sample rate in Hz.""" | ||
audio_transcript: Optional[str] = None | ||
"""For evaluations, the known transcript of the audio.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from typing import Union | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from ultravox.data import data_sample | ||
|
||
|
||
def _create_sine_wave( | ||
freq: int = 440, | ||
duration: float = 1.0, | ||
sample_rate: int = 16000, | ||
amplitude: float = 0.1, | ||
target_dtype: str = "float32", | ||
) -> Union[ | ||
np.typing.NDArray[np.float32], | ||
np.typing.NDArray[np.float64], | ||
np.typing.NDArray[np.int16], | ||
np.typing.NDArray[np.int32], | ||
]: | ||
t = np.arange(sample_rate * duration, dtype=np.float32) / sample_rate | ||
wave = amplitude * np.sin(2 * np.pi * freq * t) | ||
match target_dtype: | ||
case "int16": | ||
wave = np.int16(wave * 32767) | ||
case "int32": | ||
wave = np.int32(wave * 2147483647) | ||
case "float32": | ||
# Already float32, nothing needed. | ||
pass | ||
case "float64": | ||
wave = wave.astype(np.float64) | ||
case _: | ||
raise ValueError(f"Unsupported dtype: {target_dtype}") | ||
return wave | ||
|
||
|
||
def _create_and_validate_sample(target_dtype: str = "float32"): | ||
# Create a sine wave at 440 Hz with a duration of 1.0 second, sampled at 16 | ||
# kHz, with an amplitude of 0.1, and the specified dtype. | ||
array = _create_sine_wave(target_dtype=target_dtype) | ||
sample = data_sample.VoiceSample.from_prompt_and_raw( | ||
"Transcribe\n<|audio|>", array, 16000 | ||
) | ||
assert sample.sample_rate == 16000 | ||
assert sample.audio is not None, "sample.audio should not be None" | ||
assert len(sample.audio) == 16000 | ||
assert sample.audio.dtype == np.float32 | ||
assert sample.messages == [ | ||
{"role": "user", "content": "Transcribe\n<|audio|>"}, | ||
] | ||
# Serialize and deserialize the sample. | ||
json = sample.to_json() | ||
sample2 = data_sample.VoiceSample.from_json(json) | ||
assert sample2.sample_rate == sample.sample_rate | ||
assert sample2.audio is not None, "sample2.audio should not be None" | ||
assert len(sample2.audio) == len(sample.audio) | ||
assert sample2.audio.dtype == sample.audio.dtype | ||
assert sample2.messages == sample.messages | ||
assert np.allclose(sample2.audio, sample.audio, rtol=0.0001, atol=0.0001) | ||
|
||
|
||
def test_create_sample__int16(): | ||
_create_and_validate_sample("int16") | ||
|
||
|
||
def test_create_sample__int32(): | ||
_create_and_validate_sample("int32") | ||
|
||
|
||
def test_create_sample__float32(): | ||
_create_and_validate_sample("float32") | ||
|
||
|
||
def test_create_sample__float64(): | ||
_create_and_validate_sample("float64") | ||
|
||
|
||
def test_create_sample__raises_on_unsupported_dtype(): | ||
with pytest.raises(AssertionError): | ||
array = np.ndarray(shape=(16000,), dtype=np.uint8) | ||
_ = data_sample.VoiceSample.from_prompt_and_raw( | ||
"Transcribe\n<|audio|>", array, 16000 | ||
) |
Oops, something went wrong.