forked from Westlake-AI/Markov-Lipschitz-Deep-Learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
92 lines (69 loc) · 3.11 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import torch.nn as nn
# Module-level compute device: CUDA when PyTorch reports it is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class MLDL_model(nn.Module):
    """Symmetric fully connected autoencoder built from a list of layer widths.

    The encoder maps ``NetworkStructure[0] -> ... -> NetworkStructure[-1]`` and
    the decoder mirrors it back to ``NetworkStructure[0]``. Every Linear layer
    except the last one of each half is followed by a LeakyReLU. All modules
    live in one flat ``nn.ModuleList`` (encoder first, then decoder); the
    creation order is fixed so ``state_dict`` keys stay ``network.0``,
    ``network.1``, ...

    Args:
        args: configuration mapping. Keys read here: ``'NetworkStructure'``
            (sequence of layer widths) and, only when the network has at least
            one hidden layer, ``'DATASET'`` and ``'Mode'`` (they select the
            LeakyReLU slope).

    Attributes:
        name_list: one label per entry of the list returned by ``forward``
            (index 0 is the input); an activation shares its Linear's label.
        index_list: indices into ``name_list`` / the ``forward`` output that
            mark the final output of each layer (after its activation, if any).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.NetworkStructure = args['NetworkStructure']
        self.name_list = ['Layer 0 ({})'.format(self.NetworkStructure[0])]
        self.index_list = [0]
        self.network = nn.ModuleList()

        widths = self.NetworkStructure
        depth = len(widths) - 1  # number of Linear layers in each half

        # Encoder: Layer 0 -> Layer `depth`; no activation after the last one.
        for i in range(depth):
            label = 'Layer {} ({})'.format(i + 1, widths[i + 1])
            self._add_layer(nn.Linear(widths[i], widths[i + 1]), label,
                            with_activation=(i != depth - 1))

        # Decoder: mirror image of the encoder, ending at Layer 0'.
        for i in range(depth, 0, -1):
            label = 'Layer {}\' ({})'.format(i - 1, widths[i - 1])
            self._add_layer(nn.Linear(widths[i], widths[i - 1]), label,
                            with_activation=(i > 1))

    def _make_activation(self):
        """Return the LeakyReLU placed after a hidden Linear layer.

        When the dataset name contains 'Spheres' and the mode is 'ML-AE', this
        uses ``nn.LeakyReLU()`` with PyTorch's default slope. Every other
        configuration uses a slope of 0.1.
        """
        if 'Spheres' in self.args['DATASET'] and self.args['Mode'] == 'ML-AE':
            return nn.LeakyReLU()
        return nn.LeakyReLU(0.1)

    def _add_layer(self, linear, label, with_activation):
        """Append ``linear`` (plus an activation if requested) and record labels."""
        self.network.append(linear)
        self.name_list.append(label)
        if with_activation:
            self.network.append(self._make_activation())
            self.name_list.append(label)
        # The entry appended last is this layer's final output.
        self.index_list.append(len(self.name_list) - 1)

    def _num_encoder_modules(self):
        """Return how many leading entries of ``self.network`` form the encoder."""
        # (len - 1) Linear layers, each but the last followed by an activation.
        # Computed on demand (not stored) so previously pickled models still work.
        return max(2 * len(self.NetworkStructure) - 3, 0)

    def forward(self, data):
        """Run the full autoencoder and return every intermediate result.

        The input is first reshaped to ``(data.shape[0], -1)``. The returned
        list has that reshaped input at index 0 and the output of
        ``self.network[k - 1]`` at index ``k``, so it lines up with
        ``self.name_list``.
        """
        data = data.view(data.shape[0], -1)
        output_info = [data]
        for layer in self.network:
            data = layer(data)
            output_info.append(data)
        return output_info

    def Encoder(self, data):
        """Pass input-layer data through the encoder and return the embedding.

        Unlike ``forward``, ``data`` is not reshaped; ``nn.Linear`` acts on its
        last dimension. If the encoder has no modules, ``data`` is returned
        unchanged.
        """
        for layer in self.network[:self._num_encoder_modules()]:
            data = layer(data)
        return data

    def Decoder(self, data):
        """Pass hidden-layer (embedding) data through the decoder.

        Returns the reconstruction. If the decoder has no modules, ``data`` is
        returned unchanged.
        """
        for layer in self.network[self._num_encoder_modules():]:
            data = layer(data)
        return data