# inference.py (forked from PlusLabNLP/CLUSTER)
# encoding: utf-8
import os
import warnings

import torch
from torch.utils.data import DataLoader
from pytorch_pretrained_bert.tokenization import BertTokenizer

from pybert.io.dataset import CreateDataset
from pybert.io.data_transformer import DataTransformer
from pybert.utils.logginger import init_logger
from pybert.utils.utils import seed_everything
from pybert.config.basic_config import configs as config
from pybert.model.nn.bert_fine import BertFine
from pybert.test.predicter import Predicter
from pybert.preprocessing.preprocessor import EnglishPreProcessor

warnings.filterwarnings("ignore")

# Restrict inference to GPUs 1 and 2; set before any CUDA context is created.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'
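
# Runs inference with a fine-tuned BERT classifier (BertFine): reads the test
# split defined in pybert.config.basic_config, scores every sentence with the
# best saved checkpoint, and writes tab-separated predictions
# (id, sentence, score) to config['output']['inference_output_dir'].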
def main():
    logger = init_logger(log_name=config['model']['arch'], log_dir=config['output']['log_dir'])
    logger.info(f"seed is {config['train']['seed']}")
    device = 'cuda:%d' % config['train']['n_gpu'][0] if len(config['train']['n_gpu']) else 'cpu'
    seed_everything(seed=config['train']['seed'], device=device)

    # ---------------- load data from disk ----------------
    logger.info('starting to load data from disk')
    id2label = {value: key for key, value in config['label2id'].items()}
    DT = DataTransformer(logger=logger, seed=config['train']['seed'])
    targets, sentences, ids = DT.read_data(raw_data_path=config['data']['test_file_path'],
                                           preprocessor=EnglishPreProcessor(),
                                           is_train=False)
    tokenizer = BertTokenizer(vocab_file=config['pretrained']['bert']['vocab_path'],
                              do_lower_case=config['train']['do_lower_case'])

    # test dataset and loader
    test_dataset = CreateDataset(data=list(zip(sentences, targets)),
                                 tokenizer=tokenizer,
                                 max_seq_len=config['train']['max_seq_len'],
                                 seed=config['train']['seed'],
                                 example_type='test')
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=config['train']['batch_size'],
                             num_workers=config['train']['num_workers'],
                             shuffle=False,
                             drop_last=False,
                             pin_memory=False)

    # ---------------- load pretrained model from cache ----------------
    logger.info("initializing model")
    model = BertFine.from_pretrained(config['pretrained']['bert']['bert_model_dir'],
                                     cache_dir=config['output']['cache_dir'],
                                     num_classes=len(id2label))

    # ---------------- inference ----------------
    logger.info('model predicting....')
    predicter = Predicter(model=model,
                          logger=logger,
                          n_gpu=config['train']['n_gpu'],
                          model_path=config['output']['checkpoint_dir'] / f"best_{config['model']['arch']}_model.pth")
    result = predicter.predict(data=test_loader)

    # write tab-separated predictions: <id>\t<sentence>\t<score>
    with open(config['output']['inference_output_dir'], 'w') as f:
        for index, line, score in zip(ids, sentences, result):
            f.write(str(index) + '\t' + line + '\t' + str(score[0]) + '\n')

    if len(config['train']['n_gpu']) > 0:
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()