-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
45 lines (37 loc) · 1.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from lattice import load_lattices, drawLattice, load_dictionary
from model import LatticeRNN
from trainer import Trainer
def main():
    """Train a LatticeRNN on the lattice dataset and write checkpoints.

    Steps:
      1. Load lattices from ``./data/lattices.npz`` via ``load_lattices``.
      2. Load the tag dictionary from ``./data/dictionary.txt`` and attach a
         ``tags_idx`` list (each tag replaced by its index) to every lattice.
      3. Split the lattices 90/10 into train/eval sets, in the order
         ``load_lattices`` returned them (no shuffling).
      4. Build a ``LatticeRNN`` and train it with ``Trainer``, which is
         given ``./checkpoints/`` as its checkpoint directory.

    Raises:
        KeyError: presumably, if a lattice contains a tag that is missing
            from the dictionary (``tag2idx`` is indexed directly) — confirm
            ``load_dictionary`` returns a plain dict.
    """
    # Local import: torch is only needed here to probe for a GPU.
    import torch

    # Prefer the GPU, but fall back to CPU instead of crashing when CUDA is
    # unavailable (the previous version hard-coded 'cuda').
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load lattices dataset.
    lattices_data = load_lattices('./data/lattices.npz')

    # Flip to True to draw one example lattice and render it into the
    # 'graph-outputs' directory (requires at least 11 lattices).
    visualize_example = False
    if visualize_example:
        dot = drawLattice(lattices_data[10])
        dot.render(directory='graph-outputs', view=True)

    # Load the tag -> index dictionary and attach the index sequence to each
    # lattice as `tags_idx`.
    tag2idx = load_dictionary('./data/dictionary.txt')
    for lattice in lattices_data:
        lattice.tags_idx = [tag2idx[tag] for tag in lattice.tags]

    # Deterministic 90/10 split in dataset order.
    # NOTE(review): no shuffling — confirm the .npz order is not sorted in a
    # way that would bias the eval set.
    p_train = 0.9
    num_train = int(p_train * len(lattices_data))
    train_data, eval_data = lattices_data[:num_train], lattices_data[num_train:]
    print("Number of training data:", len(train_data))
    print("Number of eval data:", len(eval_data))
    print("===================================")

    # Define the model.
    # NOTE(review): input_dim presumably must match the per-node feature size
    # produced by lattice.py — TODO confirm.
    lattice_rnn_model = LatticeRNN(input_dim=128, output_dim=1, hidden_dim=128)
    lattice_rnn_model.to(device)

    # Train the model.
    n_epochs = 10
    n_logging_batches = 500  # forwarded to Trainer as `log_steps`
    checkpoint_dir = './checkpoints/'
    trainer = Trainer(
        lattice_rnn_model,
        train_data, eval_data,
        log_steps=n_logging_batches,
        epochs=n_epochs,
        device=device,
        checkpoint_dir=checkpoint_dir,
    )
    trainer()


if __name__ == "__main__":
    main()