-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
68 lines (61 loc) · 3.25 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python
__author__ = "Alireza Moradzadeh"
__license__ = "GPL"
__version__ = "1.0.1"
__maintainer__ = "Alireza Moradzadeh"
__email__ = "[email protected]"
import numpy as np
import os, sys
import time
import tensorflow as tf
import csv
def train(model, data_set_train, data_set_val, learning_rate=0.0005,
          batch_size=16, num_steps=5000, step_save=1000, step_show=1000,
          pre_train=False, working_dir=None):
    """
    Training loop of mini-batch gradient descent.

    Performs mini-batch gradient descent with the indicated batch_size and
    learning_rate, periodically checkpointing the model and appending
    train/validation MAE losses to ``loss.dat``.
    ----------------------------------------------------------------------
    Inputs:
        model(ForceFieldNN): Initialized a NN Force Field model.
        data_set_train: MD dataset used for the gradient updates.
        data_set_val: MD dataset used for validation loss reporting.
        learning_rate(float): Learning rate.
        batch_size(int): Batch size used for training.
        num_steps(int): Number of steps to run the update ops.
        step_save(int): Checkpoint the model every ``step_save`` steps.
        step_show(int): Log train/validation loss every ``step_show`` steps.
        pre_train(bool): Placeholder flag; pre-training is not implemented.
        working_dir(str or None): Directory receiving ``loss.dat``;
            defaults to the current working directory when None.
    """
    if pre_train:
        print("not implemented yet!")
        return

    def _feed(batch, lr=None):
        # Map one dataset batch tuple onto the model's placeholders.
        # Tuple order must match what next_batch() yields.
        (r2, z2, r3_jik, z3_jik, r3_ijk, z3_ijk,
         r3_jki, z3_jki, force) = batch
        feed = {model.R2_ph: r2, model.Z2_ph: z2,
                model.R3_jik_ph: r3_jik, model.Z3_jik_ph: z3_jik,
                model.R3_ijk_ph: r3_ijk, model.Z3_ijk_ph: z3_ijk,
                model.R3_jki_ph: r3_jki, model.Z3_jki_ph: z3_jki,
                model.F_ph: force}
        # Learning rate is only fed for the update op, not for loss eval.
        if lr is not None:
            feed[model.lr_placeholder] = lr
        return feed

    # Write the loss log via an explicit path instead of os.chdir():
    # chdir() mutated process-wide cwd inside the loop (and crashed with
    # TypeError when working_dir was left at its None default).
    loss_path = os.path.join(working_dir if working_dir is not None
                             else os.getcwd(), 'loss.dat')

    for step in range(num_steps):
        train_batch = data_set_train.next_batch(batch_size)
        # One gradient-descent update on this mini-batch.
        model.session.run(model.update_op_tensor,
                          feed_dict=_feed(train_batch, learning_rate))
        if step % step_save == 0:
            model.save(step)
        if step % step_show == 0:
            # Train loss on the batch just used for the update; validation
            # loss on a fresh, 5x-larger batch from the validation set.
            loss_train = model.session.run(model.mae_loss,
                                           feed_dict=_feed(train_batch))
            val_batch = data_set_val.next_batch(5 * batch_size)
            loss_val = model.session.run(model.mae_loss,
                                         feed_dict=_feed(val_batch))
            # Append "global_step train_loss val_loss" (original trailing
            # space kept so downstream parsers see an identical format).
            with open(loss_path, 'a') as f:
                f.write(str(step + model.global_step) + ' '
                        + str(loss_train) + ' '
                        + str(loss_val) + ' ' + '\n')