Commit e57557d

Fixed errors caused by upper-case model names, and changed the description (#82)


* fix a glm tokenizer bug
Signed-off-by: zhaohu xing <[email protected]>
* Update tokenizer.py
Signed-off-by: Anhforth <[email protected]>
BAAI-OpenPlatform authored Aug 29, 2022
1 parent dee25b7 commit e57557d
Showing 3 changed files with 8 additions and 14 deletions.
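The substance of the change: the tokenizer dispatched on model-name prefixes with case-sensitive checks, so a name such as "glm-large-ch" failed the 'GLM' test and fell through to the wrong branch. The fix lower-cases the name before every startswith comparison. A minimal sketch of the before/after behavior, using hypothetical helper names (not repository code):

```python
# Hypothetical helpers illustrating the bug this commit fixes;
# they are not part of the FlagAI code base.

def is_glm_old(model_name: str) -> bool:
    # Old check: case-sensitive, so "glm-large-ch" is not recognized.
    return model_name.startswith("GLM")

def is_glm_new(model_name: str) -> bool:
    # New check: normalize case first, then compare a lower-case prefix.
    return model_name.lower().startswith("glm")

print(is_glm_old("glm-large-ch"))  # False -- the bug
print(is_glm_new("glm-large-ch"))  # True
print(is_glm_new("GLM-large-ch"))  # True
```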
1 change: 0 additions & 1 deletion doc_zh/TUTORIAL_4_TRAINER.md
@@ -389,7 +389,6 @@ python train.py --test1=1

2. [glm-title-generation-env-trainer](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/glm_title_generation/train_env_trainer.py)


# Run with the pytorchDDP launcher or the deepspeed launcher
If you use multiple GPUs to train the model, you can run train.py directly to invoke the launcher in the FlagAI trainer.
```commandline
1 change: 1 addition & 0 deletions flagai/auto_model/auto_loader.py
@@ -100,6 +100,7 @@ def __getattr__(self, name):
}


+
class AutoLoader:

def __init__(self,
20 changes: 7 additions & 13 deletions flagai/data/tokenizer/uni_tokenizer/tokenizer.py
@@ -33,16 +33,13 @@


def is_control(ch):
-    """Check whether a character is a control character
+    """
    https://en.wikipedia.org/wiki/Control_character
    https://www.fileformat.info/info/unicode/category/Cc/index.htm
    https://www.fileformat.info/info/unicode/category/Cf/index.htm
    """
    return unicodedata.category(ch) in ('Cc', 'Cf')



class Tokenizer(BaseTokenizer):
def __init__(self,
add_block_symbols=True,
@@ -56,7 +53,7 @@ def __init__(self,
if self.tokenizer_class == "wp":
self.text_tokenizer = WordpieceTokenizer(self.vocab_file)
elif self.tokenizer_class == "bpe":
-if self.tokenizer_model_name.startswith('clip'):
+if self.tokenizer_model_name.lower().startswith('clip'):
self.text_tokenizer = MMBPETokenizer(self.vocab_file, self.merges_file)
else:
self.text_tokenizer = BPETokenizer(self.vocab_file, self.merges_file)
@@ -65,8 +62,6 @@ def __init__(self,
else:
raise NotImplementedError("cannot assign a tokenize class")

-self.is_glm = self.tokenizer_model_name.startswith('GLM')
-# self.is_clip = self.tokenizer_model_name.startswith('clip')
self.num_tokens = self.text_tokenizer.vocab_size

if self.tokenizer_class == "wp":
@@ -125,7 +120,7 @@ def __init__(self,
self.num_tokens += 2
self.num_command_tokens += 2
elif self.tokenizer_class == "bpe":
-if self.tokenizer_model_name.startswith('roberta'):
+if self.tokenizer_model_name.lower().startswith('roberta'):
self.num_command_tokens = 6
self.num_text_tokens = self.num_tokens - 3
self._command_tokens = [
@@ -151,7 +146,7 @@ def __init__(self,
])
self.num_tokens += 2
self.num_command_tokens += 2
-elif self.tokenizer_model_name.startswith('clip'):
+elif self.tokenizer_model_name.lower().startswith('clip'):
self.num_command_tokens = 2
self._command_tokens = [
CommandToken('sot', '<start_of_text>',
@@ -170,7 +165,7 @@ def __init__(self,
self.text_tokenizer.convert_token_to_id('<|endoftext|>'))
]
if add_block_symbols:
-if self.tokenizer_model_name.startswith('GLM'):
+if self.tokenizer_model_name.lower().startswith('glm'):
unk_token_id = self.num_tokens + 5
cls_token_id = self.num_tokens + 2
num_tokens_to_add = 5
@@ -215,7 +210,7 @@ def __init__(self,
self.num_text_tokens = self.text_tokenizer.vocab_size
self.num_tokens = self.num_text_tokens

-if self.tokenizer_model_name.startswith('GLM'):
+if self.tokenizer_model_name.lower().startswith('glm'):
pad_token_id = self.num_tokens
eos_token_id = self.num_tokens
unk_token_id = self.num_tokens + 4
@@ -450,7 +445,6 @@ def CommandTokenIds(self, exception=None):
result.append(s.Id)
return (result)

-
def encode_plus_non_glm(
self,
text,
@@ -517,7 +511,7 @@ def encode_plus( #for Seq2seq
truncation=True,
max_length=None,
):
if not self.tokenizer_model_name.startswith("GLM"):
if not self.tokenizer_model_name.lower().startswith("glm"):
return self.encode_plus_non_glm(source_text, second_text, truncation, max_length)
sop_id = self.get_command_id('sop') #start of piece
eop_id = self.get_command_id('eop') #end of piece
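Taken together, the tokenizer.py edits make every branch in Tokenizer.__init__ compare against a case-normalized name. A simplified sketch of the resulting dispatch pattern, with an assumed class name and flavor labels for illustration (the real constructor takes many more parameters and registers command tokens per branch):

```python
# Simplified, hypothetical sketch of the dispatch pattern that
# Tokenizer.__init__ follows after this commit; the real constructor
# does far more (vocab loading, command-token registration, ...).
class TokenizerDispatchSketch:
    def __init__(self, tokenizer_model_name: str):
        name = tokenizer_model_name.lower()
        if name.startswith("glm"):
            self.flavor = "glm"      # GLM branch adds block symbols
        elif name.startswith("clip"):
            self.flavor = "clip"     # CLIP branch adds <start_of_text>/<end_of_text>
        elif name.startswith("roberta"):
            self.flavor = "roberta"
        else:
            self.flavor = "default"

# The same branch is chosen regardless of the case of the incoming name.
assert TokenizerDispatchSketch("GLM-large-en").flavor == "glm"
assert TokenizerDispatchSketch("glm-large-ch").flavor == "glm"
```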
