Sourcery Starbot ⭐ refactored guyrosin/temporal_attention #2

Status: Open. Wants to merge 1 commit into base: main.
7 changes: 3 additions & 4 deletions data_utils.py
@@ -107,8 +107,7 @@ def load_train_test_datasets(train_path, test_path, cache_dir):
        split='train',
        cache_dir=cache_dir,
    )
-    dataset = DatasetDict({"train": train_dataset, "validation": test_dataset})
-    return dataset
+    return DatasetDict({"train": train_dataset, "validation": test_dataset})
Function load_train_test_datasets refactored.
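A sketch of the pattern: a variable assigned and then immediately returned can be inlined. The helper below is hypothetical, not from this repo:

from datasets import Dataset, DatasetDict

def build_dataset(train_rows, test_rows):
    train_dataset = Dataset.from_dict({"text": train_rows})
    test_dataset = Dataset.from_dict({"text": test_rows})
    # Before: dataset = DatasetDict({...}); return dataset
    # After: return the freshly built value directly
    return DatasetDict({"train": train_dataset, "validation": test_dataset})

dataset = build_dataset(["a", "b"], ["c"])
assert list(dataset.keys()) == ["train", "validation"]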



def split_temporal_dataset_files(
@@ -120,7 +119,7 @@ def split_temporal_dataset_files(
    train_path = Path(train_path)
    test_path = Path(test_path)
    dataset_path = get_dataset_path(train_path, corpus_name, train_size, test_size)
-    exclude_similar_sentences = True if corpus_name.startswith("liverpool") else False
+    exclude_similar_sentences = bool(corpus_name.startswith("liverpool"))
Comment on lines -123 to +122
Function split_temporal_dataset_files refactored.

out_train_path = dataset_path / train_path.name
out_test_path = dataset_path / test_path.name
if Path(out_train_path).exists() and Path(out_test_path).exists():
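A side note on the bool() change above: str.startswith already returns a bool, so the wrapper is defensive rather than a behavior change. A quick check (corpus name hypothetical):

# str.startswith returns a bool, so bool() is a no-op wrapper here
corpus_name = "liverpool_corpus"
assert corpus_name.startswith("liverpool") is True
assert bool(corpus_name.startswith("liverpool")) is True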
@@ -203,7 +202,7 @@ def find_sentences_of_words(
        logger.debug(f"Loading word_time_sentences from {file_path}")
        word_time_sentences = pickle.load(f)
    else:
-        logger.info(f"Finding relevant sentences in the corpus...")
+        logger.info("Finding relevant sentences in the corpus...")
Comment on lines -206 to +205
Function find_sentences_of_words refactored.

word_time_sentences = defaultdict(dict)
for file in text_files: # For each time period
time = TemporalText.find_time(file)
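The f-string fix is purely cosmetic: an f-string with no placeholders evaluates to the same string as a plain literal; linters such as flake8 flag it (F541). A one-line check:

# No placeholders, so the f-prefix buys nothing:
assert f"Finding relevant sentences in the corpus..." == "Finding relevant sentences in the corpus..."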
13 changes: 7 additions & 6 deletions hf_utils.py
@@ -69,11 +69,10 @@ def _load_auto_config(model_args, data_args=None, num_labels=None, **kwargs):
        if num_labels
        else {}
    )
-    kwargs.update(additional_kwargs)
-    config = AutoConfig.from_pretrained(
+    kwargs |= additional_kwargs
+    return AutoConfig.from_pretrained(
        model_args.model_name_or_path, cache_dir=model_args.cache_dir, **kwargs
    )
-    return config
Comment on lines -72 to -76
Function _load_auto_config refactored.
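One compatibility caveat on this hunk: the in-place dict union operator |= requires Python 3.9+ (PEP 584), while dict.update works on every version. A sketch with hypothetical values:

kwargs = {"num_labels": 2}
additional_kwargs = {"cache_dir": "/tmp/hf_cache"}  # hypothetical

# Python 3.9+ (PEP 584); same effect as kwargs.update(additional_kwargs)
kwargs |= additional_kwargs
assert kwargs == {"num_labels": 2, "cache_dir": "/tmp/hf_cache"}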



def load_pretrained_model(
@@ -299,8 +298,10 @@ def init_run(
    logger.add(sys.stderr, level="WARNING")
    utils.set_result_logger_level()
    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
-        + f"distributed training: {bool(training_args.local_rank != -1)}"
+        (
+            f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+            + f"distributed training: {training_args.local_rank != -1}"
+        )
Comment on lines -302 to +304
Function init_run refactored.

)
for args_instance in (model_args, data_args, training_args):
logger.info(args_instance)
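Dropping bool() inside the f-string is safe: != already yields a bool, and both spellings format identically. A quick check (value hypothetical):

local_rank = -1
assert f"{bool(local_rank != -1)}" == f"{local_rank != -1}" == "False"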
@@ -345,7 +346,7 @@ def get_model_name(model_name_or_path):
if "checkpoint-" in path.name:
model_name_or_path = f"{path.parent.name}/{path.name}"
else:
model_name_or_path = str(path.name)
model_name_or_path = path.name
Comment on lines -348 to +349
Function get_model_name refactored.

return model_name_or_path
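The str() removal is sound because pathlib's Path.name is already a str. A sketch with a hypothetical checkpoint path:

from pathlib import Path

path = Path("results/model-a/checkpoint-500")  # hypothetical
assert isinstance(path.name, str)
assert path.name == "checkpoint-500"
assert f"{path.parent.name}/{path.name}" == "model-a/checkpoint-500"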


20 changes: 8 additions & 12 deletions models/tempobert/__init__.py
@@ -3,21 +3,17 @@
from transformers.file_utils import _LazyModule

 _import_structure = {
-    "configuration_tempobert": [
-        "TempoBertConfig",
+    "configuration_tempobert": ["TempoBertConfig"],
+    "tokenization_tempobert_fast": ["TempoBertTokenizerFast"],
+    "modeling_tempobert": [
+        "TempoBertForMaskedLM",
+        "TempoBertModel",
+        "TempoBertForPreTraining",
+        "TempoBertForSequenceClassification",
+        "TempoBertForTokenClassification",
     ],
 }
-
-_import_structure["tokenization_tempobert_fast"] = ["TempoBertTokenizerFast"]
-
-_import_structure["modeling_tempobert"] = [
-    "TempoBertForMaskedLM",
-    "TempoBertModel",
-    "TempoBertForPreTraining",
-    "TempoBertForSequenceClassification",
-    "TempoBertForTokenClassification",
-]
Comment on lines -6 to -19
Lines 6-19 refactored.
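The consolidation is behavior-preserving: building a dict in one literal is equivalent to assigning the keys one by one. A minimal sketch (keys hypothetical):

incremental = {}
incremental["configuration"] = ["Config"]
incremental["modeling"] = ["Model"]

literal = {"configuration": ["Config"], "modeling": ["Model"]}
assert incremental == literal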


if TYPE_CHECKING:
from .configuration_tempobert import TempoBertConfig
from .modeling_tempobert import (
88 changes: 39 additions & 49 deletions models/tempobert/modeling_tempobert.py
@@ -77,8 +77,8 @@ def init_time_embeddings(self, config):
        if utils.is_time_id_necessary(self.time_embedding_type):
            # time_ids is a range (1, len time emb)
            time_ids = torch.arange(len(self.times)).expand((1, -1))
-        SPECIAL_TIMES_COUNT = 2  # NOTE: hardcoded (see TempoSpecialTokensMixin)
        if "attention" in self.time_embedding_type:
+            SPECIAL_TIMES_COUNT = 2  # NOTE: hardcoded (see TempoSpecialTokensMixin)
Comment on lines -80 to +81
Function TempoBertEmbeddings.init_time_embeddings refactored.

# NOTE: each time (including "special" times) has a single embedding (for all tokens in the vocabulary)
self.time_embeddings = nn.Embedding(
SPECIAL_TIMES_COUNT + len(self.times), config.hidden_size
@@ -177,10 +177,7 @@ def __init__(self, config):
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
-        if (
-            self.position_embedding_type == "relative_key"
-            or self.position_embedding_type == "relative_key_query"
-        ):
+        if self.position_embedding_type in {"relative_key", "relative_key_query"}:
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
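Chained equality against several constants collapses into a membership test; with a constant set literal, CPython folds the set to a frozenset at compile time. A quick equivalence check (value hypothetical):

position_embedding_type = "relative_key"
old = (position_embedding_type == "relative_key"
       or position_embedding_type == "relative_key_query")
new = position_embedding_type in {"relative_key", "relative_key_query"}
assert old == new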
@@ -246,10 +243,7 @@ def forward(
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

-        if (
-            self.position_embedding_type == "relative_key"
-            or self.position_embedding_type == "relative_key_query"
-        ):
+        if self.position_embedding_type in ["relative_key", "relative_key_query"]:
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
@@ -480,10 +474,7 @@ def forward(
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
-        outputs = (attention_output,) + self_outputs[
-            1:
-        ]  # add attentions if we output them
-        return outputs
+        return (attention_output,) + self_outputs[1:]


class TemporalAttention(nn.Module):
@@ -541,10 +532,7 @@ def forward(
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
-        outputs = (attention_output,) + self_outputs[
-            1:
-        ]  # add attentions if we output them
-        return outputs
+        return (attention_output,) + self_outputs[1:]


class BertIntermediate(nn.Module):
@@ -665,8 +653,7 @@ def forward(

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        return layer_output
+        return self.output(intermediate_output, attention_output)


class BertLayerWithTemporalAttention(nn.Module):
@@ -762,8 +749,7 @@ def forward(

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        return layer_output
+        return self.output(intermediate_output, attention_output)


class BertEncoder(nn.Module):
@@ -846,8 +832,16 @@ def custom_forward(*inputs):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

-        if not return_dict:
-            return tuple(
+        return (
+            BaseModelOutputWithPastAndCrossAttentions(
+                last_hidden_state=hidden_states,
+                past_key_values=next_decoder_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attentions,
+                cross_attentions=all_cross_attentions,
+            )
+            if return_dict
+            else tuple(
                v
                for v in [
                    hidden_states,
@@ -858,12 +852,6 @@ def custom_forward(*inputs):
                ]
                if v is not None
            )
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-            cross_attentions=all_cross_attentions,
-        )
+        )
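The same reshaping in miniature: two returns guarded by "if not return_dict" become a single conditional expression. The pack helpers below are hypothetical stand-ins for the encoder's return, not code from this repo:

def pack(values, return_dict):
    if not return_dict:
        return tuple(v for v in values if v is not None)
    return dict(enumerate(values))

def pack_refactored(values, return_dict):
    # Single exit point: conditional expression selects the shape
    return (
        dict(enumerate(values))
        if return_dict
        else tuple(v for v in values if v is not None)
    )

vals = [1, None, 3]
assert pack(vals, False) == pack_refactored(vals, False) == (1, 3)
assert pack(vals, True) == pack_refactored(vals, True) == {0: 1, 1: None, 2: 3}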


@@ -954,8 +942,16 @@ def custom_forward(*inputs):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

-        if not return_dict:
-            return tuple(
+        return (
+            BaseModelOutputWithPastAndCrossAttentions(
+                last_hidden_state=hidden_states,
+                past_key_values=next_decoder_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attentions,
+                cross_attentions=all_cross_attentions,
+            )
+            if return_dict
+            else tuple(
                v
                for v in [
                    hidden_states,
@@ -966,12 +962,6 @@ def custom_forward(*inputs):
                ]
                if v is not None
            )
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-            cross_attentions=all_cross_attentions,
-        )
+        )


@@ -1033,8 +1023,7 @@ def __init__(self, config):
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
-        prediction_scores = self.predictions(sequence_output)
-        return prediction_scores
+        return self.predictions(sequence_output)


class TempoBertPreTrainedModel(PreTrainedModel):
@@ -1277,16 +1266,17 @@ def forward(
            self.pooler(sequence_output) if self.pooler is not None else None
        )

-        if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        return BaseModelOutputWithPoolingAndCrossAttentions(
-            last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
-            past_key_values=encoder_outputs.past_key_values,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-            cross_attentions=encoder_outputs.cross_attentions,
-        )
+        return (
+            BaseModelOutputWithPoolingAndCrossAttentions(
+                last_hidden_state=sequence_output,
+                pooler_output=pooled_output,
+                past_key_values=encoder_outputs.past_key_values,
+                hidden_states=encoder_outputs.hidden_states,
+                attentions=encoder_outputs.attentions,
+                cross_attentions=encoder_outputs.cross_attentions,
+            )
+            if return_dict
+            else (sequence_output, pooled_output) + encoder_outputs[1:]
+        )


42 changes: 16 additions & 26 deletions semantic_change_detection.py
@@ -84,31 +84,23 @@ def get_embedding(
    require_word_in_vocab=False,
    hidden_layers_number=None,
):
-    if (require_word_in_vocab and not word in model.tokenizer.vocab) or len(
-        sentences
-    ) == 0:
+    if (
+        require_word_in_vocab
+        and word not in model.tokenizer.vocab
+        or len(sentences) == 0
+    ):
        return torch.tensor([])
    if hidden_layers_number is None:
        num_hidden_layers = model.config.num_hidden_layers
-        if num_hidden_layers == 12:
-            hidden_layers_number = 1
-        elif num_hidden_layers == 2:
-            hidden_layers_number = 3
-        else:
-            hidden_layers_number = 1
+        hidden_layers_number = 3 if num_hidden_layers == 2 else 1
Comment on lines -87 to +95
Function get_embedding refactored with the following changes:

This removes the following comments (why?):

#  in case of a single sentence, embs is actually the single embedding, not a list
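On the rewritten guard: dropping the inner parentheses is safe only because "and" binds tighter than "or", so A and B or C still groups as (A and B) or C. A quick check with hypothetical flags:

require, missing, empty = False, True, True
# 'and' binds tighter than 'or', so both spellings group the same way
assert ((require and missing) or empty) == (require and missing or empty)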

    embs = model.embed_word(
        sentences,
        word,
        time=time,
        batch_size=batch_size,
        hidden_layers_number=hidden_layers_number,
    )
-    if embs.ndim == 1:
-        # in case of a single sentence, embs is actually the single embedding, not a list
-        return embs
-    else:
-        centroid = torch.mean(embs, dim=0)
-        return centroid
+    return embs if embs.ndim == 1 else torch.mean(embs, dim=0)
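The collapsed tail in miniature: a batch of embeddings is averaged over dim=0 into a centroid, while a single embedding (ndim == 1) passes through unchanged. A sketch with toy tensors:

import torch

embs = torch.tensor([[1.0, 3.0], [3.0, 5.0]])  # two hypothetical embeddings
# Mean over dim=0 collapses the batch into one centroid vector
assert torch.equal(torch.mean(embs, dim=0), torch.tensor([2.0, 4.0]))

single = torch.tensor([1.0, 3.0])
result = single if single.ndim == 1 else torch.mean(single, dim=0)
assert torch.equal(result, single)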


def get_detection_function(score_method, config):
@@ -173,10 +165,11 @@ def semantic_change_detection_wrapper(
            if not sentences:
                logger.debug(f"Found no sentences for '{word}' at time '{time}'")
            if hasattr(model.config, 'times'):
-                missing_times = [
-                    time for time in model.config.times if time not in time_sentences
-                ]
-                if missing_times:
+                if missing_times := [
+                    time
+                    for time in model.config.times
+                    if time not in time_sentences
+                ]:
Comment on lines -176 to +172
Function semantic_change_detection_wrapper refactored.

logger.debug(f"Found no sentences for '{word}' at {missing_times}")
score = detection_function(
time_sentences,
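The walrus operator binds and tests the list in one step (named expressions, PEP 572, Python 3.8+). A sketch with hypothetical times:

times = [2015, 2016, 2017]
time_sentences = {2015: ["..."], 2017: ["..."]}
# Assign and truth-test in one expression
if missing_times := [t for t in times if t not in time_sentences]:
    print(f"Found no sentences at {missing_times}")  # -> [2016]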
@@ -195,14 +188,14 @@
        shifts_dict,
    )
    logger.info("Final results:")
-    for model, result_str in model_to_result_str.items():
+    for result_str in model_to_result_str.values():
        logger.info(result_str)


def check_words_in_vocab(words, tokenizer, verbose=False, check_split_words=False):
    missing_words = []
    for word in words:
-        if not word in tokenizer.vocab:
+        if word not in tokenizer.vocab:
Comment on lines -205 to +198
Function check_words_in_vocab refactored with the following changes:

  • Simplify logical expression using De Morgan identities (de-morgan)

if verbose:
logger.warning(f"{word=} doesn't exist in the vocab")
missing_words.append(word)
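As the de-morgan bullet says, "not word in vocab" and "word not in vocab" are the same test; the latter is the idiomatic spelling. A quick check (vocab hypothetical):

vocab = {"bank", "river"}
assert (not "cloud" in vocab) == ("cloud" not in vocab)
assert "cloud" not in vocab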
@@ -243,8 +236,7 @@ def semantic_change_detection(
        method = calc_change_score_cosine_dist
    else:
        raise ValueError(f"Unknown {score_method=}")
-    score = method(model, sentences, word, verbose)
-    return score
+    return method(model, sentences, word, verbose)
Comment on lines -246 to +239
Function semantic_change_detection refactored.



def semantic_change_detection_temporal(
@@ -328,9 +320,7 @@ def get_shifts(corpus_name, tokenizer=None):
        # The German target words are uppercased
        if tokenizer.do_lower_case:
            df_shifts.word = df_shifts.word.str.lower()
-    elif corpus_name.startswith("semeval_lat"):
-        pass
-    else:
+    elif not corpus_name.startswith("semeval_lat"):
Comment on lines -331 to +323
Function get_shifts refactored.

logger.error(f"Unsupported corpus: {corpus_name}")
exit()
shifts_dict = dict(zip(df_shifts.word, df_shifts.score))
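Folding the empty "elif ...: pass" branch into a negated condition preserves the routing. A hypothetical sketch of the control flow, not the repo's actual handling:

def route(corpus_name):
    if corpus_name.startswith("semeval_eng"):
        return "handled"
    elif not corpus_name.startswith("semeval_lat"):
        return "unsupported"
    # semeval_lat corpora fall through with no extra handling

assert route("semeval_eng_v2") == "handled"
assert route("semeval_lat_v2") is None
assert route("nyt") == "unsupported"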
5 changes: 3 additions & 2 deletions sentence_time_prediction.py
@@ -23,8 +23,9 @@ def predict_time(time_to_label, fill_mask_pipeline, time_pattern, sentence):
    tokens = list(result_dict.keys())
    # Choose the token with the highest probability
    pred_token = tokens[0]
-    pred = time_to_label[time_pattern.search(pred_token).group(1) if pred_token else 0]
-    return pred
+    return time_to_label[
+        time_pattern.search(pred_token).group(1) if pred_token else 0
+    ]
Comment on lines -26 to +28
Function predict_time refactored.



def sentence_time_prediction(
6 changes: 2 additions & 4 deletions temporal_text_dataset.py
@@ -44,10 +44,8 @@ def find_time(filename):
        m = re.match(r".*?_(\d+.*?)[\.a-zA-Z]", filename)
        if m is None:
            return None
-        time = m.group(1)
-        # Remove trailing underscores (e.g., for "nyt_2017_train.txt")
-        time = time.strip("_")
-        return time
+        time = m[1]
+        return time.strip("_")
Comment on lines -47 to +48
Function TemporalText.find_time refactored with the following changes:

This removes the following comments (why?):

# Remove trailing underscores (e.g., for "nyt_2017_train.txt")
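On m[1]: re.Match has supported indexing since Python 3.6, so m[1] is exactly m.group(1). Checked against the filename cited in the dropped comment:

import re

m = re.match(r".*?_(\d+.*?)[\.a-zA-Z]", "nyt_2017_train.txt")
# Indexing a match object is shorthand for .group()
assert m[1] == m.group(1) == "2017_"
assert m[1].strip("_") == "2017"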


def _generate_tables(self, files):
for file_idx, file in enumerate(files):