-
Notifications
You must be signed in to change notification settings - Fork 0
/
recover_model.py
232 lines (209 loc) · 8.99 KB
/
recover_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
# -*- coding: utf-8 -*-
"""
recover_model.py
author: Edoardo Centofanti
Learning Hodgkin-Huxley model with DeepONet

Loads a previously trained model from disk and plots its predictions on the
Hodgkin-Huxley test dataset.
"""
# internal modules
from src.utility_dataset import *
from src.architectures import get_optimizer, get_loss
# external modules
import torch
# for test launcher interface
import os
import yaml
import argparse
import matplotlib.pyplot as plt
#########################################
# default value
#########################################
# Pick the compute device once at import time; every tensor created below
# inherits it via torch.set_default_device.
mydevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#mydevice = 'cpu'
## Metal GPU acceleration on Mac OSX
## NOT WORKING since ComplexFloat and Float64 is not supported by the MPS backend
## It isn't worth to force the conversion since we have cuda machine for test
#if torch.backends.mps.is_available():
#    mydevice = torch.device("mps")
#    print ("MPS acceleration is enabled")
#else:
#    print ("MPS device not found.")
torch.set_default_device(mydevice) # default tensor device
torch.set_default_dtype(torch.float32) # default tensor dtype
# Define command-line arguments
parser = argparse.ArgumentParser(description="Learning Hodgkin-Huxley model with DeepONet")
parser.add_argument("--config_file", type=str, default="default_params_don.yml", help="Path to the YAML configuration file")
args = parser.parse_args()
# Read the configuration from the specified YAML file
with open(args.config_file, "r") as config_file:
    config = yaml.safe_load(config_file)
# Strip the extension: e.g. "peano_test/params.yml" -> "peano_test/params"
param_file_name = os.path.splitext(args.config_file)[0]
# Now, `param_file_name` contains the name without the .yml suffix
print("Test name:", param_file_name)
#########################################
# files names to be saved
#########################################
name_log_dir = 'exp_' + param_file_name
name_model = 'model_' + param_file_name
#########################################
# DeepONet's hyperparameter
#########################################
# NOTE: config.get(...) returns None for any key missing from the YAML file;
# downstream code is expected to tolerate (or never see) those None values.
arc = config.get("arc")
dataset_train = config.get("dataset_train")
dataset_test = config.get("dataset_test")
batch_size = config.get("batch_size")
scaling = config.get("scaling")
labels = config.get("labels") # default False
full_v_data = config.get("full_v_data") # default False
N_FourierF = config.get("N_FourierF")
# scale_FF is pinned to 1 here, overriding whatever the config file says.
scale_FF = 1 # config.get("scale_FF")
adapt_actfun = config.get("adapt_actfun")
scheduler = config.get("scheduler")
Loss = config.get("Loss")
epochs = config.get("epochs")
lr = config.get("lr")
u_dim = config.get("u_dim")
x_dim = config.get("x_dim")
G_dim = config.get("G_dim")
inner_layer_b = config.get("inner_layer_b")
inner_layer_t = config.get("inner_layer_t")
activation_b = config.get("activation_b")
activation_t = config.get("activation_t")
arc_b = config.get("arc_b")
arc_t = config.get("arc_t")
initial_b = config.get("initial_b")
initial_t = config.get("initial_t")
#### WNO parameters
width = config.get("width")
level = config.get("level")
#### FNO parameters
d_a = config.get("d_a")
d_v = config.get("d_v")
d_u = config.get("d_u")
L = config.get("L")
modes = config.get("modes")
act_fun = config.get("act_fun")
initialization = config.get("initialization")
scalar = config.get("scalar")
padding = config.get("padding")
arc_fno = config.get("arc_fno")
x_padding = config.get("x_padding")
RNN = config.get("RNN")
#### Plotting parameters
show_every = config.get("show_every")
ep_step = config.get("ep_step")
idx = config.get("idx")
plotting = config.get("plotting")
#########################################
# MAIN
#########################################
if __name__=="__main__":
    # Recover a previously trained model from disk and plot its predictions
    # against the reference numerical solution on a few test trajectories.
    # [159, 69, 134, 309]
    # NOTE(review): this random draw is immediately overwritten by the fixed
    # tensor below, so it is dead code — but it still advances the torch RNG
    # state, and the shuffled loaders below may depend on that state, so it is
    # deliberately left in place.
    idx = torch.randint(low=0, high=400, size=(4,))
    # 50 has 0 peaks
    idx = torch.tensor([50, 309, 134, 159])
    print("indexes to print = "+str(idx))
    # Load dataset
    if "LR" in dataset_train:
        # Datasets whose name contains "LR" use dedicated loaders
        # (presumably Luo-Rudy data — TODO confirm in src.utility_dataset).
        u_train, x_train, v_train, scale_fac = load_LR_train(dataset_train,full_v_data)
        u_test, x_test, v_test, indices = load_LR_test(dataset_test,full_v_data)
    else:
        u_train, x_train, v_train, scale_fac, _ = load_train(dataset_train,scaling,labels,full_v_data,shuffle=True)
        # `indices` records the shuffle order so the unscaled test data loaded
        # further below can be re-aligned with this scaled test set.
        u_test, x_test, v_test, indices = load_test(dataset_test,scale_fac,scaling,labels,full_v_data,shuffle=True)
    # batch loader
    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(v_train, u_train),
                                               batch_size = batch_size)
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(v_test, u_test),
                                              batch_size = batch_size)
    # Loss function
    myloss = get_loss(Loss)
    #modelname = "./model_" + param_file_name.replace("./", "")
    modelname = "peano_test/model_" + param_file_name.replace("peano_test/", "")
    #model = torch.load(modelname)
    # NOTE(review): torch.load unpickles a full model object — only open
    # checkpoints from trusted sources. map_location='cpu' lets a CUDA-trained
    # checkpoint be loaded on a CPU-only machine.
    model = torch.load(modelname, map_location=torch.device('cpu'))
    #### initial value of v
    # NOTE(review): this call unpacks 3 values while the call above unpacks 4;
    # load_test must return a different tuple when invoked with only
    # (dataset, full_v_data=True) — verify against src.utility_dataset.
    u_test_unscaled, x_test_unscaled, v_test_unscaled = load_test(dataset_test,full_v_data=True)
    # Same order of scaled data
    u_test_unscaled = u_test_unscaled[indices]
    v_test_unscaled = v_test_unscaled[indices]
    # Move the selected examples to CPU for matplotlib.
    esempio_test = v_test_unscaled[idx, :].to('cpu')    # unscaled stimuli I_app(t)
    esempio_test_pp = v_test[idx, :].to('cpu')          # preprocessed (scaled) stimuli fed to the model
    sol_test = u_test_unscaled[idx].to('cpu')           # reference numerical solution V_m(t)
    x_test_unscaled = x_test_unscaled.to('cpu')
    ## Third figure for approximation with DON of HH model
    with torch.no_grad(): # no grad for efficiency reasons
        out_test = model((esempio_test_pp.to(mydevice), x_test.to(mydevice)))
        out_test = out_test.to('cpu')
    params = {'legend.fontsize': 12,
              'axes.labelsize': 20,
              'axes.titlesize': 20,
              'xtick.labelsize': 15,
              'ytick.labelsize': 15}
    plt.rcParams.update(params)
    # Plot-mode switches: exactly one figure layout is produced per run,
    # selected by hand-editing these flags.
    plot_error = False
    plot_stimuli = True
    plotname = ''
    plot_plane = True
    if not plot_plane:
        # Create a single figure with a grid layout
        if plot_stimuli==True:
            fig, axs = plt.subplots(1, len(idx), figsize=(16, 2.6))
            # First row: Applied current (I_app)
            for i in range(len(idx)):
                axs[i].plot(x_test_unscaled, esempio_test[i])
                axs[0].set_ylabel('$I_{app}$(t)')
                #axs[i].set_xlabel('t')
                axs[i].set_ylim([-0.2, 10])
                axs[i].grid()
            plotname = 'Test_input.eps'
        else:
            fig, axs = plt.subplots(1, len(idx), figsize=(16, 5.3))
            if plot_error==False:
                # Second row: Numerical approximation (V_m) and DON approximation (V_m)
                for i in range(len(idx)):
                    axs[i].plot(x_test_unscaled, sol_test[i].to('cpu'), label='Numerical approximation')
                    axs[i].plot(x_test_unscaled, out_test[i], 'r--', label=arc+' approximation')
                    axs[0].set_ylabel('$V_m$ (mV)', labelpad=-5)
                    # WNO figures carry the time axis label; other architectures omit it.
                    if arc=='WNO':
                        axs[i].set_xlabel('t')
                    axs[i].set_ylim([-100, 50])
                    axs[i].grid()
                    axs[i].legend(loc='upper left')
                plotname = arc+param_file_name.replace("peano_test/"+arc.lower(), "") + '.eps'
            else:
                # OR Second row: Error between Numerical approximation (V_m) and DON approximation (V_m)
                for i in range(len(idx)):
                    axs[i].semilogy(x_test_unscaled, torch.abs(sol_test[i]-out_test[i]).to('cpu'), label='$|u_{NN}-u_{num}|$')
                    axs[0].set_ylabel('$V_m$ (mV)', labelpad=5)
                    if arc=='WNO':
                        axs[i].set_xlabel('t')
                    axs[i].set_ylim([1E-5, 100])
                    axs[i].grid()
                    axs[i].legend(loc='upper left')
                plotname = arc+param_file_name.replace("peano_test/"+arc.lower(), "") + '_err' +'.eps'
    else:
        plotname = "HHpeaks.eps"
        # Two-row layout: V_m traces on top (2/3 of the height), stimuli below.
        # NOTE(review): this branch plots only the numerical solution, not the
        # model output — presumably intentional for the "peaks" figure; confirm.
        fig, axs = plt.subplots(2, len(idx), figsize=(16, 8), gridspec_kw={'height_ratios': [2, 1]})
        # Second row: Applied current (I_app)
        for i in range(len(idx)):
            axs[1,i].plot(x_test_unscaled, esempio_test[i])
            axs[1,0].set_ylabel('$I_{app}$(t)')
            axs[1,i].set_xlabel('t')
            axs[1,i].set_ylim([-0.2, 10])
            # Only the leftmost column keeps its y tick labels.
            if i!=0:
                axs[1,i].set_yticks([])
            #axs[1,i].grid()
        # First row: Numerical approximation (V_m) and DON approximation (V_m)
        for i in range(len(idx)):
            axs[0,i].plot(x_test_unscaled, sol_test[i].to('cpu'), label='Numerical approximation')
            axs[0,0].set_ylabel('$V_m$ (mV)', labelpad=-5)
            axs[0,i].set_ylim([-100, 50])
            axs[0,i].set_xticks([])
            if i!=0:
                axs[0,i].set_yticks([])
            #axs[0,i].grid()
            #axs[0,i].legend(loc='upper left')
    print("Plot will be saved as: " + plotname)
    # Adjust layout to prevent overlap
    plt.tight_layout()
    plt.savefig(plotname,format='eps')
    plt.show()