Sourcery Starbot ⭐ refactored guyrosin/temporal_attention #2
base: main
File 1

@@ -107,8 +107,7 @@ def load_train_test_datasets(train_path, test_path, cache_dir):
         split='train',
         cache_dir=cache_dir,
     )
-    dataset = DatasetDict({"train": train_dataset, "validation": test_dataset})
-    return dataset
+    return DatasetDict({"train": train_dataset, "validation": test_dataset})


 def split_temporal_dataset_files(
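The change above is Sourcery's "inline immediately returned variable" rule (named in the review comment at the bottom of this page): a value assigned on one line only to be returned on the next is returned directly. A minimal sketch of the pattern with hypothetical names, not code from this repository:

    # Before: `total` exists only to be returned on the next line.
    def total_price(items):
        total = sum(item["price"] for item in items)
        return total

    # After: return the expression directly; behavior is identical.
    def total_price(items):
        return sum(item["price"] for item in items)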
@@ -120,7 +119,7 @@ def split_temporal_dataset_files(
     train_path = Path(train_path)
     test_path = Path(test_path)
     dataset_path = get_dataset_path(train_path, corpus_name, train_size, test_size)
-    exclude_similar_sentences = True if corpus_name.startswith("liverpool") else False
+    exclude_similar_sentences = bool(corpus_name.startswith("liverpool"))
     out_train_path = dataset_path / train_path.name
     out_test_path = dataset_path / test_path.name
     if Path(out_train_path).exists() and Path(out_test_path).exists():

Sourcery comment on lines -123 to +122: function `split_temporal_dataset_files` refactored (the `True if ... else False` expression is replaced with a direct `bool(...)` call).
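For context, the same simplification in isolation (hypothetical names, not from this repository). Since `str.startswith` already returns a bool, even the `bool(...)` wrapper is strictly optional:

    user_name = "admin_alice"

    # Before: a conditional expression that only maps a condition to True/False.
    is_admin = True if user_name.startswith("admin") else False

    # After (as in the diff): bool() states the intent directly.
    is_admin = bool(user_name.startswith("admin"))

    # Equivalent, because startswith() already returns a bool.
    is_admin = user_name.startswith("admin")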
@@ -203,7 +202,7 @@ def find_sentences_of_words(
             logger.debug(f"Loading word_time_sentences from {file_path}")
             word_time_sentences = pickle.load(f)
     else:
-        logger.info(f"Finding relevant sentences in the corpus...")
+        logger.info("Finding relevant sentences in the corpus...")
         word_time_sentences = defaultdict(dict)
         for file in text_files:  # For each time period
             time = TemporalText.find_time(file)

Sourcery comment on lines -206 to +205: function `find_sentences_of_words` refactored (the `f` prefix is dropped from a log message that has no placeholders).
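The f-string fix is cosmetic but worth knowing: an f-string with no replacement fields is parsed for nothing. A small self-contained illustration (hypothetical logger setup, not this repository's logging):

    import logging

    logger = logging.getLogger(__name__)

    # Before: the f prefix does nothing because the string has no {placeholders}.
    logger.info(f"Finding relevant sentences in the corpus...")

    # After: a plain string literal carries the same message.
    logger.info("Finding relevant sentences in the corpus...")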
File 2

@@ -69,11 +69,10 @@ def _load_auto_config(model_args, data_args=None, num_labels=None, **kwargs):
         if num_labels
         else {}
     )
-    kwargs.update(additional_kwargs)
-    config = AutoConfig.from_pretrained(
+    kwargs |= additional_kwargs
+    return AutoConfig.from_pretrained(
         model_args.model_name_or_path, cache_dir=model_args.cache_dir, **kwargs
     )
-    return config


 def load_pretrained_model(

Sourcery comment on lines -72 to -76: function `_load_auto_config` refactored (`kwargs` is merged with the `|=` operator, and the `AutoConfig.from_pretrained(...)` result is returned directly instead of being stored in `config` first).
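The `|=` dict-merge operator used here requires Python 3.9 or newer; on older interpreters `dict.update()` is the equivalent. A minimal sketch with made-up values:

    defaults = {"cache_dir": None, "num_labels": 2}
    overrides = {"num_labels": 3}

    # Python 3.9+ in-place dict union; equivalent to defaults.update(overrides).
    defaults |= overrides
    assert defaults == {"cache_dir": None, "num_labels": 3}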
@@ -299,8 +298,10 @@ def init_run(
         logger.add(sys.stderr, level="WARNING")
     utils.set_result_logger_level()
     logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
-        + f"distributed training: {bool(training_args.local_rank != -1)}"
+        (
+            f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+            + f"distributed training: {training_args.local_rank != -1}"
+        )
     )
     for args_instance in (model_args, data_args, training_args):
         logger.info(args_instance)

Sourcery comment on lines -302 to +304: function `init_run` refactored (the redundant `bool(...)` around the `!=` comparison is dropped, and the concatenated message is wrapped in parentheses).
@@ -345,7 +346,7 @@ def get_model_name(model_name_or_path):
     if "checkpoint-" in path.name:
         model_name_or_path = f"{path.parent.name}/{path.name}"
     else:
-        model_name_or_path = str(path.name)
+        model_name_or_path = path.name
     return model_name_or_path

Sourcery comment on lines -348 to +349: function `get_model_name` refactored (the redundant `str(...)` around `path.name` is removed).
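Both of the hunks above drop casts that cannot change the value: a `!=` comparison already evaluates to a bool, and `pathlib.Path.name` is already a str. A short sketch with a hypothetical checkpoint path:

    from pathlib import Path

    local_rank = -1
    # A != comparison already evaluates to a bool, so bool() around it is a no-op.
    assert (local_rank != -1) == bool(local_rank != -1)

    # Path.name is already a str, so str() around it is also a no-op.
    checkpoint = Path("runs/model/checkpoint-500")
    assert checkpoint.name == str(checkpoint.name) == "checkpoint-500"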
File 3

@@ -3,21 +3,17 @@
 from transformers.file_utils import _LazyModule

 _import_structure = {
-    "configuration_tempobert": [
-        "TempoBertConfig",
-    ],
-}
-
-_import_structure["tokenization_tempobert_fast"] = ["TempoBertTokenizerFast"]
-
-_import_structure["modeling_tempobert"] = [
-    "TempoBertForMaskedLM",
-    "TempoBertModel",
-    "TempoBertForPreTraining",
-    "TempoBertForSequenceClassification",
-    "TempoBertForTokenClassification",
-]
+    "configuration_tempobert": ["TempoBertConfig"],
+    "tokenization_tempobert_fast": ["TempoBertTokenizerFast"],
+    "modeling_tempobert": [
+        "TempoBertForMaskedLM",
+        "TempoBertModel",
+        "TempoBertForPreTraining",
+        "TempoBertForSequenceClassification",
+        "TempoBertForTokenClassification",
+    ],
+}

 if TYPE_CHECKING:
     from .configuration_tempobert import TempoBertConfig
     from .modeling_tempobert import (

Sourcery comment on lines -6 to -19: the separate `_import_structure[...]` assignments are merged into the initial dictionary literal.
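The `__init__` change folds piecemeal dictionary assignments into a single literal. A minimal sketch of the same pattern with hypothetical module names:

    # Before: the mapping is assembled across several statements.
    registry = {"config": ["MyConfig"]}
    registry["tokenizer"] = ["MyTokenizer"]
    registry["model"] = ["MyModel", "MyModelForMaskedLM"]

    # After: one literal declares the same mapping in a single place.
    registry = {
        "config": ["MyConfig"],
        "tokenizer": ["MyTokenizer"],
        "model": ["MyModel", "MyModelForMaskedLM"],
    }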
File 4

@@ -77,8 +77,8 @@ def init_time_embeddings(self, config):
         if utils.is_time_id_necessary(self.time_embedding_type):
             # time_ids is a range (1, len time emb)
             time_ids = torch.arange(len(self.times)).expand((1, -1))
-        SPECIAL_TIMES_COUNT = 2  # NOTE: hardcoded (see TempoSpecialTokensMixin)
         if "attention" in self.time_embedding_type:
+            SPECIAL_TIMES_COUNT = 2  # NOTE: hardcoded (see TempoSpecialTokensMixin)
             # NOTE: each time (including "special" times) has a single embedding (for all tokens in the vocabulary)
             self.time_embeddings = nn.Embedding(
                 SPECIAL_TIMES_COUNT + len(self.times), config.hidden_size

Sourcery comment on lines -80 to +81: function `init_time_embeddings` refactored (the `SPECIAL_TIMES_COUNT` assignment is moved into the `if` branch that uses it).
@@ -177,10 +177,7 @@ def __init__(self, config):
         self.position_embedding_type = getattr(
             config, "position_embedding_type", "absolute"
         )
-        if (
-            self.position_embedding_type == "relative_key"
-            or self.position_embedding_type == "relative_key_query"
-        ):
+        if self.position_embedding_type in {"relative_key", "relative_key_query"}:
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(
                 2 * config.max_position_embeddings - 1, self.attention_head_size

@@ -246,10 +243,7 @@ def forward(
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

-        if (
-            self.position_embedding_type == "relative_key"
-            or self.position_embedding_type == "relative_key_query"
-        ):
+        if self.position_embedding_type in ["relative_key", "relative_key_query"]:
             seq_length = hidden_states.size()[1]
             position_ids_l = torch.arange(
                 seq_length, dtype=torch.long, device=hidden_states.device
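Replacing the chained `or` of equality checks with a membership test is purely a readability change; the set literal in the first hunk and the list in the second behave the same for a fixed pair of constants. A sketch with a hypothetical value:

    position_embedding_type = "relative_key"

    # Before: two equality checks joined with `or`.
    if (
        position_embedding_type == "relative_key"
        or position_embedding_type == "relative_key_query"
    ):
        print("using relative position embeddings")

    # After: one membership test; a set, tuple, or list all work here.
    if position_embedding_type in {"relative_key", "relative_key_query"}:
        print("using relative position embeddings")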
@@ -480,10 +474,7 @@ def forward(
             output_attentions,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
-        outputs = (attention_output,) + self_outputs[
-            1:
-        ]  # add attentions if we output them
-        return outputs
+        return (attention_output,) + self_outputs[1:]


 class TemporalAttention(nn.Module):

@@ -541,10 +532,7 @@ def forward(
             output_attentions,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
-        outputs = (attention_output,) + self_outputs[
-            1:
-        ]  # add attentions if we output them
-        return outputs
+        return (attention_output,) + self_outputs[1:]


 class BertIntermediate(nn.Module):
@@ -665,8 +653,7 @@ def forward(

     def feed_forward_chunk(self, attention_output):
         intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        return layer_output
+        return self.output(intermediate_output, attention_output)


 class BertLayerWithTemporalAttention(nn.Module):

@@ -762,8 +749,7 @@ def forward(

     def feed_forward_chunk(self, attention_output):
         intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        return layer_output
+        return self.output(intermediate_output, attention_output)


 class BertEncoder(nn.Module):
@@ -846,8 +832,16 @@ def custom_forward(*inputs):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        if not return_dict:
-            return tuple(
+        return (
+            BaseModelOutputWithPastAndCrossAttentions(
+                last_hidden_state=hidden_states,
+                past_key_values=next_decoder_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attentions,
+                cross_attentions=all_cross_attentions,
+            )
+            if return_dict
+            else tuple(
                 v
                 for v in [
                     hidden_states,

@@ -858,12 +852,6 @@ def custom_forward(*inputs):
                 ]
                 if v is not None
             )
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-            cross_attentions=all_cross_attentions,
-        )

@@ -954,8 +942,16 @@ def custom_forward(*inputs):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        if not return_dict:
-            return tuple(
+        return (
+            BaseModelOutputWithPastAndCrossAttentions(
+                last_hidden_state=hidden_states,
+                past_key_values=next_decoder_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attentions,
+                cross_attentions=all_cross_attentions,
+            )
+            if return_dict
+            else tuple(
                 v
                 for v in [
                     hidden_states,

@@ -966,12 +962,6 @@
                 ]
                 if v is not None
             )
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-            cross_attentions=all_cross_attentions,
-        )

@@ -1033,8 +1023,7 @@ def __init__(self, config):
         self.predictions = BertLMPredictionHead(config)

     def forward(self, sequence_output):
-        prediction_scores = self.predictions(sequence_output)
-        return prediction_scores
+        return self.predictions(sequence_output)


 class TempoBertPreTrainedModel(PreTrainedModel):
@@ -1277,16 +1266,17 @@ def forward(
             self.pooler(sequence_output) if self.pooler is not None else None
         )

-        if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        return BaseModelOutputWithPoolingAndCrossAttentions(
-            last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
-            past_key_values=encoder_outputs.past_key_values,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-            cross_attentions=encoder_outputs.cross_attentions,
+        return (
+            BaseModelOutputWithPoolingAndCrossAttentions(
+                last_hidden_state=sequence_output,
+                pooler_output=pooled_output,
+                past_key_values=encoder_outputs.past_key_values,
+                hidden_states=encoder_outputs.hidden_states,
+                attentions=encoder_outputs.attentions,
+                cross_attentions=encoder_outputs.cross_attentions,
+            )
+            if return_dict
+            else (sequence_output, pooled_output) + encoder_outputs[1:]
         )

File 5

@@ -84,31 +84,23 @@ def get_embedding(
     require_word_in_vocab=False,
     hidden_layers_number=None,
 ):
-    if (require_word_in_vocab and not word in model.tokenizer.vocab) or len(
-        sentences
-    ) == 0:
+    if (
+        require_word_in_vocab
+        and word not in model.tokenizer.vocab
+        or len(sentences) == 0
+    ):
         return torch.tensor([])
     if hidden_layers_number is None:
         num_hidden_layers = model.config.num_hidden_layers
-        if num_hidden_layers == 12:
-            hidden_layers_number = 1
-        elif num_hidden_layers == 2:
-            hidden_layers_number = 3
-        else:
-            hidden_layers_number = 1
+        hidden_layers_number = 3 if num_hidden_layers == 2 else 1
     embs = model.embed_word(
         sentences,
         word,
        time=time,
         batch_size=batch_size,
         hidden_layers_number=hidden_layers_number,
     )
-    if embs.ndim == 1:
-        # in case of a single sentence, embs is actually the single embedding, not a list
-        return embs
-    else:
-        centroid = torch.mean(embs, dim=0)
-        return centroid
+    return embs if embs.ndim == 1 else torch.mean(embs, dim=0)

Sourcery comment on lines -87 to +95: function `get_embedding` refactored (the vocabulary check is simplified and the 12/2/else branching and the centroid block collapse into conditional expressions). The comment also flags that the refactoring removes existing code comments (asking "why?"); the one visible in this hunk is "# in case of a single sentence, embs is actually the single embedding, not a list".


 def get_detection_function(score_method, config):
@@ -173,10 +165,11 @@ def semantic_change_detection_wrapper(
                 if not sentences:
                     logger.debug(f"Found no sentences for '{word}' at time '{time}'")
                     if hasattr(model.config, 'times'):
-                        missing_times = [
-                            time for time in model.config.times if time not in time_sentences
-                        ]
-                        if missing_times:
+                        if missing_times := [
+                            time
+                            for time in model.config.times
+                            if time not in time_sentences
+                        ]:
                             logger.debug(f"Found no sentences for '{word}' at {missing_times}")
                 score = detection_function(
                     time_sentences,

Sourcery comment on lines -176 to +172: function `semantic_change_detection_wrapper` refactored (the list comprehension and the emptiness check are combined with the walrus operator).
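The walrus operator (`:=`, Python 3.8+) binds `missing_times` inside the `if` condition, so building the list and testing it for emptiness become one statement. A sketch with hypothetical data:

    known_times = [2015, 2016, 2017]
    time_sentences = {2015: ["..."], 2017: ["..."]}

    # Assign and test in one statement: the body runs only when the list is
    # non-empty, and missing_times remains bound inside it.
    if missing_times := [t for t in known_times if t not in time_sentences]:
        print(f"Found no sentences at {missing_times}")  # -> [2016]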
@@ -195,14 +188,14 @@
         shifts_dict,
     )
     logger.info("Final results:")
-    for model, result_str in model_to_result_str.items():
+    for result_str in model_to_result_str.values():
         logger.info(result_str)


 def check_words_in_vocab(words, tokenizer, verbose=False, check_split_words=False):
     missing_words = []
     for word in words:
-        if not word in tokenizer.vocab:
+        if word not in tokenizer.vocab:
             if verbose:
                 logger.warning(f"{word=} doesn't exist in the vocab")
             missing_words.append(word)

Sourcery comment on lines -205 to +198: function `check_words_in_vocab` refactored (`not word in ...` becomes the idiomatic `word not in ...`).
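When a loop body only uses the dictionary values, iterating over `.values()` avoids an unused key (the original loop also happened to rebind the name `model`). A small sketch with made-up results:

    results = {"model_a": "accuracy 0.81", "model_b": "accuracy 0.84"}

    # Before: the key is unpacked but never used (and the loop rebinds `model`).
    for model, result_str in results.items():
        print(result_str)

    # After: iterate over the values only.
    for result_str in results.values():
        print(result_str)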
@@ -243,8 +236,7 @@ def semantic_change_detection(
         method = calc_change_score_cosine_dist
     else:
         raise ValueError(f"Unknown {score_method=}")
-    score = method(model, sentences, word, verbose)
-    return score
+    return method(model, sentences, word, verbose)

Sourcery comment on lines -246 to +239: function `semantic_change_detection` refactored (the score is returned directly instead of being assigned first).


 def semantic_change_detection_temporal(
@@ -328,9 +320,7 @@ def get_shifts(corpus_name, tokenizer=None):
         # The German target words are uppercased
         if tokenizer.do_lower_case:
             df_shifts.word = df_shifts.word.str.lower()
-    elif corpus_name.startswith("semeval_lat"):
-        pass
-    else:
+    elif not corpus_name.startswith("semeval_lat"):
         logger.error(f"Unsupported corpus: {corpus_name}")
         exit()
     shifts_dict = dict(zip(df_shifts.word, df_shifts.score))

Sourcery comment on lines -331 to +323: function `get_shifts` refactored (the empty `elif ...: pass` branch is removed by negating its condition).
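Dropping the `pass` branch works by flipping the condition: the error path now runs exactly when the corpus is neither handled above nor a `semeval_lat*` corpus. A compact sketch of the equivalence with hypothetical names:

    def route(corpus_name):
        if corpus_name.startswith("semeval_eng"):
            print("lowercasing target words")
        # The `elif ...: pass / else: <error>` pair collapses into one branch
        # by negating the condition: the error fires for anything that is
        # neither handled above nor a "semeval_lat" corpus.
        elif not corpus_name.startswith("semeval_lat"):
            raise ValueError(f"Unsupported corpus: {corpus_name}")

    route("semeval_lat_subcorpus")  # accepted silently
    # route("unknown_corpus") would raise ValueError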
File 6

@@ -23,8 +23,9 @@ def predict_time(time_to_label, fill_mask_pipeline, time_pattern, sentence):
     tokens = list(result_dict.keys())
     # Choose the token with the highest probability
     pred_token = tokens[0]
-    pred = time_to_label[time_pattern.search(pred_token).group(1) if pred_token else 0]
-    return pred
+    return time_to_label[
+        time_pattern.search(pred_token).group(1) if pred_token else 0
+    ]

Sourcery comment on lines -26 to +28: function `predict_time` refactored (the looked-up label is returned directly instead of being assigned to `pred` first).


 def sentence_time_prediction(
File 7

@@ -44,10 +44,8 @@ def find_time(filename):
         m = re.match(r".*?_(\d+.*?)[\.a-zA-Z]", filename)
         if m is None:
             return None
-        time = m.group(1)
-        # Remove trailing underscores (e.g., for "nyt_2017_train.txt")
-        time = time.strip("_")
-        return time
+        time = m[1]
+        return time.strip("_")

Sourcery comment on lines -47 to +48: function `find_time` refactored. The comment also flags that the refactoring removes an existing code comment (asking "why?"): "# Remove trailing underscores (e.g., for "nyt_2017_train.txt")".


     def _generate_tables(self, files):
         for file_idx, file in enumerate(files):
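`re.Match` objects support indexing, so `m[1]` is shorthand for `m.group(1)`. A self-contained sketch of `find_time`'s pattern on the sample filename mentioned in the removed comment:

    import re

    filename = "nyt_2017_train.txt"
    m = re.match(r".*?_(\d+.*?)[\.a-zA-Z]", filename)
    if m is not None:
        # m[1] is equivalent to m.group(1); strip("_") drops the trailing
        # underscore captured from names like "nyt_2017_train.txt".
        print(m[1].strip("_"))  # -> "2017"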
Sourcery comment (on the `load_train_test_datasets` hunk in File 1): function `load_train_test_datasets` refactored with the following changes: inline-immediately-returned-variable.