-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgetmask.py
80 lines (65 loc) · 2.45 KB
/
getmask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import numpy as np
import torch
import torch.nn as nn
from VAE import MelVae
from audioldm2.utils import default_audioldm_config
import matplotlib.pyplot as plt
def saveVaeLatent(latent, path='subplots_image_.png'):
    """Plot every channel of a latent array, plus their mean, and save the figure.

    The figure has one row per channel (``latent[i]``) followed by a final row
    showing the channel-wise mean. Every row is transposed before plotting and
    gets its own colour bar. The figure is closed after saving.

    Args:
        latent: NumPy array (or array-like) whose first axis indexes channels;
            each ``latent[i]`` must be something ``imshow`` can draw after
            ``.T``. The script in this file passes an 8-channel latent.
        path: Destination image file.

    Raises:
        ValueError: if ``latent`` is a scalar or has no channels.
    """
    latent = np.asarray(latent)
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if latent.ndim == 0 or latent.shape[0] == 0:
        raise ValueError(
            f"expected an array with at least one channel, got shape {latent.shape}"
        )
    num_channels = latent.shape[0]
    num_rows = num_channels + 1  # one row per channel plus the mean row
    fig, axes = plt.subplots(num_rows, figsize=(12, 16))
    # Channel-wise mean shown in the last row (same as sum(latent) / num_channels).
    mean_map = latent.mean(axis=0)
    try:
        for i, ax in enumerate(axes):
            t = latent[i] if i < num_channels else mean_map
            imm = ax.imshow(t.T, cmap='bwr', origin="lower", aspect="auto")
            # Pass ``ax`` explicitly so each colour bar belongs to its own
            # subplot rather than to whatever axes is "current".
            colorbar = fig.colorbar(imm, ax=ax)
            colorbar.set_label('Color Scale')
            ax.set_aspect(2.5, adjustable="box")
            ax.set_title(f'Tensor {i+1}')
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Release the figure; otherwise every call leaks one in pyplot's registry.
        plt.close(fig)
def saveMel(data, path="mel_spectrum_recon.png"):
    """Render a 2-D array as a spectrogram-style image and save it to ``path``.

    The plot title is the file name of ``path`` without its directory and
    without its final extension. The figure is closed after saving.

    Args:
        data: 2-D NumPy array passed to ``imshow`` (origin at the bottom).
        path: Destination image file.
    """
    fig1, ax1 = plt.subplots()
    try:
        im = ax1.imshow(data, cmap='inferno', origin="lower", aspect="auto")
        # Attach the colour bar to this axes explicitly.
        colorbar = fig1.colorbar(im, ax=ax1)
        colorbar.set_label('Color Scale')
        # basename/splitext handle OS-specific separators and dotted file
        # names, unlike splitting on '/' and then on the first '.'.
        ax1.set_title(os.path.splitext(os.path.basename(path))[0])
        ax1.set_aspect(2.5, adjustable="box")
        fig1.savefig(path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig1)
# Script: take the mel spectrogram of an example clip, overwrite it with a
# synthetic constant "mask" pattern, encode it with the VAE, save the latent
# as 'prior_384-640.pt', and write images of the masked mel, its
# reconstruction, and the latent channels.
path = 'output/18_08_2023_09_34_43/the sound of a light saber.wav'
config = default_audioldm_config()
melvae = MelVae(config, ckpt_path='ckpt/audioldm2-full.pth')
Mel, _, _ = melvae.get_mels([path])
# Fill all of Mel[0] with a constant, then set the slice 384:640 along the
# first axis of Mel[0] to -0.5.
# NOTE(review): -11.5129 ~= ln(1e-5), presumably the log-mel silence floor -- confirm.
Mel[0] = -11.5129
Mel[0][384:640] = -0.5
# Inference only: no autograd graph is needed (saves memory), and tensors
# produced here can be converted to NumPy without an extra detach.
with torch.no_grad():
    latent = melvae.get_latens(Mel.unsqueeze(0)).squeeze(0)
    mel_recon = melvae.decode_first_stage(latent.unsqueeze(0)).squeeze(0)
print(latent.size())
torch.save(latent.detach().clone(), 'prior_384-640.pt')
# .detach().cpu() makes .numpy() safe even if a tensor tracks grad or lives
# on a GPU; both are no-ops for plain CPU tensors.
data_recon = mel_recon[0].detach().cpu().numpy().T
data = Mel[0].detach().cpu().numpy().T
saveMel(data_recon, path="mel_spectrum_recon.png")
saveMel(data, path="mel_spectrum.png")
saveVaeLatent(latent.detach().cpu().numpy(), path='latent.png')