
ELECTRA and GPT2 support #10

Open
Stochastic-Adventure opened this issue Aug 5, 2021 · 2 comments

@Stochastic-Adventure

Hi,

I'm wondering how to add ELECTRA and GPT2 support to this module.

Neither ELECTRA nor GPT2 has a pooled output, unlike BERT/RoBERTa-based models.

I noticed that in models.py the model is implemented as follows:

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )

        pooled_output = outputs[1]
        seq_output = outputs[0]
        logits = self.output2logits(pooled_output, seq_output, input_ids)

        return self.calc_loss(logits, outputs, labels)

There is no pooled_output for the ELECTRA/GPT2 sequence classification models; only seq_output is present in the outputs variable.
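
For illustration, a quick check with the standard transformers API (the checkpoint names are just examples) shows that BertModel has a pooler while ElectraModel does not:

    from transformers import BertModel, ElectraModel

    bert = BertModel.from_pretrained("bert-base-uncased")
    electra = ElectraModel.from_pretrained("google/electra-small-discriminator")

    # BertModel defines a pooler, so outputs[1] is the pooled [CLS] representation.
    print(hasattr(bert, "pooler"))     # True
    # ElectraModel has no pooler, so its outputs start with the sequence output only.
    print(hasattr(electra, "pooler"))  # False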

How can I get around this limitation and get a working version for ELECTRA/GPT2?
Thank you!

@bugface
Contributor

bugface commented Aug 9, 2021

I will look at these two models and get back to you when I add them or find a way to add them.

@bugface
Contributor

bugface commented Aug 9, 2021

For ELECTRA, you can manually extract the pooled representation of the [CLS] token (see the HuggingFace implementation at https://huggingface.co/transformers/_modules/transformers/models/electra/modeling_electra.html#ElectraForSequenceClassification), for example:

    discriminator_hidden_states = self.electra(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # ELECTRA has no pooler, so use the hidden state of the first ([CLS]) token
    # as the pooled representation.
    seq_output = discriminator_hidden_states[0]
    pooled_output = seq_output[:, 0, :]
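
A minimal, standalone sketch of this workaround (the checkpoint name is only an example; the commented lines at the end show where the tensors would plug into the existing output2logits/calc_loss calls in models.py):

    import torch
    from transformers import ElectraModel, ElectraTokenizerFast

    model = ElectraModel.from_pretrained("google/electra-small-discriminator")
    tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-small-discriminator")

    inputs = tokenizer("ELECTRA has no pooler output.", return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    seq_output = outputs[0]              # (batch, seq_len, hidden): last hidden states
    pooled_output = seq_output[:, 0, :]  # (batch, hidden): [CLS] token as pooled output

    # In models.py these tensors would then feed the existing head, e.g.:
    # logits = self.output2logits(pooled_output, seq_output, input_ids)
    # return self.calc_loss(logits, outputs, labels)
    print(pooled_output.shape)  # torch.Size([1, 256]) for electra-small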
