-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathipsm_oracle.py
119 lines (93 loc) · 3.59 KB
/
ipsm_oracle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2018 Microsoft Research Asia (author: Ke Wang)
# 2019 Northwestern Polytechnical University (author: Ke Wang)
from __future__ import absolute_import
from __future__ import division
from __future__ import absolute_import
import os
import sys
import numpy as np
sys.path.append(os.path.dirname(sys.path[0]) + '/utils')
from evaluate.eval_sdr import eval_sdr
from evaluate.eval_sdr_sources import eval_sdr_sources
from evaluate.eval_si_sdr import eval_si_sdr
from sigproc.dsp import get_phase, overlap_and_add, wavread, wavwrite
from sigproc.mask import apply_mask, ipsm_spectrum
from sigproc.spectrum import spectrum
# ---------------------------------------------------------------------------
# Mask bounds. In this script they appear only in the commented-out np.clip
# calls of the main loop, so the masks are currently applied unclipped.
# Alternative lower bound: np.finfo(np.float32).eps
EPSILON = 0.0
MAX_MASK = 10.0

# Input scp lists (one "<key> <path>" entry per line, as parsed by readwav),
# the directory of original test-set wavs handed to the evaluators, and the
# output directory for the oracle reconstructions.
mix_wav_scp = 'data/tt/mix.scp'
s1_wav_scp = 'data/tt/s1.scp'
s2_wav_scp = 'data/tt/s2.scp'
ori_dir = 'data/2speakers/wav8k/min/tt'
recons_dir = 'exp/ipsm_oracle/wav'

# Signal-analysis settings shared by spectrum(), get_phase() and
# overlap_and_add(). frame_length / frame_shift are presumably in
# milliseconds -- TODO confirm against sigproc.spectrum.
sample_rate = 8000
frame_length, frame_shift = 32, 8
window_type = "hanning"
preemphasis = 0.0
square_root_window = True

# Feature-domain flags forwarded to the spectrum / mask / reconstruction
# helpers. Do not change.
use_log = False
use_power = False
if not os.path.exists(recons_dir):
    os.makedirs(recons_dir)

# Read the three scp lists, closing each file as soon as it has been read
# instead of holding the handles open for the whole run. The names stay bound
# to the (closed) file objects, so later f_*.close() calls are harmless
# no-ops. Blank lines (e.g. an empty last line) are skipped because readwav
# cannot parse them.
with open(mix_wav_scp, "r") as f_mix_wav:
    mix_wav = [line for line in f_mix_wav if line.strip()]
with open(s1_wav_scp, "r") as f_s1_wav:
    s1_wav = [line for line in f_s1_wav if line.strip()]
with open(s2_wav_scp, "r") as f_s2_wav:
    s2_wav = [line for line in f_s2_wav if line.strip()]

# The main loop pairs entries by position, so the three lists must line up.
# Raise explicitly: an assert would be stripped under ``python -O``.
if not len(mix_wav) == len(s1_wav) == len(s2_wav):
    raise ValueError(
        "scp lists differ in length: mix=%d s1=%d s2=%d"
        % (len(mix_wav), len(s1_wav), len(s2_wav)))
def readwav(line):
    """Parse one "<key> <path>" scp line and load the referenced wav.

    Args:
        line: Text line whose first whitespace-separated token is the
            utterance key; the remainder (which may contain spaces) is the
            path passed to ``wavread``.

    Returns:
        (key, wav): the key string and the first value returned by
        ``wavread``. The second value (the frame rate) is discarded.

    Raises:
        ValueError: if the line does not contain both a key and a path.
    """
    # maxsplit=1 keeps a path containing whitespace intact instead of
    # failing the two-way unpack.
    fields = line.strip().split(None, 1)
    if len(fields) != 2:
        raise ValueError(
            "Malformed scp line (expected '<key> <path>'): %r" % (line,))
    key, path = fields
    wav, _ = wavread(path)
    return key, wav
def compute_spectrum(line):
    """Load the wav named by an scp line and return (key, spectrum features).

    All analysis parameters are taken from the module-level settings.
    """
    utt_key, samples = readwav(line)
    analysis_args = (sample_rate, frame_length, frame_shift, window_type,
                     preemphasis, use_log, use_power, square_root_window)
    features = spectrum(samples, *analysis_args)
    return utt_key, features
def compute_phase(line):
    """Load the wav named by an scp line and return only its phase.

    Unlike compute_spectrum, the utterance key is discarded.
    """
    _, samples = readwav(line)
    return get_phase(samples, sample_rate, frame_length, frame_shift,
                     window_type, preemphasis, square_root_window)
def _load_features(line):
    """Read one scp entry once; return (key, wav, spectrum features, phase).

    Uses the same settings as compute_spectrum / compute_phase, but derives
    both results from a single wavread instead of reading the file twice.
    """
    key, wav = readwav(line)
    feat = spectrum(wav, sample_rate, frame_length, frame_shift,
                    window_type, preemphasis, use_log, use_power,
                    square_root_window)
    phase = get_phase(wav, sample_rate, frame_length, frame_shift,
                      window_type, preemphasis, square_root_window)
    return key, wav, feat, phase


for mix_line, s1_line, s2_line in zip(mix_wav, s1_wav, s2_wav):
    # Each file is read exactly once; the mixture waveform is kept for
    # reconstruction instead of being re-read from disk.
    key_mix, wav, feat_mix, phase_mix = _load_features(mix_line)
    key_s1, _, feat_s1, phase_s1 = _load_features(s1_line)
    key_s2, _, feat_s2, phase_s2 = _load_features(s2_line)
    if not key_mix == key_s1 == key_s2:
        raise ValueError(
            "Mismatched utterance keys: mix=%s s1=%s s2=%s"
            % (key_mix, key_s1, key_s2))

    # Oracle IPSM masks computed from the reference sources.
    mask_s1 = ipsm_spectrum(feat_s1, feat_mix, phase_s1, phase_mix,
                            use_log, use_power)
    mask_s2 = ipsm_spectrum(feat_s2, feat_mix, phase_s2, phase_mix,
                            use_log, use_power)
    # Clipping is currently disabled; to bound the masks, re-enable:
    # mask_s1 = np.clip(mask_s1, a_min=EPSILON, a_max=MAX_MASK)
    # mask_s2 = np.clip(mask_s2, a_min=EPSILON, a_max=MAX_MASK)

    enhance_s1 = apply_mask(feat_mix, mask_s1, use_log, use_power)
    enhance_s2 = apply_mask(feat_mix, mask_s2, use_log, use_power)

    # Reconstruction: overlap_and_add also receives the mixture waveform
    # (presumably to supply the phase -- confirm in sigproc.dsp).
    wav_s1 = overlap_and_add(enhance_s1, wav, sample_rate, frame_length,
                             frame_shift, window_type, preemphasis,
                             use_log, use_power, square_root_window)
    wav_s2 = overlap_and_add(enhance_s2, wav, sample_rate, frame_length,
                             frame_shift, window_type, preemphasis,
                             use_log, use_power, square_root_window)
    wavwrite(wav_s1, sample_rate,
             os.path.join(recons_dir, key_mix + "_1.wav"))
    wavwrite(wav_s2, sample_rate,
             os.path.join(recons_dir, key_mix + "_2.wav"))
# All per-utterance work is done; release the scp handles.
for scp_handle in (f_mix_wav, f_s1_wav, f_s2_wav):
    scp_handle.close()

# Score the reconstructions against the originals. Each evaluator receives
# the parent of recons_dir rather than recons_dir itself -- presumably it
# locates the wav/ subdirectory on its own; confirm in utils/evaluate.
eval_dir = os.path.dirname(recons_dir)
for evaluate_fn in (eval_si_sdr,       # SI-SDR
                    eval_sdr_sources,  # SDR.sources
                    eval_sdr):         # SDR.v4
    evaluate_fn(ori_dir, eval_dir)