-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathtrain_mlnmt.py
28 lines (22 loc) · 930 Bytes
/
train_mlnmt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import argparse
import logging
import pprint
import importlib
from mlnmt import train
from mcg.stream import (get_tr_stream, get_dev_streams,
get_logprob_streams)
logger = logging.getLogger(__name__)


def resolve_config_module(config_arg):
    """Return the importable module name for a ``--config`` value.

    Only a trailing ``.py`` suffix is removed, so both ``config`` and
    ``config.py`` resolve to ``config``. Dotted module paths are kept
    intact: ``pkg.config.py`` becomes ``pkg.config`` and a name that merely
    contains ".py" somewhere (e.g. ``configs.pyramid``) is returned
    unchanged instead of being truncated to its first component.

    Args:
        config_arg: Value passed on the command line via ``--config``.

    Returns:
        The string to hand to ``importlib.import_module``.
    """
    suffix = ".py"
    if config_arg.endswith(suffix):
        return config_arg[:-len(suffix)]
    return config_arg


def main():
    """Parse the command line, build the model config and start training.

    The config module named by ``--config`` is imported, the callable named
    by ``--proto`` is looked up on it and called with no arguments, and the
    resulting object is passed to ``train`` together with the streams
    built from it.

    Raises:
        ImportError: If the config module cannot be imported.
        AttributeError: If the config module has no attribute named by
            ``--proto``.
    """
    # Get the arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config", help="model config file")
    parser.add_argument("--proto",
                        default="get_config_multiWay",
                        help="Prototype config to use for model configuration")
    args = parser.parse_args()

    cfg = importlib.import_module(resolve_config_module(args.config))
    config = getattr(cfg, args.proto)()

    # NOTE(review): this script configures no logging handler or level, so
    # unless an imported module does, this INFO record is dropped by the
    # default WARNING threshold -- confirm.
    logger.info("Model options:\n{}".format(pprint.pformat(config)))

    train(config, get_tr_stream(config), get_dev_streams(config),
          get_logprob_streams(config))


if __name__ == "__main__":
    main()