-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain.py
590 lines (445 loc) · 17 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from audidata.collate.default import collate_fn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import wandb
from audio_understanding.data.samplers import InfiniteSampler
from audio_understanding.utils import LinearWarmUp, parse_yaml, remove_padded_columns
def train(args) -> None:
    r"""Train the LLM decoder on audio question/answering pairs.

    Loads the YAML config given by ``args.config``, then builds:

    * the train and test datasets;
    * the audio encoder, tokenizer and LLM, optionally resuming from
      ``train.resume_ckpt_path``;
    * the optimizer and optional scheduler.

    It then iterates an infinite sampler, optimising the cross-entropy of the
    answer tokens.

    Periodic work inside the loop:
      * every ``train.test_every_n_steps`` steps, ``validate`` runs on both
        splits. The losses are printed and, unless ``args.no_log``, logged to
        wandb.
      * every ``train.save_every_n_steps`` steps, the state dicts of the
        trainable modules are written to
        ``./checkpoints/<this script's stem>/<config stem>/step=<step>.pth``.

    The loop exits after processing the step whose index equals
    ``train.training_steps``.

    Args:
        args: CLI namespace carrying ``config`` (path of the YAML file) and
            ``no_log`` (disable wandb when True).
    """
    use_wandb = not args.no_log
    cfg_path = args.config
    script_stem = Path(__file__).stem

    configs = parse_yaml(cfg_path)
    train_cfg = configs["train"]
    device = train_cfg["device"]

    # Checkpoints live under ./checkpoints/<script>/<config>/
    config_name = Path(cfg_path).stem
    ckpts_dir = Path("./checkpoints", script_stem, config_name)
    ckpts_dir.mkdir(parents=True, exist_ok=True)

    train_dataset = get_dataset(configs, split="train")
    test_dataset = get_dataset(configs, split="test")

    # The loop below has no exit other than the training_steps check
    # (InfiniteSampler presumably never exhausts).
    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=train_cfg["batch_size_per_device"],
        sampler=InfiniteSampler(train_dataset),
        num_workers=train_cfg["num_workers"],
        collate_fn=collate_fn,
        pin_memory=True,
    )

    resume_ckpt_path = train_cfg["resume_ckpt_path"]

    audio_encoder = get_audio_encoder(
        configs=configs,
        ckpt_path=resume_ckpt_path,
    ).to(device)

    # Converts text into token IDs and back.
    tokenizer = get_tokenizer(configs=configs)

    llm = get_llm(
        configs=configs,
        audio_latent_dim=audio_encoder.latent_dim,
        vocab_size=len(tokenizer),
        ckpt_path=resume_ckpt_path,
    ).to(device)

    optimizer, scheduler = get_optimizer_and_scheduler(
        configs=configs,
        params=get_learnable_params(configs, audio_encoder, llm),
    )

    if use_wandb:
        wandb.init(project="audio_understanding", name="{}".format(config_name))

    for step, batch in enumerate(tqdm(train_dataloader)):
        # ---- Prepare inputs ---------------------------------------------
        audio, question, answering = get_audio_question_answering(batch)
        # audio: (b, c, t), question: (b, t), answering: (b, t)

        audio_latent = audio_encoder.encode(
            audio=audio.to(device), train_mode=True
        )  # (b, t, d)

        question_ids = tokenizer.texts_to_ids(
            texts=question,
            fix_length=configs["max_question_len"],
        ).to(device)  # (b, t)

        answering_ids = tokenizer.texts_to_ids(
            texts=answering,
            fix_length=configs["max_answering_len"],
        ).to(device)  # (b, t)

        if train_cfg["remove_padded_columns"]:
            # Strip padded columns from the answers (speeds up training).
            answering_ids = remove_padded_columns(
                ids=answering_ids,
                pad_token_id=tokenizer.pad_token_id,
            )

        seqs = [audio_latent, question_ids, answering_ids]

        # ---- Forward / backward -----------------------------------------
        llm.train()
        predictions = llm(
            seqs=seqs,
            seq_types=["audio", "id", "id"],
            mask=None,
        )  # list of output sequences

        # Next-token prediction: the output at position i is scored against
        # the input at position i + 1. Only the answer is supervised.
        loss = ce_loss(
            output_seqs=[pred[:, :-1] for pred in predictions],
            target_seqs=[seq[:, 1:] for seq in seqs],
            loss_types=[None, None, "ce"],
            ignore_index=tokenizer.pad_token_id,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        if step % 100 == 0:
            print(loss)

        # ---- Periodic evaluation ----------------------------------------
        if step % train_cfg["test_every_n_steps"] == 0:
            train_loss, test_loss = [
                validate(
                    configs=configs,
                    dataset=split_dataset,
                    audio_encoder=audio_encoder,
                    tokenizer=tokenizer,
                    llm=llm,
                )
                for split_dataset in (train_dataset, test_dataset)
            ]

            if use_wandb:
                wandb.log(
                    data={"train_loss": train_loss, "test_loss": test_loss},
                    step=step,
                )

            print("Train loss: {}".format(train_loss))
            print("Test loss: {}".format(test_loss))

        # ---- Periodic checkpointing -------------------------------------
        if step % train_cfg["save_every_n_steps"] == 0:
            ckpt_path = Path(ckpts_dir, "step={}.pth".format(step))
            # Only modules marked trainable are saved. get_audio_encoder /
            # get_llm likewise only reload a checkpoint for trainable modules.
            ckpt = {
                key: module.state_dict()
                for key, module in (("audio_encoder", audio_encoder), ("llm", llm))
                if configs[key]["trainable"]
            }
            torch.save(ckpt, ckpt_path)
            print("Save model to {}".format(ckpt_path))

        if step == train_cfg["training_steps"]:
            break
def get_dataset(
    configs: dict,
    split: str
) -> Dataset:
    r"""Build the single dataset configured for ``split``.

    ``configs["{split}_datasets"]`` maps a dataset name to its options:
    ``root`` for every dataset, plus ``split`` for every dataset except
    WavCaps. This script supports exactly one entry per split.

    Args:
        configs: parsed YAML config. Also reads ``sample_rate`` and
            ``clip_duration``, plus ``midi_to_tokens`` and ``fps`` for MAESTRO.
        split: split prefix, e.g. ``"train"`` or ``"test"``.

    Returns:
        The constructed dataset.

    Raises:
        ValueError: if zero or several datasets are configured for ``split``,
            or the dataset name is not recognised.
        NotImplementedError: if MAESTRO's ``midi_to_tokens`` is unsupported.
    """
    datasets_split = "{}_datasets".format(split)
    dataset_configs = configs[datasets_split]

    # Check the count before importing or constructing anything. Building
    # several (possibly large) datasets only to reject them wastes time, and
    # an empty / missing entry deserves its own accurate error message.
    if not dataset_configs:
        raise ValueError(
            "No dataset is configured under '{}'.".format(datasets_split)
        )
    if len(dataset_configs) > 1:
        raise ValueError("Do not support multiple datasets in this file.")

    from audidata.io.crops import RandomCrop, StartCrop
    from audidata.transforms import Mono, TextNormalization, TimeShift

    sr = configs["sample_rate"]
    clip_duration = configs["clip_duration"]
    name, ds_configs = next(iter(dataset_configs.items()))

    if name == "GTZAN":
        from audio_understanding.datasets.gtzan import GTZAN
        return GTZAN(
            root=ds_configs["root"],
            split=ds_configs["split"],
            sr=sr,
            crop=RandomCrop(clip_duration=clip_duration),
            transform=Mono(),
        )

    if name == "LibriSpeech":
        from audio_understanding.datasets.librispeech import LibriSpeech
        return LibriSpeech(
            root=ds_configs["root"],
            split=ds_configs["split"],
            sr=sr,
            crop=StartCrop(clip_duration=clip_duration),
            transform=[Mono(), TimeShift(sr=sr, shift=(0., 0.5))],
        )

    if name == "Clotho":
        from audio_understanding.datasets.clotho import Clotho
        return Clotho(
            root=ds_configs["root"],
            split=ds_configs["split"],
            sr=sr,
            crop=StartCrop(clip_duration=clip_duration),
            transform=[Mono(), TimeShift(sr=sr, shift=(0., 0.5))],
            target_transform=TextNormalization(),
        )

    if name == "MAESTRO":
        from audidata.transforms.midi import PianoRoll
        from audio_understanding.datasets.maestro import MAESTRO
        from audio_understanding.target_transforms.midi import MIDI2Tokens

        if configs["midi_to_tokens"] == "MIDI2Tokens":
            midi_transform = MIDI2Tokens(fps=configs["fps"])
        else:
            raise NotImplementedError(configs["midi_to_tokens"])

        return MAESTRO(
            root=ds_configs["root"],
            split=ds_configs["split"],
            sr=sr,
            crop=RandomCrop(clip_duration=clip_duration, end_pad=clip_duration - 0.1),
            transform=Mono(),
            load_target=True,
            extend_pedal=True,
            # NOTE(review): PianoRoll uses a fixed fps=100 while MIDI2Tokens
            # uses configs["fps"]; confirm the two are meant to differ.
            target_transform=[PianoRoll(fps=100, pitches_num=128), midi_transform],
        )

    if name == "AudioCaps":
        from audio_understanding.datasets.audiocaps import AudioCaps
        return AudioCaps(
            root=ds_configs["root"],
            split=ds_configs["split"],
            sr=sr,
            crop=StartCrop(clip_duration=clip_duration),
            transform=Mono(),
            target_transform=TextNormalization(),
        )

    if name == "WavCaps":
        from audio_understanding.datasets.wavcaps import WavCaps
        return WavCaps(
            root=ds_configs["root"],
            sr=sr,
            crop=StartCrop(clip_duration=clip_duration),
            transform=Mono(),
            target_transform=TextNormalization(),
        )

    raise ValueError(name)
def get_audio_encoder(configs: dict, ckpt_path: str) -> nn.Module:
    r"""Build the audio encoder named in ``configs["audio_encoder"]["name"]``.

    Supported names: ``Whisper``, ``PianoTranscriptionCRnn``, ``PannsCnn14``.
    Each class is imported lazily, so only the selected one is loaded.

    Args:
        configs: parsed YAML config. Reads ``sample_rate`` and the
            ``audio_encoder`` section (``name``, ``trainable``).
        ckpt_path: optional checkpoint to resume from. Its ``"audio_encoder"``
            entry is loaded only when the encoder is trainable. This matches
            ``train()``, which saves that entry only for trainable encoders.

    Returns:
        The encoder module. ``train()`` moves it to the training device.

    Raises:
        ValueError: if the encoder name is not recognised.
    """
    name = configs["audio_encoder"]["name"]
    sr = configs["sample_rate"]
    trainable = configs["audio_encoder"]["trainable"]

    if name == "Whisper":
        from audio_understanding.audio_encoders.whisper import Whisper
        model = Whisper(sr=sr, trainable=trainable)
    elif name == "PianoTranscriptionCRnn":
        from audio_understanding.audio_encoders.piano_transcription_crnn import \
            PianoTranscriptionCRnn
        model = PianoTranscriptionCRnn(sr=sr, trainable=trainable)
    elif name == "PannsCnn14":
        from audio_understanding.audio_encoders.panns import PannsCnn14
        model = PannsCnn14(sr=sr, trainable=trainable)
    else:
        raise ValueError(name)

    if ckpt_path and trainable:
        # Load onto CPU. load_state_dict copies the tensors into the model's
        # parameters wherever they live. Without map_location, a checkpoint
        # saved on GPU cannot be opened on a CPU-only machine, and GPU memory
        # is spent on a throwaway copy.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["audio_encoder"])

    return model
def get_tokenizer(configs: dict) -> nn.Module:
    r"""Instantiate the tokenizer named in ``configs["tokenizer"]["name"]``.

    Accepted names are ``"Bert"`` and ``"BertMIDI"``. The implementing module
    is imported only for the selected name.

    Raises:
        ValueError: for any other name.
    """
    tokenizer_name = configs["tokenizer"]["name"]

    if tokenizer_name == "Bert":
        from audio_understanding.tokenizers.bert import Bert
        return Bert()

    if tokenizer_name == "BertMIDI":
        from audio_understanding.tokenizers.bert_midi import BertMIDI
        return BertMIDI()

    raise ValueError(tokenizer_name)
def get_llm(
    configs: dict,
    audio_latent_dim: int,
    vocab_size: int,
    ckpt_path: str
) -> nn.Module:
    r"""Build the LLM decoder named in ``configs["llm"]["name"]``.

    Only ``"Llama"`` is supported. Its ``LlamaConfig`` takes ``block_size``,
    ``n_layer``, ``n_head`` and ``n_embd`` from the ``llm`` config section.

    Args:
        configs: parsed YAML config (``llm`` section).
        audio_latent_dim: feature size of the audio encoder's latent, passed
            to ``LlamaConfig``.
        vocab_size: tokenizer vocabulary size, passed to ``LlamaConfig``.
        ckpt_path: optional checkpoint to resume from. Its ``"llm"`` entry is
            loaded only when ``llm.trainable`` is set. This matches
            ``train()``, which saves that entry only for a trainable LLM.

    Returns:
        The decoder module. ``train()`` moves it to the training device.

    Raises:
        ValueError: if the LLM name is not recognised.
    """
    llm_configs = configs["llm"]
    name = llm_configs["name"]

    if name == "Llama":
        from audio_understanding.llm.llama import Llama, LlamaConfig
        config = LlamaConfig(
            block_size=llm_configs["block_size"],
            audio_latent_dim=audio_latent_dim,
            vocab_size=vocab_size,
            n_layer=llm_configs["n_layer"],
            n_head=llm_configs["n_head"],
            n_embd=llm_configs["n_embd"],
        )
        model = Llama(config=config)
    else:
        raise ValueError(name)

    if ckpt_path and llm_configs["trainable"]:
        # Load onto CPU. load_state_dict copies into the parameters wherever
        # they live. This keeps GPU-saved checkpoints loadable on CPU-only
        # hosts and avoids a transient GPU copy.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["llm"])

    return model
def get_learnable_params(
    configs: dict,
    audio_encoder: nn.Module,
    llm: nn.Module
) -> list:
    r"""Collect the parameters the optimizer should update.

    A module contributes its parameters only when its config section
    (``audio_encoder`` or ``llm``) has ``trainable`` set. Encoder parameters
    come first, then the LLM's.
    """
    sections = (("audio_encoder", audio_encoder), ("llm", llm))
    return [
        param
        for section, module in sections
        if configs[section]["trainable"]
        for param in module.parameters()
    ]
def get_optimizer_and_scheduler(
    configs: dict,
    params: list[torch.Tensor]
) -> tuple[optim.Optimizer, None | optim.lr_scheduler.LambdaLR]:
    r"""Build the optimizer and an optional linear warm-up LR scheduler.

    Args:
        configs: parsed YAML config. Reads ``train.lr``,
            ``train.warm_up_steps`` and ``train.optimizer``.
        params: parameters to optimize.

    Returns:
        ``(optimizer, scheduler)``. ``scheduler`` is a ``LambdaLR`` driven by
        ``LinearWarmUp(warm_up_steps)`` when ``warm_up_steps`` is truthy,
        otherwise ``None``.

    Raises:
        ValueError: if ``train.optimizer`` names an unsupported optimizer.
    """
    train_configs = configs["train"]
    # float(): presumably so that values such as ``1e-4``, which PyYAML loads
    # as a string, are accepted.
    lr = float(train_configs["lr"])
    warm_up_steps = train_configs["warm_up_steps"]
    optimizer_name = train_configs["optimizer"]

    if optimizer_name == "AdamW":
        optimizer = optim.AdamW(params=params, lr=lr)
    else:
        # Previously fell through and failed later on an unbound `optimizer`.
        raise ValueError(optimizer_name)

    if warm_up_steps:
        lr_lambda = LinearWarmUp(warm_up_steps)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lr_lambda)
    else:
        scheduler = None

    return optimizer, scheduler
def get_audio_question_answering(
    data: dict
) -> tuple[torch.Tensor, list[str], list[str]]:
    r"""Split a collated batch into ``(audio, question, answering)``.

    Which field holds the answer depends on the dataset that produced the
    batch, read from the first ``dataset_name`` entry:

    * GTZAN -> ``label``
    * AudioCaps / Clotho / LibriSpeech / WavCaps -> ``caption``
    * MAESTRO -> ``token``

    Returns:
        audio: (b, c, t)
        question: (b, t)
        answering: (b, t)

    Raises:
        ValueError: if the dataset name is not recognised.
    """
    answer_routes = (
        (("GTZAN",), "label"),
        (("AudioCaps", "Clotho", "LibriSpeech", "WavCaps"), "caption"),
        (("MAESTRO",), "token"),
    )
    dataset_name = data["dataset_name"][0]

    for names, answer_key in answer_routes:
        if dataset_name in names:
            return data["audio"], data["question"], data[answer_key]

    raise ValueError(dataset_name)
def ce_loss(
    output_seqs: list[torch.Tensor],
    target_seqs: list[torch.Tensor],
    loss_types: list[str | None],
    ignore_index: int
) -> torch.Tensor | float:
    r"""Sum the cross-entropy losses of the sequences selected by ``loss_types``.

    Args:
        output_seqs: per-sequence model outputs. Entries scored with ``"ce"``
            must be logits of shape (b, t, vocab_size).
        target_seqs: per-sequence targets aligned with ``output_seqs``.
            Entries scored with ``"ce"`` must be token IDs of shape (b, t).
        loss_types: one entry per sequence. ``None`` skips the sequence;
            ``"ce"`` applies token-level cross-entropy.
        ignore_index: target ID excluded from the loss. Callers in this file
            pass the tokenizer's pad ID, so padding does not contribute.

    Returns:
        The summed loss, or ``0.`` (a float) if no sequence is selected.

    Raises:
        ValueError: for an unknown loss type.
    """
    total_loss = 0.
    for output, target, loss_type in zip(output_seqs, target_seqs, loss_types):
        if loss_type is None:
            continue
        if loss_type != "ce":
            raise ValueError(loss_type)
        total_loss += F.cross_entropy(
            input=output.flatten(0, 1),   # shape: (b*t, vocab_size)
            target=target.flatten(0, 1),  # shape: (b*t,)
            # Previously hard-coded to -1, which silently ignored this
            # argument and trained on padding tokens.
            ignore_index=ignore_index,
        )
    return total_loss
def validate(
    configs: dict,
    dataset: Dataset,
    audio_encoder: nn.Module,
    tokenizer: object,
    llm: nn.Module,
    valid_steps=50
) -> float:
    r"""Estimate the model's mean loss on a strided subset of ``dataset``.

    A batch of up to ``train.batch_size_per_device`` consecutive examples
    starts every ``skip_n = max(1, len(dataset) // valid_steps)`` examples.
    So at least ``valid_steps`` batches are evaluated when the dataset has
    that many examples. Inputs are prepared the same way as in the training
    loop.

    The whole pass runs under ``torch.no_grad()`` with ``llm`` in eval mode.
    The llm's previous train/eval mode is restored afterwards.

    Args:
        configs: parsed YAML config. Reads ``train.batch_size_per_device``,
            ``train.remove_padded_columns``, ``max_question_len`` and
            ``max_answering_len``.
        dataset: indexable dataset whose items are combined with ``collate_fn``.
        audio_encoder: provides ``encode(audio=..., train_mode=...)``. The
            device of its first parameter is used for all inputs.
        tokenizer: provides ``texts_to_ids`` and ``pad_token_id``.
        llm: decoder being evaluated.
        valid_steps: target number of evaluated batches.

    Returns:
        Mean per-batch loss, or ``nan`` for an empty dataset.
    """
    device = next(audio_encoder.parameters()).device
    batch_size = configs["train"]["batch_size_per_device"]
    dataset_len = len(dataset)
    if dataset_len == 0:
        # Same value np.mean([]) would give, without its RuntimeWarning.
        return float("nan")
    skip_n = max(1, dataset_len // valid_steps)

    losses = []
    was_training = llm.training
    llm.eval()
    try:
        # Validation never calls backward(). Run everything, including the
        # audio encoder (previously outside no_grad), without recording an
        # autograd graph.
        with torch.no_grad():
            for idx in range(0, dataset_len, skip_n):
                print("{}/{}".format(idx, dataset_len))

                # ------ 1. Data preparation ------
                # 1.1 Collate consecutive examples starting at idx.
                batch = collate_fn(
                    [dataset[i] for i in range(idx, min(idx + batch_size, dataset_len))]
                )
                audio, question, answering = get_audio_question_answering(batch)
                # audio: (b, c, t), question: (b, t), answering: (b, t)

                # 1.2 Encode audio into latent
                audio_latent = audio_encoder.encode(
                    audio=audio.to(device), train_mode=False
                )  # shape: (b, t, d)

                # 1.3 Tokenize question and answer texts to IDs
                question_ids = tokenizer.texts_to_ids(
                    texts=question,
                    fix_length=configs["max_question_len"],
                ).to(device)  # shape: (b, t)
                answering_ids = tokenizer.texts_to_ids(
                    texts=answering,
                    fix_length=configs["max_answering_len"],
                ).to(device)  # shape: (b, t)

                # 1.4 Remove padded columns, as in training
                if configs["train"]["remove_padded_columns"]:
                    answering_ids = remove_padded_columns(
                        ids=answering_ids,
                        pad_token_id=tokenizer.pad_token_id,
                    )

                seqs = [audio_latent, question_ids, answering_ids]

                # ------ 2. Forward and loss ------
                output_seqs = llm(
                    seqs=seqs,
                    seq_types=["audio", "id", "id"],
                    mask=None,
                )  # list

                # Next-token prediction: output i is scored against input i + 1.
                loss = ce_loss(
                    output_seqs=[seq[:, :-1] for seq in output_seqs],
                    target_seqs=[seq[:, 1:] for seq in seqs],
                    loss_types=[None, None, "ce"],
                    ignore_index=tokenizer.pad_token_id,
                )
                losses.append(loss.item())
    finally:
        llm.train(was_training)

    return float(np.mean(losses))
if __name__ == "__main__":
    # Command-line entry point: --config <yaml path> (required), --no_log to
    # disable wandb logging.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True, help="Path of config yaml.")
    cli.add_argument("--no_log", action="store_true", default=False)
    train(cli.parse_args())