diff --git a/data_utils.py b/data_utils.py
index 0730d6d..9d87b62 100644
--- a/data_utils.py
+++ b/data_utils.py
@@ -107,8 +107,7 @@ def load_train_test_datasets(train_path, test_path, cache_dir):
         split='train',
         cache_dir=cache_dir,
     )
-    dataset = DatasetDict({"train": train_dataset, "validation": test_dataset})
-    return dataset
+    return DatasetDict({"train": train_dataset, "validation": test_dataset})
 
 
 def split_temporal_dataset_files(
@@ -120,7 +119,7 @@ def split_temporal_dataset_files(
     train_path = Path(train_path)
     test_path = Path(test_path)
     dataset_path = get_dataset_path(train_path, corpus_name, train_size, test_size)
-    exclude_similar_sentences = True if corpus_name.startswith("liverpool") else False
+    exclude_similar_sentences = bool(corpus_name.startswith("liverpool"))
     out_train_path = dataset_path / train_path.name
     out_test_path = dataset_path / test_path.name
     if Path(out_train_path).exists() and Path(out_test_path).exists():
@@ -203,7 +202,7 @@ def find_sentences_of_words(
             logger.debug(f"Loading word_time_sentences from {file_path}")
             word_time_sentences = pickle.load(f)
     else:
-        logger.info(f"Finding relevant sentences in the corpus...")
+        logger.info("Finding relevant sentences in the corpus...")
         word_time_sentences = defaultdict(dict)
         for file in text_files:  # For each time period
             time = TemporalText.find_time(file)
diff --git a/hf_utils.py b/hf_utils.py
index aff6f53..62de830 100644
--- a/hf_utils.py
+++ b/hf_utils.py
@@ -69,11 +69,10 @@ def _load_auto_config(model_args, data_args=None, num_labels=None, **kwargs):
         if num_labels
         else {}
     )
-    kwargs.update(additional_kwargs)
-    config = AutoConfig.from_pretrained(
+    kwargs |= additional_kwargs
+    return AutoConfig.from_pretrained(
         model_args.model_name_or_path, cache_dir=model_args.cache_dir, **kwargs
     )
-    return config
 
 
 def load_pretrained_model(
@@ -299,8 +298,10 @@ def init_run(
         logger.add(sys.stderr, level="WARNING")
         utils.set_result_logger_level()
     logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
-        + f"distributed training: {bool(training_args.local_rank != -1)}"
+        (
+            f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+            + f"distributed training: {training_args.local_rank != -1}"
+        )
     )
     for args_instance in (model_args, data_args, training_args):
         logger.info(args_instance)
@@ -345,7 +346,7 @@ def get_model_name(model_name_or_path):
     if "checkpoint-" in path.name:
         model_name_or_path = f"{path.parent.name}/{path.name}"
     else:
-        model_name_or_path = str(path.name)
+        model_name_or_path = path.name
     return model_name_or_path
 
 
diff --git a/models/tempobert/__init__.py b/models/tempobert/__init__.py
index 8d03904..2d8005c 100644
--- a/models/tempobert/__init__.py
+++ b/models/tempobert/__init__.py
@@ -3,21 +3,17 @@
 from transformers.file_utils import _LazyModule
 
 _import_structure = {
-    "configuration_tempobert": [
-        "TempoBertConfig",
+    "configuration_tempobert": ["TempoBertConfig"],
+    "tokenization_tempobert_fast": ["TempoBertTokenizerFast"],
+    "modeling_tempobert": [
+        "TempoBertForMaskedLM",
+        "TempoBertModel",
+        "TempoBertForPreTraining",
+        "TempoBertForSequenceClassification",
+        "TempoBertForTokenClassification",
+    ],
 }
 
-_import_structure["tokenization_tempobert_fast"] = ["TempoBertTokenizerFast"]
-
-_import_structure["modeling_tempobert"] = [
-    "TempoBertForMaskedLM",
-    "TempoBertModel",
-    "TempoBertForPreTraining",
-    "TempoBertForSequenceClassification",
-    "TempoBertForTokenClassification",
-]
-
 if TYPE_CHECKING:
     from .configuration_tempobert import TempoBertConfig
     from .modeling_tempobert import (
diff --git a/models/tempobert/modeling_tempobert.py b/models/tempobert/modeling_tempobert.py
index 1046bab..df5c3a7 100644
--- a/models/tempobert/modeling_tempobert.py
+++ b/models/tempobert/modeling_tempobert.py
@@ -77,8 +77,8 @@ def init_time_embeddings(self, config):
         if utils.is_time_id_necessary(self.time_embedding_type):
             # time_ids is a range (1, len time emb)
             time_ids = torch.arange(len(self.times)).expand((1, -1))
-            SPECIAL_TIMES_COUNT = 2  # NOTE: hardcoded (see TempoSpecialTokensMixin)
             if "attention" in self.time_embedding_type:
+                SPECIAL_TIMES_COUNT = 2  # NOTE: hardcoded (see TempoSpecialTokensMixin)
                 # NOTE: each time (including "special" times) has a single embedding (for all tokens in the vocabulary)
                 self.time_embeddings = nn.Embedding(
                     SPECIAL_TIMES_COUNT + len(self.times), config.hidden_size
@@ -177,10 +177,7 @@ def __init__(self, config):
         self.position_embedding_type = getattr(
             config, "position_embedding_type", "absolute"
         )
-        if (
-            self.position_embedding_type == "relative_key"
-            or self.position_embedding_type == "relative_key_query"
-        ):
+        if self.position_embedding_type in {"relative_key", "relative_key_query"}:
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(
                 2 * config.max_position_embeddings - 1, self.attention_head_size
@@ -246,10 +243,7 @@ def forward(
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
-        if (
-            self.position_embedding_type == "relative_key"
-            or self.position_embedding_type == "relative_key_query"
-        ):
+        if self.position_embedding_type in ["relative_key", "relative_key_query"]:
             seq_length = hidden_states.size()[1]
             position_ids_l = torch.arange(
                 seq_length, dtype=torch.long, device=hidden_states.device
@@ -480,10 +474,7 @@ def forward(
             output_attentions,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
-        outputs = (attention_output,) + self_outputs[
-            1:
-        ]  # add attentions if we output them
-        return outputs
+        return (attention_output,) + self_outputs[1:]
 
 
 class TemporalAttention(nn.Module):
@@ -541,10 +532,7 @@ def forward(
             output_attentions,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
-        outputs = (attention_output,) + self_outputs[
-            1:
-        ]  # add attentions if we output them
-        return outputs
+        return (attention_output,) + self_outputs[1:]
 
 
 class BertIntermediate(nn.Module):
@@ -665,8 +653,7 @@ def forward(
 
     def feed_forward_chunk(self, attention_output):
         intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        return layer_output
+        return self.output(intermediate_output, attention_output)
 
 
 class BertLayerWithTemporalAttention(nn.Module):
@@ -762,8 +749,7 @@ def forward(
 
    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        return layer_output
+        return self.output(intermediate_output, attention_output)
 
 
 class BertEncoder(nn.Module):
@@ -846,8 +832,16 @@ def custom_forward(*inputs):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if not return_dict:
-            return tuple(
+        return (
+            BaseModelOutputWithPastAndCrossAttentions(
+                last_hidden_state=hidden_states,
+                past_key_values=next_decoder_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attentions,
+                cross_attentions=all_cross_attentions,
+            )
+            if return_dict
+            else tuple(
                 v
                 for v in [
                     hidden_states,
                     next_decoder_cache,
                     all_hidden_states,
                     all_self_attentions,
                     all_cross_attentions,
                 ]
                 if v is not None
             )
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-            cross_attentions=all_cross_attentions,
         )
@@ -954,8 +942,16 @@ def custom_forward(*inputs):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if not return_dict:
-            return tuple(
+        return (
+            BaseModelOutputWithPastAndCrossAttentions(
+                last_hidden_state=hidden_states,
+                past_key_values=next_decoder_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attentions,
+                cross_attentions=all_cross_attentions,
+            )
+            if return_dict
+            else tuple(
                 v
                 for v in [
                     hidden_states,
                     next_decoder_cache,
                     all_hidden_states,
                     all_self_attentions,
                     all_cross_attentions,
                 ]
                 if v is not None
             )
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-            cross_attentions=all_cross_attentions,
         )
@@ -1033,8 +1023,7 @@ def __init__(self, config):
         self.predictions = BertLMPredictionHead(config)
 
     def forward(self, sequence_output):
-        prediction_scores = self.predictions(sequence_output)
-        return prediction_scores
+        return self.predictions(sequence_output)
 
 
 class TempoBertPreTrainedModel(PreTrainedModel):
@@ -1277,16 +1266,17 @@ def forward(
             self.pooler(sequence_output) if self.pooler is not None else None
         )
 
-        if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        return BaseModelOutputWithPoolingAndCrossAttentions(
-            last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
-            past_key_values=encoder_outputs.past_key_values,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-            cross_attentions=encoder_outputs.cross_attentions,
+        return (
+            BaseModelOutputWithPoolingAndCrossAttentions(
+                last_hidden_state=sequence_output,
+                pooler_output=pooled_output,
+                past_key_values=encoder_outputs.past_key_values,
+                hidden_states=encoder_outputs.hidden_states,
+                attentions=encoder_outputs.attentions,
+                cross_attentions=encoder_outputs.cross_attentions,
+            )
+            if return_dict
+            else (sequence_output, pooled_output) + encoder_outputs[1:]
         )
diff --git a/semantic_change_detection.py b/semantic_change_detection.py
index 989e391..e624c4f 100644
--- a/semantic_change_detection.py
+++ b/semantic_change_detection.py
@@ -84,18 +84,15 @@ def get_embedding(
     require_word_in_vocab=False,
     hidden_layers_number=None,
 ):
-    if (require_word_in_vocab and not word in model.tokenizer.vocab) or len(
-        sentences
-    ) == 0:
+    if (
+        require_word_in_vocab
+        and word not in model.tokenizer.vocab
+        or len(sentences) == 0
+    ):
         return torch.tensor([])
     if hidden_layers_number is None:
         num_hidden_layers = model.config.num_hidden_layers
-        if num_hidden_layers == 12:
-            hidden_layers_number = 1
-        elif num_hidden_layers == 2:
-            hidden_layers_number = 3
-        else:
-            hidden_layers_number = 1
+        hidden_layers_number = 3 if num_hidden_layers == 2 else 1
     embs = model.embed_word(
         sentences,
         word,
@@ -103,12 +100,7 @@ def get_embedding(
         batch_size=batch_size,
         hidden_layers_number=hidden_layers_number,
     )
-    if embs.ndim == 1:
-        # in case of a single sentence, embs is actually the single embedding, not a list
-        return embs
-    else:
-        centroid = torch.mean(embs, dim=0)
-        return centroid
+    return embs if embs.ndim == 1 else torch.mean(embs, dim=0)
 
 
 def get_detection_function(score_method, config):
@@ -173,10 +165,11 @@ def semantic_change_detection_wrapper(
             if not sentences:
                 logger.debug(f"Found no sentences for '{word}' at time '{time}'")
             if hasattr(model.config, 'times'):
-                missing_times = [
-                    time for time in model.config.times if time not in time_sentences
-                ]
-                if missing_times:
+                if missing_times := [
+                    time
+                    for time in model.config.times
+                    if time not in time_sentences
+                ]:
                     logger.debug(f"Found no sentences for '{word}' at {missing_times}")
             score = detection_function(
                 time_sentences,
@@ -195,14 +188,14 @@ def semantic_change_detection_wrapper(
                 shifts_dict,
             )
     logger.info("Final results:")
-    for model, result_str in model_to_result_str.items():
+    for result_str in model_to_result_str.values():
         logger.info(result_str)
 
 
 def check_words_in_vocab(words, tokenizer, verbose=False, check_split_words=False):
     missing_words = []
     for word in words:
-        if not word in tokenizer.vocab:
+        if word not in tokenizer.vocab:
             if verbose:
                 logger.warning(f"{word=} doesn't exist in the vocab")
             missing_words.append(word)
@@ -243,8 +236,7 @@ def semantic_change_detection(
         method = calc_change_score_cosine_dist
     else:
         raise ValueError(f"Unknown {score_method=}")
-    score = method(model, sentences, word, verbose)
-    return score
+    return method(model, sentences, word, verbose)
 
 
 def semantic_change_detection_temporal(
@@ -328,9 +320,7 @@ def get_shifts(corpus_name, tokenizer=None):
             # The German target words are uppercased
             if tokenizer.do_lower_case:
                 df_shifts.word = df_shifts.word.str.lower()
-    elif corpus_name.startswith("semeval_lat"):
-        pass
-    else:
+    elif not corpus_name.startswith("semeval_lat"):
         logger.error(f"Unsupported corpus: {corpus_name}")
         exit()
     shifts_dict = dict(zip(df_shifts.word, df_shifts.score))
diff --git a/sentence_time_prediction.py b/sentence_time_prediction.py
index 280c2c2..e1fa512 100644
--- a/sentence_time_prediction.py
+++ b/sentence_time_prediction.py
@@ -23,8 +23,9 @@ def predict_time(time_to_label, fill_mask_pipeline, time_pattern, sentence):
     tokens = list(result_dict.keys())
     # Choose the token with the highest probability
     pred_token = tokens[0]
-    pred = time_to_label[time_pattern.search(pred_token).group(1) if pred_token else 0]
-    return pred
+    return time_to_label[
+        time_pattern.search(pred_token).group(1) if pred_token else 0
+    ]
 
 
 def sentence_time_prediction(
diff --git a/temporal_text_dataset.py b/temporal_text_dataset.py
index 98c941b..05189fa 100644
--- a/temporal_text_dataset.py
+++ b/temporal_text_dataset.py
@@ -44,10 +44,8 @@ def find_time(filename):
         m = re.match(r".*?_(\d+.*?)[\.a-zA-Z]", filename)
         if m is None:
             return None
-        time = m.group(1)
-        # Remove trailing underscores (e.g., for "nyt_2017_train.txt")
-        time = time.strip("_")
-        return time
+        time = m[1]
+        return time.strip("_")
 
     def _generate_tables(self, files):
         for file_idx, file in enumerate(files):
diff --git a/test_bert.py b/test_bert.py
index dca4768..55c1334 100644
--- a/test_bert.py
+++ b/test_bert.py
@@ -12,7 +12,7 @@ def predict_time(sentence, fill_mask_pipelines, print_results=True):
     time_tokens = [f"<{time}>" for time in fill_mask_pipelines[0].model.config.times]
     result_dict = {}
     original_sentence = sentence
-    sentence = "[MASK] " + sentence
+    sentence = f"[MASK] {sentence}"
     for model_i, fill_mask in enumerate(fill_mask_pipelines):
         fill_result = fill_mask(sentence, targets=time_tokens, truncation=True)
         result = {res["token_str"]: res["score"] for res in fill_result}
diff --git a/tokenization_utils_base.py b/tokenization_utils_base.py
index b8a5cd7..39e55dd 100644
--- a/tokenization_utils_base.py
+++ b/tokenization_utils_base.py
@@ -214,41 +214,39 @@ def max_len_sentences_pair(self) -> int:
 
     @max_len_single_sentence.setter
     def max_len_single_sentence(self, value) -> int:
-        # For backward compatibility, allow to try to setup 'max_len_single_sentence'.
         if (
-            value == self.model_max_length - self.num_special_tokens_to_add(pair=False)
-            and self.verbose
+            value
+            != self.model_max_length - self.num_special_tokens_to_add(pair=False)
+            or not self.verbose
         ):
-            if not self.deprecation_warnings.get("max_len_single_sentence", False):
-                logger.warning(
-                    "Setting 'max_len_single_sentence' is now deprecated. "
-                    "This value is automatically set up."
-                )
-            self.deprecation_warnings["max_len_single_sentence"] = True
-        else:
             raise ValueError(
                 "Setting 'max_len_single_sentence' is now deprecated. "
                 "This value is automatically set up."
             )
+        if not self.deprecation_warnings.get("max_len_single_sentence", False):
+            logger.warning(
+                "Setting 'max_len_single_sentence' is now deprecated. "
+                "This value is automatically set up."
+            )
+        self.deprecation_warnings["max_len_single_sentence"] = True
 
     @max_len_sentences_pair.setter
     def max_len_sentences_pair(self, value) -> int:
-        # For backward compatibility, allow to try to setup 'max_len_sentences_pair'.
         if (
-            value == self.model_max_length - self.num_special_tokens_to_add(pair=True)
-            and self.verbose
+            value
+            != self.model_max_length - self.num_special_tokens_to_add(pair=True)
+            or not self.verbose
         ):
-            if not self.deprecation_warnings.get("max_len_sentences_pair", False):
-                logger.warning(
-                    "Setting 'max_len_sentences_pair' is now deprecated. "
-                    "This value is automatically set up."
-                )
-            self.deprecation_warnings["max_len_sentences_pair"] = True
-        else:
             raise ValueError(
                 "Setting 'max_len_sentences_pair' is now deprecated. "
                 "This value is automatically set up."
             )
+        if not self.deprecation_warnings.get("max_len_sentences_pair", False):
+            logger.warning(
+                "Setting 'max_len_sentences_pair' is now deprecated. "
+                "This value is automatically set up."
+            )
+        self.deprecation_warnings["max_len_sentences_pair"] = True
 
     def __repr__(self) -> str:
         return (
@@ -369,12 +367,7 @@ def from_pretrained(
         else:
             # Get the vocabulary from local files
             logger.info(
-                "Model name '{}' not found in model shortcut name list ({}). "
-                "Assuming '{}' is a path, a model identifier, or url to a directory containing tokenizer files.".format(
-                    pretrained_model_name_or_path,
-                    ", ".join(s3_models),
-                    pretrained_model_name_or_path,
-                )
+                f"""Model name '{pretrained_model_name_or_path}' not found in model shortcut name list ({", ".join(s3_models)}). Assuming '{pretrained_model_name_or_path}' is a path, a model identifier, or url to a directory containing tokenizer files."""
             )
 
             if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
@@ -459,13 +452,12 @@ def from_pretrained(
                     raise error
 
             except requests.exceptions.HTTPError as err:
-                if "404 Client Error" in str(err):
-                    logger.debug(err)
-                    resolved_vocab_files[file_id] = None
-                else:
+                if "404 Client Error" not in str(err):
                     raise err
 
-        if len(unresolved_files) > 0:
+                logger.debug(err)
+                resolved_vocab_files[file_id] = None
+        if unresolved_files:
             logger.info(
                 f"Can't load following files from cache: {unresolved_files} and cannot check if these "
                 "files are necessary for the tokenizer to operate."
@@ -583,16 +575,15 @@ def _from_pretrained(
             if config_tokenizer_class is None:
                 config_tokenizer_class = config_tokenizer_class_fast
 
-        if config_tokenizer_class is not None:
-            if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace(
-                "Fast", ""
-            ):
-                logger.warning(
-                    "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
-                    "It may result in unexpected tokenization. \n"
-                    f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"
-                    f"The class this function is called from is '{cls.__name__}'."
-                )
+        if config_tokenizer_class is not None and cls.__name__.replace(
+            "Fast", ""
+        ) != config_tokenizer_class.replace("Fast", ""):
+            logger.warning(
+                "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
+                "It may result in unexpected tokenization. \n"
+                f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"
+                f"The class this function is called from is '{cls.__name__}'."
+            )
 
         # Update with newly provided kwargs
         init_kwargs.update(kwargs)
@@ -607,7 +598,7 @@ def convert_added_tokens(obj: Union[AddedToken, Any]):
                 obj.pop("__type")
                 return AddedToken(**obj)
             elif isinstance(obj, (list, tuple)):
-                return list(convert_added_tokens(o) for o in obj)
+                return [convert_added_tokens(o) for o in obj]
             elif isinstance(obj, dict):
                 return {k: convert_added_tokens(v) for k, v in obj.items()}
             return obj
@@ -708,9 +699,7 @@ def convert_added_tokens(obj: Union[AddedToken, Any]):
                 )
 
                 # Safe to call on a tokenizer fast even if token already there.
-                tokenizer.add_tokens(
-                    token, special_tokens=bool(token in special_tokens)
-                )
+                tokenizer.add_tokens(token, special_tokens=token in special_tokens)
 
         # Check all our special tokens are registered as "no split" token (we don't cut them) and are in the vocab
         added_tokens = tokenizer.sanitize_special_tokens()
@@ -766,12 +755,15 @@ def save_pretrained(
 
         special_tokens_map_file = os.path.join(
             save_directory,
-            (filename_prefix + "-" if filename_prefix else "")
-            + SPECIAL_TOKENS_MAP_FILE,
+            (
+                (f"{filename_prefix}-" if filename_prefix else "")
+                + SPECIAL_TOKENS_MAP_FILE
+            ),
         )
         tokenizer_config_file = os.path.join(
             save_directory,
-            (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE,
+            (f"{filename_prefix}-" if filename_prefix else "")
+            + TOKENIZER_CONFIG_FILE,
         )
 
         tokenizer_config = copy.deepcopy(self.init_kwargs)
@@ -788,9 +780,7 @@ def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
                 out["__type"] = "AddedToken"
                 return out
             elif isinstance(obj, (list, tuple)):
-                return list(
-                    convert_added_tokens(o, add_type_field=add_type_field) for o in obj
-                )
+                return [convert_added_tokens(o, add_type_field=add_type_field) for o in obj]
             elif isinstance(obj, dict):
                 return {
                     k: convert_added_tokens(v, add_type_field=add_type_field)
@@ -860,10 +850,9 @@ def _save_pretrained(
 
         added_tokens_file = os.path.join(
             save_directory,
-            (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE,
+            (f"{filename_prefix}-" if filename_prefix else "") + ADDED_TOKENS_FILE,
         )
-        added_vocab = self.get_added_vocab()
-        if added_vocab:
+        if added_vocab := self.get_added_vocab():
             with open(added_tokens_file, "w", encoding="utf-8") as f:
                 out_str = json.dumps(added_vocab, ensure_ascii=False)
                 f.write(out_str)
@@ -1015,14 +1004,13 @@ def _get_padding_truncation_strategies(
         elif padding is not False:
             if padding is True:
                 if verbose:
-                    if max_length is not None:
-                        if max_length is not None and (
-                            truncation is False or truncation == "do_not_truncate"
-                        ):
-                            warnings.warn(
-                                "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
-                                "To pad to max length, use `padding='max_length'`."
-                            )
+                    if max_length is not None and (
+                        truncation is False or truncation == "do_not_truncate"
+                    ):
+                        warnings.warn(
+                            "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
+                            "To pad to max length, use `padding='max_length'`."
+                        )
                     if old_pad_to_max_length is not False:
                         warnings.warn(
                             "Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`."
@@ -1032,7 +1020,7 @@ def _get_padding_truncation_strategies(
                 )  # Default to pad to the longest sequence in the batch
             elif not isinstance(padding, PaddingStrategy):
                 padding_strategy = PaddingStrategy(padding)
-            elif isinstance(padding, PaddingStrategy):
+            else:
                 padding_strategy = padding
         else:
             padding_strategy = PaddingStrategy.DO_NOT_PAD
@@ -1059,7 +1047,7 @@ def _get_padding_truncation_strategies(
                 )  # Default to truncate the longest sequences in pairs of inputs
             elif not isinstance(truncation, TruncationStrategy):
                 truncation_strategy = TruncationStrategy(truncation)
-            elif isinstance(truncation, TruncationStrategy):
+            else:
                 truncation_strategy = truncation
         else:
             truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
@@ -1639,7 +1627,7 @@ def pad(
 
         batch_outputs = {}
         for i in range(batch_size):
-            inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
+            inputs = {k: v[i] for k, v in encoded_inputs.items()}
             outputs = self._pad(
                 inputs,
                 max_length=max_length,
@@ -1691,9 +1679,7 @@ def build_inputs_with_special_tokens(
         Returns:
             :obj:`List[int]`: The model input with special tokens.
         """
-        if token_ids_1 is None:
-            return token_ids_0
-        return token_ids_0 + token_ids_1
+        return token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
 
     def prepare_for_model(
         self,
@@ -1747,7 +1733,7 @@ def prepare_for_model(
             **kwargs,
         )
 
-        pair = bool(pair_ids is not None)
+        pair = pair_ids is not None
         len_ids = len(ids)
         len_pair_ids = len(pair_ids) if pair else 0
 
@@ -1841,14 +1827,12 @@ def prepare_for_model(
         if return_length:
             encoded_inputs["length"] = len(encoded_inputs["input_ids"])
 
-        batch_outputs = BatchEncoding(
+        return BatchEncoding(
             encoded_inputs,
             tensor_type=return_tensors,
             prepend_batch_axis=prepend_batch_axis,
         )
 
-        return batch_outputs
-
     def truncate_sequences(
         self,
         ids: List[int],
@@ -1913,10 +1897,7 @@ def truncate_sequences(
                     f"but the first sequence has a length {len(ids)}. "
                 )
                 if truncation_strategy == TruncationStrategy.ONLY_FIRST:
-                    error_msg = (
-                        error_msg + "Please select another truncation strategy than "
-                        f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
-                    )
+                    error_msg = f"{error_msg}Please select another truncation strategy than {truncation_strategy}, for instance 'longest_first' or 'only_second'."
                 logger.error(error_msg)
             elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
                 logger.warning(
@@ -2042,7 +2023,7 @@ def _pad(
                     self.pad_token_id
                 ] * difference + required_input
             else:
-                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
+                raise ValueError(f"Invalid padding strategy:{str(self.padding_side)}")
         elif return_attention_mask and "attention_mask" not in encoded_inputs:
             encoded_inputs["attention_mask"] = [1] * len(required_input)
 
@@ -2170,11 +2151,7 @@ def get_special_tokens_mask(
 
         all_special_ids = self.all_special_ids  # cache the property
 
-        special_tokens_mask = [
-            1 if token in all_special_ids else 0 for token in token_ids_0
-        ]
-
-        return special_tokens_mask
+        return [1 if token in all_special_ids else 0 for token in token_ids_0]
 
     @staticmethod
     def clean_up_tokenization(out_string: str) -> str:
@@ -2187,7 +2164,7 @@ def clean_up_tokenization(out_string: str) -> str:
         Returns:
             :obj:`str`: The cleaned-up string.
         """
-        out_string = (
+        return (
             out_string.replace(" .", ".")
             .replace(" ?", "?")
             .replace(" !", "!")
@@ -2199,7 +2176,6 @@ def clean_up_tokenization(out_string: str) -> str:
             .replace(" 've", "'ve")
             .replace(" 're", "'re")
         )
-        return out_string
 
     def _eventual_warn_about_too_long_sequence(
         self, ids: List[int], max_length: Optional[int], verbose: bool
diff --git a/tokenization_utils_fast.py b/tokenization_utils_fast.py
index a43f162..96e8724 100644
--- a/tokenization_utils_fast.py
+++ b/tokenization_utils_fast.py
@@ -67,11 +67,10 @@ def prepend(texts, time_ids, stringify_time_id=noop):
 
 def prepend_tuples(text_pairs, time_id_pairs, stringify_time_id=noop):
     """Prepend the time to each text, where the inputs are pairs"""
-    texts = [
+    return [
         prepend(text_pair, time_id_pair, stringify_time_id)
         for text_pair, time_id_pair in zip(text_pairs, time_id_pairs)
     ]
-    return texts
 
 
 MODEL_TO_TRAINER_MAPPING = {
@@ -134,7 +133,7 @@ def __init__(self, *args, **kwargs):
         self._tokenizer = fast_tokenizer
 
         if slow_tokenizer is not None:
-            kwargs.update(slow_tokenizer.init_kwargs)
+            kwargs |= slow_tokenizer.init_kwargs
 
         # We call this after having initialized the backend tokenizer because we update it.
         super().__init__(**kwargs)
@@ -166,10 +165,11 @@ def get_added_vocab(self) -> Dict[str, int]:
         """
         base_vocab = self._tokenizer.get_vocab(with_added_tokens=False)
         full_vocab = self._tokenizer.get_vocab(with_added_tokens=True)
-        added_vocab = dict(
-            (tok, index) for tok, index in full_vocab.items() if tok not in base_vocab
-        )
-        return added_vocab
+        return {
+            tok: index
+            for tok, index in full_vocab.items()
+            if tok not in base_vocab
+        }
 
     def __len__(self) -> int:
         """
@@ -262,19 +262,14 @@ def convert_tokens_to_ids(
         if isinstance(tokens, str):
             return self._convert_token_to_id_with_added_voc(tokens)
 
-        ids = []
-        for token in tokens:
-            ids.append(self._convert_token_to_id_with_added_voc(token))
-        return ids
+        return [self._convert_token_to_id_with_added_voc(token) for token in tokens]
 
     def _convert_token_to_id_with_added_voc(self, token: str) -> int:
         index = self._tokenizer.token_to_id(token)
-        if index is None:
-            return self.unk_token_id
-        return index
+        return self.unk_token_id if index is None else index
 
     def _convert_id_to_token(self, index: int) -> Optional[str]:
-        return self._tokenizer.id_to_token(int(index))
+        return self._tokenizer.id_to_token(index)
 
     def _add_tokens(
         self, new_tokens: List[Union[str, AddedToken]], special_tokens=False
@@ -470,14 +465,14 @@ def _batch_encode_plus(
                 time_id = batch_time_id[seq_i]
                 encoding.time_ids = (
                     [
-                        self.time_to_id[time_id]
+                        self.time_to_id[time_id[sequence_id]]
                         if sequence_id is not None
                         else self.pad_time_id
                         for sequence_id in encoding.sequence_ids
                     ]
-                    if not isinstance(time_id, tuple)
+                    if isinstance(time_id, tuple)
                     else [
-                        self.time_to_id[time_id[sequence_id]]
+                        self.time_to_id[time_id]
                         if sequence_id is not None
                         else self.pad_time_id
                         for sequence_id in encoding.sequence_ids
@@ -624,8 +619,7 @@ def _decode(
         )
 
         if clean_up_tokenization_spaces:
-            clean_text = self.clean_up_tokenization(text)
-            return clean_text
+            return self.clean_up_tokenization(text)
         else:
             return text
 
@@ -658,10 +652,10 @@ def _save_pretrained(
         if save_slow:
             added_tokens_file = os.path.join(
                 save_directory,
-                (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE,
+                (f"{filename_prefix}-" if filename_prefix else "")
+                + ADDED_TOKENS_FILE,
             )
-            added_vocab = self.get_added_vocab()
-            if added_vocab:
+            if added_vocab := self.get_added_vocab():
                 with open(added_tokens_file, "w", encoding="utf-8") as f:
                     out_str = json.dumps(added_vocab, ensure_ascii=False)
                     f.write(out_str)
@@ -674,7 +668,8 @@ def _save_pretrained(
         if save_fast:
             tokenizer_file = os.path.join(
                 save_directory,
-                (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_FILE,
+                (f"{filename_prefix}-" if filename_prefix else "")
+                + TOKENIZER_FILE,
             )
             self.backend_tokenizer.save(tokenizer_file)
             file_names = file_names + (tokenizer_file,)
diff --git a/train_tempobert.py b/train_tempobert.py
index 549db5e..41d6f4a 100644
--- a/train_tempobert.py
+++ b/train_tempobert.py
@@ -353,8 +353,8 @@ def load_data(
     # DataCollatorForLanguageModeling is more efficient when it receives the `special_tokens_mask`.
     return_special_tokens_mask = True
 
-    if data_args.line_by_line:
-        tokenized_dataset = tokenize_dataset_line_by_line(
+    return (
+        tokenize_dataset_line_by_line(
             dataset,
             data_args,
             training_args,
@@ -364,8 +364,8 @@ def load_data(
             max_seq_length,
             return_special_tokens_mask,
         )
-    else:
-        tokenized_dataset = tokenize_dataset_concat(
+        if data_args.line_by_line
+        else tokenize_dataset_concat(
             dataset,
             data_args,
             training_args,
@@ -375,7 +375,7 @@ def load_data(
             max_seq_length,
             return_special_tokens_mask,
         )
-    return tokenized_dataset
+    )
 
 
 def train_tempobert():