From 1f92b63c98ee8f35695f606c1cd71eee55bd56dd Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Tue, 5 Sep 2023 15:57:02 +0200
Subject: [PATCH] Add explicit representations for the FIM tokens being used,
 facilitating the use of models that use different tokens (no longer specific
 to bigcode models)

---
 autodev/run_completion_model_finetuning.py    |  8 +++-
 autodev/run_service.py                        | 24 ++++++------
 .../autodev/autocomplete/completion_model.py  | 38 ++++++++++++-------
 .../completion_model_comparison.py            |  2 +-
 .../src/autodev/autocomplete/fim_config.py    | 32 ++++++++++++++++
 .../autocomplete/finetuning/dataset.py        | 23 ++++-------
 .../autodev/autocomplete/finetuning/fim.py    | 36 ++++--------------
 .../autocomplete/finetuning/finetuning.py     | 18 +++++++++
 autodev/src/autodev/autocomplete/model.py     | 17 ++++++---
 9 files changed, 120 insertions(+), 78 deletions(-)
 create mode 100644 autodev/src/autodev/autocomplete/fim_config.py

diff --git a/autodev/run_completion_model_finetuning.py b/autodev/run_completion_model_finetuning.py
index eb380b8..6e2b636 100644
--- a/autodev/run_completion_model_finetuning.py
+++ b/autodev/run_completion_model_finetuning.py
@@ -1,8 +1,13 @@
+"""
+Performs fine-tuning of an auto-completion model, teaching the model a new language
+"""
+
 import multiprocessing
 import os
 from glob import glob
 from typing import Optional
 
+from autodev.autocomplete.fim_config import BigCodeFIMTokens
 from autodev.util import logging
 from autodev.autocomplete.finetuning import FineTuningConfiguration, CompletionFineTuning
 
@@ -15,7 +20,7 @@ def run_finetuning_santacoder_thestack(lang_id,
                                        save_freq=500,
                                        log_freq=1,
                                        fim_rate=0.5,
-                                       fim_spm_rate=0.5,
+                                       fim_spm_rate=0.0,
                                        resume_from_checkpoint: Optional[bool] = None,
                                        use_lora=False,
                                        lora_r=32,
@@ -65,6 +70,7 @@ def run_finetuning_santacoder_thestack(lang_id,
         save_freq=save_freq,
         log_freq=log_freq,
         num_workers=multiprocessing.cpu_count(),
+        fim_tokens=BigCodeFIMTokens(),
         fim_rate=fim_rate,
         fim_spm_rate=fim_spm_rate,
         output_dir=output_dir,
diff --git a/autodev/run_service.py b/autodev/run_service.py
index 1c700d2..acccffe 100644
--- a/autodev/run_service.py
+++ b/autodev/run_service.py
@@ -1,23 +1,20 @@
 from typing import Literal
 
-import torch
-
-from autodev.util import logging
 from autodev.autocomplete.completion_model import CompletionModel
-from autodev.autocomplete.model import SantaCoderModelFactory, ModelFactory
+from autodev.autocomplete.model import SantaCoderModelFactory, BigCodeModelFactory
 from autodev.llm import LLMType
 from autodev.service import Service
-
+from autodev.util import logging
 
 log = logging.getLogger(__name__)
 
 
-def run_service(completion_model_name: Literal["santacoder-ruby", "santacoder", "starcoder"]):
+def run_service(completion_model_name: Literal["santacoder-ruby", "santacoder", "starcoder"], chat_model_llmtype=LLMType.OPENAI_CHAT_GPT4,
+                max_completion_tokens=32):
     # create completion model
-    completion_device = "cuda:0" if torch.cuda.is_available() else "cpu"
     if completion_model_name == "starcoder":
         completion_model_path = "bigcode/starcoder"
-        completion_model_factory = ModelFactory(completion_model_path)
+        completion_model_factory = BigCodeModelFactory(completion_model_path)
     else:
         completion_model_factory = SantaCoderModelFactory()
         if completion_model_name == "santacoder-ruby":
@@ -26,16 +23,17 @@ def run_service(completion_model_name: Literal["santacoder-ruby", "santacoder",
             completion_model_path = "bigcode/santacoder"
         else:
             raise ValueError(completion_model_name)
 
-    max_completion_tokens = 32
     log.info(f"Loading completion model '{completion_model_path}'")
-    completion_model = CompletionModel(completion_model_factory.create_model(completion_model_path),
-                                       completion_model_factory.create_tokenizer(), device=completion_device, max_new_tokens=max_completion_tokens)
+    completion_model = CompletionModel.from_model_factory(completion_model_factory, model_path=completion_model_path,
+                                                          max_tokens=max_completion_tokens)
 
-    Service(LLMType.OPENAI_CHAT_GPT4, completion_model).run()
+    # run service
+    Service(chat_model_llmtype, completion_model).run()
 
 
 if __name__ == '__main__':
     logging.configure()
-    run_service("santacoder-ruby")
+    run_service("santacoder")
+    #run_service("santacoder-ruby")
     #run_service("starcoder")
diff --git a/autodev/src/autodev/autocomplete/completion_model.py b/autodev/src/autodev/autocomplete/completion_model.py
index 302cdbc..4051a21 100644
--- a/autodev/src/autodev/autocomplete/completion_model.py
+++ b/autodev/src/autodev/autocomplete/completion_model.py
@@ -1,45 +1,55 @@
 import re
 from typing import Union
 
+import torch
 from optimum.onnxruntime import ORTModelForCausalLM
 from peft import PeftModel
 from transformers import PreTrainedModel, pipeline
 
 from .completion_task import CompletionTask, CompletionResult
+from .fim_config import FIMTokens
+from .model import ModelFactory
 
 
-class CompletionModel:
-    DEBUG = False
+def fim_prompt(task: CompletionTask, fim_tokens: FIMTokens) -> str:
+    return f"{fim_tokens.prefix_token}{task.prefix}{fim_tokens.suffix_token}{task.suffix}{fim_tokens.middle_token}"
 
-    # TODO: The tags/tokens below are specific to santacoder and other bigcode models.
-    #  To generalise this, we should probably add the tokens to the ModelFactory and pass them on wherever necessary.
-    TAG_FIM_PREFIX = "<fim-prefix>"
-    TAG_FIM_SUFFIX = "<fim-suffix>"
-    TAG_FIM_MIDDLE = "<fim-middle>"
-    re_fim_middle = re.compile(re.escape(TAG_FIM_MIDDLE))
+
+class CompletionModel:
+    DEBUG = False
 
     def __init__(self,
                  model: Union[PreTrainedModel, PeftModel, ORTModelForCausalLM],
                  tokenizer,
+                 fim_tokens: FIMTokens,
                  max_new_tokens=256,
                  device="cuda:0"):
         self.model = model
+        self.fim_tokens = fim_tokens
         self.pipe = pipeline("text-generation", model=model, max_new_tokens=max_new_tokens, device=device,
                              trust_remote_code=True, tokenizer=tokenizer)
+        self.re_fim_middle = re.compile(re.escape(fim_tokens.middle_token))
 
     @classmethod
-    def fim_prompt(cls, task: CompletionTask) -> str:
-        return f"{cls.TAG_FIM_PREFIX}{task.prefix}{cls.TAG_FIM_SUFFIX}{task.suffix}{cls.TAG_FIM_MIDDLE}"
+    def from_model_factory(cls, model_factory: ModelFactory, model_path=None, max_tokens=256, device=None) -> "CompletionModel":
+        if device is None:
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        return CompletionModel(model_factory.create_model(model_path),
+                               model_factory.create_tokenizer(),
+                               model_factory.fim_tokens,
+                               device=device,
+                               max_new_tokens=max_tokens)
 
-    @classmethod
-    def _extract_completion(cls, s: str, task: CompletionTask) -> str:
-        if cls.DEBUG:
+    def fim_prompt(self, task: CompletionTask) -> str:
+        return fim_prompt(task, self.fim_tokens)
+
+    def _extract_completion(self, s: str, task: CompletionTask) -> str:
+        if self.DEBUG:
             import pickle
             with open("completion.pkl", "wb") as f:
                 pickle.dump({"task": task, "s": s}, f)
-        m = cls.re_fim_middle.search(s)
+        m = self.re_fim_middle.search(s)
         if not m:
             return ""
         completion = s[m.end():]
diff --git a/autodev/src/autodev/autocomplete/completion_model_comparison.py b/autodev/src/autodev/autocomplete/completion_model_comparison.py
index cfcb132..9dcddb6 100644
--- a/autodev/src/autodev/autocomplete/completion_model_comparison.py
+++ b/autodev/src/autodev/autocomplete/completion_model_comparison.py
@@ -63,7 +63,7 @@ def result_dir(task_name):
 
             log.info(f"Loading model {model_id}")
             model = self.model_factory.create_model(model_id)
-            completion_model = CompletionModel(model, tokenizer, device=self.device)
+            completion_model = CompletionModel(model, tokenizer, self.model_factory.fim_tokens, device=self.device)
 
             for task_name, task in tasks.items():
                 ext = os.path.splitext(task_name)[1]
diff --git a/autodev/src/autodev/autocomplete/fim_config.py b/autodev/src/autodev/autocomplete/fim_config.py
new file mode 100644
index 0000000..339b8e3
--- /dev/null
+++ b/autodev/src/autodev/autocomplete/fim_config.py
@@ -0,0 +1,32 @@
+from transformers import PreTrainedTokenizer
+
+
+class FIMTokenIds:
+    def __init__(self, prefix_token_id: int, suffix_token_id: int, middle_token_id: int, pad_token_id: int):
+        self.prefix_token_id = prefix_token_id
+        self.suffix_token_id = suffix_token_id
+        self.middle_token_id = middle_token_id
+        self.pad_token_id = pad_token_id
+
+
+class FIMTokens:
+    def __init__(self, prefix_token: str, suffix_token: str, middle_token: str, pad_token: str):
+        self.prefix_token = prefix_token
+        self.suffix_token = suffix_token
+        self.middle_token = middle_token
+        self.pad_token = pad_token
+
+    def get_token_ids(self, tokenizer: PreTrainedTokenizer) -> FIMTokenIds:
+        prefix_token_id = self._get_token_id(self.prefix_token, tokenizer)
+        suffix_token_id = self._get_token_id(self.suffix_token, tokenizer)
+        middle_token_id = self._get_token_id(self.middle_token, tokenizer)
+        pad_token_id = self._get_token_id(self.pad_token, tokenizer)
+        return FIMTokenIds(prefix_token_id, suffix_token_id, middle_token_id, pad_token_id)
+
+    def _get_token_id(self, token: str, tokenizer: PreTrainedTokenizer):
+        return tokenizer.vocab[token]
+
+
+class BigCodeFIMTokens(FIMTokens):
+    def __init__(self):
+        super().__init__("<fim-prefix>", "<fim-suffix>", "<fim-middle>", "<fim-pad>")
diff --git a/autodev/src/autodev/autocomplete/finetuning/dataset.py b/autodev/src/autodev/autocomplete/finetuning/dataset.py
index a95c977..b887350 100644
--- a/autodev/src/autodev/autocomplete/finetuning/dataset.py
+++ b/autodev/src/autodev/autocomplete/finetuning/dataset.py
@@ -4,7 +4,7 @@
 import logging
 import random
-from typing import Tuple
+from typing import Tuple, Optional
 
 import numpy as np
 import torch
 
@@ -13,6 +13,7 @@
 from tqdm import tqdm
 
 from . import fim
+from ..fim_config import FIMTokenIds
 
 
 log = logging.getLogger(__name__)
@@ -36,6 +37,7 @@ class ConstantLengthDataset(IterableDataset):
             concat_token_id: the identifier of the token with which to connect content from different source files
             tokenizer (Tokenizer): The processor used for proccessing the data.
             dataset (dataset.Dataset): Dataset with text files.
+            fim_token_ids: specifies the token IDs to use for FIM (fill-in-the-middle)
             infinite (bool): If True the iterator is reset after dataset reaches end else stops.
             seq_length (int): Length of token sequences to return.
             num_of_sequences (int): Number of token sequences to keep in buffer.
@@ -49,6 +51,7 @@ def __init__(
         concat_token_id: int,
         tokenizer,
         dataset,
+        fim_token_ids: Optional[FIMTokenIds] = None,
         infinite=False,
         seq_length=1024,
         num_of_sequences=1024,
@@ -69,16 +72,9 @@ def __init__(
         self.fim_rate = fim_rate
         self.fim_spm_rate = fim_spm_rate
         self.seed = seed
-
-        (
-            self.suffix_tok_id,
-            self.prefix_tok_id,
-            self.middle_tok_id,
-            self.pad_tok_id,
-        ) = fim.get_fim_token_ids(self.tokenizer)
-        if not self.suffix_tok_id and self.fim_rate > 0:
-            print("FIM is not supported by tokenizer, disabling FIM")
-            self.fim_rate = 0
+        self.fim_token_ids = fim_token_ids
+        if self.fim_token_ids is None and self.fim_rate > 0:
+            raise ValueError("Must provide fim_token_ids when fim_rate > 0")
 
     def __iter__(self):
         iterator = iter(self.dataset)
@@ -117,10 +113,7 @@ def __iter__(self):
                     tokenized_input, np_rng = fim.permute(
                         tokenized_input,
                         np_rng,
-                        self.suffix_tok_id,
-                        self.prefix_tok_id,
-                        self.middle_tok_id,
-                        self.pad_tok_id,
+                        self.fim_token_ids,
                         fim_rate=self.fim_rate,
                         fim_spm_rate=self.fim_spm_rate,
                         truncate_or_pad=False,
diff --git a/autodev/src/autodev/autocomplete/finetuning/fim.py b/autodev/src/autodev/autocomplete/finetuning/fim.py
index 8c7c5f8..c0d28ea 100644
--- a/autodev/src/autodev/autocomplete/finetuning/fim.py
+++ b/autodev/src/autodev/autocomplete/finetuning/fim.py
@@ -1,35 +1,15 @@
 """
 The code in this module is based on https://github.com/loubnabnl/santacoder-finetuning
 """
-import functools
-
 import numpy as np
 
-
-# this is expensive so we cache it
-@functools.lru_cache(maxsize=None)
-def get_fim_token_ids(tokenizer):
-    try:
-        # TODO: This is specific to santacoder's tokenizer; to generalise this, we'll have to specify the tokens explicitly
-        _, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map[
-            "additional_special_tokens"
-        ]
-        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (
-            tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]
-        )
-    except KeyError:
-        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None
-    return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id
+from autodev.autocomplete.fim_config import FIMTokenIds
 
 
-## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py
 def permute(
     sample,
     np_rng,
-    suffix_tok_id,
-    prefix_tok_id,
-    middle_tok_id,
-    pad_tok_id,
+    fim_token_ids: FIMTokenIds,
     fim_rate=0.5,
     fim_spm_rate=0.5,
     truncate_or_pad=False,
@@ -55,15 +35,15 @@ def permute(
                     return sample, np_rng
                 suffix = suffix[: suffix.shape[0] - diff]
             elif diff < 0:
-                suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)])
+                suffix = np.concatenate([suffix, np.full((-1 * diff), fim_token_ids.pad_token_id)])
 
         if np_rng.binomial(1, fim_spm_rate):
             # SPM (variant 2 from FIM paper)
             new_sample = np.concatenate(
                 [
-                    [prefix_tok_id, suffix_tok_id],
+                    [fim_token_ids.prefix_token_id, fim_token_ids.suffix_token_id],
                     suffix,
-                    [middle_tok_id],
+                    [fim_token_ids.middle_token_id],
                     prefix,
                     middle,
                 ]
@@ -72,11 +52,11 @@ def permute(
             # PSM
             new_sample = np.concatenate(
                 [
-                    [prefix_tok_id],
+                    [fim_token_ids.prefix_token_id],
                     prefix,
-                    [suffix_tok_id],
+                    [fim_token_ids.suffix_token_id],
                     suffix,
-                    [middle_tok_id],
+                    [fim_token_ids.middle_token_id],
                     middle,
                 ]
             )
diff --git a/autodev/src/autodev/autocomplete/finetuning/finetuning.py b/autodev/src/autodev/autocomplete/finetuning/finetuning.py
index 79a49ee..81c19f4 100644
--- a/autodev/src/autodev/autocomplete/finetuning/finetuning.py
+++ b/autodev/src/autodev/autocomplete/finetuning/finetuning.py
@@ -21,6 +21,7 @@
 from transformers.trainer import TRAINING_ARGS_NAME
 
 from .dataset import load_train_val_datasets, chars_token_ratio, ConstantLengthDataset
+from ..fim_config import FIMTokens
 
 
 log = logging.getLogger(__name__)
@@ -55,12 +56,26 @@ class FineTuningConfiguration:
     log_freq: int = 1
     eval_freq: int = 1000
     save_freq: int = 1000
+    fim_tokens: Optional[FIMTokens] = None
+    """the tokens to use for fill-in-the-middle (FIM); may be None if FIM is not activated (fim_rate=0)"""
     fim_rate: float = 0
+    """the relative frequency with which to apply the fill-in-the-middle (FIM) transform"""
     fim_spm_rate: float = 0
+    """
+    the relative frequency with which, when applying the FIM transform, to apply the SPM (suffix-prefix-middle) variant
+    instead of the regular variant (prefix-suffix-middle)
+    """
     use_lora: bool = False
+    """whether to use low-rank adaptation (LoRA)"""
     lora_r: int = 8
+    """LoRA rank parameter"""
     lora_alpha: int = 8
+    """LoRA alpha parameter"""
     lora_target_modules: Optional[List[str]] = None
+    """
+    the names of the torch modules to which low-rank adaptation is to be applied; if None, try to determine them
+    automatically (works for select models only)
+    """
     lora_dropout = 0.1
 
 
@@ -130,10 +145,12 @@ def create_datasets(self, tokenizer):
         chars_per_token = chars_token_ratio(train_data, tokenizer, cfg.data_column)
         log.info(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
         concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else cfg.eos_token_id
+        fim_token_ids = cfg.fim_tokens.get_token_ids(tokenizer) if cfg.fim_tokens is not None else None
         train_dataset = ConstantLengthDataset(
             concat_token_id,
             tokenizer,
             train_data,
+            fim_token_ids,
             infinite=True,
             seq_length=cfg.seq_length,
             chars_per_token=chars_per_token,
@@ -146,6 +163,7 @@ def create_datasets(self, tokenizer):
             concat_token_id,
             tokenizer,
             valid_data,
+            fim_token_ids,
             infinite=False,
             seq_length=cfg.seq_length,
             chars_per_token=chars_per_token,
diff --git a/autodev/src/autodev/autocomplete/model.py b/autodev/src/autodev/autocomplete/model.py
index 9484319..75dacd1 100644
--- a/autodev/src/autodev/autocomplete/model.py
+++ b/autodev/src/autodev/autocomplete/model.py
@@ -1,5 +1,4 @@
 import logging
-import os
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Union, Optional
@@ -8,6 +7,8 @@
 from peft import PeftModel
 from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoConfig
 
+from autodev.autocomplete.fim_config import FIMTokens, BigCodeFIMTokens
+
 
 log = logging.getLogger(__name__)
 
@@ -44,12 +45,14 @@ def transform(self, model: PreTrainedModel):
         from optimum.bettertransformer import BetterTransformer
         return BetterTransformer.transform(model)
 
+
 TModel = Union[PreTrainedModel, PeftModel, ORTModelForCausalLM]
 
 
 class ModelFactory:
-    def __init__(self, base_model_id: str, transformation: Optional[ModelTransformation] = None):
+    def __init__(self, base_model_id: str, fim_tokens: FIMTokens, transformation: Optional[ModelTransformation] = None):
         self.base_model_id = base_model_id
+        self.fim_tokens = fim_tokens
         self.transformation = transformation
 
     def create_tokenizer(self) -> PreTrainedTokenizer:
@@ -136,9 +139,11 @@ def _find_onnx_model(self, dir: Path) -> Optional[Path]:
         return None
 
 
-class SantaCoderModelFactory(ModelFactory):
-    def __init__(self):
-        super().__init__(base_model_id="bigcode/santacoder")
-
+class BigCodeModelFactory(ModelFactory):
+    def __init__(self, base_model_id: str):
+        super().__init__(base_model_id, fim_tokens=BigCodeFIMTokens())
+
+
+class SantaCoderModelFactory(BigCodeModelFactory):
+    def __init__(self):
+        super().__init__(base_model_id="bigcode/santacoder")
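
Illustrative usage sketch (not part of the diff above): with the FIM tokens made explicit, a model
family that uses different token strings can be supported by defining the tokens once and passing
them through the factory. The class UnderscoreFIMTokens, its token strings and the model id
"some-org/some-fim-model" are hypothetical placeholders; the strings must match entries in the
target tokenizer's vocabulary.

    from autodev.autocomplete.completion_model import CompletionModel
    from autodev.autocomplete.fim_config import FIMTokens
    from autodev.autocomplete.model import ModelFactory


    class UnderscoreFIMTokens(FIMTokens):
        """Hypothetical FIM tokens for a model whose tokenizer uses underscore-style special tokens."""
        def __init__(self):
            super().__init__("<fim_prefix>", "<fim_suffix>", "<fim_middle>", "<fim_pad>")


    # create a factory for the hypothetical model, passing its FIM tokens explicitly, and build a
    # completion model from it (the device is auto-selected via torch.cuda.is_available())
    model_id = "some-org/some-fim-model"
    factory = ModelFactory(model_id, fim_tokens=UnderscoreFIMTokens())
    completion_model = CompletionModel.from_model_factory(factory, model_path=model_id, max_tokens=32)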