added pesq and stoi for reconstruction performance evaluation
JinZr committed Sep 8, 2024
1 parent c43977e commit 1e65a97
Showing 1 changed file with 48 additions and 4 deletions.
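
For context, here is a minimal, self-contained sketch of how the two metrics introduced by this commit are computed with the `pesq` and `pystoi` packages. The 24 kHz sampling rate and the random waveforms below are illustrative assumptions rather than values taken from the recipe; the 16 kHz resampling step and the try/except guard around PESQ mirror what the changed code in infer.py does.

import numpy as np
from pesq import pesq
from pystoi import stoi
from scipy import signal

SAMPLING_RATE = 24000  # assumed codec output rate; LibriTTS audio is 24 kHz
PESQ_RATE = 16000  # wide-band PESQ operates on 16 kHz input

# Illustrative reference ("ground truth") and degraded ("reconstructed") signals.
ref = np.random.randn(2 * SAMPLING_RATE)
gen = ref + 0.01 * np.random.randn(2 * SAMPLING_RATE)

# PESQ: resample both signals to 16 kHz, then score in wide-band mode.
num_samples = int(len(ref) / SAMPLING_RATE * PESQ_RATE)
try:
    pesq_wb = pesq(
        fs=PESQ_RATE,
        ref=signal.resample(ref, num_samples),
        deg=signal.resample(gen, num_samples),
        mode="wb",
    )
    print(f"PESQ-WB: {pesq_wb:.4f}")
except Exception as e:
    # PESQ can fail, e.g. when no utterance is detected, hence the guard.
    print(f"PESQ failed: {e}")

# STOI is evaluated at the original sampling rate.
stoi_score = stoi(x=ref, y=gen, fs_sig=SAMPLING_RATE, extended=False)
print(f"STOI: {stoi_score:.4f}")

Higher is better for both metrics; infer_dataset below averages them over each subset with statistics.mean.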
52 changes: 48 additions & 4 deletions egs/libritts/CODEC/encodec/infer.py
@@ -30,12 +30,16 @@
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List
from statistics import mean
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from codec_datamodule import LibriTTSCodecDataModule
from pesq import pesq
from pystoi import stoi
from scipy import signal
from torch import nn
from train import get_model, get_params

@@ -105,12 +109,25 @@ def remove_encodec_weight_norm(model) -> None:
            remove_weight_norm(decoder._modules[key].conv.conv)


def compute_pesq(
    ref_wav: np.ndarray, gen_wav: np.ndarray, sampling_rate: int
) -> float:
    """Compute wide-band PESQ between reference and generated audio.

    PESQ expects 16 kHz input, so both signals are resampled to 16 kHz first.
    """
    PESQ_SAMPLING_RATE = 16000
    # scipy.signal.resample takes a target number of samples, so convert
    # the waveform length to the equivalent sample count at 16 kHz.
    num_samples = int(len(ref_wav) / sampling_rate * PESQ_SAMPLING_RATE)
    ref = signal.resample(ref_wav, num_samples)
    deg = signal.resample(gen_wav, num_samples)
    return pesq(fs=PESQ_SAMPLING_RATE, ref=ref, deg=deg, mode="wb")


def compute_stoi(
    ref_wav: np.ndarray, gen_wav: np.ndarray, sampling_rate: int
) -> float:
    """Compute STOI score between reference and generated audio."""
    return stoi(x=ref_wav, y=gen_wav, fs_sig=sampling_rate, extended=False)


def infer_dataset(
    dl: torch.utils.data.DataLoader,
    subset: str,
    params: AttributeDict,
    model: nn.Module,
) -> None:
) -> Tuple[float, float]:
    """Decode dataset.
    The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
@@ -123,6 +140,9 @@
        It is returned by :func:`get_params`.
      model:
        The neural model.
    Returns:
      The average PESQ and STOI scores.
    """

    # Background worker that saves audio to disk.
@@ -150,6 +170,9 @@ def _save_worker(
    num_cuts = 0
    log_interval = 5

    pesq_wb_scores = []
    stoi_scores = []

    try:
        num_batches = len(dl)
    except TypeError:
@@ -169,6 +192,25 @@ def _save_worker(
            )
            audio_hats = audio_hats.squeeze(1).cpu()

            for cut_id, audio, audio_hat, audio_len in zip(
                cut_ids, audios, audio_hats, audio_lens
            ):
                try:
                    pesq_wb = compute_pesq(
                        ref_wav=audio[:audio_len].numpy(),
                        gen_wav=audio_hat[:audio_len].numpy(),
                        sampling_rate=params.sampling_rate,
                    )
                    pesq_wb_scores.append(pesq_wb)
                except Exception as e:
                    logging.error(f"Error while computing PESQ for cut {cut_id}: {e}")

                stoi_score = compute_stoi(
                    ref_wav=audio[:audio_len].numpy(),
                    gen_wav=audio_hat[:audio_len].numpy(),
                    sampling_rate=params.sampling_rate,
                )
                stoi_scores.append(stoi_score)

            futures.append(
                executor.submit(
                    _save_worker,
@@ -192,6 +234,7 @@ def _save_worker(
        # Wait for all background saving jobs to finish.
        for f in futures:
            f.result()
    return mean(pesq_wb_scores), mean(stoi_scores)


@torch.no_grad()
@@ -285,12 +328,13 @@ def main():

logging.info(f"Processing {subset} set, saving to {save_wav_dir}")

        infer_dataset(
        pesq_wb, stoi_score = infer_dataset(
            dl=dl,
            subset=subset,
            params=params,
            model=model,
        )
        logging.info(f"{subset}: PESQ-WB: {pesq_wb:.4f}, STOI: {stoi_score:.4f}")

logging.info(f"Wav files are saved to {params.save_wav_dir}")
logging.info("Done!")
