Skip to content

Commit

Permalink
Merge pull request RVC-Project#1769 from RVC-Project/formatter-dev
Browse files (browse the repository at this point in the history)
chore(format): run black on dev
  • Loading branch information
RVC-Boss authored Jan 26, 2024
2 parents (850ec48 + 738e55f), commit c09d1bc
Show file tree
Hide file tree
Showing 15 changed files with 266 additions and 148 deletions.
10 changes: 5 additions & 5 deletions gui_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,11 +877,11 @@ def audio_callback(
self.input_wav_denoise[-self.block_frame - 2 * self.zc :]
)[160:]
else:
self.input_wav_res[
-160 * (indata.shape[0] // self.zc + 1) :
] = self.resampler(self.input_wav[-indata.shape[0] - 2 * self.zc :])[
160:
]
self.input_wav_res[-160 * (indata.shape[0] // self.zc + 1) :] = (
self.resampler(self.input_wav[-indata.shape[0] - 2 * self.zc :])[
160:
]
)
# infer
if self.function == "vc":
infer_wav = self.rvc.infer(
Expand Down
156 changes: 113 additions & 43 deletions infer-web.py

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions infer/lib/infer_pack/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,13 +400,17 @@ def forward(self, f0: torch.Tensor, upp: int):
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
idx + 2
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
rad_values = (
f0_buf / self.sampling_rate
) % 1 ###%1意味着n_har的乘积无法后处理优化
rand_ini = torch.rand(
f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
tmp_over_one = torch.cumsum(
rad_values, 1
) # % 1 #####%1意味着后面的cumsum无法再优化
tmp_over_one *= upp
tmp_over_one = F.interpolate(
tmp_over_one.transpose(2, 1),
Expand Down
8 changes: 6 additions & 2 deletions infer/lib/infer_pack/models_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,17 @@ def forward(self, f0, upp):
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
idx + 2
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
rad_values = (
f0_buf / self.sampling_rate
) % 1 ###%1意味着n_har的乘积无法后处理优化
rand_ini = torch.rand(
f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
tmp_over_one = torch.cumsum(
rad_values, 1
) # % 1 #####%1意味着后面的cumsum无法再优化
tmp_over_one *= upp
tmp_over_one = F.interpolate(
tmp_over_one.transpose(2, 1),
Expand Down
114 changes: 60 additions & 54 deletions infer/modules/ipex/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def torch_bmm(input, mat2, *, out=None):
): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[
start_idx:end_idx, start_idx_2:end_idx_2
] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out,
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = (
original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out,
)
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
Expand Down Expand Up @@ -138,61 +138,67 @@ def scaled_dot_product_attention(
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if no_shape_one:
hidden_states[
start_idx:end_idx, start_idx_2:end_idx_2
] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[
start_idx:end_idx, start_idx_2:end_idx_2
]
if attn_mask is not None
else attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = (
original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=(
attn_mask[start_idx:end_idx, start_idx_2:end_idx_2]
if attn_mask is not None
else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
)
)
else:
hidden_states[
:, start_idx:end_idx, start_idx_2:end_idx_2
] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[
:, start_idx:end_idx, start_idx_2:end_idx_2
]
if attn_mask is not None
else attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = (
original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=(
attn_mask[
:, start_idx:end_idx, start_idx_2:end_idx_2
]
if attn_mask is not None
else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
)
)
else:
if no_shape_one:
hidden_states[
start_idx:end_idx
] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx]
if attn_mask is not None
else attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
hidden_states[start_idx:end_idx] = (
original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=(
attn_mask[start_idx:end_idx]
if attn_mask is not None
else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
)
)
else:
hidden_states[
:, start_idx:end_idx
] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=attn_mask[:, start_idx:end_idx]
if attn_mask is not None
else attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
hidden_states[:, start_idx:end_idx] = (
original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=(
attn_mask[:, start_idx:end_idx]
if attn_mask is not None
else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
)
)
else:
return original_scaled_dot_product_attention(
Expand Down
42 changes: 25 additions & 17 deletions infer/modules/ipex/hijacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def return_xpu(device):
return (
f"xpu:{device[-1]}"
if isinstance(device, str) and ":" in device
else f"xpu:{device}"
if isinstance(device, int)
else torch.device("xpu")
if isinstance(device, torch.device)
else "xpu"
else (
f"xpu:{device}"
if isinstance(device, int)
else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
)
)


Expand Down Expand Up @@ -271,12 +271,16 @@ def ipex_hijacks():
"torch.batch_norm",
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
input,
weight
if weight is not None
else torch.ones(input.size()[1], device=input.device),
bias
if bias is not None
else torch.zeros(input.size()[1], device=input.device),
(
weight
if weight is not None
else torch.ones(input.size()[1], device=input.device)
),
(
bias
if bias is not None
else torch.zeros(input.size()[1], device=input.device)
),
*args,
**kwargs,
),
Expand All @@ -286,12 +290,16 @@ def ipex_hijacks():
"torch.instance_norm",
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
input,
weight
if weight is not None
else torch.ones(input.size()[1], device=input.device),
bias
if bias is not None
else torch.zeros(input.size()[1], device=input.device),
(
weight
if weight is not None
else torch.ones(input.size()[1], device=input.device)
),
(
bias
if bias is not None
else torch.zeros(input.size()[1], device=input.device)
),
*args,
**kwargs,
),
Expand Down
8 changes: 5 additions & 3 deletions infer/modules/train/extract_feature_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,11 @@ def readwave(wav_path, normalize=False):
feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
inputs = {
"source": feats.half().to(device)
if is_half and device not in ["mps", "cpu"]
else feats.to(device),
"source": (
feats.half().to(device)
if is_half and device not in ["mps", "cpu"]
else feats.to(device)
),
"padding_mask": padding_mask.to(device),
"output_layer": 9 if version == "v1" else 12, # layer 9
}
Expand Down
22 changes: 12 additions & 10 deletions infer/modules/vc/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,28 @@ def get_vc(self, sid, *to_return_protect):

to_return_protect0 = {
"visible": self.if_f0 != 0,
"value": to_return_protect[0]
if self.if_f0 != 0 and to_return_protect
else 0.5,
"value": (
to_return_protect[0] if self.if_f0 != 0 and to_return_protect else 0.5
),
"__type__": "update",
}
to_return_protect1 = {
"visible": self.if_f0 != 0,
"value": to_return_protect[1]
if self.if_f0 != 0 and to_return_protect
else 0.33,
"value": (
to_return_protect[1] if self.if_f0 != 0 and to_return_protect else 0.33
),
"__type__": "update",
}

if sid == "" or sid == []:
if self.hubert_model is not None: # 考虑到轮询, 需要加个判断看是否 sid 是由有模型切换到无模型的
if (
self.hubert_model is not None
): # 考虑到轮询, 需要加个判断看是否 sid 是由有模型切换到无模型的
logger.info("Clean model cache")
del (self.net_g, self.n_spk, self.hubert_model, self.tgt_sr) # ,cpt
self.hubert_model = (
self.net_g
) = self.n_spk = self.hubert_model = self.tgt_sr = None
self.hubert_model = self.net_g = self.n_spk = self.hubert_model = (
self.tgt_sr
) = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
###楼下不这么折腾清理不干净
Expand Down
28 changes: 21 additions & 7 deletions tools/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,28 @@
)
sid.change(fn=vc.get_vc, inputs=[sid], outputs=[spk_item])
gr.Markdown(
value=i18n("男转女推荐+12key, 女转男推荐-12key, 如果音域爆炸导致音色失真也可以自己调整到合适音域. ")
value=i18n(
"男转女推荐+12key, 女转男推荐-12key, 如果音域爆炸导致音色失真也可以自己调整到合适音域. "
)
)
vc_input3 = gr.Audio(label="上传音频(长度小于90秒)")
vc_transform0 = gr.Number(label=i18n("变调(整数, 半音数量, 升八度12降八度-12)"), value=0)
vc_transform0 = gr.Number(
label=i18n("变调(整数, 半音数量, 升八度12降八度-12)"), value=0
)
f0method0 = gr.Radio(
label=i18n("选择音高提取算法,输入歌声可用pm提速,harvest低音好但巨慢无比,crepe效果好但吃GPU"),
label=i18n(
"选择音高提取算法,输入歌声可用pm提速,harvest低音好但巨慢无比,crepe效果好但吃GPU"
),
choices=["pm", "harvest", "crepe", "rmvpe"],
value="pm",
interactive=True,
)
filter_radius0 = gr.Slider(
minimum=0,
maximum=7,
label=i18n(">=3则使用对harvest音高识别的结果使用中值滤波,数值为滤波半径,使用可以削弱哑音"),
label=i18n(
">=3则使用对harvest音高识别的结果使用中值滤波,数值为滤波半径,使用可以削弱哑音"
),
value=3,
step=1,
interactive=True,
Expand Down Expand Up @@ -107,19 +115,25 @@
rms_mix_rate0 = gr.Slider(
minimum=0,
maximum=1,
label=i18n("输入源音量包络替换输出音量包络融合比例,越靠近1越使用输出包络"),
label=i18n(
"输入源音量包络替换输出音量包络融合比例,越靠近1越使用输出包络"
),
value=1,
interactive=True,
)
protect0 = gr.Slider(
minimum=0,
maximum=0.5,
label=i18n("保护清辅音和呼吸声,防止电音撕裂等artifact,拉满0.5不开启,调低加大保护力度但可能降低索引效果"),
label=i18n(
"保护清辅音和呼吸声,防止电音撕裂等artifact,拉满0.5不开启,调低加大保护力度但可能降低索引效果"
),
value=0.33,
step=0.01,
interactive=True,
)
f0_file = gr.File(label=i18n("F0曲线文件, 可选, 一行一个音高, 代替默认F0及升降调"))
f0_file = gr.File(
label=i18n("F0曲线文件, 可选, 一行一个音高, 代替默认F0及升降调")
)
but0 = gr.Button(i18n("转换"), variant="primary")
vc_output1 = gr.Textbox(label=i18n("输出信息"))
vc_output2 = gr.Audio(label=i18n("输出音频(右下角三个点,点了可以下载)"))
Expand Down
1 change: 1 addition & 0 deletions tools/infer/infer-pm-index256.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
对源特征进行检索
"""

import os
import logging

Expand Down
1 change: 1 addition & 0 deletions tools/infer/train-index-v2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
格式:直接cid为自带的index位;aid放不下了,通过字典来查,反正就5w个
"""

import os
import traceback
import logging
Expand Down
1 change: 1 addition & 0 deletions tools/infer/train-index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
格式:直接cid为自带的index位;aid放不下了,通过字典来查,反正就5w个
"""

import os
import logging

Expand Down
Loading

0 comments on commit c09d1bc

Please sign in to comment.