-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathevaluate.py
96 lines (84 loc) · 3.23 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# -*- coding: utf-8 -*-
from __future__ import print_function

import atexit
import os

import numpy as np
import tensorflow as tf
import tqdm
from scipy.io.wavfile import write

from data_load import load_data
from graph import Graph
from hyperparams import Hyperparams as hp
from utils import load_spectrograms
from utils import spectrogram2wav
def mse(list1, list2):
    """Return the mean squared error over every element of two array-likes.

    Both inputs are converted with ``np.asarray``. Plain Python sequences
    therefore work as well as NumPy arrays; for ndarrays the conversion is a
    no-op. The inputs must be broadcast-compatible. The mean is taken over
    all elements (``axis=None``).
    """
    diff = np.asarray(list1) - np.asarray(list2)
    return (diff ** 2).mean(axis=None)
def calculate_mse(arr1, arr2):
    """Mean squared error between two arrays whose first-axis lengths may differ.

    The array with fewer rows is copied into a float zero array shaped like
    the other one, so the missing trailing rows compare against zeros. That
    padded copy is then scored with ``mse`` against the longer array. On a
    tie in length, ``arr2`` serves as the reference shape.
    """
    if len(arr1) > len(arr2):
        longer, shorter = arr1, arr2
    else:
        longer, shorter = arr2, arr1
    padded = np.zeros(longer.shape)
    padded[:shorter.shape[0]] = shorter
    return mse(longer, padded)
# Only the first ``evaluate_wav_num`` evaluation utterances are scored.
evaluate_wav_num = 400
# File to which the evaluation score line is appended.
output_file = "evaluate_scores.txt"
# Append-mode handle opened at import time. It is kept for backward
# compatibility with code that writes through ``opf``. Nothing in this module
# closes it, so a close is registered for interpreter exit: buffered writes
# get flushed and the descriptor is released deterministically.
opf = open(output_file, "a")
atexit.register(opf.close)
def evaluate():
    """Score the restored model's predicted spectrograms against stored ones.

    Steps:
    - Build the evaluation graph.
    - Restore the latest checkpoint from ``hp.logdir``.
    - Decode mel frames for the first ``evaluate_wav_num`` evaluation texts.
    - Run ``g.z_hat`` on them to get the predicted magnitudes.
    - Average ``calculate_mse`` between each prediction and the ``mag_ans``
      array that ``load_spectrograms`` returns for the matching file path.

    The average is printed and appended to ``output_file``, prefixed with
    ``hp.logdir``.

    Raises:
        ValueError: if ``load_data`` returns no texts, or if the numbers of
            texts and file paths differ after truncation, which would
            misalign predictions with references.
    """
    # Number of decoding passes; each pass fills one time step of y_hat.
    max_decoder_steps = 200

    # Load graph
    g = Graph(mode="evaluate")
    print("Graph loaded")

    # Load data
    fpaths, _, texts = load_data(mode="evaluate")
    if len(texts) == 0:
        raise ValueError("load_data(mode='evaluate') returned no texts")

    # Zero-pad every text to a rectangular int32 batch. maxlen is taken over
    # the full set (before truncation), matching how the batch was built.
    maxlen = max(len(t) for t in texts)
    new_texts = np.zeros((len(texts), maxlen), np.int32)
    for i, text in enumerate(texts):
        new_texts[i, :len(text)] = list(text)

    new_texts = new_texts[:evaluate_wav_num]
    fpaths = fpaths[:evaluate_wav_num]
    if len(fpaths) != new_texts.shape[0]:
        raise ValueError(
            "Mismatched evaluation data: {} texts but {} file paths".format(
                new_texts.shape[0], len(fpaths)))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
        print("Evaluate Model Restored!")

        # Feed forward -- mel. Each pass feeds the frames filled in so far
        # back in through g.y and copies only step j of the output into y_hat.
        y_hat = np.zeros((new_texts.shape[0], max_decoder_steps, hp.n_mels * hp.r),
                         np.float32)
        for j in tqdm.tqdm(range(max_decoder_steps)):
            _y_hat = sess.run(g.y_hat, {g.x: new_texts, g.y: y_hat})
            y_hat[:, j, :] = _y_hat[:, j, :]

        # Feed forward -- mag, computed from the decoded mel frames.
        mags = sess.run(g.z_hat, {g.y_hat: y_hat})

    # Score each prediction against the stored spectrogram for the same file.
    # calculate_mse zero-pads the shorter one along the time axis.
    err = 0.0
    for fpath, mag in zip(fpaths, mags):
        fname, _, mag_ans = load_spectrograms(fpath)
        print("File {} is being evaluated ...".format(fname))
        err += calculate_mse(mag, mag_ans)
    err = err / float(len(fpaths))
    print(err)

    # Open/close per run so the score is flushed to disk as soon as it is
    # written, rather than relying on a long-lived module-level handle.
    with open(output_file, "a") as score_file:
        score_file.write(hp.logdir + " spectrogram mse: " + str(err) + "\n")
if __name__ == "__main__":
    # Script entry point: run the full evaluation, then report completion.
    # Nothing here runs when the module is merely imported.
    evaluate()
    print("Done")