tanda-sequential-finetuning-with-asnq.diff
diff --git a/examples/run_glue.py b/examples/run_glue.py
index 19316cb..bef98bb 100644
--- a/examples/run_glue.py
+++ b/examples/run_glue.py
@@ -354,6 +354,8 @@ def main(): @@ this change adds `sequential fine-tuning` option
"than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--do_train", action='store_true',
help="Whether to run training.")
+ parser.add_argument("--sequential", action='store_true',
+ help="Sequential fine-tune from a checkpoint.")
parser.add_argument("--do_eval", action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--evaluate_during_training", action='store_true',
@@ -488,7 +490,10 @@ def main(): @@ this change adds `sequential fine-tuning` option
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None)
- model = model_class.from_pretrained(args.model_name_or_path,
+ if args.sequential:
+ model = model_class.from_pretrained(args.output_dir)
+ else:
+ model = model_class.from_pretrained(args.model_name_or_path,
from_tf=bool('.ckpt' in args.model_name_or_path),
config=config,
cache_dir=args.cache_dir if args.cache_dir else None)
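
The hunk above wires the new --sequential flag into checkpoint selection: with the flag set, the model is reloaded from --output_dir (the checkpoint written by an earlier fine-tuning run) instead of from --model_name_or_path. A minimal, self-contained sketch of that selection, using BertForSequenceClassification as a stand-in for whatever model_class the script resolves from --model_type; the flag values and paths below are hypothetical:

from transformers import BertConfig, BertForSequenceClassification

# Hypothetical stand-ins for run_glue.py's parsed arguments.
sequential = False                                  # --sequential
model_name_or_path = "bert-base-uncased"            # --model_name_or_path
output_dir = "/tmp/tanda-asnq-transfer-checkpoint"  # --output_dir of the earlier run

config = BertConfig.from_pretrained(model_name_or_path, num_labels=2)
if sequential:
    # Second TANDA step ("adapt"): resume from the checkpoint saved by the
    # first ("transfer") fine-tuning run, e.g. on ASNQ.
    model = BertForSequenceClassification.from_pretrained(output_dir)
else:
    # First step: start from the pretrained weights.
    model = BertForSequenceClassification.from_pretrained(model_name_or_path, config=config)
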
diff --git a/transformers/data/metrics/__init__.py b/transformers/data/metrics/__init__.py
index c9ebaac..eddb931 100644
--- a/transformers/data/metrics/__init__.py
+++ b/transformers/data/metrics/__init__.py
@@ -79,5 +79,7 @@ if _has_sklearn: @@ this change adds processing for `asnq`
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "wnli":
return {"acc": simple_accuracy(preds, labels)}
+ elif task_name == "asnq":
+ return {"acc": simple_accuracy(preds, labels)}
else:
raise KeyError(task_name)
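
With this branch in place, the compute-metrics dispatcher in this module (glue_compute_metrics) reports plain accuracy for "asnq", the same metric used for SST-2 or QNLI. simple_accuracy is a mean over element-wise equality, so the returned dictionary amounts to the following sketch (toy arrays, for illustration only):

import numpy as np

preds = np.array([1, 0, 1, 1])    # model predictions for four QA pairs
labels = np.array([1, 0, 0, 1])   # gold labels

acc = float((preds == labels).mean())  # 3 of 4 match
print({"acc": acc})                    # {'acc': 0.75}, as the "asnq" branch would report
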
diff --git a/transformers/data/processors/glue.py b/transformers/data/processors/glue.py
index 518251b..de77dca 100644
--- a/transformers/data/processors/glue.py
+++ b/transformers/data/processors/glue.py
@@ -513,6 +513,46 @@ class WnliProcessor(DataProcessor): @@ this change adds processing for `asnq`
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
+
+class AsnqProcessor(DataProcessor):
+ """Processor for the ASNQ data set (for TANDA)."""
+
+ def get_example_from_tensor_dict(self, tensor_dict):
+ """See base class."""
+ return InputExample(tensor_dict['idx'].numpy(),
+ tensor_dict['sentence'].numpy().decode('utf-8'),
+ None,
+ str(tensor_dict['label'].numpy()))
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "%s-%s" % (set_type, i)
+ text_a = line[0]
+ text_b = line[1]
+ if int(line[-1].strip()) == 4:
+ label = "1"
+ else:
+ label = "0"
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
glue_tasks_num_labels = {
"cola": 2,
"mnli": 3,
@@ -523,6 +563,7 @@ glue_tasks_num_labels = {
"qnli": 2,
"rte": 2,
"wnli": 2,
+ "asnq": 2,
}
glue_processors = {
@@ -536,6 +577,7 @@ glue_processors = {
"qnli": QnliProcessor,
"rte": RteProcessor,
"wnli": WnliProcessor,
+ "asnq": AsnqProcessor,
}
glue_output_modes = {
@@ -549,4 +591,5 @@ glue_output_modes = {
"qnli": "classification",
"rte": "classification",
"wnli": "classification",
+ "asnq": "classification",
}
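
Together with the three registrations above (glue_tasks_num_labels, glue_processors, glue_output_modes), AsnqProcessor makes the dataset selectable in run_glue.py via --task_name asnq. The mapping in _create_examples keeps only ASNQ label 4 as the positive class (a candidate sentence drawn from the long answer that also contains the short answer); every other label becomes negative. A small illustration with hypothetical stand-in rows (real ASNQ TSV lines carry the question/sentence pair plus an integer label in the last column):

# Hypothetical ASNQ-style rows: [question, candidate sentence, ..., label].
rows = [
    ["who wrote hamlet", "Hamlet was written by William Shakespeare.", "4"],
    ["who wrote hamlet", "Hamlet is a tragedy set in Denmark.", "2"],
]

for i, line in enumerate(rows):
    guid = "train-%d" % i
    label = "1" if int(line[-1].strip()) == 4 else "0"
    print(guid, line[0], "|", line[1], "->", label)
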