# Lint as: python3
r"""Code example for a custom NER model in LIT, using PyTorch.

This demo shows how to use a custom token-classification (NER) model with LIT,
in just a few lines of code. We use a transformers model, with a minimal amount
of code to implement the LIT API. Compared to models/glue_models.py, this has
fewer features, but the code is more readable.

This file is adapted from LIT's simple_pytorch_demo.py, but serves a
token-classification model instead of a sentence classifier. The transformers
library can load weights from either PyTorch or TensorFlow 2, so you can use
any saved model compatible with the underlying model class
(AutoModelForTokenClassification). To train something for this demo, use any
fine-tuning code that works with transformers token classification, such as
https://github.com/huggingface/transformers#quick-tour-of-the-fine-tuningusage-scripts

To run locally:
    python lit.py \
        --port=5432 --model_path=/path/to/saved/model \
        --labels=/path/to/labels.txt --test_data_dir=/path/to/data

Then navigate to localhost:5432 to access the demo UI.

NOTE: this demo loads its data with I2b2Dataset from utils_ner; you can swap in
a different dataset by passing any other LIT Dataset to the server.
"""
import os
from typing import List

import numpy as np
import torch
import transformers
from absl import app
from absl import flags
from absl import logging
from captum.attr import configure_interpretable_embedding_layer
from captum.attr import remove_interpretable_embedding_layer

from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.app import JsonDict

from utils_ner import I2b2Dataset
from utils_ner import get_labels

# NOTE: additional flags defined in server_flags.py
FLAGS = flags.FLAGS
flags.DEFINE_string(
    "model_path", None,
    "Path to trained model, in standard transformers format, e.g. as "
    "saved by model.save_pretrained() and tokenizer.save_pretrained()")
flags.DEFINE_string(
    "labels", None,
    "Path to the labels file used to build the tag vocabulary")
flags.DEFINE_string(
    "test_data_dir", None,
    "Directory containing the test.txt data file")

def _from_pretrained(cls, *args, **kw):
    """Load a transformers model in PyTorch, with fallback to TF2/Keras weights."""
    try:
        return cls.from_pretrained(*args, **kw)
    except OSError as e:
        logging.warning("Caught OSError loading model: %s", e)
        logging.warning(
            "Re-trying to convert from TensorFlow checkpoint (from_tf=True)")
        return cls.from_pretrained(*args, from_tf=True, **kw)

class NerModel(lit_model.Model):
    """Simple NER model."""

    compute_grads: bool = True  # if True, compute and return gradients.

    def __init__(self, model_name_or_path, labels_file='labels.txt'):
        self.LABELS = get_labels(labels_file)
        num_labels = len(self.LABELS)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name_or_path)
        model_config = transformers.AutoConfig.from_pretrained(
            model_name_or_path,
            num_labels=num_labels,
            output_hidden_states=True,
            output_attentions=True,
        )
        # This is just a regular PyTorch model.
        self.model = _from_pretrained(
            transformers.AutoModelForTokenClassification,
            model_name_or_path,
            config=model_config)
        self.model.load_state_dict(
            torch.load(os.path.join(model_name_or_path, 'pytorch_model.bin'),
                       map_location='cpu'))
        self.model.eval()

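    # LIT uses max_minibatch_size() to chunk inputs before calling
    # predict_minibatch(); the method below only reads inputs[0], so the batch
    # size is pinned to 1.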
    def max_minibatch_size(self):
        return 1

    def predict_minibatch(self,
                          inputs: List[JsonDict],
                          config=None) -> List[JsonDict]:
        """Run prediction on a single example.

        The batch size is fixed at 1 for simplicity; to use a batch size greater
        than one, this would need to encode with self.tokenizer.batch_encode_plus,
        as in the LIT examples.

        :param inputs: JSON of the sentence and the token index to interpret
        :param config: unused
        :return: prediction output aligned with the output spec
        """
        mask_token = '[MASK]'
        sentence = inputs[0]['Sentence']
        interpret_token_id = int(inputs[0]['Token Index to Explain'])
        tokens = ['[CLS]'] + self.tokenizer.tokenize(sentence) + ['[SEP]']
        input_ids = [self.tokenizer.convert_tokens_to_ids(tokens)]
        input_mask = [[1] * len(input_ids[0])]
        input_ids_tensor = torch.tensor(input_ids, dtype=torch.long)
        input_mask_tensor = torch.tensor(input_mask, dtype=torch.long)
        # Needed for calculating grad based on embeddings
        interpretable_embedding = configure_interpretable_embedding_layer(
            self.model, 'bert.embeddings.word_embeddings')
        input_embeddings = interpretable_embedding.indices_to_embeddings(
            input_ids_tensor)
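        # Captum's interpretable embedding layer replaces the model's word
        # embedding lookup, so we convert token ids to embedding vectors here
        # and pass them via `inputs_embeds`; this keeps a differentiable path
        # from the logits back to the embeddings for the gradient computation
        # below.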
        model_input = {
            "inputs_embeds": input_embeddings,
            "attention_mask": input_mask_tensor}
        model_output = self.model(**model_input)
        logits, embs, unused_attentions = model_output[:3]
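        # The model is expected to return a tuple of (logits, hidden_states,
        # attentions) because output_hidden_states and output_attentions are
        # enabled in the config; embs[0] is the embedding-layer output that the
        # gradients below are taken against.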
        logits_ndarray = logits.detach().cpu().numpy()
        example_preds = np.argmax(logits_ndarray, axis=2)
        confidences = torch.softmax(
            torch.from_numpy(logits_ndarray), dim=2).detach().cpu().numpy()
        label_map = {i: label for i, label in enumerate(self.LABELS)}
        predictions = [label_map[pred] for pred in example_preds[0]]
        outputs = {}
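        # Per-layer attention for batch element 0, shaped
        # (num_heads, seq_len, seq_len) to match the AttentionHeads fields
        # declared in output_spec().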
        for i, attention_layer in enumerate(unused_attentions):
            outputs[f'layer_{i}/attention'] = (
                attention_layer[0].detach().cpu().numpy().copy())
        # TODO: the LIT LIME explainer does not yet support targeting a specific
        # token. Until that is fixed, we explain the first non-O prediction if
        # there is one, or else the first token after [CLS].
        if interpret_token_id < 0 or mask_token in sentence:
            non_o_indices = np.where(example_preds[0] != 0)[0]
            token_index = non_o_indices[0] if len(non_o_indices) > 0 else 1
        else:
            # TODO: once the LIT LIME explainer is configurable, set token_index
            # from the UI.
            token_index = interpret_token_id
        outputs['tokens'] = tokens
        outputs['bio_tags'] = predictions
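        # Gradient of the selected token's logits with respect to the
        # embedding-layer output, surfaced as TokenGradients so LIT's
        # gradient-based salience interpreters can use it.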
        grad = torch.autograd.grad(torch.unbind(logits[0][token_index]), embs[0])
        outputs['grads'] = grad[0][0].detach().cpu().numpy()
        outputs['probas'] = confidences[0][token_index]
        outputs['token_ids'] = list(range(0, len(tokens)))
        remove_interpretable_embedding_layer(self.model, interpretable_embedding)
        yield outputs

    def input_spec(self) -> lit_types.Spec:
        return {
            "Sentence": lit_types.TextSegment(),
            "Token Index to Explain": lit_types.Scalar()
        }

    def output_spec(self) -> lit_types.Spec:
        spec = {
            "tokens": lit_types.Tokens(),
            "bio_tags": lit_types.SequenceTags(align="tokens"),
            "token_ids": lit_types.SequenceTags(align="tokens"),
            "grads": lit_types.TokenGradients(align="tokens"),
            "probas": lit_types.MulticlassPreds(parent="bio_tags",
                                                vocab=self.LABELS)
        }
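        # One AttentionHeads field per transformer layer, mirroring the
        # layer_{i}/attention keys produced in predict_minibatch().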
        for i in range(self.model.config.num_hidden_layers):
            spec[f'layer_{i}/attention'] = lit_types.AttentionHeads(
                align=("tokens", "tokens"))
        return spec

def main(_):
    # Load the model we defined above.
    models = {"NCBI BERT Finetuned": NerModel(FLAGS.model_path,
                                              labels_file=FLAGS.labels)}
    datasets = {"I2b2 2014": I2b2Dataset(data_dir=FLAGS.test_data_dir)}
    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    lit_demo.serve()


if __name__ == "__main__":
    app.run(main)