
Commit b8b318b

Add trainer file for running the transformer
1 parent d05b634 commit b8b318b

3 files changed (+322, -5 lines)


applications/FLASK/Transformer/arg_utils.py (+25)
@@ -50,3 +50,28 @@ def add_dataset_arguments(args: argparse.Namespace, default: str):
         type=float,
         help='Fraction of dataset to use (default: 1.0)',
         metavar='NUM')
+
+
+def add_training_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument("--skip-validation",
+                        action="store_true",
+                        default=False,
+                        help="Do not run validation (default: false)")
+    parser.add_argument(
+        "--always-shuffle",
+        action="store_true",
+        default=False,
+        help=
+        "Always shuffle training dataset, even if pretokenized (default: false)"
+    )
+    parser.add_argument(
+        "--validation-set-fraction",
+        type=float,
+        default=0.01,
+        help="Fraction of the validation dataset to use (default: 0.01)")
+    parser.add_argument(
+        "--save-prototext",
+        action="store_true",
+        default=False,
+        help="Save prototext experiment file instead of protobin (slower but "
+        "debuggable) (default: false)")

applications/FLASK/Transformer/datasets/pretokenize/QM9_Pretokenize.py (+1, -5)
@@ -1,9 +1,6 @@
 import numpy as np
 from SMILES_tokenizer import MolTokenizer
-
-
-def random_zero_array(arr, probability, mask):
-    return np.where(np.random.random(arr.shape) < probability, mask, arr)
+from data_utils import random_zero_array
 
 
 def main():
@@ -27,4 +24,3 @@ def main():
 
 if __name__ == '__main__':
     main()
-
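For reference (not part of the commit), the removed helper is the masking routine that now lives in data_utils; a minimal sketch of the behavior the new import is expected to preserve, with the example array purely illustrative:

import numpy as np


def random_zero_array(arr, probability, mask):
    # Replace each element of ``arr`` with ``mask`` with the given probability.
    return np.where(np.random.random(arr.shape) < probability, mask, arr)


tokens = np.arange(1, 11)
masked = random_zero_array(tokens, 0.15, mask=0)
# On average roughly 15% of the entries are zeroed, e.g. [1 2 0 4 5 6 0 8 9 10]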
New file added by this commit (the transformer trainer script; path not captured in this view) (+296)
@@ -0,0 +1,296 @@
+"""
+Constructs the LBANN distributed training script for transformers.
+"""
+import argparse
+import datetime
+import os.path
+
+import lbann
+import lbann.models
+import lbann.contrib.args
+import lbann.contrib.launcher
+from lbann.launcher.batch_script import BatchScript
+
+import utils.paths
+
+import dataset_utils
+import model
+
+
+def construct_training_task(model: lbann.Model,
+                            args: argparse.Namespace,
+                            learning_rate: float = 0.0001,
+                            beta1: float = 0.9,
+                            beta2: float = 0.98,
+                            eps: float = 1e-9,
+                            clip_gradient: float = 0.0,
+                            lr_decay: str = 'fixed',
+                            lr_decay_steps: int = 0,
+                            end_learning_rate: float = 1e-5,
+                            warmup_steps: int = 0,
+                            adamw_decay: float = 0.1) -> BatchScript:
+    """
+    Construct an LBANN trainer batch script for training transformers.
+
+    :param model: An LBANN model.
+    :param args: Command-line arguments.
+    :param learning_rate: Learning rate.
+    :param beta1: Adam beta1 factor.
+    :param beta2: Adam beta2 factor.
+    :param eps: Adam epsilon factor.
+    :param clip_gradient: Clip gradient norm to value (0 disables).
+    :param lr_decay: Learning rate decay type (values: fixed, cosine, none).
+    :param lr_decay_steps: Steps for the total learning decay process (in cosine
+                           decay).
+    :param end_learning_rate: Learning rate after decay.
+    :param warmup_steps: Learning rate warmup steps.
+    :param adamw_decay: Weight decay if using the AdamW optimizer.
+    :return: A batch script object that will run distributed training.
+    """
+
+    # Setup working directory
+    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
+    work_dir = f'{timestamp}_{args.job_name}'
+    work_dir = os.path.abspath(work_dir)
+    os.makedirs(work_dir, exist_ok=True)
+
+    # Create batch script
+    train_script = make_batch_script(model, args.dataset, work_dir, args,
+                                     learning_rate, beta1, beta2, eps,
+                                     clip_gradient, lr_decay, lr_decay_steps,
+                                     end_learning_rate, warmup_steps,
+                                     adamw_decay)
+
+    return train_script
+
+
+# ----------------------------------------------
+# Data reader
+# ----------------------------------------------
+def make_data_reader(dataset_name: str, fraction: float, validate: bool,
+                     val_fraction: float, always_shuffle: bool):
+    reader = lbann.reader_pb2.DataReader()
+    _reader = reader.reader.add()
+    _reader.name = 'python'
+    _reader.role = 'train'
+    _reader.shuffle = (True if always_shuffle
+                       or 'pretokenized' not in dataset_name else False)
+    _reader.fraction_of_data_to_use = fraction
+    _reader.python.module = dataset_name
+    _reader.python.module_dir = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)),
+        'datasets',
+    )
+    _reader.python.sample_function = 'get_train_sample'
+    _reader.python.num_samples_function = 'num_train_samples'
+    _reader.python.sample_dims_function = 'sample_dims'
+
+    if validate:
+        # Validation data reader
+        vreader = reader.reader.add()
+        vreader.name = 'python'
+        vreader.role = 'validate'
+        vreader.shuffle = False
+        vreader.fraction_of_data_to_use = val_fraction
+        vreader.python.module = _reader.python.module
+        vreader.python.module_dir = _reader.python.module_dir
+        vreader.python.sample_function = 'get_val_sample'
+        vreader.python.num_samples_function = 'num_val_samples'
+        vreader.python.sample_dims_function = 'sample_dims'
+
+    return reader
+
+
+# ----------------------------------------------
+# Batch script
+# ----------------------------------------------
+def make_batch_script(model: lbann.Model,
+                      dataset_name: str,
+                      work_dir: str,
+                      args: argparse.Namespace,
+                      learning_rate: float = 0.0001,
+                      beta1: float = 0.9,
+                      beta2: float = 0.98,
+                      eps: float = 1e-9,
+                      clip_gradient: float = 0.0,
+                      lr_decay: str = 'fixed',
+                      lr_decay_steps: int = 0,
+                      end_learning_rate: float = 1e-5,
+                      warmup_steps: int = 0,
+                      adamw_decay: float = 0.1):
+    # Setup training algorithm
+    algo = lbann.BatchedIterativeOptimizer("sgd", epoch_count=args.num_epochs)
+    if hasattr(args, 'kfac') and args.kfac:
+        algo = create_kfac_optimizer(algo, args)
+
+    # Create LBANN trainer and data reader
+    trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size,
+                            training_algo=algo)
+    reader = make_data_reader(dataset_name, args.dataset_fraction,
+                              not args.skip_validation,
+                              args.validation_set_fraction,
+                              args.always_shuffle)
+
+    # Optimizer with learning rate schedule
+    if args.optimizer.lower() == 'adamw':
+        opt = lbann.Adam(learn_rate=learning_rate,
+                         beta1=beta1,
+                         beta2=beta2,
+                         eps=eps,
+                         adamw_weight_decay=adamw_decay)
+    elif args.optimizer.lower() == 'adam':
+        opt = lbann.Adam(learn_rate=learning_rate,
+                         beta1=beta1,
+                         beta2=beta2,
+                         eps=eps)
+
+    if lr_decay == 'fixed':
+        if warmup_steps > 0:
+            raise NotImplementedError(
+                'Warmup not implemented with fixed learning rate')
+
+        model.callbacks.append(
+            lbann.CallbackDropFixedLearningRate(
+                drop_epoch=[1],
+                amt=2,
+            ))
+        model.callbacks.append(
+            lbann.CallbackDropFixedLearningRate(
+                drop_epoch=[2, 4, 8, 12],
+                amt=0.75,
+            ))
+    elif lr_decay == 'cosine':
+        model.callbacks.append(
+            lbann.CallbackCosineDecayLearningRate(
+                lr_max=learning_rate,
+                lr_min=end_learning_rate,
+                decay_steps=lr_decay_steps,
+                initial_warmup_learning_rate=end_learning_rate,
+                warmup_steps=warmup_steps,
+            ))
+
+        print(f'Training schedule: warmup to LR={learning_rate:.6f} in '
+              f'{warmup_steps} steps, cosine decay to '
+              f'LR={end_learning_rate:.6f} in {lr_decay_steps} steps')
+
+    if clip_gradient > 0:
+        model.callbacks.append(
+            lbann.CallbackClipGradientNorm(global_norm=True,
+                                           value=clip_gradient))
+
+    # Checkpoint after every epoch
+    if args.checkpoint:
+        trainer.callbacks.append(
+            lbann.CallbackCheckpoint(
+                checkpoint_dir=os.path.join(work_dir, 'checkpoint'),
+                checkpoint_epochs=1,
+            ))
+
+        # Dump weights after every epoch
+        model.callbacks.append(
+            lbann.CallbackDumpWeights(
+                directory=os.path.join(work_dir, 'weights'),
+                epoch_interval=1,
+            ))
+
+    # Print a progress bar
+    if args.progress:
+        model.callbacks.append(
+            lbann.CallbackProgressBar(newline_interval=100,
+                                      print_mem_usage=True))
+
+    model.callbacks.extend(lbann.contrib.args.create_profile_callbacks(args))
+
+    script_params = lbann.contrib.args.get_scheduler_kwargs(args)
+    script_params['work_dir'] = work_dir
+    script_params['job_name'] = args.job_name
+    script_params['environment'] = {
+        "LBANN_NO_INPLACE": 1,
+        "LBANN_DISABLE_DISTCONV": 1,
+    }
+
+    save_text = args.save_prototext
+    filename = 'experiment.prototext' if save_text else 'experiment.protobin'
+    # Create Protobuf file
+    protobuf_file = os.path.join(work_dir, filename)
+
+    lbann.proto.save_prototext(protobuf_file,
+                               binary=not save_text,
+                               trainer=trainer,
+                               model=model,
+                               data_reader=reader,
+                               optimizer=opt)
+
+    # Create batch script
+    script_params.pop('setup_only', None)  # Drop this argument.
+    script = lbann.contrib.launcher.make_batch_script(**script_params)
+    script.add_command('echo "Started training at $(date)"')
+    script.add_parallel_command([
+        lbann.lbann_exe(),
+        f'--prototext={protobuf_file}',
+    ] + lbann.contrib.args.get_profile_args(args))
+    script.add_command('status=$?')
+    script.add_command('echo "Finished training at $(date)"')
+    script.add_command('exit ${status}')
+    return script
+
+
+def main():
+    # Setup command line options
+    parser = argparse.ArgumentParser()
+    lbann.contrib.args.add_scheduler_arguments(parser, 'lbann_transformer')
+    lbann.contrib.args.add_profiling_arguments(parser)
+    lbann.contrib.args.add_training_arguments(parser)
+    dataset_utils.add_transformer_architecture_arguments(parser)
+    dataset_utils.add_training_arguments(parser)
+    dataset_utils.add_dataset_arguments(parser, default='qm9')
+
+    parser.add_argument('--optimizer',
+                        type=str,
+                        default='adam',
+                        choices=['adam', 'adamw'],
+                        help='Stochastic optimizer used in training')
+
+    # Model parameters
+    parser.add_argument(
+        "--dropout",
+        type=float,
+        default=0.1,
+        help="Dropout ratio in transformer model. 0 disables (default: 0.1)")
+    parser.add_argument(
+        "--input-dropout",
+        type=float,
+        default=0.0,
+        help="Dropout ratio after input encoding (default: 0.0 = disabled)")
+
+    parser.set_defaults(progress=True, skip_validation=True)
+    args = parser.parse_args()
+
+    # Load dataset
+    dataset = dataset_utils.load_dataset(args.dataset)
+
+    # Construct model (named to avoid shadowing the ``model`` module)
+    transformer_model: lbann.Model = model.create_encoder_decoder_transformer(
+        dataset, args)
+
+    # Construct trainer
+    train_script: BatchScript = construct_training_task(transformer_model, args)
+
+    # Run trainer
+    retval = train_script.run(overwrite=True)
+    if retval != 0:
+        return
+
+    if args.checkpoint:
+        print(
+            'Training complete, to evaluate the translation with a BLEU score, '
+            'run ``evaluate_translation_bleu.py`` with the model checkpoint path '
+            'and the same arguments as this training run.')
+    else:
+        print(
+            'Training complete, to evaluate the translation with a BLEU score, '
+            'run this script with --checkpoint')
+
+
+if __name__ == '__main__':
+    main()
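For reference (not part of the commit), a hedged sketch of how the new script might be driven; the file name trainer.py and the --job-name/--dataset/--num-epochs/--mini-batch-size flags come from helpers outside this diff and are assumptions, while --optimizer and --dropout are defined above:

# Hypothetical launch, assuming the new file is named trainer.py and the
# scheduler/dataset helpers register the flags shown:
#
#   python trainer.py --job-name=tf_qm9 --dataset=qm9 --num-epochs=10 \
#       --mini-batch-size=256 --optimizer=adamw --dropout=0.1
#
# Programmatic equivalent for a single run from Python:
import sys

import trainer  # assumed module name for the file added above

sys.argv = ['trainer.py', '--job-name', 'tf_qm9', '--dataset', 'qm9',
            '--optimizer', 'adam']
trainer.main()  # writes the experiment protobuf and a batch script, then runs it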
