added lora adapter save interval and modified tokenizer
Signed-off-by: ftgreat <[email protected]>
ftgreat committed Aug 24, 2023
1 parent b70a407 commit fdb485d
Showing 9 changed files with 231 additions and 23 deletions.
13 changes: 7 additions & 6 deletions examples/Aquila/Aquila-chat/Aquila-chat-lora.yaml
@@ -1,12 +1,13 @@
batch_size: 1
epochs: 5
batch_size: 4
epochs: 10
gradient_accumulation_steps: 1
lr: 4.0e-5
warm_up: 0.01
lora_r: 8
warm_up_iters: 200
lora_r: 16
lora_alpha: 32
save_interval: 200
log_interval: 1
save_interval: 500
log_interval: 10
bmt_cpu_offload: False
bmt_pre_load: True
bmt_lr_decay_style: 'cosine'
@@ -17,4 +18,4 @@ eps: 1.0e-8
lora: True

enable_sft_dataset_dir: './data/'
enable_sft_dataset_file: 'convo_samples.jsonl'
enable_sft_dataset_file: 'sft_v0.9.10_train_chinese.jsonl'
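For context, a minimal sketch of reading the updated LoRA settings; it assumes PyYAML and the repository-relative path shown in the file header above, and simply echoes the fields this commit changes (in the usual LoRA formulation, `lora_alpha / lora_r` = 32 / 16 gives a scaling factor of 2.0).

```python
# Minimal sketch: load the updated config and print the fields touched by this commit.
# Assumes PyYAML is installed and the script runs from the repository root.
import yaml

with open("examples/Aquila/Aquila-chat/Aquila-chat-lora.yaml") as f:
    cfg = yaml.safe_load(f)

# Adapter checkpoints are now written every 500 iterations instead of 200,
# and LoRA runs with rank 16 / alpha 32.
for key in ("batch_size", "epochs", "warm_up_iters", "lora_r", "lora_alpha",
            "save_interval", "log_interval"):
    print(key, cfg[key])
```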
2 changes: 1 addition & 1 deletion examples/Aquila/Aquila-chat/README.md
@@ -29,7 +29,7 @@
The Wudao·Aquila series of models will continue to be open-sourced in improved versions. You can first delete `checkpoints_in/aquilachat-7b` under the original directory and then download the new weights; all other usage remains unchanged.

- 2023/07/24: Released weight files v0.9 and open-sourced AquilaCode-multi and AquilaCode-py. The AquilaChat-7B and Aquila-7B weights are unchanged, and there are currently no plans to update the AquilaCode-7B-NV and AquilaCode-7B-TS weights.
- Aquila-7B md5: 18eac56434db0198494b22b321633785
- Aquila-7B md5: 5b56d31c8154c9184a38ff7bc6b4d887
- AquilaChat-7B md5: 465683009c8b536ef4cca85febb0227c
- AquilaCode-multi md5:07cfce9440a0fa1ac2768b39d2cf4286
- AquilaCode-py md5:3faa85fc03d8fda70a73064f48d02d85
2 changes: 1 addition & 1 deletion examples/Aquila/Aquila-chat/README_en.md
@@ -31,7 +31,7 @@ The additional details of the Aquila model will be presented in the official tec
We will continue to release improved versions of the Aquila model as open source. You can start by deleting the `checkpoints_in/aquilachat-7b` in the original directory and then download the new weights. Other usage methods remain unchanged. For more details, please refer to the following change log:

- 2023/07/24 :Released v0.8 checkpoint files,AquilaCode-multi and AquilaCode-python have been released while AquilaCode-7B-NV and AquilaCode-7B-TS are temporarily not maintained. There are no updates for the weights of Aquila-7B and AquilaChat-7B.
- Aquila-7B md5: 18eac56434db0198494b22b321633785
- Aquila-7B md5: 5b56d31c8154c9184a38ff7bc6b4d887
- AquilaChat-7B md5: 465683009c8b536ef4cca85febb0227c
- AquilaCode-multi md5:07cfce9440a0fa1ac2768b39d2cf4286
- AquilaCode-py md5:3faa85fc03d8fda70a73064f48d02d85
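Since both READMEs publish a new md5 checksum for Aquila-7B, a small hedged sketch of verifying a downloaded file against the value above; the file name is a placeholder for whichever archive the checksum refers to.

```python
# Sketch: compare a downloaded checkpoint file against the published md5.
# "aquila-7b.bin" is a placeholder name, not a path from this repository.
import hashlib

def md5_of(path, chunk_size=1 << 20):
    digest = hashlib.md5()
    with open(path, "rb") as f:
        while block := f.read(chunk_size):
            digest.update(block)
    return digest.hexdigest()

print(md5_of("aquila-7b.bin") == "5b56d31c8154c9184a38ff7bc6b4d887")
```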
2 changes: 2 additions & 0 deletions examples/Aquila/Aquila-chat/aquila_chat.py
@@ -7,6 +7,7 @@
import gc
gc.collect()
torch.cuda.empty_cache()
import sys;sys.path.append("/data2/yzd/FlagAI")
from flagai.auto_model.auto_loader import AutoLoader
from flagai.data.tokenizer import Tokenizer
from flagai.env_args import EnvArgs
@@ -38,6 +39,7 @@
training_script=__file__,
)
env_args = env_args.parse_args()

#env_args.wandb = False

# overwrite
8 changes: 5 additions & 3 deletions examples/Aquila/Aquila-chat/generate_chat_lora.py
@@ -3,15 +3,16 @@
# Licensed under the Apache License, Version 2.0 (the "License")
import os
import torch
import sys;sys.path.append("/data2/yzd/FlagAI/")
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
from flagai.data.tokenizer import Tokenizer
import torch.nn as nn
from flagai.model.predictor.aquila import aquila_generate


state_dict = "./checkpoints_in/"
model_name = 'aquila-7b'
state_dict = "./checkpoints_in/lora/"
model_name = 'aquilachat-7b'


loader = AutoLoader("lm",
@@ -20,7 +21,7 @@
use_cache=True,
fp16=True,
device='cuda',
adapter_dir='/data2/yzd/FlagAI/examples/Aquila/Aquila-chat/checkpoints_out/aquila_experiment/2023080216/') # Directory to adapter_model.bin and adapter_config.json
adapter_dir='/data2/yzd/FlagAI/examples/Aquila/Aquila-chat/checkpoints_out/aquilachat_experiment/2023081313/') # Directory to adapter_model.bin and adapter_config.json
model = loader.get_model()

tokenizer = loader.get_tokenizer()
@@ -37,6 +38,7 @@
"Write a short story about a dragon and a knight.",
"翻译成英文: '我饿了想吃饭'",
"write a fairy tale for me",
"用英文回答: 世界上最高的地方在哪里?"
]

for text in texts:
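The trainer change later in this commit (see `flagai/env_trainer_v1.py` below) writes each adapter snapshot into an iteration-numbered subfolder, so `adapter_dir` can point at any one of them. A short hedged sketch; the run directory and iteration folder below are placeholders, not values taken from this commit.

```python
# Sketch: pick one per-iteration adapter snapshot for generation.
# Each folder is expected to hold adapter_model.bin and adapter_config.json.
import os

run_dir = "./checkpoints_out/aquilachat_experiment/2023081313"  # example run directory
iteration_folder = "199"  # placeholder; folders are named after the trainer's iteration counter
adapter_dir = os.path.join(run_dir, iteration_folder)
assert os.path.isfile(os.path.join(adapter_dir, "adapter_config.json")), adapter_dir
```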
2 changes: 1 addition & 1 deletion examples/Aquila/Aquila-chat/hostfile
@@ -1 +1 @@
192.168.20.3 slots=1
192.168.20.2 slots=8
167 changes: 167 additions & 0 deletions flagai/data/tokenizer/uni_tokenizer/tokenization_utils.py
@@ -0,0 +1,167 @@
import logging

import regex as re
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union, overload

# logger used by cut_text() below to report recoverable splitting errors
logger = logging.getLogger(__name__)

class Trie:
"""
Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass
Loose reference https://en.wikipedia.org/wiki/Trie
"""

def __init__(self):
self.data = {}

def add(self, word: str):
if not word:
# Prevent empty string
return
ref = self.data
for char in word:
ref[char] = char in ref and ref[char] or {}
ref = ref[char]
ref[""] = 1

def split(self, text: str) -> List[str]:
states = OrderedDict()

# This will contain every indices where we need
# to cut.
# We force to cut at offset 0 and len(text) (added later)
offsets = [0]

# This is used by the lookahead which needs to skip over
# some text where the full match exceeded the place in the initial
# for loop
skip = 0
# Main loop, Giving this algorithm O(n) complexity
for current, current_char in enumerate(text):
if skip and current < skip:
# Prevents the lookahead for matching twice
# like extra_id_100 and id_100
continue

# This will track every state
# that stop matching, we need to stop tracking them.
# If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
# fail on "b", we need to remove 0 from the valid states.
to_remove = set()
# Whenever we found a match, we need to drop everything
# this is a greedy algorithm, it will match on the first found token
reset = False

# In this case, we already have partial matches (But unfinished)
for start, trie_pointer in states.items():
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.

# Lookahead to match longest first
# Important in case of extra_id_1 vs extra_id_100
# Here we are also actively looking for other earlier partial
# matches
# "[CLS]", "L", we need to match CLS even if L is special
for lookstart, looktrie_pointer in states.items():
if lookstart > start:
# This partial match is later, we can stop looking
break
elif lookstart < start:
# This partial match is earlier, the trie pointer
# was already updated, so index is + 1
lookahead_index = current + 1
end = current + 1
else:
# Here lookstart == start and
# looktrie_pointer == trie_pointer
# It wasn't updated yet so indices are current ones
lookahead_index = current
end = current
next_char = text[lookahead_index] if lookahead_index < len(text) else None
if "" in looktrie_pointer:
start = lookstart
end = lookahead_index
skip = lookahead_index

while next_char in looktrie_pointer:
looktrie_pointer = looktrie_pointer[next_char]
lookahead_index += 1
if "" in looktrie_pointer:
start = lookstart
end = lookahead_index
skip = lookahead_index

if lookahead_index == len(text):
# End of string
break
next_char = text[lookahead_index]
# End lookahead

# Storing and resetting
offsets.append(start)
offsets.append(end)
reset = True
break
elif current_char in trie_pointer:
# The current character being looked at has a match within the trie
# update the pointer (it will be stored back into states later).
trie_pointer = trie_pointer[current_char]

# Storing back the new pointer into the states.
# Partial matches got longer by one.
states[start] = trie_pointer
else:
# The new character has not match in the trie, we need
# to stop keeping track of this partial match.
# We can't do it directly within the loop because of how
# python iteration works
to_remove.add(start)

# Either clearing the full start (we found a real match)
# Or clearing only the partial matches that didn't work.
if reset:
states = {}
else:
for start in to_remove:
del states[start]

# If this character is a starting character within the trie
# start keeping track of this partial match.
if current >= skip and current_char in self.data:
states[current] = self.data[current_char]

# We have a cut at the end with states.
for start, trie_pointer in states.items():
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.
end = len(text)
offsets.append(start)
offsets.append(end)
# Longest cut is always the one with lower start so the first
# item so we need to break.
break

return self.cut_text(text, offsets)

def cut_text(self, text, offsets):
# We have all the offsets now, we just need to do the actual splitting.
# We need to eventually add the first part of the string and the eventual
# last part.
offsets.append(len(text))
tokens = []
start = 0
for end in offsets:
if start > end:
logger.error(
"There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
" anyway."
)
continue
elif start == end:
# This might happen if there's a match at index 0
# we're also preventing zero-width cuts in case of two
# consecutive matches
continue
tokens.append(text[start:end])
start = end

return tokens
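A small usage sketch of the `Trie` above (it mirrors the pure-Python trie used by Hugging Face tokenizers): once the special tokens are added, `split` cuts them out of the text in a single left-to-right pass, which is what the tokenizer change below relies on.

```python
# Sketch: split text on registered special tokens in one pass.
from flagai.data.tokenizer.uni_tokenizer.tokenization_utils import Trie

trie = Trie()
trie.add("[CLS]")
trie.add("<|endoftext|>")

print(trie.split("[CLS] Hello world<|endoftext|>"))
# expected: ['[CLS]', ' Hello world', '<|endoftext|>']
```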
50 changes: 41 additions & 9 deletions flagai/data/tokenizer/uni_tokenizer/tokenizer.py
@@ -32,6 +32,7 @@
from typing import List, Union, Optional
import unicodedata
import json
from flagai.data.tokenizer.uni_tokenizer.tokenization_utils import Trie


def is_control(ch):
@@ -372,7 +373,6 @@ def __init__(self,
if tk not in self.command_name_map:
res = self.search_special(tk)
self.add_command_token(tk, res,self.tokenizer_class)

self.command_name_map = {tok.name: tok for tok in self._command_tokens}
self.command_token_map = {
tok.token: tok
@@ -392,7 +392,8 @@ def __init__(self,
self.token_end_id = self.TokenToId("<|endoftext|>")
except KeyError:
self.token_end_id = self.TokenToId("[SEP]")


self.tokens_trie = Trie()
#print("All special tokens: ", str([(k, v.token, v.Id) for k,v in self.command_name_map.items()]))

def get_vocab(self):
@@ -455,7 +456,7 @@ def _is_special(ch):
return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')

def _encode(self, text):
tokens = self.text_tokenizer.tokenize(text)
tokens = self.tokenize(text)
ids = self.text_tokenizer.convert_tokens_to_ids(tokens)
return ids

@@ -484,7 +485,7 @@ def EncodeAsTokens(self, text, process_fn=None):
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = self.text_tokenizer.tokenize(processed_text)
tokens = self.tokenize(processed_text)
return tokens

def IdToToken(self, id):
@@ -521,10 +522,10 @@ def DecodeIds(self, ids):
tokens, self.command_token_map)

def encode(self, text):
if hasattr(self.text_tokenizer, "encode"):
return self.text_tokenizer.encode(text)
# if hasattr(self.text_tokenizer, "encode"):
# return self.text_tokenizer.encode(text)
return self.convert_tokens_to_ids(
self.text_tokenizer.tokenize(text))
self.tokenize(text))

def decode(self, ids):
if hasattr(self.text_tokenizer, "decode"):
@@ -791,8 +792,14 @@ def tokenize_as_tensor(self, texts):
sot_token=sot_token,
eot_token=eot_token)

def tokenize(self, text, maxlen=None, add_spatial_tokens=False):
tokens = self.text_tokenizer.tokenize(text)
def _create_trie(self, unique_no_split_tokens):
trie = Trie()
for token in unique_no_split_tokens:
trie.add(token)
self.tokens_trie = trie

def _tokenize(self, text, maxlen=None, add_spatial_tokens=False):
tokens = self.text_tokenizer.tokenize(text)

if add_spatial_tokens:
tokens.insert(0, self.get_command_id('cls'))
@@ -803,6 +810,31 @@ def tokenize(self, text, maxlen=None, add_spatial_tokens=False):
self.truncate_sequence(maxlen, tokens, pop_index=-index)
return tokens

def tokenize(self, text, maxlen=None, add_spatial_tokens=False):
tokens = self.tokens_trie.split(text)
# ["This is something", "<special_token_1>", " else"]
for i, token in enumerate(tokens):
if token in self._command_token_tokens:
left = tokens[i - 1] if i > 0 else None
right = tokens[i + 1] if i < len(tokens) - 1 else None
# We strip left and right by default
if right:
tokens[i + 1] = right.lstrip()
if left:
tokens[i - 1] = left.rstrip()
# ["This is something", "<special_token_1>", "else"]
tokenized_text = []
for token in tokens:
# Need to skip eventual empty (fully stripped) tokens
if not token:
continue
if token in self._command_token_tokens:
tokenized_text.append(token)
else:
tokenized_text.extend(self._tokenize(token))
# ["This", " is", " something", "<special_token_1>", "else"]
return tokenized_text

def search_special(self, name):
if name == "cls":
if self.check_special('<s>'): return '<s>'
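Taken together, the new path splits on the trie first, strips the whitespace that touches a special token, and only sends ordinary text chunks to the underlying tokenizer. A hedged sketch of that flow under two assumptions: `_tokenize` ultimately delegates to `text_tokenizer.tokenize`, and `tokens_trie` has been populated (e.g. via `_create_trie`) with the command tokens. `split_then_tokenize` and `plain_tokenize` are illustrative names, not FlagAI API.

```python
# Illustrative stand-in for the new Tokenizer.tokenize path (not FlagAI API).
from flagai.data.tokenizer.uni_tokenizer.tokenization_utils import Trie

def split_then_tokenize(text, special_tokens, plain_tokenize):
    trie = Trie()                      # what _create_trie builds from the command tokens
    for tok in special_tokens:
        trie.add(tok)
    pieces = trie.split(text)
    # Strip the whitespace adjacent to each special token, as tokenize() does.
    for i, piece in enumerate(pieces):
        if piece in special_tokens:
            if i > 0:
                pieces[i - 1] = pieces[i - 1].rstrip()
            if i < len(pieces) - 1:
                pieces[i + 1] = pieces[i + 1].lstrip()
    out = []
    for piece in pieces:
        if not piece:                  # skip chunks that were fully stripped
            continue
        if piece in special_tokens:    # special tokens pass through unsplit
            out.append(piece)
        else:
            out.extend(plain_tokenize(piece))
    return out

# plain_tokenize stands in for text_tokenizer.tokenize; a whitespace split for the demo.
print(split_then_tokenize("[CLS] hello world", {"[CLS]"}, str.split))
# -> ['[CLS]', 'hello', 'world']
```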
8 changes: 6 additions & 2 deletions flagai/env_trainer_v1.py
@@ -676,7 +676,9 @@ def do_train(self,
if self.save_dir and (self.iteration + 1) % self.save_interval == 0 and \
self.iteration != best_iteration:
if self.adapter_save:
self.model.save_pretrained(save_directory=self.save_dir)
save_dir = self.save_dir+'/'+str(self.iteration)
os.makedirs(save_dir,exist_ok=True)
self.model.save_pretrained(save_directory=save_dir)
else:
best_iteration = self.iteration
save_checkpoint(self.iteration+1,
@@ -697,7 +699,9 @@ def do_train(self,
# self.iteration-1 as the exact iteration
if self.save_dir and (self.iteration-1) != best_iteration:
if self.adapter_save:
self.model.save_pretrained(save_directory=self.save_dir)
save_dir = self.save_dir+'/'+str(self.iteration)
os.makedirs(save_dir,exist_ok=True)
self.model.save_pretrained(save_directory=save_dir)
else:
save_checkpoint(self.iteration+1,
best_iteration+1,
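With this change each adapter snapshot lands in its own iteration-numbered folder under `save_dir` instead of overwriting a single directory. A minimal sketch of listing those snapshots after a run; the run directory name below is a placeholder.

```python
# Sketch: enumerate the per-iteration adapter snapshots written when adapter_save is on.
# "checkpoints_out/my_experiment" is a placeholder run directory.
import os

run_dir = "checkpoints_out/my_experiment"
snapshots = sorted((d for d in os.listdir(run_dir) if d.isdigit()), key=int)
for name in snapshots:
    path = os.path.join(run_dir, name)
    # each folder holds the adapter files (adapter_model.bin and adapter_config.json,
    # as referenced in generate_chat_lora.py above)
    print(name, sorted(os.listdir(path)))
```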
