Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🍎 Add MLX support #103

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,10 @@ ffmpeg
ffmpeg-python
pre-commit
fire
tqdm
more-itertools
tiktoken
huggingface_hub
scipy
mlx>=0.11
numba
Empty file.
169 changes: 169 additions & 0 deletions whisperplus/apple_pipeline/mlx_whisper/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright © 2023 Apple Inc.

import os
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Union

import mlx.core as mx
import numpy as np

# Hard-coded audio hyperparameters. Everything below SAMPLE_RATE..CHUNK_LENGTH
# is derived from those four values.
SAMPLE_RATE = 16000
N_FFT = 400  # passed as `nperseg` to stft() by log_mel_spectrogram()
HOP_LENGTH = 160  # passed as `noverlap` to stft(), where it is the frame stride
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = N_SAMPLES // HOP_LENGTH # 3000 frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolution has a stride of 2
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 100 frames/s, i.e. 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 50 tokens/s, i.e. 20ms per audio token


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary.

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A 1-D mlx array containing the audio waveform, in float32 dtype. The
    16-bit PCM samples produced by ffmpeg are divided by 32768.0.

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status; the message carries ffmpeg's
        stderr output.
    FileNotFoundError
        Propagated from ``subprocess.run`` when the ffmpeg executable cannot
        be found.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        # ffmpeg's stderr is not guaranteed to be valid UTF-8 (e.g. it echoes
        # file names and metadata); a strict decode here would raise
        # UnicodeDecodeError and hide the actual ffmpeg failure.
        detail = e.stderr.decode(errors="replace")
        raise RuntimeError(f"Failed to load audio: {detail}") from e

    # Raw little-endian int16 samples -> float32 mlx array.
    return mx.array(np.frombuffer(out, np.int16)).flatten().astype(mx.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Force ``array`` to have exactly ``length`` entries along ``axis``.

    Longer inputs keep their first ``length`` entries; shorter inputs are
    padded at the end via ``mx.pad``. An input that already has the requested
    size is returned unchanged. ``length`` defaults to N_SAMPLES.
    """
    current = array.shape[axis]

    if current > length:
        # Too long: slice [0, length) on the target axis, everything else whole.
        index = [slice(None) for _ in range(array.ndim)]
        index[axis] = slice(0, length)
        return array[tuple(index)]

    if current < length:
        # Too short: pad only after the existing data, only on the target axis.
        widths = [(0, 0) for _ in range(array.ndim)]
        widths[axis] = (0, length - current)
        return mx.pad(array, widths)

    return array


@lru_cache(maxsize=None)
def mel_filters(n_mels: int) -> mx.array:
    """
    Return the mel filterbank matrix used to project an STFT onto the Mel scale.

    The matrices are read from ``assets/mel_filters.npz`` next to this module
    so that librosa is not needed at runtime. The archive was created with:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )

    Results are memoized per ``n_mels`` by ``lru_cache``.
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    module_dir = os.path.dirname(__file__)
    npz_path = os.path.join(module_dir, "assets", "mel_filters.npz")
    filterbanks = mx.load(npz_path)
    return filterbanks[f"mel_{n_mels}"]


@lru_cache(maxsize=None)
def hanning(size):
    """
    Return a Hann window of ``size`` points as an mlx array.

    The window is built from NumPy's ``size + 1``-point window with the final
    sample dropped. Results are memoized per ``size``.
    """
    one_longer = np.hanning(size + 1)
    truncated = one_longer[:-1]
    return mx.array(truncated)


def stft(x, window, nperseg=256, noverlap=None, nfft=None, axis=-1, pad_mode="reflect"):
    """
    Compute a short-time Fourier transform of the 1-D signal ``x``.

    The signal is padded by ``nperseg // 2`` samples on both sides, cut into
    rows of ``nfft`` samples, multiplied by ``window`` and passed through
    ``mx.fft.rfft``.

    Notes
    -----
    * ``noverlap`` is used as the stride between consecutive frames (i.e. it
      acts as a hop size); it defaults to ``nfft // 4``.
    * ``nfft`` defaults to ``nperseg``.
    * ``axis`` is accepted but not used by this implementation.
    * ``pad_mode`` must be ``"constant"`` (zero padding) or ``"reflect"``;
      any other value raises ``ValueError``.
    """
    nfft = nperseg if nfft is None else nfft
    noverlap = nfft // 4 if noverlap is None else noverlap

    half = nperseg // 2
    if pad_mode == "constant":
        padded = mx.pad(x, [(half, half)])
    elif pad_mode == "reflect":
        # Mirror the signal around its first and last samples, excluding the
        # edge samples themselves.
        head = x[1:half + 1][::-1]
        tail = x[-(half + 1):-1][::-1]
        padded = mx.concatenate([head, x, tail])
    else:
        raise ValueError(f"Invalid pad_mode {pad_mode}")

    # Overlapping frames are exposed as a strided view of the padded signal.
    n_frames = (padded.size - nperseg + noverlap) // noverlap
    frames = mx.as_strided(padded, shape=[n_frames, nfft], strides=[noverlap, 1])
    return mx.fft.rfft(frames * window)


def log_mel_spectrogram(
    audio: Union[str, np.ndarray, mx.array],
    n_mels: int = 80,
    padding: int = 0,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, mx.array], shape = (*)
        The path to audio or either a NumPy or mlx array containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters; must be a value accepted by
        ``mel_filters`` (80 or 128)

    padding: int
        Number of zero samples to pad to the right

    Returns
    -------
    mx.array, shape = (n_frames, n_mels)
        An array that contains the log-Mel spectrogram. The shape follows from
        ``magnitudes @ filters.T`` and assumes the stored filterbank has
        ``n_mels`` rows, as in the librosa call quoted in ``mel_filters``.

    Notes
    -----
    The default mlx device is switched to the CPU for the duration of the
    call and is always restored afterwards, including when an error is raised.
    """
    device = mx.default_device()
    mx.set_default_device(mx.cpu)
    try:
        if isinstance(audio, str):
            audio = load_audio(audio)
        elif not isinstance(audio, mx.array):
            audio = mx.array(audio)

        if padding > 0:
            audio = mx.pad(audio, (0, padding))
        window = hanning(N_FFT)
        freqs = stft(audio, window, nperseg=N_FFT, noverlap=HOP_LENGTH)
        # Drop the final frame and take the power spectrum.
        magnitudes = freqs[:-1, :].abs().square()

        filters = mel_filters(n_mels)
        mel_spec = magnitudes @ filters.T

        # Clamp before log10 to avoid log(0), limit the dynamic range to 8
        # (log10 units) below the maximum, then rescale.
        log_spec = mx.maximum(mel_spec, 1e-10).log10()
        log_spec = mx.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
    finally:
        # Without this, a failure above (e.g. ffmpeg error, unsupported
        # n_mels) would leave the whole process pinned to the CPU device.
        mx.set_default_device(device)
126 changes: 126 additions & 0 deletions whisperplus/apple_pipeline/mlx_whisper/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import time

import mlx.core as mx
from mlx_whisper import audio, decoding, load_models, transcribe

# Audio clip used by every benchmark in this script. This is a relative path,
# so the script has to be started from a directory where it resolves.
audio_file = "mlx_whisper/assets/ls_test.flac"


def parse_arguments():
    """
    Parse the benchmark's command-line options from ``sys.argv``.

    Returns an ``argparse.Namespace`` with:
      * ``mlx_dir`` -- folder of MLX models (default ``"mlx_models"``)
      * ``all``     -- True when ``--all`` was given
      * ``models``  -- comma-separated model list, or None when not given
    """
    cli = argparse.ArgumentParser(description="Benchmark script.")

    cli.add_argument("--mlx-dir", type=str, default="mlx_models", help="The folder of MLX models")
    cli.add_argument(
        "--all",
        action="store_true",
        help="Use all available models, i.e. tiny,small,medium,large-v3",
    )
    cli.add_argument(
        "-m",
        "--models",
        type=str,
        help="Specify models as a comma-separated list (e.g., tiny,small,medium)",
    )

    namespace = cli.parse_args()
    return namespace


def timer(fn, *args, warmup: int = 5, num_its: int = 10):
    """
    Return the mean wall-clock time, in seconds, of one ``fn(*args)`` call.

    Parameters
    ----------
    fn:
        The callable to benchmark.
    *args:
        Positional arguments forwarded to ``fn`` on every call.
    warmup: int
        Number of untimed calls made first (keyword-only, default 5).
    num_its: int
        Number of timed calls to average over (keyword-only, default 10).

    Raises
    ------
    ValueError
        If ``num_its`` is smaller than 1, since no average can be computed.
    """
    if num_its < 1:
        raise ValueError(f"num_its must be at least 1, got {num_its}")

    # Untimed warm-up calls, so one-off start-up costs are excluded.
    for _ in range(warmup):
        fn(*args)

    tic = time.perf_counter()
    for _ in range(num_its):
        fn(*args)
    toc = time.perf_counter()
    return (toc - tic) / num_its


def feats(n_mels: int = 80):
    """
    Compute log-Mel features for the benchmark clip.

    Loads ``audio_file``, pads or trims it to the default length, computes the
    spectrogram with ``n_mels`` filters and forces evaluation with ``mx.eval``
    before returning it.
    """
    waveform = audio.pad_or_trim(audio.load_audio(audio_file))
    spectrogram = audio.log_mel_spectrogram(waveform, n_mels)
    mx.eval(spectrogram)
    return spectrogram


def model_forward(model, mels, tokens):
    """Call ``model(mels, tokens)``, force evaluation with ``mx.eval`` and return the result."""
    output = model(mels, tokens)
    mx.eval(output)
    return output


def decode(model, mels):
    """Benchmark target: run ``decoding.decode`` on ``mels`` with ``model``."""
    result = decoding.decode(model, mels)
    return result


def everything(model_path):
    """Benchmark target: transcribe ``audio_file`` end to end using ``model_path``."""
    result = transcribe(audio_file, path_or_hf_repo=model_path)
    return result


if __name__ == "__main__":
    # Script entry point: time feature extraction once, then for each selected
    # model time a forward pass, a decode and a full transcription.
    # NOTE: --mlx-dir is parsed but not used below; models are always fetched
    # by their "mlx-community/whisper-<name>-mlx" repo id.
    args = parse_arguments()
    if args.all:
        models = ["tiny", "small", "medium", "large-v3"]
    elif args.models:
        # Tolerate spaces and stray commas ("tiny, small,") so they do not end
        # up inside the repo id built below.
        models = [name.strip() for name in args.models.split(",") if name.strip()]
    else:
        models = ["tiny"]

    print("Selected models:", models)

    feat_time = timer(feats)
    print(f"\nFeature time {feat_time:.3f}")

    # Fixed token sequence fed to the forward-pass benchmark. It does not
    # depend on the model, so it is built once instead of once per iteration.
    # [None] adds a leading axis of size 1.
    # fmt: off
    tokens = mx.array(
        [
            50364, 1396, 264, 665, 5133, 23109, 25462, 264, 6582, 293,
            750, 632, 42841, 292, 370, 938, 294, 4054, 293, 12653,
            356, 50620, 50620, 23563, 322, 3312, 13, 50680,
        ],
        mx.int32,
    )[None]
    # fmt: on

    for model_name in models:
        model_path = f"mlx-community/whisper-{model_name}-mlx"
        print(f"\nModel: {model_name.upper()}")
        model = load_models.load_model(path_or_hf_repo=model_path, dtype=mx.float16)
        # Features use the model's own n_mels; cast to float16 to match the
        # dtype the model was loaded with.
        mels = feats(model.dims.n_mels)[None].astype(mx.float16)
        model_forward_time = timer(model_forward, model, mels, tokens)
        print(f"Model forward time {model_forward_time:.3f}")
        decode_time = timer(decode, model, mels)
        print(f"Decode time {decode_time:.3f}")
        everything_time = timer(everything, model_path)
        print(f"Everything time {everything_time:.3f}")
        print(f"\n{'-----' * 10}\n")
Loading
Loading