feature(whl): add tabmwp env and prompt pg policy (#667)
* wrong
* update config
* update command policy
* debug
* debug
* add glm
* add glm
* add glm model
* add eval return
* reformat
* modify action space
* modify action space
* polish answer process
* update policy
* update rwkv
* update policy
* polish
* polish
* debug prompt pg
* add parse
* update load env
* add merge files
* add merge files
* feature(whl): add internlm
* feature(whl): add internlm
* update fix parse
* add new dataset
* fix datafiles
* polish code
* polish env
* polish
* polish
* add model wrapper
* polish wrapper
* polish
* remove redundant files
* reformat
* polish
* debug
* polish readme
* reformat
* polish tabmwp
* test
Showing 21 changed files with 1,107 additions and 49 deletions.
ding/model/template/language_transformer.py
@@ -0,0 +1,63 @@
import torch

from ding.utils import MODEL_REGISTRY
from torch import nn

try:
    from transformers import AutoTokenizer, AutoModelForTokenClassification
except ImportError:
    import sys
    from ditk import logging
    logging.warning("transformers not found, please install it using: pip install transformers")
    sys.exit(1)

@MODEL_REGISTRY.register('language_transformer')
class LanguageTransformer(nn.Module):

    def __init__(
            self,
            model_name: str = "bert-base-uncased",
            add_linear: bool = False,
            embedding_size: int = 128,
            freeze_encoder: bool = True
    ) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)

        # Freeze the transformer encoder and only train the linear layer
        if freeze_encoder:
            for param in self.model.parameters():
                param.requires_grad = False

        if add_linear:
            # Add an additional small, adjustable linear layer on top of BERT, tuned through RL
            self.embedding_size = embedding_size
            self.linear = nn.Linear(
                self.model.config.hidden_size, embedding_size
            )  # 768 for bert-base-uncased, distilbert-base-uncased
        else:
            self.linear = None

    def _calc_embedding(self, x: list) -> torch.Tensor:
        # ``truncation=True`` means that if a prompt exceeds the tokenizer's ``max_length``, the excess
        # part is truncated. ``padding=True`` pads shorter prompts in the batch to the length of the
        # longest one. Together these settings give every encoded sequence in the batch the same length,
        # which enables batch-wise computation.
        input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
        output = self.model(**input, output_hidden_states=True)
        # Get the last layer's hidden states
        last_hidden_states = output.hidden_states[-1]
        # Get the [CLS] hidden state as the sentence embedding
        sentence_embedding = last_hidden_states[:, 0, :]  # len(input_list) x hidden_size

        if self.linear:
            sentence_embedding = self.linear(sentence_embedding)  # len(input_list) x embedding_size

        return sentence_embedding

    def forward(self, train_samples: list, candidate_samples: list) -> dict:
        prompt_embedding = self._calc_embedding(train_samples)
        cands_embedding = self._calc_embedding(candidate_samples)
        scores = torch.mm(prompt_embedding, cands_embedding.t())
        return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
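
For context, here is a minimal usage sketch (not part of the commit) showing how the model scores candidate prompts against a query and exposes a categorical distribution for sampling. The example strings and variable names are illustrative, and running it downloads bert-base-uncased from the Hugging Face hub:

    import torch
    from ding.model.template.language_transformer import LanguageTransformer

    model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=128)
    query = ["What is 3 + 4?"]  # one query prompt
    candidates = ["Example A", "Example B", "Example C"]  # three candidate in-context examples
    out = model(query, candidates)
    assert out['logit'].shape == (1, 3)  # one query scored against three candidates
    action = out['dist'].sample()  # index of a sampled candidate, shape (1,)
    log_prob = out['dist'].log_prob(action)  # log-probability used by the policy gradient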
ding/model/template/tests/test_language_transformer.py
@@ -0,0 +1,25 @@
import pytest

from ding.model.template.language_transformer import LanguageTransformer


@pytest.mark.unittest
class TestNLPPretrainedModel:

    def check_model(self):
        test_pids = [1]
        cand_pids = [0, 2, 4]
        problems = [
            "This is problem 0", "This is the first question", "Second problem is here", "Another problem",
            "This is the last problem"
        ]
        ctxt_list = [problems[pid] for pid in test_pids]
        cands_list = [problems[pid] for pid in cand_pids]

        model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
        output = model(ctxt_list, cands_list)
        # ``forward`` returns a dict, so check the shape of its ``logit`` entry.
        assert output['logit'].shape == (1, 3)

        model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
        output = model(ctxt_list, cands_list)
        assert output['logit'].shape == (1, 3)
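
The ``dist`` returned by ``forward`` is what enables a REINFORCE-style update. The following self-contained sketch shows one such update step; the reward value and optimizer settings are assumptions for illustration, not taken from the commit (in TabMWP the reward would reflect whether the downstream LLM answers correctly given the selected in-context example):

    import torch
    from ding.model.template.language_transformer import LanguageTransformer

    model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=128)
    # Only the linear head requires grad when freeze_encoder=True (the default).
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

    out = model(["What is 3 + 4?"], ["Example A", "Example B", "Example C"])
    action = out['dist'].sample()
    reward = torch.tensor([1.0])  # hypothetical reward: 1.0 if the LLM answered correctly, else 0.0
    loss = -(out['dist'].log_prob(action) * reward).mean()  # REINFORCE objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()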
ding/policy/__init__.py
@@ -54,3 +54,4 @@
# new-type policy
from .ppof import PPOFPolicy
from .prompt_pg import PromptPGPolicy
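
Once imported here, the policy can be referenced by its registered name in a config. Below is a minimal, illustrative fragment, assuming PromptPGPolicy registers itself as 'prompt_pg' in the policy registry (the DI-engine convention) and pairs with the 'language_transformer' model registered above; the exact field names in the real TabMWP config may differ:

    # Hypothetical config sketch: field names are assumptions, not taken from the commit.
    tabmwp_prompt_pg_config = dict(
        policy=dict(
            type='prompt_pg',
            model=dict(
                model_name="bert-base-uncased",
                add_linear=True,
                embedding_size=128,
                freeze_encoder=True,
            ),
        ),
    )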