forked from zhaoforever/nn-irm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmono_enhance.py
69 lines (57 loc) · 2.98 KB
/
mono_enhance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python
# coding=utf-8
import argparse
import librosa
import glob
import os
import pickle
from model import MaskComputer, IRMEstimator
from compute_mask import apply_cmvn, stft, nfft
from dataset import splice_frames
import numpy as np
# Largest value representable by a signed 16-bit integer; used as the
# full-scale amplitude when normalised samples are converted to np.int16.
MAX_INT16 = np.iinfo(np.int16).max
def _to_int16(samples):
    """Peak-normalise ``samples`` and scale them to the np.int16 range.

    All-zero or empty input is returned as zeros / an empty array instead of
    dividing by a zero peak (which would yield NaNs and garbage PCM values).
    """
    peak = np.max(np.abs(samples)) if samples.size else 0.0
    if peak > 0:
        samples = samples / peak
    return (samples * MAX_INT16).astype(np.int16)


def _write_wav(path, samples, sample_rate):
    """Write ``samples`` to ``path`` as an int16 wave file."""
    # Imported here so the module's top-level imports stay untouched.
    # librosa.output.write_wav was removed in librosa 0.8; scipy's writer
    # stores an int16 array as 16-bit PCM.
    from scipy.io import wavfile
    # NOTE: for kaldi, must write in np.int16
    wavfile.write(path, sample_rate, _to_int16(samples))


def run(args):
    """Compute a mask for every ``*.wav`` in ``args.noisy_dir`` and dump it.

    For each wave the magnitude spectrogram is CMVN-normalised, spliced with
    the configured left/right context and fed to the mask estimator. The
    resulting mask is pickled to ``<dumps_dir>/<noisy_dir basename>/<stem>.irm``.
    When ``args.write_wav`` is set, the noisy spectrogram is additionally
    multiplied by ``1 - mask`` and the result is resynthesised to a wave file
    with the same basename in the same directory.

    Args:
        args: namespace providing ``noisy_dir``, ``model_state``,
            ``dumps_dir``, ``frame_length``, ``frame_shift``,
            ``left_context``, ``right_context`` and ``write_wav``.
    """
    sample_rate = 16000

    num_bins = nfft(args.frame_length)
    context = args.left_context + args.right_context + 1
    estimator = IRMEstimator(int(num_bins / 2 + 1), nframes=context)
    computer = MaskComputer(estimator, args.model_state)

    sub_dir = os.path.basename(os.path.abspath(args.noisy_dir))
    dst_dir = os.path.join(args.dumps_dir, sub_dir)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(dst_dir, exist_ok=True)

    # Sorted so that the processing order is reproducible between runs.
    for noisy_wave in sorted(glob.glob(os.path.join(args.noisy_dir, '*.wav'))):
        name = os.path.basename(noisy_wave)
        # splitext strips only the final extension, so basenames that contain
        # extra dots (e.g. "utt.CH1.wav") keep distinct mask files instead of
        # all collapsing onto "utt.irm".
        stem = os.path.splitext(name)[0]

        # f x t
        noisy_specs = stft(noisy_wave, sample_rate, args.frame_length,
                           args.frame_shift, 'hamming')
        # Transposed before CMVN/splicing; presumably t x f -- TODO confirm
        # against apply_cmvn / splice_frames.
        input_specs = splice_frames(
            apply_cmvn(np.abs(noisy_specs.transpose())),
            args.left_context, args.right_context)
        mask_n = computer.compute_masks(input_specs)

        with open(os.path.join(dst_dir, stem + '.irm'), 'wb') as f:
            pickle.dump(mask_n, f)

        if not args.write_wav:
            continue

        # The mask is applied as (1 - mask), transposed back to match the
        # layout of noisy_specs.
        clean_specs = noisy_specs * (1 - mask_n).transpose()
        # Keyword arguments: newer librosa releases reject these positionally.
        clean_samples = librosa.istft(clean_specs,
                                      hop_length=args.frame_shift,
                                      win_length=args.frame_length,
                                      window='hamming')
        _write_wav(os.path.join(dst_dir, name), clean_samples, sample_rate)
# Command-line entry point: parse the options and hand them to run().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Command to enhance mono-channel wave")

    # Positional arguments.
    parser.add_argument('noisy_dir', type=str,
                        help="directory of noisy wave")
    parser.add_argument('model_state', type=str,
                        help="parameters of mask estimator to be used")

    # Output location.
    parser.add_argument('--dumps-dir', type=str, default='enhan_mask',
                        dest='dumps_dir',
                        help="directory for dumping enhanced wave")

    # STFT/iSTFT framing.
    parser.add_argument('--frame-length', type=int, default=512,
                        dest='frame_length',
                        help="frame length for STFT/iSTFT")
    parser.add_argument('--frame-shift', type=int, default=256,
                        dest='frame_shift',
                        help="frame shift for STFT/iSTFT")

    # Number of neighbouring frames spliced around each input frame.
    parser.add_argument('--left-context', type=int, dest="left_context",
                        default=3,
                        help="left context of inputs for neural networks")
    parser.add_argument('--right-context', type=int, dest="right_context",
                        default=3,
                        help="right context of inputs for neural networks")

    # Off by default: only the masks are dumped unless this flag is given.
    parser.add_argument('--write-wav', action='store_true', dest="write_wav",
                        default=False,
                        help="whether to write out enhanced wave")

    args = parser.parse_args()
    run(args)