-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathml_torch_rnn.py
75 lines (61 loc) · 2.1 KB
/
ml_torch_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import math
import random
import numpy as np
import matplotlib.pyplot as plt
class TrainDataSet(torch.utils.data.Dataset):
    """Synthetic map-style dataset pairing a noisy sine sample with a cosine target.

    Sample ``i`` is taken at ``position = i / scale``:

    * ``x = sin(pi * position) + U(-noise, noise)`` -- noisy input
    * ``y = cos(pi * position)``                    -- clean target

    Samples are computed on demand; nothing is stored.  The noise is drawn from
    the module-level :mod:`random` generator, so call ``random.seed`` first if
    reproducible data is needed.

    Args:
        length: Number of samples reported by ``len()``. Defaults to 100000.
        noise: Half-width of the uniform noise added to ``x``. Defaults to 0.1.
        scale: Divisor mapping an index to a position. Defaults to 100.
    """

    def __init__(self, length=100000, noise=0.1, scale=100):
        self.length = length
        self.noise = noise
        self.scale = scale

    def __getitem__(self, item):
        """Return the ``(x, y)`` pair of Python floats for index ``item``.

        Negative indices count from the end, as for a Python sequence.

        Raises:
            IndexError: If ``item`` is outside ``[-len(self), len(self))``.
        """
        index = item
        if index < 0:
            index += self.length
        # Without an upper bound, plain iteration (``for s in dataset``) falls
        # back to the legacy __getitem__ protocol and would never terminate.
        if not 0 <= index < self.length:
            raise IndexError(
                f"index {item} out of range for dataset of length {self.length}"
            )
        position = float(index) / self.scale
        x = math.sin(math.pi * position) + random.uniform(-self.noise, self.noise)
        y = math.cos(math.pi * position)
        return x, y

    def __len__(self):
        """Return the configured number of samples."""
        return self.length
class DemoRNN(torch.nn.Module):
    """Stacked Elman RNN followed by a linear read-out applied at every time step.

    Args:
        input_size: Features per time step in the input. Defaults to 1.
        hidden_size: Width of each RNN layer. Defaults to 10.
        num_layers: Number of stacked RNN layers. Defaults to 3.
        output_size: Features produced per time step. Defaults to 1.
    """

    def __init__(self, input_size=1, hidden_size=10, num_layers=3, output_size=1):
        super().__init__()
        self.rnn = torch.nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.output_layer = torch.nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x, hidden=None):
        """Run the RNN over ``x`` and project every time step.

        Args:
            x: Input sequence, ``(batch, seq_len, input_size)`` because the RNN
                is built with ``batch_first=True``.
            hidden: Initial hidden state ``(num_layers, batch, hidden_size)``,
                or ``None`` to start from zeros.

        Returns:
            ``(output, h_out)`` where ``output`` has shape
            ``(batch * seq_len, output_size)`` with rows ordered batch-major,
            and ``h_out`` is the final hidden state returned by the RNN.
        """
        output, h_out = self.rnn(x, hidden)
        # reshape, not view: depending on backend the batch_first RNN output can
        # be non-contiguous when batch > 1, and view() would then raise.
        flat = output.reshape(-1, output.shape[-1])
        return self.output_layer(flat), h_out
def main():
    """Train DemoRNN to map a noisy sine wave onto a cosine wave, plotting live.

    The dataset is read in order (``shuffle=False``) in batches of 1000
    consecutive samples.  Each batch is fed to the RNN as a single sequence
    (batch dimension of 1).  The final hidden state of one batch seeds the
    next one, detached so gradients do not flow across batches.  After each
    optimizer step the model is re-run from a zero hidden state and the input
    (blue), target (green) and prediction (red) are redrawn.
    """
    train_data = TrainDataSet()
    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=1000, shuffle=False)
    nn_model = DemoRNN()
    print(nn_model)
    optimize = torch.optim.Adam(nn_model.parameters(), lr=0.01)
    loss_func = torch.nn.L1Loss()
    h_output = None
    plt.ion()
    plt.show()
    for x, y in train_loader:
        # Default collation turns the Python floats into float64 tensors of
        # shape (batch,); the model expects a float32 (1, seq_len, 1) input
        # and the loss a (seq_len, 1) target to match the model output.
        x = x.float().view(1, -1, 1)
        y = y.float().view(-1, 1)
        row_idx = np.arange(x.shape[1])

        nn_model.train()
        optimize.zero_grad()
        output, h_output = nn_model(x, h_output)
        # Carry the hidden state forward but cut the autograd graph, otherwise
        # backward() would try to reach into the previous batch's graph.
        h_output = h_output.detach()
        loss = loss_func(output, y)
        loss.backward()
        optimize.step()

        nn_model.eval()
        with torch.no_grad():  # inference only; no graph needed
            predict_y, _ = nn_model(x, None)

        plt.cla()
        plt.plot(row_idx, x[0, :, 0].numpy(), "b-")
        plt.plot(row_idx, y[:, 0].numpy(), "g-")
        plt.plot(row_idx, predict_y[:, 0].numpy(), "r-")
        plt.pause(0.1)
if __name__ == "__main__":
    # Launch the training demo only when this file is executed directly;
    # importing the module just defines the dataset, model and main().
    main()