-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathparameters.py
84 lines (70 loc) · 1.84 KB
/
parameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import os
import codecs
# ---------------------------------------------------------------------------
# Device selection: prefer GPU when CUDA is available, otherwise CPU.
# ---------------------------------------------------------------------------
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_CUDA else torch.device("cpu")

# ---------------------------------------------------------------------------
# Special vocabulary token ids (shared by encoder/decoder vocabularies).
# ---------------------------------------------------------------------------
PAD_token = 0  # pads short sentences up to batch length
EOS_token = 1  # end-of-sentence marker
UNK_token = 2  # unknown / out-of-vocabulary word
SOS_token = 3  # start-of-sentence marker

# ---------------------------------------------------------------------------
# Batching and training-schedule hyperparameters.
# ---------------------------------------------------------------------------
BATCH_SIZE = 32
BATCH_SIZE_FAST = 16
BATCH_SIZE_TEST = 128
BATCH_SIZE_VALI = 32
EPOCH = 10
EPOCH_FAST_CONVERGENCE = 20
EPOCH_RNN = 20
SRC_LENGTH = 500  # max source sequence length
TGT_LENGTH = 70   # max target sequence length

# ---------------------------------------------------------------------------
# Pre-trained embedding files.
# NOTE: the embedding dimension (512) must match HIDDEN_SIZE below.
# ---------------------------------------------------------------------------
EMBED_PATH_SRC = "./data/pretrained_embedding_512_src.dat"
EMBED_PATH_TGT = "./data/pretrained_embedding_512_tgt.dat"

# ---------------------------------------------------------------------------
# Optimization / training-process parameters.
# ---------------------------------------------------------------------------
CLIP = 5.0           # gradient-clipping threshold
CONTINUE_TRAIN = 0   # nonzero resumes training from a checkpoint — presumably; verify against trainer


def _dataset_paths(split):
    # Build the four companion files for one dataset split, in the fixed
    # order the loaders expect: source one-hot, target one-hot, target
    # mask, target gold.
    stem = "./data/" + split + ".txt"
    return [
        stem + ".src.onehot",
        stem + ".tgt.tagged.onehot",
        stem + ".tgt.tagged.mask",
        stem + ".tgt.tagged.gold",
    ]


# Dataset file lists (test / train / validation splits).
TEST_DATA_PATH = _dataset_paths("test")
TRAIN_DATA_PATH = _dataset_paths("train")
VALID_DATA_PATH = _dataset_paths("val")

# ---------------------------------------------------------------------------
# Encoder & decoder architecture parameters (convolutional seq2seq).
# ---------------------------------------------------------------------------
EMBED_SIZE = 512
HIDDEN_SIZE = 512
KERNEL_SIZE_ENC = 5
KERNEL_SIZE_DEC = 3
ENC_LAYERS = 20
DEC_LAYERS = 5

# Vocabulary sizes.
VOCAB_SIZE_SRC = 40000
VOCAB_SIZE_TGT = 10000

# Dataset sizes (number of examples per split):
#   training set   287227
#   test set        11490
#   validation set  13368
TRAIN_SIZE = 287227
TEST_SIZE = 11490
VALI_SIZE = 13368

# ---------------------------------------------------------------------------
# Logging / validation cadence (measured in training steps).
# ---------------------------------------------------------------------------
PRINT_EVERY = 50
PRINT_EVERY_SMALL = 10
VALID_EVERY = 100

# ---------------------------------------------------------------------------
# Scheduled sampling (teacher-forcing ratio bounds and current value).
# ---------------------------------------------------------------------------
TEACHER_FORCING_RATIO_MAX = 1.0
TEACHER_FORCING_RATIO_MIN = 0.3
TEACHER_FORCING_RATIO = 1.0

# Dropout probability.
DROPOUT_RATIO = 0.1