-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmy_fonduer_model.py
74 lines (58 loc) · 2.59 KB
/
my_fonduer_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing import Iterable, Set, Tuple
from pandas import DataFrame
import numpy as np
from emmental.data import EmmentalDataLoader
from fonduer.parser.models import Document
from fonduer.candidates.models import Candidate
from fonduer.learning.dataset import FonduerDataset
from fonduer_model import FonduerModel, _F_matrix, _L_matrix
def get_entity_relation(candidate: Candidate) -> Tuple:
    """Return the spans of a candidate's mentions as a tuple.

    For every mention yielded by ``candidate.get_mentions()`` the value of
    ``mention.context.get_span()`` is collected, preserving mention order.
    """
    mentions = candidate.get_mentions()
    spans = (mention.context.get_span() for mention in mentions)
    return tuple(spans)
def get_unique_entity_relations(candidates: Iterable[Candidate]) -> Set[Tuple]:
    """Collect the distinct entity relations found among ``candidates``.

    Each candidate is converted with ``get_entity_relation``; duplicates
    collapse because the results are gathered into a set.
    """
    return {get_entity_relation(cand) for cand in candidates}
# Label values. In the code visible in this file only TRUE is referenced, as
# the column index of the positive class in predicted probabilities.
# NOTE(review): ABSTAIN/FALSE look like labeling-function return values kept
# for parity with the training code - confirm before removing.
ABSTAIN = -1
FALSE = 0
TRUE = 1
class MyFonduerModel(FonduerModel):
    """Fonduer model that turns a document's candidates into relation rows.

    Depending on ``self.model_type`` the candidates are scored either with the
    Emmental model (``"emmental"``) or with the first label model, and the
    selected candidates are returned as a table of unique mention spans.
    """

    # Task name used both for the dataloader's label mapping and the dataset.
    TASK_NAME = "wiki"
    # A candidate is kept by the Emmental branch when its positive-class
    # probability is strictly greater than this value.
    POSITIVE_THRESHOLD = 0.6
    # Batch size of the test dataloader.
    BATCH_SIZE = 100

    def _classify(self, doc: Document) -> DataFrame:
        """Classify the candidates of ``doc`` and return the extracted relations.

        Args:
            doc: Document whose candidates (of the first, and only, candidate
                class of ``self.candidate_extractor``) are classified.

        Returns:
            A DataFrame with one column per mention class of the candidate
            class and one row per unique entity relation, indexed 0..n-1.
            Rows follow the order of the selected candidates (descending
            probability in the label-model branch). When nothing is selected
            the frame is empty but still carries the mention columns.
        """
        # Only one candidate class is defined.
        candidate_class = self.candidate_extractor.candidate_classes[0]
        test_cands = getattr(doc, candidate_class.__tablename__ + "s")
        columns = [mention.__name__ for mention in candidate_class.mentions]

        # Nothing to score: skip featurization/labeling and the model call.
        if not test_cands:
            return DataFrame([], columns=columns)

        if self.model_type == "emmental":
            # Featurization
            features_list = self.featurizer.apply(doc)

            # Convert features into a sparse matrix
            F_test = _F_matrix(features_list[0], self.key_names)

            # Dataloader for test
            test_dataloader = EmmentalDataLoader(
                task_to_label_dict={self.TASK_NAME: "labels"},
                dataset=FonduerDataset(
                    self.TASK_NAME, test_cands, F_test, self.word2id, 2
                ),
                split="test",
                batch_size=self.BATCH_SIZE,
                shuffle=False,
            )

            test_preds = self.emmental_model.predict(
                test_dataloader, return_preds=True
            )
            probs = np.array(test_preds["probs"][self.TASK_NAME])
            positive = np.flatnonzero(probs[:, TRUE] > self.POSITIVE_THRESHOLD)
            true_preds = [test_cands[i] for i in positive]
        else:
            labels_list = self.labeler.apply(doc, lfs=self.lfs)
            L_test = _L_matrix(labels_list[0], self.key_names)

            marginals = self.label_models[0].predict_proba(L_test)
            for cand, prob in zip(test_cands, marginals[:, TRUE]):
                cand.prob = prob
            # NOTE(review): unlike the Emmental branch, no probability
            # threshold is applied here - every candidate is returned, ranked
            # by probability. Confirm whether that is intended.
            true_preds = sorted(
                test_cands, key=lambda cand: cand.prob, reverse=True
            )

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # ranking computed above survives and the row order is deterministic
        # (a set would discard it).
        rows = list(
            dict.fromkeys(get_entity_relation(cand) for cand in true_preds)
        )
        # Build the frame once; DataFrame.append was removed in pandas 2.0.
        return DataFrame(rows, columns=columns)