-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
104 lines (83 loc) · 2.89 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from datetime import datetime
from coconet import Net, save_heatmaps, harmonize_melody_and_save_midi
from hparams import I, T, P, MIN_MIDI_PITCH, MAX_MIDI_PITCH
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the JSB chorales dataset. The archive stores pickled object arrays,
# hence allow_pickle=True (safe here: local, trusted file).
data = np.load('Jsb16thSeparated.npz', encoding='bytes', allow_pickle=True)

# Data augmentation: transpose every training chorale by -3..+3 semitones.
all_tracks = [y + shift for y in data['train'] for shift in range(-3, 4)]

# Construct training data: cut each track into non-overlapping windows of T
# timesteps, keep only windows with no rests (all pitches > 0), and shift
# pitches down by MIN_MIDI_PITCH so values lie in [0, P).
# NOTE(review): assumes each raw track is (timesteps, instruments) before the
# transpose -- confirm against the dataset layout.
train_tracks = []
for track in all_tracks:
    track = track.transpose()  # -> (instruments, timesteps)
    cut = 0
    # '<=' keeps the final full window when the track length is an exact
    # multiple of T (the original '<' silently dropped it).
    while cut <= track.shape[1] - T:
        window = track[:, cut:cut + T]
        if (window > 0).all():
            train_tracks.append(window - MIN_MIDI_PITCH)
        cut += T
train_tracks = np.array(train_tracks).astype(int)

# Fixed validation sample: first T timesteps of the first validation chorale.
# Row 0 (the soprano line) is the melody harmonized periodically in training.
test_sample = data['valid'][0].transpose()[:, :T]
test_sample_melody = test_sample[0]
if __name__ == '__main__':
    # --- hyperparameters -------------------------------------------------
    batch_size = 16
    n_layers = 64
    hidden_size = 128
    n_train_steps = 80000
    log_every = 500          # print loss and dump heatmaps every 500 steps
    checkpoint_every = 5000  # render a MIDI sample and decay lr every 5000 steps

    model = Net(n_layers, hidden_size).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)
    losses = []  # per-step loss history
    model.train()
    N = batch_size

    for i in range(n_train_steps):
        # Random binary mask of shape (N, I, T): 1 = pitch is visible to the
        # model as conditioning context, 0 = pitch must be predicted.
        C = np.random.randint(2, size=(N, I, T))

        # Sample a batch of piano rolls, shape (N, I, T), ints in [0, P).
        indices = np.random.choice(train_tracks.shape[0], size=N)
        batch = train_tracks[indices]

        # Flat class targets for CrossEntropyLoss, shape (N*I*T,).
        targets = torch.tensor(batch.reshape(-1)).to(device)

        # One-hot encode the batch -> x of shape (N, I, T, P).
        flat = batch.reshape(N * I * T)
        x = np.zeros((N * I * T, P))
        x[np.arange(N * I * T), flat] = 1
        x = x.reshape(N, I, T, P)
        x = torch.tensor(x).type(torch.FloatTensor).to(device)
        C2 = torch.tensor(C).type(torch.FloatTensor).to(device)

        out = model(x, C2)
        out = out.view(N * I * T, P)
        loss = loss_fn(out, targets)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if i % log_every == 0:
            now_date_time = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
            print(f'{now_date_time} | step: {i} | loss: {loss.item()}')
            # Mask revealing only instrument 0 (the melody): ones for row 0,
            # zeros for the remaining I - 1 voices (the original hardcoded 3).
            D0 = np.ones((1, T))
            D1 = np.zeros((I - 1, T))
            D = np.concatenate([D0, D1], axis=0).astype(int)
            y = np.random.randint(P, size=(I, T))
            y[0, :] = np.array(test_sample_melody - MIN_MIDI_PITCH)
            save_heatmaps(model, y, D, i, device)

        if i % checkpoint_every == 0:
            harmonize_melody_and_save_midi(test_sample_melody, i, model, device)
            # NOTE(review): presumably the harmonization helper switches the
            # model to eval mode; restore training mode -- confirm in coconet.
            model.train()

        # Decay the learning rate by 25% every checkpoint interval.
        if i % checkpoint_every == 0 and i > 0:
            for g in optimizer.param_groups:
                g['lr'] *= .75

    # Persist the final weights.
    torch.save(model.state_dict(), 'pretrained.pt')