forked from deeppavlov/DeepPavlov
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeep.py
77 lines (60 loc) · 3.05 KB
/
deep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Copyright 2017 Neural Networks and Deep Learning lab, MIPT
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
from pathlib import Path
import sys
import os
# Resolve the directory one level above the one containing this file and
# append it to sys.path. Presumably this is what lets the `deeppavlov` and
# `utils` imports below resolve when the script is run directly -- confirm.
p = (Path(__file__) / ".." / "..").resolve()
sys.path.append(str(p))
from deeppavlov.core.commands.train import train_model_from_config
from deeppavlov.core.commands.infer import interact_model, predict_on_stream
from deeppavlov.core.common.log import get_logger
from deeppavlov.download import deep_download
from utils.telegram_utils.telegram_ui import interact_model_by_telegram
from utils.server_utils.server import start_model_server
# Module-level logger; main() uses it for its info and error messages.
log = get_logger(__name__)
# Command-line interface definition; main() calls parser.parse_args().
parser = argparse.ArgumentParser()

# A tuple (instead of a set) keeps the order of the modes stable in the
# usage/help text and in "invalid choice" errors from one run to the next;
# set iteration order for strings changes with hash randomisation.
parser.add_argument(
    "mode",
    help="select a mode: train, interact, predict, interactbot, riseapi or download",
    type=str,
    choices=('train', 'interact', 'predict', 'interactbot', 'riseapi', 'download'),
)
parser.add_argument(
    "config_path",
    help="path to a pipeline json config",
    type=str,
)
# Optional; main() falls back to an environment variable when it is omitted.
parser.add_argument(
    "-t", "--token",
    help="telegram bot token",
    type=str,
)
parser.add_argument(
    "-b", "--batch-size",
    dest="batch_size",
    default=1,
    help="inference batch size",
    type=int,
)
parser.add_argument(
    "-f", "--input-file",
    dest="file_path",
    default=None,
    help="Path to the input file",
    type=str,
)
# Can be combined with any mode to download before the mode itself runs.
parser.add_argument(
    "-d", "--download",
    action="store_true",
    help="download model components",
)
def _resolve_config_path(config_path):
    """Return the pipeline config path to use for a command-line argument.

    If ``config_path`` points to an existing file it is returned unchanged.
    Otherwise it is treated as a config name and looked up recursively as
    ``<name>.json`` under the ``configs`` directory next to this script.
    The first match is returned as a string; when nothing matches, the
    argument is returned unchanged.
    """
    if Path(config_path).is_file():
        return config_path

    # The endswith() filter is a simple way to not allow '*' and '?': a match
    # that only exists because of a wildcard will not end with the literal
    # argument text.
    matches = [candidate
               for candidate in Path(__file__).parent.glob(f'configs/**/{config_path}.json')
               if str(candidate.with_suffix('')).endswith(config_path)]
    if not matches:
        return config_path

    log.info(f"Interpreting '{config_path}' as '{matches[0]}'")
    return str(matches[0])


def main():
    """Parse the command line and run the selected mode on a pipeline config.

    Components are downloaded first when ``--download`` is given or the mode
    is ``download``; the ``download`` mode does nothing else. The
    ``interactbot`` mode needs a Telegram token, taken from ``--token`` or,
    failing that, from the ``TELEGRAM_TOKEN`` environment variable; without
    one an error is logged and the bot is not started.
    """
    args = parser.parse_args()
    pipeline_config_path = _resolve_config_path(args.config_path)
    token = args.token or os.getenv('TELEGRAM_TOKEN')

    if args.download or args.mode == 'download':
        deep_download(['-c', pipeline_config_path])

    if args.mode == 'train':
        train_model_from_config(pipeline_config_path)
    elif args.mode == 'interact':
        interact_model(pipeline_config_path)
    elif args.mode == 'interactbot':
        if not token:
            # Name the variable that is actually read above.
            log.error('Token required: initiate -t param or TELEGRAM_TOKEN env var with Telegram bot token')
        else:
            interact_model_by_telegram(pipeline_config_path, token)
    elif args.mode == 'riseapi':
        start_model_server(pipeline_config_path)
    elif args.mode == 'predict':
        predict_on_stream(pipeline_config_path, args.batch_size, args.file_path)


if __name__ == "__main__":
    main()