# main_XJTU.py
# Forked from wang-fujin/PINN4SOH.
from dataloader.dataloader import XJTUdata
from Model.Model import PINN
import argparse
import os
# Expose only GPU 0 to CUDA-based libraries started from this process.
# NOTE(review): this runs after the project imports above; if those imports
# already initialise CUDA the setting may come too late — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def load_data(args, small_sample=None):
    """Build the train/valid/test loaders for one XJTU batch.

    Every file under ``data/XJTU data`` whose name contains ``args.batch``
    is selected. Files whose name also contains ``'4'`` or ``'8'`` form the
    test set; the remaining files form the training set.

    Args:
        args: Parsed hyper-parameters. ``args.batch`` is read here and the
            whole object is forwarded to ``XJTUdata``.
        small_sample: If not ``None``, only the first ``small_sample``
            training files (in sorted file-name order) are kept.

    Returns:
        dict: ``{'train': ..., 'valid': ..., 'test': ...}`` taken from the
        ``'train_2'`` / ``'valid_2'`` entries of the training read and the
        ``'test_3'`` entry of the test read.
    """
    root = 'data/XJTU data'
    data = XJTUdata(root=root, args=args)
    train_list = []
    test_list = []
    # os.listdir() returns entries in arbitrary order. Sort them so that the
    # split, and above all the ``train_list[:small_sample]`` prefix, is the
    # same on every machine and every run.
    for file in sorted(os.listdir(root)):
        if args.batch not in file:
            continue
        path = os.path.join(root, file)
        # NOTE(review): this is a substring test on the whole file name, so a
        # name containing e.g. "14" is also routed to the test set — confirm
        # that this is intended.
        if '4' in file or '8' in file:
            test_list.append(path)
        else:
            train_list.append(path)
    if small_sample is not None:
        train_list = train_list[:small_sample]

    train_loader = data.read_all(specific_path_list=train_list)
    test_loader = data.read_all(specific_path_list=test_list)
    dataloader = {'train': train_loader['train_2'],
                  'valid': train_loader['valid_2'],
                  'test': test_loader['test_3']}
    return dataloader
def main():
    """Run the full-data experiments: ten training runs for each XJTU batch.

    For every batch name the parsed arguments are updated in place with the
    batch, a per-run ``save_folder`` and ``log_dir``; the data loaders are
    rebuilt and a fresh ``PINN`` is trained for each run.
    """
    args = get_args()
    batchs = ['2C', '3C', 'R2.5', 'R3', 'RW', 'satellite']
    for batch_idx, batch_name in enumerate(batchs):
        args.batch = batch_name
        for run in range(1, 11):
            # Folder layout: <batch index>-<batch index>/Experiment<run number>
            save_folder = (
                f'results of reviewer/XJTU results/{batch_idx}-{batch_idx}'
                f'/Experiment{run}'
            )
            if not os.path.exists(save_folder):
                os.makedirs(save_folder)
            args.save_folder = save_folder
            args.log_dir = 'logging.txt'

            loaders = load_data(args)
            model = PINN(args)
            model.Train(
                trainloader=loaders['train'],
                validloader=loaders['valid'],
                testloader=loaders['test'],
            )
def small_sample():
    """Run the small-sample experiments.

    For each training-set size ``n`` in 1..4 and each XJTU batch, ten runs
    are trained using only the first ``n`` training files. ``batch_size``,
    ``alpha`` and ``beta`` are overridden to 128, 0.5 and 10 for these runs.
    """
    args = get_args()
    batchs = ['2C', '3C', 'R2.5', 'R3', 'RW', 'satellite']
    for n in (1, 2, 3, 4):
        for batch_idx, batch_name in enumerate(batchs):
            args.batch = batch_name
            # Hyper-parameter overrides specific to the small-sample setting.
            args.batch_size = 128
            args.alpha = 0.5
            args.beta = 10
            for run in range(1, 11):
                save_folder = (
                    f'results/XJTU results (small sample {n})/'
                    f'{batch_idx}-{batch_idx}/Experiment{run}'
                )
                if not os.path.exists(save_folder):
                    os.makedirs(save_folder)
                args.save_folder = save_folder
                args.log_dir = 'logging.txt'

                loaders = load_data(args, small_sample=n)
                model = PINN(args)
                model.Train(
                    trainloader=loaders['train'],
                    validloader=loaders['valid'],
                    testloader=loaders['test'],
                )
def get_args() -> argparse.Namespace:
    """Parse the command-line hyper-parameters for the XJTU experiments.

    Arguments are read from ``sys.argv``; every option has a default, so the
    script can be started without any flags.

    Returns:
        argparse.Namespace: the parsed options (data selection, scheduler,
        model and loss settings, plus logging/output locations).
    """
    # The first positional parameter of ArgumentParser is ``prog``; pass the
    # title as ``description`` so the usage line shows the real program name.
    parser = argparse.ArgumentParser(description='Hyper Parameters for XJTU dataset')
    parser.add_argument('--data', type=str, default='XJTU', help='XJTU, HUST, MIT, TJU')
    parser.add_argument('--train_batch', type=int, default=0, choices=[-1,0,1,2,3,4,5],
                        help='如果是-1,读取全部数据,并随机划分训练集和测试集;否则,读取对应的batch数据'
                             '(if -1, read all data and random split train and test sets; '
                             'else, read the corresponding batch data)')
    parser.add_argument('--test_batch', type=int, default=1, choices=[-1,0,1,2,3,4,5],
                        help='如果是-1,读取全部数据,并随机划分训练集和测试集;否则,读取对应的batch数据'
                             '(if -1, read all data and random split train and test sets; '
                             'else, read the corresponding batch data)')
    parser.add_argument('--batch',type=str,default='2C',choices=['2C','3C','R2.5','R3','RW','satellite'])
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--normalization_method', type=str, default='min-max', help='min-max,z-score')

    # scheduler related
    parser.add_argument('--epochs', type=int, default=200, help='epoch')
    parser.add_argument('--early_stop', type=int, default=20, help='early stop')
    parser.add_argument('--warmup_epochs', type=int, default=30, help='warmup epoch')
    parser.add_argument('--warmup_lr', type=float, default=0.002, help='warmup lr')
    parser.add_argument('--lr', type=float, default=0.01, help='base lr')
    parser.add_argument('--final_lr', type=float, default=0.0002, help='final lr')
    parser.add_argument('--lr_F', type=float, default=0.001, help='lr of F')

    # model related
    parser.add_argument('--F_layers_num', type=int, default=3, help='the layers num of F')
    parser.add_argument('--F_hidden_dim', type=int, default=60, help='the hidden dim of F')

    # loss related
    parser.add_argument('--alpha', type=float, default=0.7, help='loss = l_data + alpha * l_PDE + beta * l_physics')
    parser.add_argument('--beta', type=float, default=0.2, help='loss = l_data + alpha * l_PDE + beta * l_physics')

    # logging / output locations
    parser.add_argument('--log_dir', type=str, default='text log.txt', help='log dir, if None, do not save')
    parser.add_argument('--save_folder', type=str, default='results of reviewer/XJTU results', help='save folder')

    args = parser.parse_args()
    return args
# Script entry point: runs the full-data experiments via main().
# small_sample() is defined in this file but is not invoked here.
if __name__ == '__main__':
    main()