diff --git a/EduNLP/ModelZoo/jiuzhang/__init__.py b/EduNLP/ModelZoo/jiuzhang/__init__.py
index e332ae3b..9e805efa 100644
--- a/EduNLP/ModelZoo/jiuzhang/__init__.py
+++ b/EduNLP/ModelZoo/jiuzhang/__init__.py
@@ -1,2 +1,2 @@
 from .jiuzhang import *
-from .modeling import CPTModel as JiuzhangModel
+from .modeling import CPTModel as Jiuzhang
diff --git a/EduNLP/ModelZoo/jiuzhang/jiuzhang.py b/EduNLP/ModelZoo/jiuzhang/jiuzhang.py
index 92296aaa..b42fee19 100644
--- a/EduNLP/ModelZoo/jiuzhang/jiuzhang.py
+++ b/EduNLP/ModelZoo/jiuzhang/jiuzhang.py
@@ -8,7 +8,7 @@
 from typing import List
 from ..rnn.harnn import HAM
 from transformers import BartConfig as JiuzhangConfig
-from .modeling import CPTModel as JiuzhangModel
+from .modeling import CPTModel as Jiuzhang
 
 __all__ = ["JiuzhangForPropertyPrediction", "JiuzhangForKnowledgePrediction"]
 
@@ -20,10 +20,10 @@ def __init__(self, pretrained_model_dir=None, head_dropout=0.5, init=True):
         jiuzhang_config = JiuzhangConfig.from_pretrained(pretrained_model_dir)
         if init:
             print(f'Load Jiuzhang from checkpoint: {pretrained_model_dir}')
-            self.jiuzhang = JiuzhangModel.from_pretrained(pretrained_model_dir, ignore_mismatched_sizes=True)
+            self.jiuzhang = Jiuzhang.from_pretrained(pretrained_model_dir, ignore_mismatched_sizes=True)
         else:
             print(f'Load Jiuzhang from config: {pretrained_model_dir}')
-            self.jiuzhang = JiuzhangModel(jiuzhang_config)
+            self.jiuzhang = Jiuzhang(jiuzhang_config)
         self.hidden_size = self.jiuzhang.config.hidden_size
         self.head_dropout = head_dropout
         self.dropout = nn.Dropout(head_dropout)
@@ -90,10 +90,10 @@ def __init__(self,
         jiuzhang_config = JiuzhangConfig.from_pretrained(pretrained_model_dir)
         if init:
             print(f'Load Jiuzhang from checkpoint: {pretrained_model_dir}')
-            self.jiuzhang = JiuzhangModel.from_pretrained(pretrained_model_dir, ignore_mismatched_sizes=True)
+            self.jiuzhang = Jiuzhang.from_pretrained(pretrained_model_dir, ignore_mismatched_sizes=True)
         else:
             print(f'Load Jiuzhang from config: {pretrained_model_dir}')
-            self.jiuzhang = JiuzhangModel(jiuzhang_config)
+            self.jiuzhang = Jiuzhang(jiuzhang_config)
         self.hidden_size = self.jiuzhang.config.hidden_size
         self.head_dropout = head_dropout
         self.dropout = nn.Dropout(head_dropout)
diff --git a/EduNLP/ModelZoo/jiuzhang/modeling.py b/EduNLP/ModelZoo/jiuzhang/modeling.py
index a9291f9d..b661d9f3 100644
--- a/EduNLP/ModelZoo/jiuzhang/modeling.py
+++ b/EduNLP/ModelZoo/jiuzhang/modeling.py
@@ -28,18 +28,12 @@
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqModelOutput,
-    Seq2SeqLMOutput,
-    Seq2SeqSequenceClassifierOutput,
 )
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
 from transformers import BartConfig as CPTConfig
 from transformers import BertModel, BertConfig
 
 
-logger = logging.get_logger(__name__)
-
-
 def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
     """
     Shift input ids one token to the right.
@@ -84,17 +78,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
 
 
-def attention_mask_func(attention_scores, attention_mask):
-    return attention_scores + attention_mask
-
-
-def init_method(std):
-    def init_(tensor):
-        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
-
-    return init_
-
-
 class CPTLearnedPositionalEmbedding(nn.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size.
@@ -377,30 +360,6 @@ def forward(
         return outputs
 
 
-class CPTClassificationHead(nn.Module):
-    """Head for sentence-level classification tasks."""
-
-    def __init__(
-        self,
-        input_dim: int,
-        inner_dim: int,
-        num_classes: int,
-        pooler_dropout: float,
-    ):
-        super().__init__()
-        self.dense = nn.Linear(input_dim, inner_dim)
-        self.dropout = nn.Dropout(p=pooler_dropout)
-        self.out_proj = nn.Linear(inner_dim, num_classes)
-
-    def forward(self, hidden_states: torch.Tensor):
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.dense(hidden_states)
-        hidden_states = torch.tanh(hidden_states)
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.out_proj(hidden_states)
-        return hidden_states
-
-
 class CPTPretrainedModel(PreTrainedModel):
     config_class = CPTConfig
     base_model_prefix = "model"
@@ -725,568 +684,3 @@ def forward(
             # encoder_attentions=encoder_outputs.attentions if isinstance(encoder_outputs, dict) else None,
             encoder_last_hidden_state=encoder_hidden_states,
         )
-
-
-class CPTForConditionalGeneration(CPTPretrainedModel):
-    base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
-        r"encoder\.version",
-        r"decoder\.version",
-        r"lm_head\.weight",
-    ]
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.model = CPTModel(config)
-        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
-        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
-
-        self.init_weights()
-
-    def get_encoder(self):
-        return self.model.get_encoder()
-
-    def get_decoder(self):
-        return self.model.get_decoder()
-
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
-        return new_embeddings
-
-    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
-        old_num_tokens = self.final_logits_bias.shape[-1]
-        if new_num_tokens <= old_num_tokens:
-            new_bias = self.final_logits_bias[:, :new_num_tokens]
-        else:
-            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
-            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
-        self.register_buffer("final_logits_bias", new_bias)
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        **kwargs,
-    ):
-        r"""
-        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
-            Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
-            config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
-            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
-
-        Returns:
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if labels is not None:
-            if decoder_input_ids is None:
-                decoder_input_ids = shift_tokens_right(
-                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
-                )
-
-        outputs = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            encoder_outputs=encoder_outputs,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
-
-        masked_lm_loss = None
-        if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
-
-        if not return_dict:
-            output = (lm_logits,) + outputs[1:]
-            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
-
-        return Seq2SeqLMOutput(
-            loss=masked_lm_loss,
-            logits=lm_logits,
-            past_key_values=outputs.past_key_values,
-            decoder_hidden_states=outputs.decoder_hidden_states,
-            decoder_attentions=outputs.decoder_attentions,
-            cross_attentions=outputs.cross_attentions,
-            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
-            encoder_hidden_states=outputs.encoder_hidden_states,
-            encoder_attentions=outputs.encoder_attentions,
-        )
-
-    def prepare_inputs_for_generation(
-        self,
-        decoder_input_ids,
-        past=None,
-        attention_mask=None,
-        head_mask=None,
-        use_cache=None,
-        encoder_outputs=None,
-        **kwargs,
-    ):
-        # cut decoder_input_ids if past is used
-        if past is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
-
-        return {
-            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
-            "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
-            "decoder_input_ids": decoder_input_ids,
-            "attention_mask": attention_mask,
-            "head_mask": head_mask,
-            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
-        }
-
-    @staticmethod
-    def _expand_inputs_for_generation(
-        input_ids: torch.LongTensor,
-        expand_size: int = 1,
-        is_encoder_decoder: bool = False,
-        attention_mask: torch.LongTensor = None,
-        encoder_outputs=None,
-        **model_kwargs,
-    ):
-        expanded_return_idx = (
-            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
-        )
-        input_ids = input_ids.index_select(0, expanded_return_idx)
-
-        if "token_type_ids" in model_kwargs:
-            token_type_ids = model_kwargs["token_type_ids"]
-            model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
-
-        if attention_mask is not None:
-            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
-
-        if is_encoder_decoder:
-            assert encoder_outputs is not None
-            device = encoder_outputs.last_hidden_state.device
-            encoder_outputs["hidden_states"] = tuple(
-                h.index_select(0, expanded_return_idx.to(device)) for h in encoder_outputs["hidden_states"]
-            )
-            model_kwargs["encoder_outputs"] = encoder_outputs
-        return input_ids, model_kwargs
-
-    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
-        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
-
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
-
-class CPTForSequenceClassification(CPTPretrainedModel):
-    def __init__(self, config: CPTConfig, cls_mode=3, **kwargs):
-        super().__init__(config, **kwargs)
-        self.model = CPTModel(config)
-        cls_mode = getattr(config, "cls_mode", cls_mode)
-        if cls_mode == 1:
-            logger.info("Encoder for classification.")
-            cls_dim = config.d_model
-        elif cls_mode == 2:
-            logger.info("Decoder for classification.")
-            cls_dim = config.d_model
-        elif cls_mode == 3:
-            logger.info("Both encoder & decoder for classification.")
-            cls_dim = config.d_model * 2
-        else:
-            raise NotImplementedError
-
-        self.cls_head = CPTClassificationHead(
-            cls_dim,
-            cls_dim,
-            config.num_labels,
-            config.classifier_dropout,
-        )
-        self.model._init_weights(self.cls_head.dense)
-        self.model._init_weights(self.cls_head.out_proj)
-        self.cls_mode = cls_mode
-        config.cls_mode = cls_mode
-
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        if labels is not None:
-            use_cache = False
-
-        if input_ids is None and inputs_embeds is not None:
-            raise NotImplementedError(
-                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
-            )
-
-        outputs = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            head_mask=head_mask,
-            decoder_head_mask=decoder_head_mask,
-            encoder_outputs=encoder_outputs,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=True,
-        )
-
-        hidden_states = outputs.last_hidden_state
-        enc_hidden_states = outputs.encoder_last_hidden_state
-        enc_rep = enc_hidden_states[:, 0]
-
-        if self.cls_mode >= 2:
-            eos_mask = input_ids.eq(self.config.eos_token_id)
-
-            if len(torch.unique(eos_mask.sum(1))) > 1:
-                raise ValueError("All examples must have the same number of <eos> tokens.")
-            dec_rep = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[:, -1, :]
-
-        if self.cls_mode == 1:
-            logits = self.cls_head(enc_rep)
-        elif self.cls_mode == 2:
-            logits = self.cls_head(dec_rep)
-        elif self.cls_mode == 3:
-            rep = torch.cat([enc_rep, dec_rep], dim=-1)
-            logits = self.cls_head(rep)
-        else:
-            raise NotImplementedError
-
-        loss = None
-        if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return Seq2SeqSequenceClassifierOutput(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            decoder_hidden_states=outputs.decoder_hidden_states,
-            decoder_attentions=outputs.decoder_attentions,
-            cross_attentions=outputs.cross_attentions,
-            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
-            encoder_hidden_states=outputs.encoder_hidden_states,
-            encoder_attentions=outputs.encoder_attentions,
-        )
-
-
-class CPTForPretraining(CPTPretrainedModel):
-    base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
-        r"encoder\.version",
-        r"decoder\.version",
-        r"lm_head\.weight",
-    ]
-
-    def __init__(self, config: CPTConfig):
-        super().__init__(config)
-        self.model = CPTModel(config)
-        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
-        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
-        self.num_decoder_layers = config.decoder_layers
-
-        self.init_weights()
-
-    def get_encoder(self):
-        return self.model.get_encoder()
-
-    def get_decoder(self):
-        return self.model.get_decoder()
-
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
-        return new_embeddings
-
-    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
-        old_num_tokens = self.final_logits_bias.shape[-1]
-        if new_num_tokens <= old_num_tokens:
-            new_bias = self.final_logits_bias[:, :new_num_tokens]
-        else:
-            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
-            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
-        self.register_buffer("final_logits_bias", new_bias)
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        labels=None,
-        use_decoder=None,
-    ):
-        decoder_input_ids = shift_tokens_right(
-            labels, self.config.pad_token_id, self.config.decoder_start_token_id
-        )
-
-        batch_ids = torch.arange(input_ids.size(0)).to(use_decoder)
-        use_decoder_batch_ids = batch_ids[use_decoder == 1]
-        no_use_decoder_batch_ids = batch_ids[use_decoder != 1]
-        reorder_batch_ids = torch.cat([use_decoder_batch_ids, no_use_decoder_batch_ids], dim=0)
-        input_ids = input_ids[reorder_batch_ids]
-        attention_mask = attention_mask[reorder_batch_ids]
-        decoder_input_ids = decoder_input_ids[reorder_batch_ids]
-        num_use_decoder = use_decoder_batch_ids.size(0)
-
-        encoder_outputs = self.model.encoder(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=torch.ones_like(input_ids),
-            output_hidden_states=True,
-        )
-        encoder_outputs_for_decoder = encoder_outputs.hidden_states[-self.num_decoder_layers - 1]
-        encoder_output = encoder_outputs.last_hidden_state
-
-        decoder_lm_logits = None
-        if num_use_decoder > 0:
-            decoder_outputs = self.model(
-                input_ids[:num_use_decoder], attention_mask=attention_mask[:num_use_decoder],
-                decoder_input_ids=decoder_input_ids[:num_use_decoder],
-                encoder_outputs=encoder_outputs_for_decoder[:num_use_decoder]
-            ).last_hidden_state
-            decoder_lm_logits = self.lm_head(decoder_outputs) + self.final_logits_bias
-
-        encoder_lm_logits = None
-        if num_use_decoder < input_ids.size(0):
-            encoder_lm_logits = self.lm_head(encoder_output[num_use_decoder:]) + self.final_logits_bias
-
-        if decoder_lm_logits is None:
-            reorder_lm_logits = encoder_lm_logits
-        elif encoder_lm_logits is None:
-            reorder_lm_logits = decoder_lm_logits
-        else:
-            reorder_lm_logits = torch.cat([decoder_lm_logits, encoder_lm_logits], dim=0)
-        _, reverse_batch_ids = reorder_batch_ids.sort(dim=0)
-        lm_logits = reorder_lm_logits[reverse_batch_ids]
-
-        loss_fct = CrossEntropyLoss()
-        loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
-
-        return Seq2SeqLMOutput(
-            loss=loss,
-            logits=decoder_lm_logits,
-        )
-
-
-class CPTForSC(CPTPretrainedModel):
-    base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
-        r"encoder\.version",
-        r"decoder\.version",
-        r"lm_head\.weight",
-    ]
-
-    def __init__(self, config: CPTConfig, fronzen=True, cross=False, decoder_rate=0.5):
-        super().__init__(config)
-        self.model = CPTModel(config)
-        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
-        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
-        self.num_decoder_layers = config.decoder_layers
-        self.fronzen = fronzen
-        self.cross = cross
-        self.decoder_rate = decoder_rate
-
-        self.init_weights()
-
-    def get_encoder(self):
-        return self.model.get_encoder()
-
-    def get_decoder(self):
-        return self.model.get_decoder()
-
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
-        return new_embeddings
-
-    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
-        old_num_tokens = self.final_logits_bias.shape[-1]
-        if new_num_tokens <= old_num_tokens:
-            new_bias = self.final_logits_bias[:, :new_num_tokens]
-        else:
-            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
-            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
-        self.register_buffer("final_logits_bias", new_bias)
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def encode_and_decode(self, input_ids, attention_mask, decoder_input_ids, num_use_decoder):
-        encoder_outputs = self.model.encoder(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=torch.ones_like(input_ids),
-            output_hidden_states=True,
-        )
-        encoder_outputs_for_decoder = encoder_outputs.hidden_states[-self.num_decoder_layers - 1]
-        encoder_output = encoder_outputs.last_hidden_state
-
-        decoder_lm_logits = None
-        encoder_lm_logits = None
-
-        if num_use_decoder > 0:
-            decoder_outputs = self.model(
-                input_ids[:num_use_decoder], attention_mask=attention_mask[:num_use_decoder],
-                decoder_input_ids=decoder_input_ids[:num_use_decoder],
-                encoder_outputs=encoder_outputs_for_decoder[:num_use_decoder]
-            ).last_hidden_state
-            decoder_lm_logits = self.lm_head(decoder_outputs) + self.final_logits_bias
-
-        if num_use_decoder < input_ids.size(0):
-            encoder_lm_logits = self.lm_head(encoder_output[num_use_decoder:]) + self.final_logits_bias
-
-        if decoder_lm_logits is None:
-            reorder_lm_logits = encoder_lm_logits
-        elif encoder_lm_logits is None:
-            reorder_lm_logits = decoder_lm_logits
-        else:
-            reorder_lm_logits = torch.cat([decoder_lm_logits, encoder_lm_logits], dim=0)
-
-        return reorder_lm_logits
-
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        labels=None,
-        adv_labels=None,
-        use_decoder=None,
-    ):
-        decoder_input_ids = shift_tokens_right(
-            labels, self.config.pad_token_id, self.config.decoder_start_token_id
-        )
-
-        batch_ids = torch.arange(input_ids.size(0)).to(use_decoder)
-        use_decoder_batch_ids = batch_ids[use_decoder == 1]
-        no_use_decoder_batch_ids = batch_ids[use_decoder != 1]
-        reorder_batch_ids = torch.cat([use_decoder_batch_ids, no_use_decoder_batch_ids], dim=0)
-        input_ids = input_ids[reorder_batch_ids]
-        attention_mask = attention_mask[reorder_batch_ids]
-        decoder_input_ids = decoder_input_ids[reorder_batch_ids]
-        labels = labels[reorder_batch_ids]
-        adv_labels = adv_labels[reorder_batch_ids]
-        num_use_decoder = use_decoder_batch_ids.size(0)
-
-        loss_fct = CrossEntropyLoss()
-
-        if self.fronzen:
-            with torch.no_grad():
-                lm_logits = self.encode_and_decode(
-                    input_ids=input_ids,
-                    attention_mask=attention_mask,
-                    decoder_input_ids=decoder_input_ids,
-                    num_use_decoder=num_use_decoder,
-                )
-            lm_loss = None
-        else:
-            lm_logits = self.encode_and_decode(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                decoder_input_ids=decoder_input_ids,
-                num_use_decoder=num_use_decoder,
-            )
-            lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
-
-        generate_input_ids = lm_logits.argmax(-1)
-        not_masked_indices = (labels == -100)
-        generate_input_ids[not_masked_indices] = input_ids[not_masked_indices]
-
-        if self.cross:
-            # cross the adv process
-            use_decoder = torch.bernoulli(torch.tensor([self.decoder_rate] * input_ids.size(0))).long()
-        else:
-            # adv process need reverse the use_decoder
-            use_decoder = torch.ones_like(use_decoder)
-            use_decoder[:num_use_decoder] = 0
-        batch_ids = torch.arange(input_ids.size(0)).to(use_decoder)
-        use_decoder_batch_ids = batch_ids[use_decoder == 1]
-        no_use_decoder_batch_ids = batch_ids[use_decoder != 1]
-        reorder_batch_ids = torch.cat([use_decoder_batch_ids, no_use_decoder_batch_ids], dim=0)
-        generate_input_ids = generate_input_ids[reorder_batch_ids]
-        attention_mask = attention_mask[reorder_batch_ids]
-        decoder_input_ids = decoder_input_ids[reorder_batch_ids]
-        adv_labels = adv_labels[reorder_batch_ids]
-        num_use_decoder = use_decoder_batch_ids.size(0)
-
-        adv_lm_logits = self.encode_and_decode(
-            input_ids=generate_input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            num_use_decoder=num_use_decoder,
-        )
-
-        adv_lm_loss = loss_fct(adv_lm_logits.view(-1, self.config.vocab_size), adv_labels.view(-1))
-
-        if lm_loss is None:
-            lm_loss = torch.zeros_like(adv_lm_loss)
-
-        loss = adv_lm_loss + lm_loss
-
-        return Seq2SeqLMOutput(
-            loss=loss,
-            logits=adv_lm_logits,
-        )
diff --git a/EduNLP/Pretrain/jiuzhang_vec.py b/EduNLP/Pretrain/jiuzhang_vec.py
index 284fdd5a..04b18089 100644
--- a/EduNLP/Pretrain/jiuzhang_vec.py
+++ b/EduNLP/Pretrain/jiuzhang_vec.py
@@ -1,7 +1,4 @@
-import os
-from typing import List, Union
-from transformers import BertForMaskedLM
-from transformers import DataCollatorForLanguageModeling, DataCollatorWithPadding
+from transformers import DataCollatorWithPadding
 from transformers import Trainer, TrainingArguments
 from copy import deepcopy
 
diff --git a/EduNLP/Vector/jiuzhang_vec.py b/EduNLP/Vector/jiuzhang_vec.py
index 0ed42070..e87680e2 100644
--- a/EduNLP/Vector/jiuzhang_vec.py
+++ b/EduNLP/Vector/jiuzhang_vec.py
@@ -1,4 +1,4 @@
-from EduNLP.ModelZoo.jiuzhang import JiuzhangModel as Jiuzhang
+from EduNLP.ModelZoo.jiuzhang import Jiuzhang
 from .meta import Vector
 import torch
 
diff --git a/tests/test_pretrain/test_pretrained_jiuzhang.py b/tests/test_pretrain/test_pretrained_jiuzhang.py
index cf547acf..e606ae90 100644
--- a/tests/test_pretrain/test_pretrained_jiuzhang.py
+++ b/tests/test_pretrain/test_pretrained_jiuzhang.py
@@ -3,20 +3,16 @@
 os.environ["WANDB_DISABLED"] = "true"
 import torch
 from EduNLP.ModelZoo.jiuzhang import JiuzhangForPropertyPrediction, JiuzhangForKnowledgePrediction
-from EduNLP.ModelZoo.jiuzhang.modeling import CPTModel as HFJiuzhangModel
+from EduNLP.ModelZoo.jiuzhang import Jiuzhang as HFJiuzhang
 from EduNLP.Pretrain import JiuzhangTokenizer
 from EduNLP.Pretrain import finetune_jiuzhang_for_property_prediction, finetune_jiuzhang_for_knowledge_prediction
 from EduNLP.Vector import T2V, JiuzhangModel
 from EduNLP.I2V import get_pretrained_i2v, Jiuzhang
 
 TEST_GPU = False
-from transformers import AutoConfig
 
 
 class TestPretrainJiuzhang:
-    def save_model(self, pretrained_model_dir):
-        model = HFJiuzhangModel.from_pretrained("fnlp/cpt-base")
-        model.save_pretrained(pretrained_model_dir)
 
     def test_tokenizer(self, standard_luna_data, pretrained_model_dir):
         test_items = [
@@ -52,8 +48,13 @@ def test_tokenizer(self, standard_luna_data, pretrained_model_dir):
         res = tokenizer(test_items, key=lambda x: x["ques_content"], return_tensors=False)
         assert isinstance(res["input_ids"], list)
 
+    def test_save_model(self, pretrained_model_dir):
+        model = HFJiuzhang.from_pretrained("fnlp/cpt-base")
+        tokenizer = JiuzhangTokenizer.from_pretrained(pretrained_model_dir)
+        model.resize_token_embeddings(len(tokenizer.bert_tokenizer))
+        model.save_pretrained(pretrained_model_dir)
+
     def test_t2v(self, pretrained_model_dir):
-        pretrained_model_dir = pretrained_model_dir
         items = [
             {'stem': '如图$\\FigureID{088f15ea-8b7c-11eb-897e-b46bfc50aa29}$, \
                 若$x,y$满足约束条件$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$'}
         ]
@@ -61,10 +62,6 @@ def test_t2v(self, pretrained_model_dir):
         tokenizer = JiuzhangTokenizer.from_pretrained(pretrained_model_dir)
         encodes = tokenizer(items, key=lambda x: x['stem'])
 
-        model = HFJiuzhangModel.from_pretrained("fnlp/cpt-base")
-        model.resize_token_embeddings(len(tokenizer.bert_tokenizer))
-        model.save_pretrained(pretrained_model_dir)
-
         t2v = JiuzhangModel(pretrained_model_dir)
         output = t2v(encodes)
         assert output.shape[2] == t2v.vector_size
@@ -78,7 +75,6 @@ def test_t2v(self, pretrained_model_dir):
         t2v.infer_vector(encodes, pooling_strategy='average')
 
     def test_i2v(self, pretrained_model_dir):
-        pretrained_model_dir = pretrained_model_dir
         items = [
             {'stem': '如图$\\FigureID{088f15ea-8b7c-11eb-897e-b46bfc50aa29}$, \
                 若$x,y$满足约束条件$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$'}
         ]
@@ -100,7 +96,6 @@ def test_i2v(self, pretrained_model_dir):
         assert len(t_vec[0][0]) == i2v.vector_size
 
     def test_train_pp(self, standard_luna_data, pretrained_model_dir):
-        self.save_model(pretrained_model_dir)
         data_params = {
             "stem_key": "ques_content",
            "label_key": "difficulty"
         }
@@ -114,8 +109,6 @@ def test_train_pp(self, standard_luna_data, pretrained_model_dir):
         train_items = standard_luna_data
         # train without eval_items
-        model = HFJiuzhangModel.from_pretrained("fnlp/cpt-base")
-        model.save_pretrained(pretrained_model_dir)
         finetune_jiuzhang_for_property_prediction(
             train_items,
             pretrained_model_dir,
@@ -140,8 +133,6 @@ def test_train_pp(self, standard_luna_data, pretrained_model_dir):
         model(**encodes)
 
     def test_train_kp(self, standard_luna_data, pretrained_model_dir):
-        # pretrained_model_dir = 'D:\\EduNLP'
-        self.save_model(pretrained_model_dir)
         data_params = {
             "stem_key": "ques_content",
             "label_key": "know_list"