Remove superfluous imports
Xinyuan Li committed Feb 5, 2024
1 parent 6df000b commit b734069
Showing 1 changed file with 37 additions and 50 deletions.
87 changes: 37 additions & 50 deletions art/estimators/speech_recognition/pytorch_icefall.py
@@ -21,18 +21,14 @@
| Repository link: https://github.com/k2-fsa/icefall/tree/master
"""
import ast
from argparse import Namespace
import logging
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import List, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np
import torch

from art import config
from art.estimators.pytorch import PyTorchEstimator
from art.estimators.speech_recognition.speech_recognizer import SpeechRecognizerMixin, PytorchSpeechRecognizerMixin
from art.utils import get_file

if TYPE_CHECKING:
# pylint: disable=C0412
@@ -56,7 +52,7 @@ class PyTorchIcefall(PytorchSpeechRecognizerMixin, SpeechRecognizerMixin, PyTorchEstimator):
from pathlib import Path
import k2

estimator_params = PyTorchEstimator.estimator_params + ["icefall_config_filepath"]
estimator_params = PyTorchEstimator.estimator_params

def __init__(
self,
@@ -67,7 +63,7 @@ def __init__(
preprocessing: "PREPROCESSING_TYPE" = None,
device_type: str = "gpu",
verbose: bool = True,
model_ensemble = None,
model_ensemble=None,
):
"""
Initialization of an instance PyTorchIcefall
@@ -110,21 +106,15 @@ def __init__(
if self.postprocessing_defences is not None: # pragma: no cover
raise ValueError("This estimator does not support `postprocessing_defences`.")

# Set cpu/gpu device
self._device = torch.device("cpu")
if torch.cuda.is_available():
self._device = torch.device("cuda", 0)

# load_model_ensemble
if model_ensemble is not None:
self.params = model_ensemble['params']
self.transducer_model = model_ensemble['model']
self.word2ids = model_ensemble['word2ids']
self.get_id2word = model_ensemble['get_id2word']
self.params = model_ensemble["params"]
self.transducer_model = model_ensemble["model"]
self.word2ids = model_ensemble["word2ids"]
self.get_id2word = model_ensemble["get_id2word"]

self.transducer_model.to(self.device)
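
For context, the `model_ensemble` dict is expected to carry the icefall hyperparameters, the trained transducer model, and the word/token-id mappings, matching the four keys read above. A minimal sketch of assembling such a dict follows; `build_model_ensemble` and its inputs are hypothetical names, and the recipe-specific loading of each piece is not shown:

def build_model_ensemble(params, model, word2ids, id2word):
    """Assemble the dict expected by PyTorchIcefall(model_ensemble=...).

    All four inputs come from an icefall recipe; their construction is
    recipe-specific and only sketched here.
    """
    return {
        "params": params,  # icefall hyperparameter namespace/dict
        "model": model,  # trained transducer (encoder/decoder/joiner)
        "word2ids": word2ids,  # mapping: word -> token id(s)
        "get_id2word": id2word,  # mapping/callable: token id(s) -> word
    }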


def predict(self, x: np.ndarray, batch_size: int = 1, **kwargs) -> np.ndarray:
"""
Perform prediction for a batch of inputs.
@@ -156,7 +146,7 @@ def predict(self, x: np.ndarray, batch_size: int = 1, **kwargs) -> np.ndarray:
num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))

for sample_index in range(num_batch):
wav = x_preprocessed[sample_index] # np.array, len = wav len
wav = x_preprocessed[sample_index] # np.array, len = wav len

# extract features
x, _, _ = self.transform_model_input(x=torch.tensor(wav))
@@ -165,7 +155,7 @@ def predict(self, x: np.ndarray, batch_size: int = 1, **kwargs) -> np.ndarray:
encoder_out, encoder_out_lens = self.transducer_model.encoder(x=x, x_lens=shape)
hyp = greedy_search(model=self.transducer_model, encoder_out=encoder_out, id2word=self.get_id2word)
decoded_output.append(hyp)

return np.concatenate(decoded_output)
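
A hedged usage sketch for `predict`: the estimator consumes an object array of variable-length 1-D waveforms and returns decoded transcriptions. The ensemble construction is assumed to follow the build_model_ensemble() sketch above; only the call signature comes from the code itself.

import numpy as np

# Hypothetical usage; `ensemble` as returned by a loader such as the
# build_model_ensemble() sketch above.
# estimator = PyTorchIcefall(model_ensemble=ensemble)
x = np.array(
    [np.random.randn(16000).astype(np.float32),  # 1.0 s at 16 kHz
     np.random.randn(24000).astype(np.float32)],  # 1.5 s at 16 kHz
    dtype=object,
)
# transcriptions = estimator.predict(x, batch_size=1)  # np.ndarray of strings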

def loss_gradient(self, x, y: np.ndarray, **kwargs) -> np.ndarray:
@@ -217,34 +207,31 @@ def transform_model_input(self, x, y=None, compute_gradient=False):

from dataclasses import dataclass, asdict

@dataclass
class FbankConfig:
params = {
# Spectrogram-related part
dither: float = 0.0
window_type: str = "povey"
"dither": 0.0,
"window_type": "povey",
"sample_frequency": 16000,
"snip_edges": False,
"num_mel_bins": 23,
# Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
frame_length: float = 0.025
frame_shift: float = 0.01
remove_dc_offset: bool = True
round_to_power_of_two: bool = True
energy_floor: float = 1e-10
min_duration: float = 0.0
preemphasis_coefficient: float = 0.97
raw_energy: bool = True

"frame_length": 25.0,
"frame_shift": 10.0,
"remove_dc_offset": True,
"round_to_power_of_two": True,
"energy_floor": 1e-10,
"min_duration": 0.0,
"preemphasis_coefficient": 0.97,
"raw_energy": True,
# Fbank-related part
low_freq: float = 20.0
high_freq: float = -400.0
num_mel_bins: int = 40
use_energy: bool = False
vtln_low: float = 100.0
vtln_high: float = -500.0
vtln_warp: float = 1.0

params = asdict(FbankConfig())
params.update({"sample_frequency": 16000, "snip_edges": False, "num_mel_bins": 23})
params["frame_shift"] *= 1000.0
params["frame_length"] *= 1000.0
"low_freq": 20.0,
"high_freq": -400.0,
"num_mel_bins": 40,
"use_energy": False,
"vtln_low: float": 100.0,
"vtln_high: float": -500.0,
"vtln_warp: float": 1.0,
}
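
The dict above mirrors Kaldi-style fbank options, with `frame_length` and `frame_shift` now given directly in milliseconds rather than converted from seconds. For comparison (an illustration, not necessarily the frontend this file uses), torchaudio's Kaldi-compatible `fbank` accepts the same option names:

import torch
import torchaudio.compliance.kaldi as kaldi

# Illustrative only: compute log-mel features with Kaldi-compatible options.
wav = torch.randn(1, 16000)  # (channels, samples) at 16 kHz
feats = kaldi.fbank(
    wav,
    sample_frequency=16000.0,
    frame_length=25.0,  # milliseconds
    frame_shift=10.0,  # milliseconds
    num_mel_bins=23,
    dither=0.0,
    snip_edges=False,
    window_type="povey",
)
print(feats.shape)  # (num_frames, num_mel_bins)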

feature_list = []
num_frames = []
@@ -254,7 +241,7 @@ class FbankConfig:
isnan = torch.isnan(x[i])
nisnan = torch.sum(isnan).item()
if nisnan > 0:
logging.info("input isnan={}/{} {}".format(nisnan, x[i].shape, x[i][isnan], torch.max(torch.abs(x[i]))))
logging.info(f"input isnan={nisnan}/{x[i][isnan]} {torch.max(torch.abs(x[i]))}")

xx = x[i]
xx = xx.to(self._device)
@@ -264,9 +251,9 @@ class FbankConfig:
num_frames.append(feat_i.shape[1])

indices = sorted(range(len(feature_list)), key=lambda i: feature_list[i].shape[1], reverse=True)
indices = torch.LongTensor(indices)
num_frames = torch.IntTensor([num_frames[idx] for idx in indices])
start_frames = torch.zeros(len(x), dtype=torch.int)
indices = torch.tensor(indices, dtype=torch.int64, device=self._device)
num_frames = torch.tensor([num_frames[idx] for idx in indices], dtype=torch.int32, device=self._device)
start_frames = torch.zeros(len(x), dtype=torch.int32, device=self._device)

supervisions["sequence_idx"] = indices.int()
supervisions["start_frame"] = start_frames
@@ -310,7 +297,7 @@ def input_shape(self) -> Tuple[int, ...]:
return self._input_shape # type: ignore

@property
def model(self):
def model(self) -> "torch.nn.Module":
"""
Get current model.
@@ -366,4 +353,4 @@ def compute_loss_and_decoded_output(
hyp = greedy_search(model=self.transducer_model, encoder_out=encoder_out, id2word=self.get_id2word)
decoded_output.append(hyp)

return np.concatenate(decoded_output)
return loss, np.concatenate(decoded_output)
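
A hedged sketch of the mixin entry point: `compute_loss_and_decoded_output` returns a differentiable loss together with greedy-search transcriptions, which is how ART's PyTorch ASR attack loops typically drive an estimator. The instance, tensor shapes, and data below are placeholders, not part of this commit.

import numpy as np
import torch

# Placeholders; `estimator` would be a constructed PyTorchIcefall instance.
# adv = torch.from_numpy(np.random.randn(1, 16000).astype(np.float32))
# labels = np.array(["HELLO WORLD"])
# loss, decoded = estimator.compute_loss_and_decoded_output(adv, labels)
# loss.backward()  # loss is a torch scalar; decoded is an np.ndarray of strings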
