
Commit

Break up datasets.py (#141)
* Break up datasets.py

This splits out types.py and registry.py to move the list of pre-defined datasets to its own file and avoid circular refs.

A package-level wildcard import with an explicit __all__ is used to minimize changes to surrounding code.

* sr

* cr

* merge

* restore typing
juberti authored Oct 28, 2024
1 parent 041c4fe commit 487e939
Showing 20 changed files with 808 additions and 754 deletions.
17 changes: 17 additions & 0 deletions ultravox/data/__init__.py
@@ -0,0 +1,17 @@
from ultravox.data.data_sample import *
from ultravox.data.datasets import *
from ultravox.data.registry import *
from ultravox.data.types import *

__all__ = [
    "SizedIterableDataset",
    "EmptyDataset",
    "InterleaveDataset",
    "Range",
    "Dataproc",
    "VoiceDataset",
    "VoiceDatasetArgs",
    "VoiceSample",
    "create_dataset",
    "register_datasets",
]
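
With the re-exports above, existing call sites can keep importing directly from ultravox.data. A minimal usage sketch (not part of this diff; it only uses names exported in __all__ and defined in data_sample.py):

from ultravox.data import VoiceSample

sample = VoiceSample.from_prompt("Transcribe\n<|audio|>")
# from_prompt creates a text-only sample: no audio, a single user message.
assert sample.audio is None
assert sample.messages[0]["role"] == "user"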
112 changes: 112 additions & 0 deletions ultravox/data/data_sample.py
@@ -0,0 +1,112 @@
import base64
import dataclasses
import io
from typing import Any, Dict, List, Optional

import librosa
import numpy as np
import soundfile as sf
from numpy import typing as npt

SAMPLE_RATE = 16000


def audio_from_file(path: str) -> np.ndarray:
    """Load audio from a file, converting to float32 PCM @ 16 kHz."""
    audio, _ = librosa.load(path, sr=SAMPLE_RATE)
    assert audio.dtype == np.float32
    return audio


def audio_from_buf(buf: bytes) -> np.ndarray:
    """Load audio from a buffer, converting to float32 PCM @ 16 kHz."""
    audio, _ = librosa.load(io.BytesIO(buf), sr=SAMPLE_RATE)
    assert audio.dtype == np.float32
    return audio


def audio_to_wav(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> bytes:
    """Convert audio to WAV format, 16-bit PCM @ 16 kHz."""
    assert audio.dtype == np.float32
    with io.BytesIO() as buf:
        sf.write(buf, audio, sample_rate, format="WAV", subtype="PCM_16")
        return buf.getvalue()


def audio_to_wav_base64(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> str:
    """Convert audio to a base64-encoded WAV file."""
    return base64.b64encode(audio_to_wav(audio, sample_rate)).decode("utf-8")


def audio_to_data_uri(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> str:
    """Convert audio to a data URI."""
    return f"data:audio/wav;base64,{audio_to_wav_base64(audio, sample_rate)}"


def messages_from_prompt(prompt: str) -> List[Dict[str, str]]:
    return [{"role": "user", "content": prompt}]


@dataclasses.dataclass
class VoiceSample:
    @staticmethod
    def from_json(data: Dict[str, Any]) -> "VoiceSample":
        """Convert from JSON format; audio is expected as base64ed WAV."""
        bytes = base64.b64decode(data["audio"])
        return VoiceSample(data["messages"], audio_from_buf(bytes))

    @staticmethod
    def from_prompt(prompt: str) -> "VoiceSample":
        """Create a VoiceSample from a prompt only."""
        return VoiceSample(messages_from_prompt(prompt), None)

    @staticmethod
    def from_prompt_and_file(prompt: str, path: str) -> "VoiceSample":
        """Create a VoiceSample from a prompt and an audio file."""
        return VoiceSample(messages_from_prompt(prompt), audio_from_file(path))

    @staticmethod
    def from_prompt_and_buf(prompt: str, buf: bytes) -> "VoiceSample":
        """Create a VoiceSample from a prompt and an encoded audio buffer."""
        return VoiceSample(messages_from_prompt(prompt), audio_from_buf(buf))

    @staticmethod
    def from_prompt_and_raw(
        prompt: str, buf: np.ndarray, sample_rate: int
    ) -> "VoiceSample":
        """Create a VoiceSample from a prompt and raw audio data with sample rate."""
        # Keep in native sample rate; we'll resample later if needed.
        return VoiceSample(messages_from_prompt(prompt), buf, sample_rate)

    def to_json(self) -> Dict[str, Any]:
        """Convert to JSON format; audio is written as base64ed WAV."""
        obj: Dict[str, Any] = {"messages": self.messages}
        if self.audio is not None:
            obj["audio"] = audio_to_wav_base64(self.audio, self.sample_rate)
        return obj

    def __post_init__(self):
        """Ensure audio is float32 PCM."""
        if self.audio is not None:
            if self.audio.dtype == np.float64:
                self.audio = self.audio.astype(np.float32)
            elif self.audio.dtype == np.int16:
                self.audio = self.audio.astype(np.float32) / np.float32(32768.0)
            elif self.audio.dtype == np.int32:
                self.audio = self.audio.astype(np.float32) / np.float32(2147483648.0)
            assert (
                self.audio.dtype == np.float32
            ), f"Unexpected audio dtype: {self.audio.dtype}"
            assert self.audio.ndim == 1, f"Unexpected audio shape: {self.audio.shape}"

    def add_past_messages(self, past_messages: List[Dict[str, str]]):
        self.messages = past_messages + self.messages

    messages: List[Dict[str, str]]
    """List of messages, each with a "role" and "content" field."""
    audio: Optional[npt.NDArray[np.float32]] = None
    """Audio data as float32 PCM @ `sample_rate`."""
    sample_rate: int = SAMPLE_RATE
    """Audio sample rate in Hz."""
    audio_transcript: Optional[str] = None
    """For evaluations, the known transcript of the audio."""
84 changes: 84 additions & 0 deletions ultravox/data/data_sample_test.py
@@ -0,0 +1,84 @@
from typing import Union

import numpy as np
import pytest

from ultravox.data import data_sample


def _create_sine_wave(
    freq: int = 440,
    duration: float = 1.0,
    sample_rate: int = 16000,
    amplitude: float = 0.1,
    target_dtype: str = "float32",
) -> Union[
    np.typing.NDArray[np.float32],
    np.typing.NDArray[np.float64],
    np.typing.NDArray[np.int16],
    np.typing.NDArray[np.int32],
]:
    t = np.arange(sample_rate * duration, dtype=np.float32) / sample_rate
    wave = amplitude * np.sin(2 * np.pi * freq * t)
    match target_dtype:
        case "int16":
            wave = np.int16(wave * 32767)
        case "int32":
            wave = np.int32(wave * 2147483647)
        case "float32":
            # Already float32, nothing needed.
            pass
        case "float64":
            wave = wave.astype(np.float64)
        case _:
            raise ValueError(f"Unsupported dtype: {target_dtype}")
    return wave


def _create_and_validate_sample(target_dtype: str = "float32"):
    # Create a sine wave at 440 Hz with a duration of 1.0 second, sampled at 16
    # kHz, with an amplitude of 0.1, and the specified dtype.
    array = _create_sine_wave(target_dtype=target_dtype)
    sample = data_sample.VoiceSample.from_prompt_and_raw(
        "Transcribe\n<|audio|>", array, 16000
    )
    assert sample.sample_rate == 16000
    assert sample.audio is not None, "sample.audio should not be None"
    assert len(sample.audio) == 16000
    assert sample.audio.dtype == np.float32
    assert sample.messages == [
        {"role": "user", "content": "Transcribe\n<|audio|>"},
    ]
    # Serialize and deserialize the sample.
    json = sample.to_json()
    sample2 = data_sample.VoiceSample.from_json(json)
    assert sample2.sample_rate == sample.sample_rate
    assert sample2.audio is not None, "sample2.audio should not be None"
    assert len(sample2.audio) == len(sample.audio)
    assert sample2.audio.dtype == sample.audio.dtype
    assert sample2.messages == sample.messages
    assert np.allclose(sample2.audio, sample.audio, rtol=0.0001, atol=0.0001)


def test_create_sample__int16():
    _create_and_validate_sample("int16")


def test_create_sample__int32():
    _create_and_validate_sample("int32")


def test_create_sample__float32():
    _create_and_validate_sample("float32")


def test_create_sample__float64():
    _create_and_validate_sample("float64")


def test_create_sample__raises_on_unsupported_dtype():
    with pytest.raises(AssertionError):
        array = np.ndarray(shape=(16000,), dtype=np.uint8)
        _ = data_sample.VoiceSample.from_prompt_and_raw(
            "Transcribe\n<|audio|>", array, 16000
        )
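
These tests run under pytest, for example: pytest ultravox/data/data_sample_test.py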