-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathirm_oracle_phase.py
102 lines (81 loc) · 3.06 KB
/
irm_oracle_phase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2018 Microsoft Research Asia (author: Ke Wang)
# 2019 Northwestern Polytechnical University (author: Ke Wang)
from __future__ import absolute_import
from __future__ import division
from __future__ import absolute_import
import os
import sys
sys.path.append(os.path.dirname(sys.path[0]) + '/utils')
from evaluate.eval_sdr import eval_sdr
from evaluate.eval_sdr_sources import eval_sdr_sources
from evaluate.eval_si_sdr import eval_si_sdr
from sigproc.dsp import overlap_and_add, wavread, wavwrite
from sigproc.mask import apply_mask, irm
from sigproc.spectrum import spectrum
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# List files for the mixture and the two sources. Every line is later split
# into "<key> <wav path>" by readwav().
mix_wav_scp = 'data/tt/mix.scp'
s1_wav_scp = 'data/tt/s1.scp'
s2_wav_scp = 'data/tt/s2.scp'
# ori_dir is handed to the evaluation helpers as the reference location;
# recons_dir is where the reconstructed wavs are written.
ori_dir = 'data/2speakers/wav8k/min/tt'
recons_dir = 'exp/irm_oracle_phase/wav'

# Analysis/synthesis parameters, forwarded unchanged to spectrum() and
# overlap_and_add().
sample_rate = 8000
frame_length = 32  # NOTE(review): presumably milliseconds -- confirm in sigproc
frame_shift = 8  # NOTE(review): presumably milliseconds -- confirm in sigproc
window_type = 'hanning'
preemphasis = 0.0
square_root_window = True

# do not change
use_log = False
use_power = False
# do not change

# Create the output directory. Trying first and checking afterwards avoids
# the race between an existence test and the creation; a genuine failure
# (e.g. the path exists but is not a directory) is still raised.
try:
    os.makedirs(recons_dir)
except OSError:
    if not os.path.isdir(recons_dir):
        raise

# Load the three list files. Each handle is closed as soon as its lines are
# read, even if reading fails. The f_* names stay bound after the `with`
# blocks, so the f_*.close() calls at the end of the script remain valid
# (closing an already-closed file is a no-op).
with open(mix_wav_scp, "r") as f_mix_wav:
    mix_wav = f_mix_wav.readlines()
with open(s1_wav_scp, "r") as f_s1_wav:
    s1_wav = f_s1_wav.readlines()
with open(s2_wav_scp, "r") as f_s2_wav:
    s2_wav = f_s2_wav.readlines()

# The lists are consumed line by line in parallel, so they must be the same
# length. An explicit raise is used because `assert` is stripped under -O.
if not (len(mix_wav) == len(s1_wav) == len(s2_wav)):
    raise ValueError(
        "scp files differ in length: mix=%d, s1=%d, s2=%d"
        % (len(mix_wav), len(s1_wav), len(s2_wav)))
def readwav(line):
    """Load the waveform referenced by one list-file line.

    The line is stripped and split on whitespace into exactly two fields,
    an utterance key and a wav path; any other field count raises
    ``ValueError`` from the unpacking.

    Returns:
        A ``(key, wav)`` tuple, where ``wav`` is the first of the two values
        returned by ``wavread``. The second value is discarded.
    """
    utt_id, wav_path = line.strip().split()
    samples, _ = wavread(wav_path)
    return utt_id, samples
def compute_spectrum(line):
    """Read the wav named on a list-file line and compute its spectrum.

    The waveform is loaded with ``readwav`` and passed to ``spectrum``
    together with the module-level analysis settings.

    Returns:
        A ``(key, feat)`` tuple: the utterance key from the line and the
        value returned by ``spectrum``.
    """
    utt_id, samples = readwav(line)
    spec = spectrum(
        samples, sample_rate, frame_length, frame_shift, window_type,
        preemphasis, use_log, use_power, square_root_window)
    return utt_id, spec
# For every utterance: compute the mask from the two source spectra, apply
# it to the mixture spectrum, resynthesise both sources and write them to
# recons_dir. The three line lists are walked in lockstep.
for mix_line, s1_line, s2_line in zip(mix_wav, s1_wav, s2_wav):
    key_mix, feat_mix = compute_spectrum(mix_line)
    key_s1, feat_s1 = compute_spectrum(s1_line)
    key_s2, feat_s2 = compute_spectrum(s2_line)
    # Corresponding lines must describe the same utterance. An explicit
    # raise is used because `assert` is stripped under -O.
    if not (key_mix == key_s1 == key_s2):
        raise ValueError(
            "utterance keys do not match: mix=%r, s1=%r, s2=%r"
            % (key_mix, key_s1, key_s2))

    # Mask for source 1 as returned by irm(); source 2 gets the complement.
    mask_s1 = irm(feat_s1, feat_s2, use_log, use_power)
    mask_s2 = 1 - mask_s1

    # The source waveforms themselves are needed for reconstruction.
    # NOTE(review): this reads each source wav a second time, since
    # compute_spectrum() does not return the waveform it loaded.
    key_wav_s1, ori_wav_s1 = readwav(s1_line)
    key_wav_s2, ori_wav_s2 = readwav(s2_line)
    if key_wav_s1 != key_wav_s2:
        raise ValueError(
            "source keys do not match: s1=%r, s2=%r"
            % (key_wav_s1, key_wav_s2))

    enhance_s1 = apply_mask(feat_mix, mask_s1, use_log, use_power)
    enhance_s2 = apply_mask(feat_mix, mask_s2, use_log, use_power)

    # Reconstruction
    # NOTE(review): the original source waveform is passed alongside the
    # masked spectrum; presumably it supplies the phase -- confirm in
    # sigproc.dsp.overlap_and_add.
    wav_s1 = overlap_and_add(enhance_s1, ori_wav_s1, sample_rate,
                             frame_length, frame_shift, window_type,
                             preemphasis, use_log, use_power,
                             square_root_window)
    wav_s2 = overlap_and_add(enhance_s2, ori_wav_s2, sample_rate,
                             frame_length, frame_shift, window_type,
                             preemphasis, use_log, use_power,
                             square_root_window)

    wavwrite(wav_s1, sample_rate, recons_dir + "/" + key_wav_s1 + "_1.wav")
    wavwrite(wav_s2, sample_rate, recons_dir + "/" + key_wav_s2 + "_2.wav")

# Release the list-file handles (a no-op for any handle already closed).
f_mix_wav.close()
f_s1_wav.close()
f_s2_wav.close()

# The evaluation helpers are given the parent of recons_dir; compute it once.
eval_dir = os.path.dirname(recons_dir)

# SI-SDR
eval_si_sdr(ori_dir, eval_dir)

# SDR.sources
eval_sdr_sources(ori_dir, eval_dir)

# SDR.v4
eval_sdr(ori_dir, eval_dir)