-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path bi_enc.jac
69 lines (63 loc) · 1.57 KB
/
bi_enc.jac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
// Node that wraps the bi_enc actions. Its abilities read their inputs from,
// and write their outputs to, the `has` fields of the visiting walker.
node bi_enc {
    // Import every bi_enc action this node calls. save_model is used by the
    // train ability below, so it is declared here alongside train/infer
    // (mirroring how the save_model walker declares it before calling it).
    can bi_enc.train, bi_enc.infer, bi_enc.save_model;

    // Load the JSON dataset at visitor.train_file, train on it with the
    // walker's from_scratch / num_train_epochs settings, then save the model
    // to visitor.model_name when that field is non-empty.
    can train {
        train_data = file.load_json(visitor.train_file);
        bi_enc.train(
            dataset=train_data,
            from_scratch=visitor.from_scratch,
            training_parameters={
                "num_train_epochs": visitor.num_train_epochs
            }
        );
        // An empty model_name (the train walker's default) skips saving.
        if (visitor.model_name) {
            bi_enc.save_model(model_path=visitor.model_name);
        }
    }

    // Run bi_enc.infer with visitor.query as the only context and
    // visitor.labels as the candidates, then copy the "predicted" field of
    // the first result onto visitor.prediction.
    can infer {
        res = bi_enc.infer(
            contexts=[visitor.query],
            candidates=visitor.labels,
            context_type="text",
            candidate_type="text"
        )[0];  // one context was passed, so presumably one result — TODO confirm
        visitor.prediction = res["predicted"];
    }
}
// Walker that trains the bi-encoder through the bi_enc node's train ability.
// On root it reuses an already-attached bi_enc node and creates one only if
// none exists. Without that check, every run would add another bi_enc node,
// and all of them would be visited (and trained) on later runs.
walker train {
    // Path of the JSON training data (read by file.load_json in the node).
    has train_file;
    // num_train_epochs and from_scratch are passed to bi_enc.train;
    // model_name is the save path, and an empty string means "don't save".
    has num_train_epochs = 100, from_scratch = true, model_name = "";

    root {
        // Visit the existing bi_enc node(s); spawn one only when none exist.
        take --> node::bi_enc else {
            spawn here ++> node::bi_enc;
            take --> node::bi_enc;
        }
    }

    bi_enc: here::train;
}
// Walker that runs bi-encoder inference through the bi_enc node's infer
// ability. On root it reuses an already-attached bi_enc node and creates one
// only if none exists, so repeated runs produce one prediction rather than
// one per accumulated duplicate node.
walker infer {
    // query: text passed to bi_enc.infer as the single context.
    // interactive: when true, read queries from stdin in an endless loop;
    // otherwise run once on `query` and report the result.
    has query, interactive = true;
    // labels: candidates passed to bi_enc.infer.
    // prediction: filled in by the node's infer ability.
    has labels, prediction;

    root {
        // Visit the existing bi_enc node(s); spawn one only when none exist.
        take --> node::bi_enc else {
            spawn here ++> node::bi_enc;
            take --> node::bi_enc;
        }
    }

    bi_enc {
        if (interactive) {
            // No exit condition: the loop runs until the process is
            // interrupted (the prompt tells the user to press Ctrl-C).
            while true {
                query = std.input("Enter input text (Ctrl-C to exit)> ");
                here::infer;
                std.out(prediction);
            }
        } else {
            here::infer;
            report prediction;
        }
    }
}
// Walker that calls the bi_enc.save_model action with model_path.
// It has no take statements, so the call runs only at the node where the
// walker is spawned.
walker save_model {
    can bi_enc.save_model;
    // Destination passed straight through to bi_enc.save_model.
    has model_path;

    // Pass by keyword, the same form the bi_enc node uses for this action.
    bi_enc.save_model(model_path=model_path);
}
// Walker that calls the bi_enc.load_model action with model_path.
// It has no take statements, so the call runs only at the node where the
// walker is spawned.
walker load_model {
    can bi_enc.load_model;
    // Source location passed straight through to bi_enc.load_model.
    has model_path;

    bi_enc.load_model(model_path);
}