-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_feature.py
122 lines (100 loc) · 3.31 KB
/
get_feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/python3
"""Batch-extract text-conditioning features with AudioLDM2 for a list of prompts."""
import os
import torch
import logging
from audioldm2 import text_to_feature, build_model, get_time, save_condition, read_json
import argparse
from tqdm import tqdm
# Explicitly enable HF tokenizers parallelism (also suppresses the fork warning).
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Quiet matplotlib's chatty DEBUG/INFO logging.
matplotlib_logger = logging.getLogger('matplotlib')
matplotlib_logger.setLevel(logging.WARNING)
# Command-line interface. Defaults reproduce the authors' original setup.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_path",
    type=str,
    required=False,
    default="ckpt/audioldm2-full.pth",
    # alternative: "ckpt/audioldm2-speech-gigaspeech.pth"
    # Fixed help text: it previously duplicated the --text_list description.
    help="Path to the model checkpoint to load",
)
parser.add_argument(
    "-tl",
    "--text_list",
    type=str,
    required=False,
    default="",
    help="A file that contains text prompts to the model for audio generation",
)
parser.add_argument(
    "-s",
    "--save_path",
    type=str,
    required=False,
    help="The path to save model output",
    default="/home/huangqiaochu/dtj/data/audiocaps",
)
parser.add_argument(
    "--model_name",
    type=str,
    required=False,
    help="Name of the pretrained checkpoint to use",
    default="audioldm2-full",
    choices=["audioldm2-full", "audioldm2-music-665k", "audioldm2-full-large-1150k", "audioldm2-speech-ljspeech", "audioldm2-speech-gigaspeech"],
)
parser.add_argument(
    "-d",
    "--device",
    type=str,
    required=False,
    help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
    default="auto",
)
parser.add_argument(
    "--num_text",
    type=int,
    required=False,
    default=3,
    help="Generate using how many text prompts at one time",
)
parser.add_argument(
    "--seed",
    type=int,
    required=False,
    default=1234,
    help="Change this value (any integer number) will lead to a different generation result.",
)
args = parser.parse_args()
# Derive the dataset split name from the prompt-list filename, e.g.
# ".../test_fixed.json" -> "test"; normalize "valid" to "validation".
subset = args.text_list.split('/')[-1].split('.')[0].split('_')[0]
if subset == 'valid':
    subset = 'validation'
print(subset)

root_savepath = args.save_path
# Features are written under <save_path>/crossattn_audiomae_generated/<subset>.
save_path = os.path.join(args.save_path, 'crossattn_audiomae_generated', subset)
# NOTE(review): save_path_mask was computed identically to save_path and is
# never used below; kept as an alias for backward compatibility.
save_path_mask = save_path
# NOTE(review): random_seed is not referenced later in this script.
random_seed = args.seed

os.makedirs(save_path, exist_ok=True)
# Main driver: read the prompt list, build the model once, then encode the
# prompts in batches of args.num_text and save each batch's conditioning.
if args.text_list:
    print("Generate audio based on the text prompts in %s" % args.text_list)

# NOTE(review): read_json is called unconditionally; an empty --text_list
# presumably fails inside read_json — confirm intended behavior.
text_all, name_all = read_json(args.text_list)

audioldm2 = build_model(ckpt_path=args.model_path, model_name=args.model_name, device=args.device)

step = args.num_text
# A single strided loop replaces the original full-batches-plus-remainder
# pair of loops: range(0, len, step) yields the same start indices, and
# slicing past the end of the list safely produces the short final batch.
for start_idx in tqdm(range(0, len(text_all), step), desc="Processing", unit="iteration"):
    texts = text_all[start_idx:start_idx + step]
    # The parallel name slice from name_all was never used, so it is dropped.
    c = text_to_feature(audioldm2, texts)
    save_condition(c, root_savepath, texts, subset)