forked from twistedcubic/attention-rank-collapse
-
Notifications
You must be signed in to change notification settings - Fork 0
/
glue.py
148 lines (127 loc) · 5.38 KB
/
glue.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import logging
import os
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Union
import torch
from filelock import FileLock
from torch.utils.data.dataset import Dataset
from dataclasses import dataclass
from transformers.tokenization_bart import BartTokenizer, BartTokenizerFast
from transformers.tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_xlm_roberta import XLMRobertaTokenizer
from glue_processors import glue_convert_examples_to_features, glue_output_modes, glue_processors, InputFeatures
logger = logging.getLogger(__name__)
@dataclass
class GlueDataTrainingArguments:
    """
    Data-related options for GLUE training and evaluation.

    Declared as a dataclass so that `HfArgumentParser` can expose every field
    as a command-line argument; each field's ``help`` metadata becomes the
    CLI help text.
    """

    # Task identifier; lower-cased in __post_init__.
    task_name: str = field(
        metadata={"help": f"The name of the task to train on: {', '.join(glue_processors)}"}
    )
    # Directory containing the task's input files.
    data_dir: str = field(
        metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
    )
    # Token budget per example after tokenization.
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    # When True, requests that cached feature files be rebuilt.
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )

    def __post_init__(self):
        """Normalize ``task_name`` to lower case."""
        normalized = self.task_name.lower()
        self.task_name = normalized
class Split(Enum):
    """Dataset split selector; each member's name and value are the same string."""

    train = "train"
    dev = "dev"
    test = "test"
class GlueDataset(Dataset):
    """
    This will be superseded by a framework-agnostic approach
    soon.

    Torch ``Dataset`` of ``InputFeatures`` built from the examples that the
    task's processor reads out of ``args.data_dir``.

    Args:
        args: data arguments (task name, data dir, max sequence length,
            overwrite-cache flag).
        tokenizer: tokenizer handed to ``glue_convert_examples_to_features``;
            its class name is also part of the cache file name.
        limit_length: if given, only the first ``limit_length`` examples are
            converted to features.
        mode: split to load, either a ``Split`` member or its name as a string.
        cache_dir: directory for the cached features file; defaults to
            ``args.data_dir``.
        use_cache: keyword-only. When True, reuse an existing cache file unless
            ``args.overwrite_cache`` is set. Defaults to False, in which case
            features are always rebuilt from the dataset files (the cache file
            is still written).

    Raises:
        KeyError: if ``mode`` is a string that is not a ``Split`` member name.

    NOTE: the cache file name encodes split, tokenizer class, max_seq_length
    and task name, but not ``limit_length``; with ``use_cache=True`` a cache
    written under a different limit would be reused.
    """

    args: GlueDataTrainingArguments
    output_mode: str
    features: List[InputFeatures]

    def __init__(
        self,
        args: GlueDataTrainingArguments,
        tokenizer: PreTrainedTokenizer,
        limit_length: Optional[int] = None,
        mode: Union[str, Split] = Split.train,
        cache_dir: Optional[str] = None,
        *,
        use_cache: bool = False,
    ):
        self.args = args
        self.processor = glue_processors[args.task_name]()
        self.output_mode = glue_output_modes[args.task_name]
        if isinstance(mode, str):
            try:
                mode = Split[mode]
            except KeyError:
                # Drop the chained lookup error; the message below says what went wrong.
                raise KeyError("mode is not a valid split name") from None
        # Load data features from cache or dataset file
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else args.data_dir,
            "cached_{}_{}_{}_{}".format(
                mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
            ),
        )
        # Copy so the in-place swap below cannot mutate a list owned by the processor.
        label_list = list(self.processor.get_labels())
        if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
            RobertaTokenizer,
            RobertaTokenizerFast,
            XLMRobertaTokenizer,
            BartTokenizer,
            BartTokenizerFast,
        ):
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        self.label_list = label_list
        # The lock serializes feature building/caching between processes sharing
        # the cache file; other processes only reuse the cache when use_cache=True.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):
            if use_cache and os.path.exists(cached_features_file) and not args.overwrite_cache:
                start = time.time()
                # NOTE: torch.load unpickles the file; keep cache_dir on trusted storage.
                self.features = torch.load(cached_features_file)
                logger.info(
                    "Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )
            else:
                logger.info("Creating features from dataset file at %s", args.data_dir)
                if mode == Split.dev:
                    examples = self.processor.get_dev_examples(args.data_dir)
                elif mode == Split.test:
                    examples = self.processor.get_test_examples(args.data_dir)
                else:
                    examples = self.processor.get_train_examples(args.data_dir)
                if limit_length is not None:
                    examples = examples[:limit_length]
                self.features = glue_convert_examples_to_features(
                    examples,
                    tokenizer,
                    max_length=args.max_seq_length,
                    label_list=label_list,
                    output_mode=self.output_mode,
                )
                start = time.time()
                torch.save(self.features, cached_features_file)
                # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
                logger.info(
                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )

    def __len__(self):
        return len(self.features)

    def truncate_dataset(self, n_total):
        """Keep only the first ``n_total`` features (slice semantics; in place on ``self.features``)."""
        # Slicing already clamps to the list length, so no explicit min() is needed.
        self.features = self.features[:n_total]

    def __getitem__(self, i) -> InputFeatures:
        return self.features[i]

    def get_labels(self):
        """Return the label list used for feature conversion (after any MNLI swap)."""
        return self.label_list