Skip to content

Commit

Permalink
Merge pull request #5 from sensein/enh/organize
Browse files Browse the repository at this point in the history
Generate feature file (and other fixes)
  • Loading branch information
satra authored Mar 15, 2024
2 parents f93ef65 + a10bf75 commit 4287359
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 158 deletions.
314 changes: 176 additions & 138 deletions docs/b2ai_script.ipynb

Large diffs are not rendered by default.

39 changes: 36 additions & 3 deletions src/b2aiprep/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,40 @@
import os
from pathlib import Path

import click

from .process import to_features, verify_speaker_from_files


@click.command()
@click.group()
def main():
"""Example script."""
click.echo("Will process some files later")
pass


@main.command()
@click.argument("filename", type=click.Path(exists=True))
@click.argument("subject", type=str)
@click.argument("task", type=str)
@click.option("--outdir", type=click.Path(), default=os.getcwd(), show_default=True)
@click.option("--n_mels", type=int, default=20, show_default=True)
@click.option("--n_coeff", type=int, default=20, show_default=True)
@click.option("--compute_deltas/--no-compute_deltas", default=True, show_default=True)
def convert(filename, subject, task, outdir, n_mels, n_coeff, compute_deltas):
to_features(
filename,
subject,
task,
outdir=Path(outdir),
n_mels=n_mels,
n_coeff=n_coeff,
compute_deltas=compute_deltas,
)


@main.command()
@click.argument("file1", type=click.Path(exists=True))
@click.argument("file2", type=click.Path(exists=True))
@click.argument("model", type=str)
def verify(file1, file2, model):
score, prediction = verify_speaker_from_files(file1, file2, model=model)
print(f"Score: {float(score):.2f} Prediction: {bool(prediction)}")
85 changes: 68 additions & 17 deletions src/b2aiprep/process.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
"""Functions to prepare acoustic data for the Bridge2AI voice project."""

import numpy as np
import os
import typing as ty
from hashlib import md5
from pathlib import Path

import speechbrain.processing.features as spf
import torch
import torchaudio
from scipy.signal import butter
from scipy import signal
from speechbrain.augment.time_domain import Resample
from speechbrain.dataio.dataio import read_audio, read_audio_info
from speechbrain.inference.speaker import EncoderClassifier


class Audio:
def __init__(self, signal: torch.tensor, sample_rate: float):
def __init__(self, signal: torch.tensor, sample_rate: int) -> None:
"""Initialize audio object
:param signal: is a Torch tensor
:param sample_rate: is a float
"""
self.signal = signal
self.sample_rate = sample_rate
self.sample_rate = int(sample_rate)

@classmethod
def from_file(cls, filename: str, channel: int = 0):
def from_file(cls, filename: Path, channel: int = 0) -> "Audio":
"""Load audio file
If the file contains more than one channel of audio,
Expand All @@ -31,11 +35,13 @@ def from_file(cls, filename: str, channel: int = 0):
"""
signal = read_audio(filename)
meta = read_audio_info(filename)
if signal.shape[-1] > 1:
if len(signal.shape) > 1 and signal.shape[-1] > 1:
signal = signal[:, [channel]]
else:
signal = signal[:, None]
return cls(signal, meta.sample_rate)

def to_16khz(self):
def to_16khz(self) -> "Audio":
"""Resample audio to 16kHz and return new Audio object
TODO: The default resampler does a poor job of taking care of
Expand All @@ -45,6 +51,27 @@ def to_16khz(self):
return Audio(resampler(self.signal.unsqueeze(0)).squeeze(0), 16000)


def embed_speaker(audio: Audio, model: str) -> torch.tensor:
"""Compute the speaker embedding of the audio signal"""
classifier = EncoderClassifier.from_hparams(source=model)
embeddings = classifier.encode_batch(audio.signal.T)
return embeddings.squeeze()


def verify_speaker(audio1: Audio, audio2: Audio, model: str) -> ty.Tuple[float, float]:
from speechbrain.inference.speaker import SpeakerRecognition

verification = SpeakerRecognition.from_hparams(source=model)
score, prediction = verification.verify_batch(audio1.signal.T, audio2.signal.T)
return score, prediction


def verify_speaker_from_files(file1: Path, file2: Path, model: str) -> ty.Tuple[float, float]:
audio1 = Audio.from_file(file1)
audio2 = Audio.from_file(file2)
return verify_speaker(audio1, audio2, model)


def specgram(
audio: Audio, win_length: int = 25, hop_lenth: int = 10, log: bool = False
) -> torch.tensor:
Expand Down Expand Up @@ -87,13 +114,37 @@ def MFCC(
return features.squeeze()


def resample_iir(audio: Audio, lowcut: float, new_sample_rate: int, order: int = 5):
def resample_iir(audio: Audio, lowcut: float, new_sample_rate: int, order: int = 4) -> Audio:
"""Resample audio using IIR filter"""
b, a = butter(order, lowcut, btype="low", fs=audio.sample_rate)
filtered = torchaudio.functional.filtfilt(
audio.signal.unsqueeze(0),
a_coeffs=torch.tensor(a.astype(np.float32)),
b_coeffs=torch.tensor(b.astype(np.float32)),
)
resampler = Resample(orig_freq=audio.sample_rate, new_freq=16000)
return Audio(resampler(filtered).squeeze(0), new_sample_rate)
sos = signal.butter(order, lowcut, btype="low", output="sos", fs=new_sample_rate)
filtered = torch.from_numpy(signal.sosfiltfilt(sos, audio.signal.squeeze()).copy()).float()
resampler = Resample(orig_freq=audio.sample_rate, new_freq=new_sample_rate)
return Audio(resampler(filtered.unsqueeze(0)).squeeze(0), new_sample_rate)


def to_features(
filename: Path,
subject: str,
task: str,
outdir: Path = Path(os.getcwd()),
n_mels: int = 20,
n_coeff: int = 20,
compute_deltas: bool = True,
) -> ty.Tuple[dict, Path]:
with open(filename, "rb") as f:
md5sum = md5(f.read()).hexdigest()
audio = Audio.from_file(filename)
audio = audio.to_16khz()
features = specgram(audio)
features_melfilterbank = melfilterbank(features, n_mels=n_mels)
features_mfcc = MFCC(features_melfilterbank, n_coeff=n_coeff, compute_deltas=compute_deltas)
features = {
"specgram": features,
"melfilterbank": features_melfilterbank,
"mfcc": features_mfcc,
"sample_rate": audio.sample_rate,
"checksum": md5sum,
}
outfile = outdir / f"sub-{subject}_task-{task}_md5-{md5sum}_features.pt"
torch.save(features, outfile)
return features, outfile

0 comments on commit 4287359

Please sign in to comment.