diff --git a/docs/docusaurus_tsx/docs/FAQ/special_tokens.md b/docs/docusaurus_tsx/docs/FAQ/special_tokens.md index 44fb1269..c6a20f2d 100644 --- a/docs/docusaurus_tsx/docs/FAQ/special_tokens.md +++ b/docs/docusaurus_tsx/docs/FAQ/special_tokens.md @@ -1,21 +1,39 @@ # What special tokens are used? -In the v2, special tokens were different for SEQ2SEQ and LM: -LM was BOS, PAD, EOS with IDs (0, 1, 2) and the first vocab token started at id=3 -SEQ2SEQ was UNK, PAD, BOS, EOS with IDs (0, 1, 2, 3) and first vocab token started at id=4 - -In v3 we changed this behavior to align things: - group.add( - "--default_specials", - "-default_specilas", - nargs="+", - type=str, - default=[ - DefaultTokens.UNK, - DefaultTokens.PAD, - DefaultTokens.BOS, - DefaultTokens.EOS, - ]) +There are four main special tokens: +- BOS for "beginning of sentence"; +- PAD for "padding"; +- EOS for "end of sentence"; +- UNK for "unknown". + +## Special tokens actually used + +Depending on the context, these tokens can take various values: + +1. Default behaviour, training from scratch + +Some default values are defined as [constants](https://github.com/eole-nlp/eole/blob/ff39275c50d12951963008da11d029940b590713/eole/constants.py#L8) for the project: +```python +class DefaultTokens(object): + PAD = "<blank>" + BOS = "<s>" + EOS = "</s>" + UNK = "<unk>" +``` + +2. Retrieving a pretrained model from HF + +The special tokens are retrieved and configured from the `special_tokens_map.json` file provided with the HF model files. + +3. Custom behaviour + +In any case, these tokens can be overridden via the corresponding configuration settings: +- `bos_token` +- `pad_token` +- `eos_token` +- `unk_token` + +## Special tokens behaviour in Eole When we train a SEQ2SEQ model we use: SRC: srctok1 srctok2 srctok3 ....
srctokn diff --git a/eole/bin/convert/convert_HF.py b/eole/bin/convert/convert_HF.py index c3dfa64d..7836e221 100755 --- a/eole/bin/convert/convert_HF.py +++ b/eole/bin/convert/convert_HF.py @@ -190,26 +190,6 @@ "XLMRobertaXLForMaskedLM": TransformerEncoderModelConfig, } -decoder_start_table = { - "LlamaForCausalLM": "", - "MistralForCausalLM": "", - "MixtralForCausalLM": "", - "PhiForCausalLM": "", - "Phi3ForCausalLM": "", - "GPT2LMHeadModel": "", - "XLMRobertaXLForMaskedLM": "", -} - -specials_table = { - "LlamaForCausalLM": ["", "", ""], - "MistralForCausalLM": ["", "", ""], - "MixtralForCausalLM": ["", "", ""], - "PhiForCausalLM": ["", "", ""], - "Phi3ForCausalLM": ["", "", ""], - "GPT2LMHeadModel": ["", "", ""], - "XLMRobertaXLForMaskedLM": ["", "", "", ""], -} - class Tokenizer: def __init__(self, model_path: str): @@ -313,6 +293,12 @@ def run(cls, args): ) else: tokenizer_config_json = None + if os.path.exists(os.path.join(args.model_dir, "special_tokens_map.json")): + tokenizer_config_json = os.path.join( + args.model_dir, "special_tokens_map.json" + ) + else: + tokenizer_config_json = None if os.path.exists(os.path.join(args.model_dir, "generation_config.json")): generation_config_json = os.path.join( args.model_dir, "generation_config.json" @@ -415,6 +401,22 @@ def run(cls, args): raise huggingface_hub.utils.EntryNotFoundError( "No valid model files found" ) + try: + try: + special_tokens_json = huggingface_hub.hf_hub_download( + repo_id=args.model_dir, + filename="special_tokens_map.json", + token=args.token, + ) + except huggingface_hub.utils.EntryNotFoundError: + raise huggingface_hub.utils.EntryNotFoundError( + "Something went wrong the repo does not contain" + "any special_tokens_map.json file" + ) + except Exception as e: + if isinstance(e, huggingface_hub.utils.EntryNotFoundError): + special_tokens_json = None + print(e) with open(config_path, encoding="utf-8") as fconfig: config = json.load(fconfig) @@ -584,7 +586,6 @@ def run(cls, args): "n_positions": 0, } left_pad = True - eos_token = None optional_eos = [] mapped_tokens = [] gpt2_pretok = False @@ -944,21 +945,9 @@ def get_weight(checkpoint, tensor_name): if "added_tokens_decoder" in data.keys(): eos_tokens = [ data["added_tokens_decoder"][str(index)]["content"] - for index in eos_token_id[1:] + for index in eos_token_id ] optional_eos = eos_tokens[1:] - eos_token = eos_tokens[0] - elif isinstance(eos_token_id, int): - if "eos_token" in data.keys(): - if isinstance(data["eos_token"], dict): - # Llama2 style - eos_token = data["eos_token"]["content"] - elif isinstance(data["eos_token"], str): - eos_token = data["eos_token"] - elif "added_tokens_decoder" in data.keys(): - eos_token = data["added_tokens_decoder"][str(eos_token_id)][ - "content" - ] # Automatically convert added_tokens into mapped_tokens if "added_tokens_decoder" in data.keys(): mapped_tokens = [ @@ -973,6 +962,29 @@ def get_weight(checkpoint, tensor_name): else: add_bos_token = True + vocabs = {"specials": {}} + + if special_tokens_json is not None: + with open(special_tokens_json, encoding="utf-8") as f: + special_tokens_map = json.load(f) + for token_name in ["bos_token", "unk_token", "eos_token", "pad_token"]: + token = special_tokens_map.get(token_name, None) + if isinstance(token, list): + vocabs["specials"][token_name] = token[0] + elif isinstance(token, str): + vocabs["specials"][token_name] = token + elif isinstance(token, dict): + vocabs["specials"][token_name] = token["content"] + elif tokenizer_json is not None: + with open(tokenizer_json, 
encoding="utf-8") as f: + data = json.load(f) + vocab = {v: k for k, v in data["model"]["vocab"].items()} + for token_name in ["bos_token", "unk_token", "eos_token", "pad_token"]: + if f"{token_name}_id" in config.keys(): + vocabs["specials"][token_name] = vocab[ + config[f"{token_name}_id"] + ] + if generation_config_json is not None: with open(generation_config_json, encoding="utf-8") as f: data = json.load(f) @@ -982,8 +994,6 @@ def get_weight(checkpoint, tensor_name): for key in keys: if key in data.keys(): generation_config_dict[key] = data[key] - - vocabs = {} if ( tokenizer_model is not None ): # sentencepiece mode (might be good to check it's a SP model) @@ -1003,19 +1013,8 @@ def get_weight(checkpoint, tensor_name): vocab.extend(newtokens) for tok in data["added_tokens"]: vocab[tok["id"]] = tok["content"] - if "<|startoftext|>" in vocab: - index = vocab.index("<|startoftext|>") - vocab[index] = DefaultTokens.BOS - if eos_token is not None: - if eos_token in vocab and "" not in vocab: - index = vocab.index(eos_token) - vocab[index] = DefaultTokens.EOS - if "<0x00>" in vocab: - index = vocab.index("<0x00>") - vocab[index] = DefaultTokens.PAD src_vocab = pyonmttok.build_vocab_from_tokens( vocab, - special_tokens=specials_table[arch], ) else: # # BPE mode - we leverage the HF tokenizer.json info src_subword_type = "bpe" @@ -1023,6 +1022,10 @@ def get_weight(checkpoint, tensor_name): data = json.load(f) # gpt2_pretok pretokenizers = data.get("pre_tokenizer", {}).get("pretokenizers", [{}]) + pre_tokenizer = data.get("pre_tokenizer", None) + pretokenizers = pre_tokenizer.get("pretokenizers", None) + if pretokenizers is None: + pretokenizers = [pre_tokenizer] for pretokenizer in pretokenizers: if pretokenizer.get("type", None) == "ByteLevel": gpt2_pretok = True @@ -1031,23 +1034,14 @@ def get_weight(checkpoint, tensor_name): # "Ā" is '\x00' in unicode (cf tokenize.py gpt2 mapping) for tok in data["model"]["vocab"] ] + if DefaultTokens.PAD in vocab: + vocabs["specials"]["pad_token"] = DefaultTokens.PAD voc_size = len(vocab) if vocab_size > voc_size: for i in range(vocab_size - voc_size): vocab.append(DefaultTokens.VOCAB_PAD + str(i)) for tok in data["added_tokens"]: vocab[tok["id"]] = tok["content"] - if "<|startoftext|>" in vocab: - index = vocab.index("<|startoftext|>") - vocab[index] = DefaultTokens.BOS - if "<|begin_of_text|>" in vocab: - index = vocab.index("<|begin_of_text|>") - vocab[index] = DefaultTokens.BOS - if eos_token is not None: - if eos_token in vocab and "" not in vocab: - index = vocab.index(eos_token) - vocab[index] = DefaultTokens.EOS - src_vocab = pyonmttok.build_vocab_from_tokens(vocab) tokenizer_basename = "bpe.model" @@ -1062,7 +1056,7 @@ def get_weight(checkpoint, tensor_name): vocabs["src"] = src_vocab vocabs["tgt"] = src_vocab if add_bos_token: - vocabs["decoder_start_token"] = decoder_start_table[arch] + vocabs["decoder_start_token"] = vocabs["specials"]["bos_token"] else: vocabs["decoder_start_token"] = "" vocab_dict = vocabs_to_dict(vocabs) @@ -1089,6 +1083,7 @@ def get_weight(checkpoint, tensor_name): tgt_vocab_size=vocab_size, vocab_size_multiple=8, decoder_start_token=vocabs["decoder_start_token"], + **vocabs["specials"], transforms=["onmt_tokenize", "filtertoolong"], transforms_configs={ "filtertoolong": {"src_seq_length": 512, "tgt_seq_length": 512}, diff --git a/eole/bin/tools/LM_scoring.py b/eole/bin/tools/LM_scoring.py index 81c00dac..9484148a 100644 --- a/eole/bin/tools/LM_scoring.py +++ b/eole/bin/tools/LM_scoring.py @@ -82,7 +82,8 @@ def run(cls, 
args): ) vocabs, model, model_opt = config.model.model_class.load_test_model(config) - padding_idx = vocabs["tgt"][DefaultTokens.PAD] + pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD) + padding_idx = vocabs["tgt"][pad_token] criterion = torch.nn.CrossEntropyLoss( ignore_index=padding_idx, reduction="none" ) diff --git a/eole/config/data.py b/eole/config/data.py index 5237fca4..193379d2 100644 --- a/eole/config/data.py +++ b/eole/config/data.py @@ -26,15 +26,17 @@ class BaseVocabConfig(Config): description="Default decoder start token. For most models it is = BOS. " "Some fairseq models require .", ) - default_specials: list = Field( - default=[ - constants.DefaultTokens.UNK, - constants.DefaultTokens.PAD, - constants.DefaultTokens.BOS, - constants.DefaultTokens.EOS, - ], - description="Default specials used for vocab initialization. " - "UNK, PAD, BOS, EOS will take IDs 0, 1, 2, 3.", + bos_token: str | None = Field( + default=constants.DefaultTokens.BOS, + ) + eos_token: str | None = Field( + default=constants.DefaultTokens.EOS, + ) + unk_token: str | None = Field( + default=constants.DefaultTokens.UNK, + ) + pad_token: str | None = Field( + default=constants.DefaultTokens.PAD, ) # pre trained embeddings stuff, might be put elsewhere both_embeddings: str | None = Field( diff --git a/eole/inference_engine.py b/eole/inference_engine.py index 3ca83584..4b19f4d2 100755 --- a/eole/inference_engine.py +++ b/eole/inference_engine.py @@ -278,6 +278,13 @@ def __init__(self, config, model_type=None): vocabs["src"] = src_vocab vocabs["tgt"] = src_vocab vocabs["decoder_start_token"] = "" + # TODO: this should be loaded from model config + vocabs["specials"] = { + "bos_token": DefaultTokens.BOS, + "pad_token": DefaultTokens.PAD, + "eos_token": DefaultTokens.EOS, + "unk_token": DefaultTokens.UNK, + } self.vocabs = vocabs # Build transform pipe transforms = make_transforms(config, self.transforms_cls, self.vocabs) @@ -290,7 +297,10 @@ def predict_batch(self, batch, config): _input_tokens = [ self.vocabs["src"].lookup_index(id) for id in start_ids - if id != self.vocabs["src"].lookup_token(DefaultTokens.PAD) + if id + != self.vocabs["src"].lookup_token( + self.vocabs["specials"].get("pad_token", DefaultTokens.PAD) + ) ] input_tokens.append(_input_tokens) if self.model_type == ModelType.DECODER: diff --git a/eole/inputters/inputter.py b/eole/inputters/inputter.py index fa58e03d..88441279 100644 --- a/eole/inputters/inputter.py +++ b/eole/inputters/inputter.py @@ -16,6 +16,13 @@ def build_vocab(config, specials): 'decoder_start_token': DefaultTokens.BOS } """ + vocabs = {} + vocabs["specials"] = { + "bos_token": config.bos_token, + "pad_token": config.pad_token, + "eos_token": config.eos_token, + "unk_token": config.unk_token, + } def _pad_vocab_to_multiple(vocab, multiple): vocab_size = len(vocab) @@ -26,8 +33,8 @@ def _pad_vocab_to_multiple(vocab, multiple): vocab.add_token(DefaultTokens.VOCAB_PAD + str(i)) return vocab - default_specials = config.default_specials - vocabs = {} + default_specials = list(vocabs["specials"].values()) + src_vocab = _read_vocab_file(config.src_vocab, config.src_words_min_frequency) src_specials = [ @@ -45,7 +52,7 @@ def _pad_vocab_to_multiple(vocab, multiple): src_vocab = pyonmttok.build_vocab_from_tokens( src_vocab, maximum_size=config.src_vocab_size, special_tokens=src_specials ) - src_vocab.default_id = src_vocab[DefaultTokens.UNK] + src_vocab.default_id = src_vocab[config.unk_token] if src_vocab.default_id >= len(src_vocab): src_vocab.default_id = ( 0 # 
patch that assigns OOV to id=0 when UNK does not exist @@ -80,7 +87,6 @@ def _pad_vocab_to_multiple(vocab, multiple): vocabs["tgt"] = tgt_vocab vocabs["decoder_start_token"] = config.decoder_start_token - return vocabs @@ -126,6 +132,8 @@ def vocabs_to_dict(vocabs): vocabs_dict["decoder_start_token"] = vocabs["decoder_start_token"] else: vocabs_dict["decoder_start_token"] = DefaultTokens.BOS + if "specials" in vocabs.keys(): + vocabs_dict["specials"] = vocabs["specials"] return vocabs_dict @@ -148,5 +156,7 @@ def dict_to_vocabs(vocabs_dict): vocabs["tgt"] = pyonmttok.build_vocab_from_tokens(vocabs_dict["tgt"]) if vocabs["tgt"].default_id >= len(vocabs["src"]): vocabs["tgt"].default_id = 0 # patch that assigns OOV to id=0 + if "specials" in vocabs_dict.keys(): + vocabs["specials"] = vocabs_dict["specials"] return vocabs diff --git a/eole/inputters/text_utils.py b/eole/inputters/text_utils.py index 83c02be7..bd1ef022 100644 --- a/eole/inputters/text_utils.py +++ b/eole/inputters/text_utils.py @@ -68,7 +68,9 @@ def numericalize(vocabs, example, model_type=ModelType.ENCODER_DECODER): numeric["tgt"]["tgt_ids"] = [] tgt_text = example["tgt"]["tgt"].split(" ") numeric["tgt"]["tgt_ids"] = vocabs["tgt"]( - [decoder_start_token] + tgt_text + [DefaultTokens.EOS] + [decoder_start_token] + + tgt_text + + [vocabs["specials"].get("eos_token", "")] ) elif model_type == ModelType.DECODER: @@ -79,7 +81,9 @@ def numericalize(vocabs, example, model_type=ModelType.ENCODER_DECODER): if example["tgt"] is not None: numeric["tgt"]["tgt_ids"] = [] tgt_text = example["tgt"]["tgt"].split(" ") - numeric["tgt"]["tgt_ids"] = vocabs["tgt"](tgt_text + [DefaultTokens.EOS]) + numeric["tgt"]["tgt_ids"] = vocabs["tgt"]( + tgt_text + [vocabs["specials"].get("eos_token", "")] + ) if decoder_start_token == "": numeric["tgt"]["tgt_ids"] = numeric["tgt"]["tgt_ids"][1:] @@ -88,17 +92,21 @@ def numericalize(vocabs, example, model_type=ModelType.ENCODER_DECODER): if example["tgt"] is not None: # TO BE DISCUSSED tgt_text = example["tgt"]["tgt"].split(" ") txt = ( - [DefaultTokens.BOS] + [vocabs["specials"].get("bos_token", "")] + tgt_text - + [DefaultTokens.EOS] - + [DefaultTokens.EOS] + + [vocabs["specials"].get("eos_token", "")] + + [vocabs["specials"].get("eos_token", "")] + src_text - + [DefaultTokens.EOS] + + [vocabs["specials"].get("eos_token", "")] ) numeric["src"]["src_ids"] = vocabs["src"](txt) numeric["tgt"]["tgt_ids"] = vocabs["src"](txt) else: - txt = [DefaultTokens.BOS] + src_text + [DefaultTokens.EOS] + txt = ( + [vocabs["specials"].get("bos_token", "")] + + src_text + + [vocabs["specials"].get("eos_token", "")] + ) numeric["src"]["src_ids"] = vocabs["src"](txt) else: diff --git a/eole/models/model.py b/eole/models/model.py index 2354ef71..a7e73074 100644 --- a/eole/models/model.py +++ b/eole/models/model.py @@ -57,12 +57,13 @@ def build_decoder(model_config, running_config=None): def build_src_emb(model_config, vocabs, running_config=None): # Build embeddings. 
+ pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD) src_emb = Embeddings( word_vec_size=model_config.embeddings.src_word_vec_size, position_encoding_type=model_config.embeddings.position_encoding_type, position_shift=model_config.embeddings.position_shift, dropout=getattr(running_config, "dropout", [0.0])[0], - word_padding_idx=vocabs["src"][DefaultTokens.PAD], + word_padding_idx=vocabs["tgt"][pad_token], word_vocab_size=len(vocabs["src"]), sparse=getattr(running_config, "optim", None) == "sparseadam", freeze_word_vecs=model_config.embeddings.freeze_word_vecs_enc, @@ -75,12 +76,13 @@ def build_tgt_emb( model_config, vocabs, running_config=None, share_embeddings=False, src_emb=None ): # Build embeddings. + pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD) tgt_emb = Embeddings( word_vec_size=model_config.embeddings.tgt_word_vec_size, position_encoding_type=model_config.embeddings.position_encoding_type, position_shift=model_config.embeddings.position_shift, dropout=getattr(running_config, "dropout", [0.0])[0], - word_padding_idx=vocabs["tgt"][DefaultTokens.PAD], + word_padding_idx=vocabs["tgt"][pad_token], word_vocab_size=len(vocabs["tgt"]), sparse=getattr(running_config, "optim", None) == "sparseadam", freeze_word_vecs=model_config.embeddings.freeze_word_vecs_dec, diff --git a/eole/models/model_saver.py b/eole/models/model_saver.py index b94dffec..4183246f 100644 --- a/eole/models/model_saver.py +++ b/eole/models/model_saver.py @@ -8,6 +8,7 @@ from eole.modules.lora import lora_state_dict from eole.config import recursive_model_fields_set from eole.config.run import TrainConfig +from eole.constants import DefaultTokens try: from safetensors.torch import save_file @@ -65,6 +66,14 @@ def load_checkpoint(model_path): if os.path.exists(vocab_path): with open(vocab_path) as f: checkpoint["vocab"] = json.load(f) + # use default specials if not specified + if "specials" not in checkpoint["vocab"].keys(): + checkpoint["vocab"]["specials"] = { + "bos_token": DefaultTokens.BOS, + "pad_token": DefaultTokens.PAD, + "eos_token": DefaultTokens.EOS, + "unk_token": DefaultTokens.UNK, + } else: raise FileNotFoundError(f"{model_path} does not contain vocab.json") optim_path = os.path.join(model_path, "optimizer.pt") diff --git a/eole/predict/inference.py b/eole/predict/inference.py index b6dbbab4..aaf86b6e 100644 --- a/eole/predict/inference.py +++ b/eole/predict/inference.py @@ -93,12 +93,19 @@ def __init__( self.model = model self.vocabs = vocabs self._tgt_vocab = vocabs["tgt"] - self._tgt_eos_idx = [vocabs["tgt"].lookup_token(DefaultTokens.EOS)] + [ - vocabs["tgt"].lookup_token(tok) for tok in optional_eos - ] - self._tgt_pad_idx = vocabs["tgt"].lookup_token(DefaultTokens.PAD) - self._tgt_bos_idx = vocabs["tgt"].lookup_token(DefaultTokens.BOS) - self._tgt_unk_idx = vocabs["tgt"].lookup_token(DefaultTokens.UNK) + self._tgt_eos_idx = [ + vocabs["tgt"].lookup_token(vocabs.get("specials", {}).get("eos_token", "")) + ] + [vocabs["tgt"].lookup_token(eos_token) for eos_token in optional_eos] + # defaulting to DefaultTokens.PAD might not always work + self._tgt_pad_idx = vocabs["tgt"].lookup_token( + vocabs.get("specials", {}).get("pad_token", DefaultTokens.PAD) + ) + self._tgt_bos_idx = vocabs["tgt"].lookup_token( + vocabs.get("specials", {}).get("bos_token", "") + ) + self._tgt_unk_idx = vocabs["tgt"].lookup_token( + vocabs.get("specials", {}).get("unk_token", "") + ) self._tgt_sep_idx = vocabs["tgt"].lookup_token(DefaultTokens.SEP) self._tgt_start_with = 
vocabs["tgt"].lookup_token(vocabs["decoder_start_token"]) self._tgt_vocab_len = len(self._tgt_vocab) diff --git a/eole/tests/test_models.py b/eole/tests/test_models.py index 1c5023b4..68d66516 100644 --- a/eole/tests/test_models.py +++ b/eole/tests/test_models.py @@ -58,7 +58,16 @@ def get_vocabs(self): ], ) - vocabs = {"src": src_vocab, "tgt": tgt_vocab} + vocabs = { + "src": src_vocab, + "tgt": tgt_vocab, + "specials": { + "bos_token": DefaultTokens.BOS, + "pad_token": DefaultTokens.PAD, + "eos_token": DefaultTokens.EOS, + "unk_token": DefaultTokens.UNK, + }, + } return vocabs def get_batch(self, source_l=3, bsize=1): diff --git a/eole/trainer.py b/eole/trainer.py index a84f461d..8775de7b 100644 --- a/eole/trainer.py +++ b/eole/trainer.py @@ -35,8 +35,8 @@ def build_trainer(config, device_id, model, vocabs, optim, model_saver=None): used to save the model """ - train_loss = LossCompute.from_config(config, model, vocabs["tgt"]) - valid_loss = LossCompute.from_config(config, model, vocabs["tgt"], train=False) + train_loss = LossCompute.from_config(config, model, vocabs) + valid_loss = LossCompute.from_config(config, model, vocabs, train=False) estim_loss_lambda = config.training.estim_loss_lambda estim_loss_lambda_steps = config.training.estim_loss_lambda_steps diff --git a/eole/transforms/tokenize.py b/eole/transforms/tokenize.py index ccf492d5..6c1711f3 100644 --- a/eole/transforms/tokenize.py +++ b/eole/transforms/tokenize.py @@ -155,6 +155,15 @@ def _repr_args(self): } return ", ".join([f"{kw}={arg}" for kw, arg in kwargs.items()]) + def warm_up(self, vocabs=None): + super().warm_up(None) + if vocabs is not None: + self.pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD) + self.eos_token = vocabs["specials"].get("eos_token", DefaultTokens.EOS) + else: + self.pad_token = DefaultTokens.PAD + self.eos_token = DefaultTokens.EOS + def tokenize_string(self, string, side="src", is_train=False): raise NotImplementedError @@ -164,7 +173,7 @@ def _tokenize(self, tokens, side="src", is_train=False): # in case the tokenizer doesn't preserve them. sentence = " ".join(tokens) # Locate the end-of-sentence placeholders. - sent_list = sentence.split(DefaultTokens.EOS) + sent_list = sentence.split(self.eos_token) # Tokenize each sentence separately. segmented = [] for _sentence in sent_list: @@ -176,10 +185,10 @@ def _tokenize(self, tokens, side="src", is_train=False): _sentence_tokens = [] for _chunk in _sentence_chunks: _sentence_tokens += self.tokenize_string(_chunk, side, is_train) + [ - DefaultTokens.PAD + self.pad_token # not sure this covers all cases ] # Re-insert the eos token. - segmented += _sentence_tokens[:-1] + [DefaultTokens.EOS] + segmented += _sentence_tokens[:-1] + [self.eos_token] return segmented[:-1] def apply(self, example, is_train=False, stats=None, **kwargs): diff --git a/eole/utils/loss.py b/eole/utils/loss.py index 0d6b8e10..8b24cf0b 100644 --- a/eole/utils/loss.py +++ b/eole/utils/loss.py @@ -28,7 +28,7 @@ class LossCompute(nn.Module): lambda_coverage: Hyper-param to apply coverage attention if any lambda_align: Hyper-param for alignment loss tgt_shift_index (int): 1 for NMT, 0 for LM - vocab: target vocab + vocabs: full vocabs with specials module that maps the output of the decoder to a distribution over the target vocabulary. 
lm_generator (:obj:`ctranslate2.Generator`): LM Generator @@ -43,7 +43,7 @@ def __init__( lambda_coverage=0.0, lambda_align=0.0, tgt_shift_index=1, - vocab=None, + vocabs=None, lm_generator=None, lm_prior_lambda=None, lm_prior_tau=None, @@ -55,15 +55,19 @@ def __init__( self.lambda_coverage = lambda_coverage self.lambda_align = lambda_align self.tgt_shift_index = tgt_shift_index - self.vocab = vocab + self.vocabs = vocabs self.lm_generator = lm_generator self.lm_prior_lambda = lm_prior_lambda self.lm_prior_tau = lm_prior_tau self.lm_prior_model = lm_prior_model self.estimloss = nn.MSELoss(reduction="sum") + self.pad_token = self.vocabs["specials"].get("pad_token", DefaultTokens.PAD) + self.unk_token = self.vocabs["specials"].get("unk_token", DefaultTokens.UNK) + self.eos_token = self.vocabs["specials"].get("eos_token", DefaultTokens.EOS) + @classmethod - def from_config(cls, config, model, vocab, train=True): + def from_config(cls, config, model, vocabs, train=True): """ Returns a subclass which wraps around an nn.Module subclass (such as nn.NLLLoss) which defines the loss criterion. The LossCompute @@ -74,7 +78,8 @@ def from_config(cls, config, model, vocab, train=True): device = torch.device( "cuda" if eole.utils.misc.use_gpu(config.training) else "cpu" ) - padding_idx = vocab[DefaultTokens.PAD] + pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD) + padding_idx = vocabs["tgt"][pad_token] if config.model.decoder is not None: lambda_align = getattr( @@ -135,7 +140,7 @@ def from_config(cls, config, model, vocab, train=True): lambda_coverage=lambda_coverage, lambda_align=lambda_align, tgt_shift_index=tgt_shift_idx, - vocab=vocab, + vocabs=vocabs, lm_generator=lm_generator, lm_prior_lambda=lm_prior_lambda, lm_prior_tau=lm_prior_tau, @@ -182,7 +187,7 @@ def _compute_lm_loss_ct2(self, output, target): scores = F.log_softmax(scores.to(torch.float32), dim=-1) src = target.detach().clone() - src[src == self.vocab[DefaultTokens.EOS]] = self.padding_idx + src[src == self.vocabs["tgt"][self.eos_token]] = self.padding_idx src = src[:, :-1, :] src_len = src[:, :, 0].ne(self.padding_idx).sum(1) # ct2 expects src with lengths without padding @@ -195,8 +200,8 @@ def _compute_lm_loss_ct2(self, output, target): # again we use raw probs to rescale with tau and apply log_softmax lm_scores = self._bottle(lm_scores) / self.lm_prior_tau lm_scores = F.log_softmax(lm_scores.to(torch.float32), dim=-1) - lm_scores[:, self.vocab[DefaultTokens.UNK]] = -50 - lm_scores[:, self.vocab[DefaultTokens.EOS]] -= 20 + lm_scores[:, self.vocabs["tgt"]["unk_token"]] = -50 + lm_scores[:, self.vocabs["tgt"]["eos_token"]] -= 20 # lm_scores are in log space so log_target=True lm_loss = F.kl_div(scores, lm_scores, reduction="none", log_target=True).sum(-1) non_padding = self._bottle(output).ne(self.padding_idx)[:, 0] @@ -215,7 +220,7 @@ def _compute_lm_loss(self, output, target): scores = F.log_softmax(scores.to(torch.float32), dim=-1) src = target.detach().clone() - src[src == self.vocab[DefaultTokens.EOS]] = self.padding_idx + src[src == self.vocabs["tgt"]["eos_token"]] = self.padding_idx src = src[:, :-1, :] src_len = src[:, :, 0].ne(self.padding_idx).sum(1) # ct2 expects src with lengths without padding @@ -226,8 +231,8 @@ def _compute_lm_loss(self, output, target): ) # again we use raw probs to rescale with tau and apply log_softmax lm_scores = F.log_softmax(lm_scores.to(torch.float32), dim=-1) - lm_scores[:, self.vocab[DefaultTokens.UNK]] = -50 - lm_scores[:, self.vocab[DefaultTokens.EOS]] -= 20 + lm_scores[:, 
self.vocabs["tgt"]["unk_token"]] = -50 + lm_scores[:, self.vocabs["tgt"]["eos_token"]] -= 20 # lm_scores are in log space so log_target=True lm_loss = F.kl_div(scores, lm_scores, reduction="none", log_target=True).sum(-1) non_padding = self._bottle(output).ne(self.padding_idx)[:, 0] diff --git a/recipes/cometkiwi/cometkiwi-xl-eole.yaml b/recipes/cometkiwi/cometkiwi-xl-eole.yaml index 1c98e5b1..d25bc872 100755 --- a/recipes/cometkiwi/cometkiwi-xl-eole.yaml +++ b/recipes/cometkiwi/cometkiwi-xl-eole.yaml @@ -9,7 +9,6 @@ tgt_vocab_size: 250880 vocab_size_multiple: 1 report_every: 50 skip_empty_level: silent -default_specials: ['', '', '', ''] data_task: "encoder" # transforms config diff --git a/recipes/cometkiwi/cometkiwi-xxl-eole.yaml b/recipes/cometkiwi/cometkiwi-xxl-eole.yaml index 1381520d..599acb35 100755 --- a/recipes/cometkiwi/cometkiwi-xxl-eole.yaml +++ b/recipes/cometkiwi/cometkiwi-xxl-eole.yaml @@ -9,7 +9,6 @@ tgt_vocab_size: 250880 vocab_size_multiple: 1 report_every: 50 skip_empty_level: silent -default_specials: ['', '', '', ''] # transforms config transforms: [sentencepiece, filtertoolong] diff --git a/recipes/gpt2/inference.yaml b/recipes/gpt2/inference.yaml index b3542422..4e6b897a 100644 --- a/recipes/gpt2/inference.yaml +++ b/recipes/gpt2/inference.yaml @@ -1,10 +1,3 @@ -transforms: [onmt_tokenize] -transforms_configs: - onmt_tokenize: - src_subword_type: bpe - src_subword_model: ${EOLE_MODEL_DIR}/openai_gpt2/bpe.model - gpt2_pretok: true - world_size: 1 gpu_ranks: [0] diff --git a/recipes/gpt2/run.sh b/recipes/gpt2/run.sh new file mode 100755 index 00000000..4b78ec06 --- /dev/null +++ b/recipes/gpt2/run.sh @@ -0,0 +1,7 @@ +# naive script with commands from the readme +# (useful to make sure the recipe still runs) + +eole convert HF --model_dir openai-community/gpt2 --output $EOLE_MODEL_DIR/openai_gpt2 --token $HF_TOKEN +echo -e "The European Union was created in" > lm_input.txt +eole predict -c inference.yaml +eole tools eval_hellaswag -c inference.yaml \ No newline at end of file diff --git a/recipes/llama2/run.sh b/recipes/llama2/run.sh new file mode 100755 index 00000000..6d5c4378 --- /dev/null +++ b/recipes/llama2/run.sh @@ -0,0 +1,20 @@ +# naive script with commands from the readme +# (useful to make sure the recipe still runs) + +eole convert HF --model_dir meta-llama/Llama-2-7b-chat-hf --output $EOLE_MODEL_DIR/llama2-7b-chat-hf --token $HF_TOKEN +echo -e "[INST] <> +You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. +<> +What are some nice places to visit in France? [/INST]" | sed ':a;N;$!ba;s/\n/⦅newline⦆/g' > test_prompt.txt +eole predict -c llama-inference.yaml -src test_prompt.txt -output test_output.txt +eole predict -c llama-inference-tp-2gpu.yaml -src test_prompt.txt -output test_output.txt +[ ! 
-d ./data ] && mkdir ./data +# Alpaca +wget -P ./data https://opennmt-models.s3.amazonaws.com/llama/alpaca_clean.txt +# Vicuna +wget -P ./data https://opennmt-models.s3.amazonaws.com/llama/sharegpt.txt +# Open Assisstant +wget -P ./data https://opennmt-models.s3.amazonaws.com/llama/osst1.flattened.txt +eole train -c llama-finetune.yaml +eole model lora --action merge --base_model ${EOLE_MODEL_DIR}/llama2-7b-chat-hf --lora_weights ./finetune/llama2-7b-chat-hf-finetune --output ./finetune/merged \ No newline at end of file diff --git a/recipes/llama3.1/llama-inference.yaml b/recipes/llama3.1/llama-inference.yaml index abf81f43..46ce70d2 100755 --- a/recipes/llama3.1/llama-inference.yaml +++ b/recipes/llama3.1/llama-inference.yaml @@ -5,7 +5,6 @@ model_path: "${EOLE_MODEL_DIR}/llama3.1-8b" seed: 42 max_length: 256 # max_length: 1 -gpu: 0 batch_type: sents batch_size: 4 world_size: 1 diff --git a/recipes/llama3.1/run.sh b/recipes/llama3.1/run.sh new file mode 100755 index 00000000..3c3ebe68 --- /dev/null +++ b/recipes/llama3.1/run.sh @@ -0,0 +1,341 @@ +# naive script with commands from the readme +# (useful to make sure the recipe still runs) + +eole convert HF --model_dir meta-llama/Meta-Llama-3.1-8B --output $EOLE_MODEL_DIR/llama3.1-8b --token $HF_TOKEN + +echo -e "You are given this machine learning research paper, please read it carefully and answer the follow up question. + +=== BEGIN === + +2306.15595v2 [cs.CL] 28 Jun 2023 + +arXiv + +EXTENDING CONTEXT WINDOW OF LARGE LAN- +GUAGE MODELS VIA POSITION INTERPOLATION + +Shouyuan Chen Sherman Wong Liangjian Chen Yuandong Tian +Meta Platforms Inc. +{chenshouyuan, shermanwong, cli, yuandong}@meta . com + +1 INTRODUCTION + +Large language models (LLMs) typically come with a pre-defined context window size. For exam- +ple, inputs to LLaMA models (Touvron et al., 2023) must be fewer than 2048 tokens. This pre-set +context window limit is frequently exceeded in applications such as conducting long conversations, +summarizing long documents, or executing long-term planning. For these applications, LLMs with +longer context windows are preferred. However, training an LLM from scratch with long context +windows requires significant investments. This naturally leads to a question: Can we extend the +context window of an existing pre-trained LLM? + +One straightforward approach is to fine-tune an existing pre-trained Transformer with a longer con- +text window. However, empirically, we found that models trained this way adapt to long context +windows very slowly. After training for more than 10000 batches, the effective context window +saw a minimal increase, moving from 2048 to 2560 (Table 4). This suggests that such method is +inefficient for extending to substantially longer context windows. + +While certain techniques such as ALiBi (Press et al., 2022) and LeX (Sun et al., 2022) enable length +extrapolation of Transformers, i.e. train on short context windows and inference on longer ones, +many existing pre-trained LLMs, including LLaMA (Touvron et al., 2023), use positional encodings +that have weak extrapolation properties (e.g., RoPE (Su et al., 2021)). Therefore, the applicability +of these techniques for extending the context window sizes of such LLMs remains limited. + +In this work, we introduce Position Interpolation to enable context window extensions for certain +existing pre-trained LLMs, including LLaMA. 
The key idea is, instead of extrapolation, we directly +down-scale the position indices so that the maximum position index matches the previous context +window limit in the pre-training stage. See Figure 1 for an illustration. In other words, to accom- +modate more input tokens, we interpolate the position encodings at neighboring integer positions, +utilizing the fact that position encodings can be applied on non-integer positions, as opposed to +extrapolating outside the trained positions, which may lead to catastrophic values. We verify our +approach theoretically, by showing that the interpolated attention score has a much smaller upper + +bound (~ 600x smaller in LLaMA 7B setting) than the extrapolated one, and is thus much more +stable. Therefore, interpolated position encodings are easier for the model to adapt. + +Empirically, we found that Position Interpolation is highly effective and efficient, requiring only a +very short period of fine-tuning for the model to fully adapt to greatly extended context windows. +We present experimental results for extending the context window to up to 32768 from the initial +2048 across 7B to 65B LLaMA models using Position Interpolation. Our results show that + +1. Position Interpolation can easily enable very long context windows (e.g. 32768), requiring +only fine-tuning for 1000 steps on the Pile (Gao et al., 2020) to achieve a good quality. +The cost of fine-tuning is negligible compared to the pre-training costs. This confirms +our hypothesis that it is relatively easy for the models to adapt to interpolated position +encodings. + +2. Position Interpolation generates strong models that can effectively make use of much ex- +tended context window. We show that models extended by Position Interpolation enjoy +significant perplexity gains from greatly extended context windows for text modeling, and +we show that the perplexity reduces graceful with the enlargement of context windows. +We also applied Position Interpolation in a long text summarization task, and demonstrate +competitive performances. + +3. Position Interpolation preserves model quality relatively well for tasks within its original +context window sizes. We present a variety of evaluation results for the extended LLaMA +models on the original LLaMA benchmark. Compared with original LLaMA models, the +extended LLLaM A models saw a minor degradation on several standard benchmarks within +a 2048 token limit. + +Our results highlight the innate ability of Transformer models to “extrapolate to sequence lengths +longer than the ones encountered during training” as hypothesized in the seminal work of Vaswani +et al. (2017). We reaffirm this hypothesis and suggest that the previously known weakness of ex- +trapolating to longer sequences for language modeling (Press et al., 2022) may be due to direct + +extrapolation of positional encodings and it can be largely mitigated by interpolating position en- +codings instead. + +Concurrent work. Right before our release, we are informed with a concurrent blogpost (Super- +HOT kaiokendev (2023)) that also interpolates positional encoding in RoPE to extend the context +window from 2K to 8K. Recently, open source community picks it up in Reddit post ! and Github +Issues 2, which shows that fine-tuning with LoRA (Hu et al., 2021) also seems to work well. 
Our +paper shows a full fine-tuning with up to 65B model work well with Position Interpolation, and we +also give theoretical explanations why interpolation achieves much more stable results than extrap- +olation, by showing that the upper bound of interplated attention score is much lower than that of +extrapolated ones. + +2 METHOD + +2.1 BACKGROUND: ROTARY POSITION EMBEDDING (ROPE) + +Transformer models require explicit positional information to be injected, typically in the form of +positional encodings, to represent the order of inputs. We consider Rotary Position Embedding +(ROPE) (Su et al., 2021), which is the position encoding used in the LLLaMA model (Touvron et al., +2023). Given a position index m € [0, ¢) and an embedding vector x := [zg, 71,..., 241], Where +d is the dimension of the attention head, RoPE defines a vector-valued complex function f{x, m) as +follows + +Using RoPE, the self-attention score +is only dependent on relative position m — 7 through trigonometric functions. Here q and k are the +query and key vector for a specific attention head. At each layer, RoPE is applied on both query and +key embeddings for computing attention scores. + +2.2 DIRECT EXTRAPOLATION + +While the attention score in RoPE only depends on the relative positions, which is what we want, +its extrapolation performance is not great . In particular, when directly extending to larger context +windows unseen in the training, the perplexity may shoot up to very high numbers (i.e., > 10%), +comparable to untrained models. + +Ideally, we want to see the model trained on a context window of size L = 2048 to still work +reasonably well on longer context window, but may not have the capability to leverage information +that appears beyond L. For example, to answer a question located at 3000, the model trained on +maximal window size of I = 2048 cannot leverage evidences provided at location 0, but still +can leverage the evidences provided at location 2900. In contrast, in reality we see catastrophic +behaviors, i.e., question at location 3000 cannot be answered correctly, even if the evidences are +located at location 2900. + +What is the reason behind? How could this happen if the attention score a,,,—,, decays as the relative +distance |m — n/| increases, according to Section 3.4.3 of (Su et al., 2021), and content from very +far distances should not matter that much? It turns out that the upper bound derived in Section 3.4.3 +of (Su et al., 2021) may be too loose: while it indeed decays with respect to |m — nl, the bound +can still be quite large (i.e., the bound can be critically depends on the magnitude of v;) and thus +vacuous. In fact, if we treat all trigonometric functions as basis functions (i.e, ¢;(s) := #93), and +think about Eqn. 2 as basis expansion as the following: + +where s is the positional span between a query and a key and h; := (ga; + igaj+1){k2j — tk2j+1) +are complex coefficients depending on q and k (here the definition of h; is exactly the same as the +definition of k; in Sec 3.4.3 in RoPE (Su et al., 2021)). Now the the issue becomes clear: as shown +in Fig. 2, a, can be small in magnitude in the range of [0, 2048], but gives huge values out of the +region. The underlying reason is that the trigonometric family {¢;} (with sufficiently large d) is +a universal approximator and can fit any arbitrary functions. Therefore, for a, there always exist +coefficients {h;} (i.e. 
key and query) that corresponds to small function values in [0, 2048] but + +much larger in regions beyond. + +2.3 PROPOSED APPROACH: POSITION INTERPOLATION (PI) + +In Fig. 2, thanks to the smoothness of bases functions ¢; interpolation is much more stable and will +not lead to wild values. Therefore, instead of extrapolate the attention score in Eqn. 3 to s > L, +how about we define an attention score a{s) = a(Ls/L’) where L’ is the longer context window? +Formally, we replace RoPE f by {’ defined as follows + +We call this transformation on the position encoding Position Interpolation. In this step, we reduce +position indices from [0, L') to [0, L) to match the original range of indices before computing RoPE. +Consequently, as inputs to RoPE, the maximum relative distance between any two tokens has been +reduced from I’ to L. Since we align the ranges of position indices and relative distances before +and after extension, we mitigate the effect on attention score computation due to context window +extensions, which can allow the model easier to adapt. To further demonstrate this is the case, in the +following theorem, we show that the interpolated attention score is well-behaved: + +While there is no close form for B(s) := 4/21 |Ag41(s)|, numerically it is at least larger than d, and for many positional difference s, B(s) is much larger than d +(check Appendix B for the plot). Therefore, the interpolation bound is at least 2 - 294.73 ~ 600 x +smaller than the extrapolation bound, and thus the interpolated attention score is much more stable +than extrapolated one. + +Notably, our method of rescaling of position indices does not introduce extra weight, or modify +the model architecture in any way. This makes it attractive in practical applications, since most +infrastructure and optimization for the original model can be reused after the extension. + +Fine-tuning. We can further fine-tune the interpolated model using the next token prediction task +with interpolated position encodings on the extended context window size using a pre-training cor- +pus such as the Pile (Gao et al., 2020). In the next section, we show that our fine-tuning process +only needs tens to hundreds thousands of examples. We also find that the result of the fine-tuning +is not sensitive to the choice of examples. The reason may be that the model is only adapting to the +new context window during the fine-tuning phase, starting from a good initialization, as opposed to +acquiring new knowledge. + +Other ways to reduce interpolation/extrapolation bound. From the expression of the interpola- +tion (Eqn. 5) and extrapolation bound (Eqn. 8), a common term is max; ||, which is the maximal +magnitude of query/key products. If we enforce a regularization on || during LLM training, it is +possible that the catastrophic extrapolation error can be mitigated or even resolved. In fact, if we +apply ridge regression with proper regularization to fit a curve in Fig. 2, the magnitude of extrapo- +lated a(s) when s > L can be comparable to that within [0, L]. To our knowledge, we are not aware +of existing LLM pre-training techniques that leverage this regularization and will leave it for future +work. + +3 EXPERIMENTS + +We show Position Interpolation can effectively extend context window up to 32 times of the original +size, and such extension can be done with only several hundreds of training steps. We show the +resulting models are strong LLMs with fully effective long context windows. 
We demonstrate its +performance in a number of tasks including language modeling, passkey retrieval, and long doc- +ument summarization. We also present benchmark results of the extended models on the original +LLaMA evaluation benchmarks. +3.1 SETUP + +Model Variants. We extended the pre-trained 7B, 13B, 33B and 65B LLaMA models (Touvron +et al., 2023) to various context window of sizes up to 32768, using either direct fine-tuning or +Position Interpoloation method. Except for rescaling the position indices for models extended with +Position Interpolation, we did not modify LLaMA model architectures (Touvron et al., 2023) in any +ways. + +Training Procedure. We fine-tune all model variants using the next token prediction objective. We +use AdamW (Loshchilov & Hutter, 2019) with 5; = 0.9 and 2 = 0.95. We use a linear learning +rate warmup of 20 steps starting from 10% of the maximum learning rate. For 7B and 13B models, +we set the learning rate to 2 x 1075 and for 33B and 65B models we set the learning rate to 1072. We +set the weight decay to zero. For extending 7B, 13B and 33B models to the 8192 context window +size, we use 32 A100 GPUs and 64 global batch size. For all other cases we use 128 A100 GPUs and +128 global batch size. We note that the main need of using more GPUs is memory limitation during +fine-tuning, and it is possible to use fewer GPUs in certain cases. We train all models using PyTorch +(Paszke et al., 2019) with Fully Sharded Data Parallel (Zhao et al., 2023) and Flash Attention (Dao +et al., 2022). + +If not specified otherwise, for the Position Interpolation method, we fine-tune the models for 1000 +steps. For the direct fine-tuning method, we use 10000 steps. We primarily fine-tune using the Pile +training dataset (Gao et al., 2020). In Section 3.4 we also compared fine-tuning performance on the +RedPajama dataset (Computer, 2023). + +3.2 LONG SEQUENCE LANGUAGE MODELING + +We evaluate the long sequence language modeling performance of our extended models and base- +lines on two datasets: book corpus (PG-19) (Rae et al., 2020) and cleaned Arxiv Math proof-pile +dataset (Azerbayev et al., 2022). + +We use the test splits of PG19 (Rae et al., 2020) and proof-pile (Azerbayev et al., 2022). For PG19, +we use the whole test split consisting of 100 documents. For the proof-pile dataset, we use a random +subsample of 128 documents with at least 32768 SentencePiece (Kudo & Richardson, 2018) tokens +and truncate to the first 32768 tokens for each test document. We evaluate perplexity at various +context window size by using a sliding window approach following Press et al. (2022) with stride +S = 256. + +In Table 1 and Table 2, we report the perplexity results for our models and baselines on the datasets. +From the results, we found that models extended with our method enjoy a significantly improved +perplexity from longer context window sizes. By increasing the context window size from 2048 to +16384, we observed -0.28 and -0.5 reductions of perplexity for extending LLaMA 7B models on +both datasets, -0.27 and -0.48 reductions for extending LL.aMA 13B models, and -0.14 and -0.42 +reductions for extending LLaMA 33B models. For LLaMA 65B models, we observed -0.12 and +-0.3 reductions of perplexity by extending to the 8192 context window size. + +In general, we observed a consistent trend of our models achieving better perplexity with longer +context windows. 
This indicates our models can effectively make use of the longer context windows +to better predict next tokens in language modeling tasks. Moreover, we found this trend extends to +32768 window size without diminishing on the PG19 dataset for LLaMA 7B and 13B models. This +indicates that our method may enable extension to even longer context windows. + +In contrast, we observed that models extended via the direct fine-tuning method has shown regres- +sion (up to +0.48) or minor improvement (up to -0.12) on the perplexity at longer context windows. +This indicates that models extended this way have limited capability of making use of context win- +dows longer than their pre-trained settings. + +We saw a minor degradation of the perplexity on the original context window of 2048 for our ex- +tended models in some cases. For example, on the Proof-pile dataset, we saw a degradation ranging +from 0.01 to 0.05 across all models with extended with Position Interpolation. A small degradation +of performance within original evaluation context window is expected since Position Interpolation +forces position encodings in original context window to reside in a much narrower region, which +may negatively affect the language model’s performance. We present more benchmark results on +the original context window size in Section 3.4. + +In Table 3 we report the relationship between perplexity and the number of fine-tuning steps for +LLaMA 7B model extending to 8192 and 16384 context window sizes using Position Interpolation +evaluated on the PG19 dataset. We can see without fine-tuning (at step 0) the model can exhibit +certain language modeling capability, as indicated by < 20 perplexity for extending to 8192 context +window (in contrast, the direct extrapolation method leads to > 10% perplexity). With fine-tuning, +we observed that the perplexity improves quickly. At 200 steps the models surpassed the original +model’s perplexity on 2048 context window size, indicating the models gaining ability of effectively +using sequences longer than the pre-training settings for language modeling. At 1000 steps, we can +see the models have improved steadily and achieve a significantly better perplexity. + +3.3 MEASURING EFFECTIVE CONTEXT WINDOW SIZE THROUGH PASSKEY RETRIEVAL + +We study the effective context window size, i.e. the maximum distance of a token can effectively +attend to during inference, of our models after extension. To measure this, we follow a synthetic +evaluation task of passkey retrieval proposed by Mohtashami & Jaggi (2023). In this task, the models +are asked to recover a random passkey hidden in a long document. See Figure 3 for the format of +the document. + +Given a language model, we estimate the upper and lower bounds of effective context windows as +follows. Suppose the random passkey is k tokens away from the end of the input. When a model +persistently fails to retrieve the correct passkey value across several independent attempts, it suggests +that the effective context window size of the model is less than k. Conversely, if a model consistently +succeeds in retrieving the correct passkey value, we deduce that the effective context window size +of the model is at least k. + +We evaluate the 7B and 33B LLaMA model variants that are extended via Position Interpolation or +direct fine-tuning. 
For each model, we use 32 different &£ uniformly spaced in the targeted context +window L’ and run the above tests for 10 times for each k, where each time a random passkey of 5 +random digits is used. In Table 4, we report kyax as a function of the number of fine-tuning steps, + +We can see that models extended via Position Interpolation all successfully attain their desired ex- +tension objectives in terms of effective context window sizes, indicating by the effective context +window size reaching maximum kp, = L/, after merely fine-tuning for 200 steps, consistently +across both 7B and 33B model sizes and up to 32768 context windows. In contrast, LLLaMA models +that are extended via direct fine-tuning only saw a minimal increase of the effective context win- +dow size kay from 2048 to 2560, even after fine-tuning for more than 10000 steps, with no clear +indication of an acceleration in the increase of window size. + +3.4 BENCHMARKS ON ORIGINAL CONTEXT WINDOW SIZE + +We evaluate the models extended by Position Interpolation on several standard benchmark tasks +within the original context window size of 2048. The evaluation results are listed in Table 5. From +the results, we saw that models extended to 8192 produce comparable results on the original bench- +mark which is designed for a much smaller context window, with a degradation of up to 2% on +the benchmark tasks, for both 7B and 33B model sizes. Models extended to longer context win- +dows regressed more on the benchmarks, but still in reasonable ranges for most tasks. We also note +that the choice of fine-tuning datasets does not seem to lead significant difference in the benchmark +performances, which may be due to the limited number of fine-tuning steps used in our method. +The regression on benchmark tasks is consistent with our observation on perplexity regression in +Section 3.2. + +3.5 LONG DOCUMENT SUMMARIZATION + +In this task, we evaluate our models’ performance on the long document summarization task. In +particular, we consider the GovReport (Huang et al., 2021) dataset, which contains 17457 documents +for training and 972 documents for evaluation. Each document comes with a human generated +summary. We truncate all input documents to their first 15000 tokens. + +We fine-tune the LL.aMA models extended with Position Interpolation with a context window of +16384. Note the rescaling of position indices are still required during this fine-tuning step. We first +Model Size Context Window Fine-tune on BoolQ PIQA Race-M Race-H WinoGrande + +format the raw document using the prompt template in Figure 4, and then concatenate the prompt +with the ground-truth summary (truncate to 1000 tokens) associated with each document. We fine- +tune the model using the next token prediction task with the above setup for 10 epochs. The losses +from the input prompt proportion of training examples are excluded during our fine-tuning. + +We use a generation temperature of 0.5 and top, = 0.95 as our inference parameter to generate a +summarization of each document in the test set. The final output is truncated at 1000 tokens. We +used the ROUGE-1/ROUGE-2/ROUGE-L scores (Lin, 2004) as the evaluation metrics to evaluate +the models’ outputs vs the ground-truth summaries. + +In Table 6 we report our evaluation results. We have also included results from two baselines in +existing SCROLLS Leaderboard (Shaham et al., 2022; Ainslie et al., 2023). 
In general, we have +obtained competitive R1 score among other models with minimal tuning of hyper-parameters. This +result suggests our models with 16384 context window can effectively handle the long document +summarization task. + +=== END OF FILE === + +Question: What is the paper about? +Answer: " | sed ':a;N;$!ba;s/\n/⦅newline⦆/g' > test_prompt.txt + +eole predict -c llama-inference.yaml -src test_prompt.txt -output test_output.txt \ No newline at end of file diff --git a/recipes/llama3/run.sh b/recipes/llama3/run.sh new file mode 100755 index 00000000..fa1075ad --- /dev/null +++ b/recipes/llama3/run.sh @@ -0,0 +1,6 @@ +# naive script with commands from the readme +# (useful to make sure the recipe still runs) + +eole convert HF --model_dir meta-llama/Meta-Llama-3-8B-Instruct --output $EOLE_MODEL_DIR/llama3-8b-instruct --token $HF_TOKEN +echo -e "What are some nice places to visit in France?" | sed ':a;N;$!ba;s/\n/⦅newline⦆/g' > test_prompt.txt +eole predict -c llama-inference.yaml -src test_prompt.txt -output test_output.txt \ No newline at end of file diff --git a/recipes/mistral/README.md b/recipes/mistral/README.md index 39283788..6ada5438 100644 --- a/recipes/mistral/README.md +++ b/recipes/mistral/README.md @@ -17,7 +17,7 @@ export HF_TOKEN= ### Download and convert model ``` -eole convert HF --model_dir TheBloke/Mistral-7B-Instruct-v0.2-AWQ --output ${EOLE_MODEL_DIR}/mistral-7b-instruct-v0.2-awq --token $HF_TOKEN +eole convert HF --model_dir mistralai/Mistral-7B-v0.3 --output ${EOLE_MODEL_DIR}/mistral-7b-v0.3 --token $HF_TOKEN ``` diff --git a/recipes/mistral/mistral-7b-awq-gemm-inference.yaml b/recipes/mistral/mistral-7b-awq-gemm-inference.yaml index 1cbff4bb..e9c5e0b9 100755 --- a/recipes/mistral/mistral-7b-awq-gemm-inference.yaml +++ b/recipes/mistral/mistral-7b-awq-gemm-inference.yaml @@ -1,11 +1,11 @@ transforms: [sentencepiece] transforms_configs: sentencepiece: - src_subword_model: "$EOLE_MODEL_DIR/mistral-7b-v0.3/tokenizer.model" - tgt_subword_model: "$EOLE_MODEL_DIR/mistral-7b-v0.3/tokenizer.model" + src_subword_model: "$EOLE_MODEL_DIR/mistral-7b-instruct-v0.2-awq/tokenizer.model" + tgt_subword_model: "$EOLE_MODEL_DIR/mistral-7b-instruct-v0.2-awq/tokenizer.model" # Model info -model_path: "$EOLE_MODEL_DIR/mistral-7b-v0.3" +model_path: "$EOLE_MODEL_DIR/mistral-7b-instruct-v0.2-awq" # Inference seed: 42 @@ -19,8 +19,7 @@ gpu_ranks: [0] # parallel_mode: "tensor_parallel" #quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] #quant_type: "bnb_NF4" -# precision: fp16 -precision: fp16 +compute_dtype: fp16 #top_k: 1 #top_p: 0.6 #temperature: 0.9 diff --git a/recipes/mistral/run.sh b/recipes/mistral/run.sh new file mode 100755 index 00000000..60ce090b --- /dev/null +++ b/recipes/mistral/run.sh @@ -0,0 +1,3 @@ +eole convert HF --model_dir mistralai/Mistral-7B-v0.3 --output ${EOLE_MODEL_DIR}/mistral-7b-v0.3 --token $HF_TOKEN +echo -e "What are some nice places to visit in France?" 
| sed ':a;N;$!ba;s/\n/⦅newline⦆/g' > test_prompt.txt +eole predict -c mistral-7b-awq-gemm-inference.yaml -src test_prompt.txt -output test_output.txt \ No newline at end of file diff --git a/recipes/wiki_103/README.md b/recipes/wiki_103/README.md index 4134a5ea..a8c6b756 100644 --- a/recipes/wiki_103/README.md +++ b/recipes/wiki_103/README.md @@ -45,7 +45,7 @@ transforms_configs: ### Build vocabulary command The vocabulary is built using: ```bash -eole build_vocab -config wiki_103.yaml -n_sample -1 # -n_threads 4 +eole build_vocab -config wiki_103.yaml -n_sample -1 # -num_threads 4 ``` ## Step 3: Train the model
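For reference, the lookup pattern this patch applies across `LM_scoring.py`, `loss.py`, `model.py` and `predict/inference.py` can be summarized in a short sketch. It is illustrative only and not part of the diff: it assumes an installed `eole`, reuses the `DefaultTokens` import path shown in `model_saver.py` above, and the helper name `get_special_indices` is hypothetical.

```python
from eole.constants import DefaultTokens  # import path as used in eole/models/model_saver.py


def get_special_indices(vocabs):
    # Hypothetical helper: resolve special-token ids from the new
    # vocabs["specials"] mapping, falling back to the project defaults so
    # that checkpoints saved before this change keep working.
    specials = vocabs.get("specials", {})
    pad_token = specials.get("pad_token", DefaultTokens.PAD)
    eos_token = specials.get("eos_token", DefaultTokens.EOS)
    unk_token = specials.get("unk_token", DefaultTokens.UNK)
    return {
        "pad": vocabs["tgt"][pad_token],
        "eos": vocabs["tgt"][eos_token],
        "unk": vocabs["tgt"][unk_token],
    }
```

The same idea covers user overrides: values set for `bos_token`, `pad_token`, `eos_token` or `unk_token` in the training config (see the `eole/config/data.py` hunk above) end up in `vocabs["specials"]`, so downstream code no longer hard-codes `DefaultTokens`.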