Skip to content

Commit

Permalink
Upgrade to HF transformers 4.3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Feb 24, 2021
1 parent 1662d78 commit 78ba8f6
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 13 deletions.
6 changes: 3 additions & 3 deletions nbs/01-gpt2-with-value-head.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"# export\n",
"\n",
"from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model, GPT2PreTrainedModel\n",
"from transformers.modeling_utils import top_k_top_p_filtering\n",
"from transformers import top_k_top_p_filtering\n",
"from torch import nn\n",
"from torch.nn import Identity\n",
"import torch.nn.functional as F\n",
Expand Down Expand Up @@ -129,7 +129,7 @@
" def forward(\n",
" self,\n",
" input_ids=None,\n",
" past=None,\n",
" past_key_values=None,\n",
" attention_mask=None,\n",
" token_type_ids=None,\n",
" position_ids=None,\n",
Expand All @@ -142,7 +142,7 @@
" \n",
" transformer_outputs = self.transformer(\n",
" input_ids,\n",
" past=past,\n",
" past_key_values=past_key_values,\n",
" attention_mask=attention_mask,\n",
" token_type_ids=token_type_ids,\n",
" position_ids=position_ids,\n",
Expand Down
11 changes: 5 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@

-e .
jupyterlab==2.0.1
nbdev==0.2.16
numpy==1.18.2
pandas==1.0.3
simpletransformers==0.21.4
torch==1.4.0
simpletransformers==0.60.9
torch>=1.4.0
tqdm==4.43.0
transformers==2.6.0
wandb==0.8.35
matplotlib==3.2.1
transformers==4.3.2
wandb==0.10.20
matplotlib==3.2.1
2 changes: 1 addition & 1 deletion settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ language = English
custom_sidebar = False
license = apache2
status = 2
requirements = torch>=1.4.0 transformers==2.6.0 numpy>=1.18.2
requirements = torch>=1.4.0 transformers==4.3.2 numpy>=1.18.2
nbs_path = ./nbs/
doc_path = docs
doc_host = https://lvwerra.github.io
Expand Down
6 changes: 3 additions & 3 deletions trl/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Cell

from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Model, GPT2PreTrainedModel
from transformers.modeling_utils import top_k_top_p_filtering
from transformers import top_k_top_p_filtering
from torch import nn
from torch.nn import Identity
import torch.nn.functional as F
Expand Down Expand Up @@ -78,7 +78,7 @@ def detach_value_head(self):
def forward(
self,
input_ids=None,
past=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
Expand All @@ -91,7 +91,7 @@ def forward(

transformer_outputs = self.transformer(
input_ids,
past=past,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
Expand Down

0 comments on commit 78ba8f6

Please sign in to comment.