Fix Tokenizer encode (#645)
* same encode with HF

* sequence_start -> add_bos

* complement
AllentDan authored Nov 19, 2023
1 parent c02e281 commit 07640a3
Showing 6 changed files with 29 additions and 33 deletions.
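
In short: Tokenizer.encode now takes an explicit add_bos flag (default True, matching HuggingFace behavior) and the magic '<BOS>'/'<EOS>' substrings disappear from the prompt templates. A minimal sketch of the resulting contract, using a toy tokenizer so it runs standalone (nothing here is lmdeploy's real implementation):

    class ToyTokenizer:
        bos_token_id = 1

        def encode(self, s, add_bos=True, **kwargs):
            ids = [ord(c) for c in s]  # fake ids, one per character
            return ([self.bos_token_id] + ids) if add_bos else ids

    tok = ToyTokenizer()
    assert tok.encode('hi') == [1, 104, 105]              # first turn: bos + tokens
    assert tok.encode('hi', add_bos=False) == [104, 105]  # later turns: no bos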
18 changes: 9 additions & 9 deletions lmdeploy/model.py
@@ -210,7 +210,7 @@ def decorate_prompt(self, prompt, sequence_start=True):
         assert self.capability == 'chat', \
             f'{type(self).__name__} has no capability of {self.capability}'
         if sequence_start:
-            return f'<BOS>{self.system}{self.meta_instruction}{self.eosys}' \
+            return f'{self.system}{self.meta_instruction}{self.eosys}' \
                    f'{self.user}{prompt}{self.eoh}' \
                    f'{self.assistant}'
         else:
@@ -230,7 +230,7 @@ def messages2prompt(self, messages, sequence_start=True):
         if isinstance(messages, str):
             return self.get_prompt(messages, sequence_start)
         eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys)
-        ret = '<BOS>'
+        ret = ''
         if self.meta_instruction:
             ret += f'{self.system}:{self.meta_instruction}{self.eosys}'
@@ -355,7 +355,7 @@ def decorate_prompt(self, prompt, sequence_start=True):
         assert self.capability == 'chat', \
             f'{type(self).__name__} has no capability of {self.capability}'
         if sequence_start:
-            return f'<BOS>{self.system}{self.meta_instruction}{self.eosys}' \
+            return f'{self.system}{self.meta_instruction}{self.eosys}' \
                    f'{self.user}{prompt}{self.eoh}' \
                    f'{self.assistant}'
         else:
@@ -374,7 +374,7 @@ def messages2prompt(self, messages, sequence_start=True):
         if isinstance(messages, str):
             return self.get_prompt(messages, sequence_start)
         eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys)
-        ret = '<BOS>'
+        ret = ''
         if self.meta_instruction:
             ret += f'{self.system}{self.meta_instruction}{self.eosys}'
@@ -424,7 +424,7 @@ def decorate_prompt(self, prompt, sequence_start=True):
         assert self.capability == 'chat', \
             f'{type(self).__name__} has no capability of {self.capability}'
         if sequence_start:
-            return f'<BOS>{self.b_inst} ' \
+            return f'{self.b_inst} ' \
                    f'{self.b_sys} {self.default_sys_prompt} {self.e_sys}' \
                    f'{prompt} {self.e_inst} '
@@ -443,7 +443,7 @@ def messages2prompt(self, messages, sequence_start=True):
             return self.get_prompt(messages, sequence_start)
         system, users, assistants = self._translate_messages(messages)
         system = self.default_sys_prompt if not system else system
-        ret = f'<BOS>{self.b_inst} {self.b_sys} {system} {self.e_sys}'
+        ret = f'{self.b_inst} {self.b_sys} {system} {self.e_sys}'
         for i, (user, assistant) in enumerate(zip(users, assistants)):
             if i != 0:
                 ret += f'{self.b_inst} '
@@ -559,16 +559,16 @@ def _infill_prompt(self, prompt):
         prefix, suffix = prompt.split('<FILL>')
         if self.suffix_first:
             # format as "<PRE> <SUF>{suf} <MID> {pre}"
-            prompt = f'<BOS><PRE> <SUF>{suffix} <MID> {prefix}'
+            prompt = f'<PRE> <SUF>{suffix} <MID> {prefix}'
         else:
             # format as "<PRE> {pre} <SUF>{suf} <MID>"
-            prompt = f'<BOS><PRE> {prefix} <SUF>{suffix} <MID>'
+            prompt = f'<PRE> {prefix} <SUF>{suffix} <MID>'
         return prompt

     def _get_prompt(self, prompt, sequence_start):
         prompt = prompt.strip()
         if sequence_start:
-            return f'<BOS>{self.b_inst} ' \
+            return f'{self.b_inst} ' \
                    f'{self.b_sys}{self.default_sys_prompt}{self.e_sys}' \
                    f'{prompt} {self.e_inst}'
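
All of the template edits above are the same one-line change: the '<BOS>' placeholder comes out of the prompt strings, because bos insertion now belongs to the tokenizer. A toy rendering of the post-commit template shape (separators invented, not any real model's):

    def decorate_prompt(prompt, sequence_start=True):
        # Invented separators; only the shape mirrors the diff above.
        system, meta, eosys = '<|System|>:', 'Be helpful.', '\n'
        user, eoh, assistant = '<|User|>:', '\n', '<|Bot|>:'
        if sequence_start:
            return f'{system}{meta}{eosys}{user}{prompt}{eoh}{assistant}'
        return f'{user}{prompt}{eoh}{assistant}'

    assert '<BOS>' not in decorate_prompt('hi')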
2 changes: 1 addition & 1 deletion lmdeploy/serve/async_engine.py
@@ -204,7 +204,7 @@ async def generate(
         prompt = messages
         if do_preprocess:
             prompt = self.model.messages2prompt(prompt, sequence_start)
-        input_ids = self.tokenizer.encode(prompt)
+        input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
         finish_reason = 'stop' if stop else None
         if self.steps[str(session_id)] + len(
                 input_ids) + request_output_len >= self.tm_model.session_len:
4 changes: 4 additions & 0 deletions lmdeploy/serve/turbomind/chatbot.py
@@ -459,6 +459,10 @@ def _stream_infer(self,
             session.sequence_length = 0

         input_ids, input_lengths = self.preprocess(prompt)
+        # got input_ids with default add_bos == True
+        if not sequence_start and input_ids[0][0] == self.bos_id:
+            input_ids = input_ids[:, 1:]
+            input_lengths = input_lengths - 1
         # will crash if last_token_id == eos_id and send empty input_ids
         if sequence_end and request_output_len == 0:
             input_ids = np.array([[self.bos_id]], dtype=np.uint32)
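
Here preprocess returns batch-shaped numpy ids with bos already attached, so mid-session the fix slices the first column off instead of re-encoding. A toy reproduction with made-up ids:

    import numpy as np

    bos_id = 1
    input_ids = np.array([[1, 5, 6, 7]], dtype=np.uint32)  # shape (batch=1, seq)
    input_lengths = np.array([4], dtype=np.uint32)

    sequence_start = False
    if not sequence_start and input_ids[0][0] == bos_id:
        input_ids = input_ids[:, 1:]       # drop the leading bos column
        input_lengths = input_lengths - 1

    assert input_ids.tolist() == [[5, 6, 7]]
    assert int(input_lengths[0]) == 3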
32 changes: 11 additions & 21 deletions lmdeploy/tokenizer.py
@@ -53,23 +53,15 @@ def _maybe_add_prefix_space(self, tokens, decoded):
         else:
             return decoded

-    def encode(self, s: str):
+    def encode(self, s: str, add_bos: bool = True, **kwargs):
         """Tokenize a prompt.
         Args:
             s (str): a prompt
         Returns:
             list[int]: token ids
         """
-        add_bos = False
-        add_eos = False
-        if s.find('<BOS>') != -1:
-            s = s.replace('<BOS>', '')
-            add_bos = True
-        if s == '<EOS>':
-            s = ''
-            add_eos = True
-        return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
+        return self.model.Encode(s, add_bos=add_bos, **kwargs)

     def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
@@ -175,22 +167,20 @@ def _maybe_add_prefix_space(self, tokens, decoded):
         else:
             return decoded

-    def encode(self, s: str):
+    def encode(self, s: str, add_bos: bool = True, **kwargs):
         """Tokenize a prompt.
         Args:
             s (str): a prompt
         Returns:
             list[int]: token ids
         """
-        add_special_tokens = False
-        if s.find('<BOS>') != -1:
-            s = s.replace('<BOS>', '<s>')
-        if s == '<EOS>':
-            s = '</s>'
-        if len(s) == 0:
-            add_special_tokens = True
-        return self.model.encode(s, add_special_tokens=add_special_tokens)
+        encoded = self.model.encode(s, **kwargs)
+        if not add_bos:
+            # in the middle of a session
+            if len(encoded) and encoded[0] == self.bos_token_id:
+                encoded = encoded[1:]
+        return encoded

     def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
@@ -261,15 +251,15 @@ def eos_token_id(self):
         """end of the sentence token id."""
         return self.model.eos_token_id

-    def encode(self, s: str):
+    def encode(self, s: str, add_bos: bool = True, **kwargs):
         """Tokenize a prompt.
         Args:
             s (str): a prompt
         Returns:
             list[int]: token ids
         """
-        return self.model.encode(s)
+        return self.model.encode(s, add_bos, **kwargs)

     def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """De-tokenize.
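
Note the asymmetry between the two backends: the sentencepiece branch passes add_bos straight through to Encode, while the HF branch encodes with its defaults and strips a leading bos when mid-session. That guard, rendered standalone with made-up ids:

    def strip_leading_bos(encoded, add_bos, bos_token_id):
        # Mid-session (add_bos=False): drop the bos the backend added by default.
        if not add_bos and len(encoded) and encoded[0] == bos_token_id:
            return encoded[1:]
        return encoded

    assert strip_leading_bos([1, 42, 43], add_bos=False, bos_token_id=1) == [42, 43]
    assert strip_leading_bos([42, 43], add_bos=False, bos_token_id=1) == [42, 43]
    assert strip_leading_bos([1, 42], add_bos=True, bos_token_id=1) == [1, 42]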
2 changes: 1 addition & 1 deletion lmdeploy/turbomind/chat.py
@@ -122,7 +122,7 @@ def main(model_path,
             seed = random.getrandbits(64)
         else:
             prompt = model.get_prompt(prompt, nth_round == 1)
-            input_ids = tokenizer.encode(prompt)
+            input_ids = tokenizer.encode(prompt, nth_round == 1)
             if step + len(
                     input_ids) + request_output_len >= tm_model.session_len:
                 print('WARNING: exceed session max length.'
4 changes: 3 additions & 1 deletion lmdeploy/turbomind/turbomind.py
@@ -30,7 +30,9 @@ def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
     assert isinstance(stop_words, List) and \
         all(isinstance(elem, str) for elem in stop_words), \
         f'stop_words must be a list but got {type(stop_words)}'
-    stop_words = [tokenizer.encode(stop_word)[-1] for stop_word in stop_words]
+    stop_words = [
+        tokenizer.encode(stop_word, False)[-1] for stop_word in stop_words
+    ]
     assert isinstance(stop_words, List) and all(
         isinstance(elem, int) for elem in stop_words), 'invalid stop_words'
     # each id in stop_words represents a stop word
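
Stop words are encoded mid-session style (add_bos=False, passed positionally) so the bos id never appears in the encoded list, and [-1] then takes the stop word's final token id. A toy check (fake ids, not a real vocabulary):

    def toy_encode(s, add_bos=True, bos_id=1):
        ids = [100 + len(piece) for piece in s.split()]  # fake per-piece ids
        return ([bos_id] + ids) if add_bos else ids

    stop_words = [toy_encode(w, False)[-1] for w in ['</s>', 'User:']]
    assert stop_words == [104, 105]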
