
Commit 1f92b63

Add explicit representations for the FIM tokens being used, facilitating
the use of models that use different tokens (no longer specific to
bigcode models)

opcode81 committed Sep 5, 2023
1 parent 57fbc1f commit 1f92b63
Showing 9 changed files with 120 additions and 78 deletions.
8 changes: 7 additions & 1 deletion autodev/run_completion_model_finetuning.py
@@ -1,8 +1,13 @@
"""
Performs fine-tuning of an auto-completion model, teaching the model a new language
"""

import multiprocessing
import os
from glob import glob
from typing import Optional

from autodev.autocomplete.fim_config import BigCodeFIMTokens
from autodev.util import logging
from autodev.autocomplete.finetuning import FineTuningConfiguration, CompletionFineTuning

@@ -15,7 +20,7 @@ def run_finetuning_santacoder_thestack(lang_id,
save_freq=500,
log_freq=1,
fim_rate=0.5,
fim_spm_rate=0.5,
fim_spm_rate=0.0,
resume_from_checkpoint: Optional[bool] = None,
use_lora=False,
lora_r=32,
@@ -65,6 +70,7 @@ def run_finetuning_santacoder_thestack(lang_id,
save_freq=save_freq,
log_freq=log_freq,
num_workers=multiprocessing.cpu_count(),
fim_tokens=BigCodeFIMTokens(),
fim_rate=fim_rate,
fim_spm_rate=fim_spm_rate,
output_dir=output_dir,
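For orientation, a minimal sketch of how the fine-tuning entry point might be invoked after this change. Only keyword arguments visible in this diff are used, the language id "ruby" is merely an example, and the function now constructs BigCodeFIMTokens() internally, so callers do not pass FIM tokens here:

```python
# Hypothetical invocation sketch; argument values are illustrative, not repository defaults.
from autodev.run_completion_model_finetuning import run_finetuning_santacoder_thestack

run_finetuning_santacoder_thestack(
    "ruby",             # lang_id of The Stack subset to fine-tune on (example value)
    fim_rate=0.5,       # apply the fill-in-the-middle transform to half of the samples
    fim_spm_rate=0.0,   # new default: always use the PSM (prefix-suffix-middle) layout
    use_lora=True,      # enable low-rank adaptation
    lora_r=32,
)
```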
24 changes: 11 additions & 13 deletions autodev/run_service.py
@@ -1,23 +1,20 @@
from typing import Literal

import torch

from autodev.util import logging
from autodev.autocomplete.completion_model import CompletionModel
from autodev.autocomplete.model import SantaCoderModelFactory, ModelFactory
from autodev.autocomplete.model import SantaCoderModelFactory, BigCodeModelFactory
from autodev.llm import LLMType
from autodev.service import Service

from autodev.util import logging

log = logging.getLogger(__name__)


def run_service(completion_model_name: Literal["santacoder-ruby", "santacoder", "starcoder"]):
def run_service(completion_model_name: Literal["santacoder-ruby", "santacoder", "starcoder"], chat_model_llmtype=LLMType.OPENAI_CHAT_GPT4,
max_completion_tokens=32):
# create completion model
completion_device = "cuda:0" if torch.cuda.is_available() else "cpu"
if completion_model_name == "starcoder":
completion_model_path = "bigcode/starcoder"
completion_model_factory = ModelFactory(completion_model_path)
completion_model_factory = BigCodeModelFactory(completion_model_path)
else:
completion_model_factory = SantaCoderModelFactory()
if completion_model_name == "santacoder-ruby":
@@ -26,16 +23,17 @@ def run_service(completion_model_name: Literal["santacoder-ruby", "santacoder",
completion_model_path = "bigcode/santacoder"
else:
raise ValueError(completion_model_name)
max_completion_tokens = 32
log.info(f"Loading completion model '{completion_model_path}'")
completion_model = CompletionModel(completion_model_factory.create_model(completion_model_path),
completion_model_factory.create_tokenizer(), device=completion_device, max_new_tokens=max_completion_tokens)
completion_model = CompletionModel.from_model_factory(completion_model_factory, model_path=completion_model_path,
max_tokens=max_completion_tokens)

Service(LLMType.OPENAI_CHAT_GPT4, completion_model).run()
# run service
Service(chat_model_llmtype, completion_model).run()


if __name__ == '__main__':
logging.configure()

run_service("santacoder-ruby")
run_service("santacoder")
#run_service("santacoder-ruby")
#run_service("starcoder")
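As a usage note, the service entry point now exposes the chat model type and the completion token budget as parameters rather than hard-coding them. A minimal sketch under the signature shown above; the values are illustrative:

```python
# Sketch of the revised run_service signature; values are examples only.
from autodev.llm import LLMType
from autodev.run_service import run_service

run_service(
    "starcoder",                                  # served via BigCodeModelFactory("bigcode/starcoder")
    chat_model_llmtype=LLMType.OPENAI_CHAT_GPT4,  # chat model is now configurable
    max_completion_tokens=64,                     # overrides the default budget of 32
)
```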
38 changes: 24 additions & 14 deletions autodev/src/autodev/autocomplete/completion_model.py
@@ -1,45 +1,55 @@
import re
from typing import Union

import torch
from optimum.onnxruntime import ORTModelForCausalLM
from peft import PeftModel
from transformers import PreTrainedModel, pipeline

from .completion_task import CompletionTask, CompletionResult
from .fim_config import FIMTokens
from .model import ModelFactory


class CompletionModel:
DEBUG = False
def fim_prompt(task: CompletionTask, fim_tokens: FIMTokens) -> str:
return f"{fim_tokens.prefix_token}{task.prefix}{fim_tokens.suffix_token}{task.suffix}{fim_tokens.middle_token}"

# TODO: The tags/tokens below are specific to santacoder and other bigcode models.
# To generalise this, we should probably add the tokens to the ModelFactory and pass them on wherever necessary.
TAG_FIM_PREFIX = "<fim-prefix>"
TAG_FIM_SUFFIX = "<fim-suffix>"
TAG_FIM_MIDDLE = "<fim-middle>"

re_fim_middle = re.compile(re.escape(TAG_FIM_MIDDLE))
class CompletionModel:
DEBUG = False

def __init__(self,
model: Union[PreTrainedModel, PeftModel, ORTModelForCausalLM],
tokenizer,
fim_tokens: FIMTokens,
max_new_tokens=256,
device="cuda:0"):
self.model = model
self.fim_tokens = fim_tokens
self.pipe = pipeline("text-generation", model=model, max_new_tokens=max_new_tokens, device=device,
trust_remote_code=True, tokenizer=tokenizer)
self.re_fim_middle = re.compile(re.escape(fim_tokens.middle_token))

@classmethod
def fim_prompt(cls, task: CompletionTask) -> str:
return f"{cls.TAG_FIM_PREFIX}{task.prefix}{cls.TAG_FIM_SUFFIX}{task.suffix}{cls.TAG_FIM_MIDDLE}"
def from_model_factory(cls, model_factory: ModelFactory, model_path=None, max_tokens=256, device=None) -> "CompletionModel":
if device is None:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
return CompletionModel(model_factory.create_model(model_path),
model_factory.create_tokenizer(),
model_factory.fim_tokens,
device=device,
max_new_tokens=max_tokens)

@classmethod
def _extract_completion(cls, s: str, task: CompletionTask) -> str:
if cls.DEBUG:
def fim_prompt(self, task: CompletionTask) -> str:
return fim_prompt(task, self.fim_tokens)

def _extract_completion(self, s: str, task: CompletionTask) -> str:
if self.DEBUG:
import pickle
with open("completion.pkl", "wb") as f:
pickle.dump({"task": task, "s": s}, f)

m = cls.re_fim_middle.search(s)
m = self.re_fim_middle.search(s)
if not m:
return ""
completion = s[m.end():]
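To make the prompt format concrete, a small sketch of the module-level fim_prompt helper with the bigcode tokens. The CompletionTask constructor call is an assumption; the diff only shows that a task exposes .prefix and .suffix:

```python
# Illustrative only: the CompletionTask construction below is assumed, not taken from the diff.
from autodev.autocomplete.completion_task import CompletionTask
from autodev.autocomplete.completion_model import fim_prompt
from autodev.autocomplete.fim_config import BigCodeFIMTokens

task = CompletionTask(prefix="def add(a, b):\n    return ", suffix="\n")
print(fim_prompt(task, BigCodeFIMTokens()))
# <fim-prefix>def add(a, b):
#     return <fim-suffix>
# <fim-middle>
```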
@@ -63,7 +63,7 @@ def result_dir(task_name):

log.info(f"Loading model {model_id}")
model = self.model_factory.create_model(model_id)
completion_model = CompletionModel(model, tokenizer, device=self.device)
completion_model = CompletionModel(model, tokenizer, self.model_factory.fim_tokens, device=self.device)

for task_name, task in tasks.items():
ext = os.path.splitext(task_name)[1]
32 changes: 32 additions & 0 deletions autodev/src/autodev/autocomplete/fim_config.py
@@ -0,0 +1,32 @@
from transformers import PreTrainedTokenizer


class FIMTokenIds:
def __init__(self, prefix_token_id: int, suffix_token_id: int, middle_token_id: int, pad_token_id: int):
self.prefix_token_id = prefix_token_id
self.suffix_token_id = suffix_token_id
self.middle_token_id = middle_token_id
self.pad_token_id = pad_token_id


class FIMTokens:
def __init__(self, prefix_token: str, suffix_token: str, middle_token: str, pad_token: str):
self.prefix_token = prefix_token
self.suffix_token = suffix_token
self.middle_token = middle_token
self.pad_token = pad_token

def get_token_ids(self, tokenizer: PreTrainedTokenizer) -> FIMTokenIds:
prefix_token_id = self._get_token_id(self.prefix_token, tokenizer)
suffix_token_id = self._get_token_id(self.suffix_token, tokenizer)
middle_token_id = self._get_token_id(self.middle_token, tokenizer)
pad_token_id = self._get_token_id(self.pad_token, tokenizer)
return FIMTokenIds(prefix_token_id, suffix_token_id, middle_token_id, pad_token_id)

def _get_token_id(self, token: str, tokenizer: PreTrainedTokenizer):
return tokenizer.vocab[token]


class BigCodeFIMTokens(FIMTokens):
def __init__(self):
super().__init__("<fim-prefix>", "<fim-suffix>", "<fim-middle>", "<fim-pad>")
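This new module is the heart of the change: FIM tokens are now explicit values that can be resolved to ids via the tokenizer, so supporting a model family with different markers is a matter of subclassing FIMTokens. A brief sketch; the underscore-style token strings in the subclass are an assumption for illustration, not part of this commit:

```python
# Sketch: resolving bigcode FIM token ids and defining tokens for another model family.
from transformers import AutoTokenizer

from autodev.autocomplete.fim_config import BigCodeFIMTokens, FIMTokens

tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")
token_ids = BigCodeFIMTokens().get_token_ids(tokenizer)
print(token_ids.prefix_token_id, token_ids.suffix_token_id, token_ids.middle_token_id)


class UnderscoreFIMTokens(FIMTokens):
    """Assumed token spelling for models that use underscores instead of dashes."""

    def __init__(self):
        super().__init__("<fim_prefix>", "<fim_suffix>", "<fim_middle>", "<fim_pad>")
```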
23 changes: 8 additions & 15 deletions autodev/src/autodev/autocomplete/finetuning/dataset.py
@@ -4,7 +4,7 @@

import logging
import random
from typing import Tuple
from typing import Tuple, Optional

import numpy as np
import torch
@@ -13,6 +13,7 @@
from tqdm import tqdm

from . import fim
from ..fim_config import FIMTokenIds

log = logging.getLogger(__name__)

@@ -36,6 +37,7 @@ class ConstantLengthDataset(IterableDataset):
concat_token_id: the identifier of the token with which to connect content from different source files
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
fim_tokens: specifies the tokens to use for FIM (fill-in-the-middle)
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
seq_length (int): Length of token sequences to return.
num_of_sequences (int): Number of token sequences to keep in buffer.
@@ -49,6 +51,7 @@ def __init__(
concat_token_id: int,
tokenizer,
dataset,
fim_token_ids: Optional[FIMTokenIds] = None,
infinite=False,
seq_length=1024,
num_of_sequences=1024,
@@ -69,16 +72,9 @@ def __init__(
self.fim_rate = fim_rate
self.fim_spm_rate = fim_spm_rate
self.seed = seed

(
self.suffix_tok_id,
self.prefix_tok_id,
self.middle_tok_id,
self.pad_tok_id,
) = fim.get_fim_token_ids(self.tokenizer)
if not self.suffix_tok_id and self.fim_rate > 0:
print("FIM is not supported by tokenizer, disabling FIM")
self.fim_rate = 0
self.fim_token_ids = fim_token_ids
if self.fim_token_ids is None and self.fim_rate > 0:
raise ValueError("Must provide fim_token_ids when fim_rate > 0")

def __iter__(self):
iterator = iter(self.dataset)
@@ -117,10 +113,7 @@ def __iter__(self):
tokenized_input, np_rng = fim.permute(
tokenized_input,
np_rng,
self.suffix_tok_id,
self.prefix_tok_id,
self.middle_tok_id,
self.pad_tok_id,
self.fim_token_ids,
fim_rate=self.fim_rate,
fim_spm_rate=self.fim_spm_rate,
truncate_or_pad=False,
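The dataset's FIM contract also changed: instead of probing the tokenizer and silently disabling FIM, it now expects explicit FIMTokenIds and raises when fim_rate > 0 without them. A minimal construction sketch; tokenizer and train_data are assumed to come from the surrounding training setup, and the values are illustrative:

```python
# Sketch of the new ConstantLengthDataset contract; tokenizer/train_data are assumed inputs.
from autodev.autocomplete.finetuning.dataset import ConstantLengthDataset
from autodev.autocomplete.fim_config import BigCodeFIMTokens

fim_token_ids = BigCodeFIMTokens().get_token_ids(tokenizer)

dataset = ConstantLengthDataset(
    tokenizer.eos_token_id,   # concat_token_id joining content from different source files
    tokenizer,
    train_data,
    fim_token_ids,            # omitting this while fim_rate > 0 now raises ValueError
    infinite=True,
    seq_length=1024,
    fim_rate=0.5,
    fim_spm_rate=0.0,
)
```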
36 changes: 8 additions & 28 deletions autodev/src/autodev/autocomplete/finetuning/fim.py
@@ -1,35 +1,15 @@
"""
The code in this module is based on https://github.com/loubnabnl/santacoder-finetuning
"""
import functools

import numpy as np


# this is expensive so we cache it
@functools.lru_cache(maxsize=None)
def get_fim_token_ids(tokenizer):
try:
# TODO: This is specific to santacoder's tokenizer; to generalise this, we'll have to specify the tokens explicitly
_, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map[
"additional_special_tokens"
]
suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (
tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]
)
except KeyError:
suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None
return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id
from autodev.autocomplete.fim_config import FIMTokenIds


## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py
def permute(
sample,
np_rng,
suffix_tok_id,
prefix_tok_id,
middle_tok_id,
pad_tok_id,
fim_token_ids: FIMTokenIds,
fim_rate=0.5,
fim_spm_rate=0.5,
truncate_or_pad=False,
@@ -55,15 +35,15 @@ def permute(
return sample, np_rng
suffix = suffix[: suffix.shape[0] - diff]
elif diff < 0:
suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)])
suffix = np.concatenate([suffix, np.full((-1 * diff), fim_token_ids.pad_token_id)])

if np_rng.binomial(1, fim_spm_rate):
# SPM (variant 2 from FIM paper)
new_sample = np.concatenate(
[
[prefix_tok_id, suffix_tok_id],
[fim_token_ids.prefix_token_id, fim_token_ids.suffix_token_id],
suffix,
[middle_tok_id],
[fim_token_ids.middle_token_id],
prefix,
middle,
]
@@ -72,11 +52,11 @@
# PSM
new_sample = np.concatenate(
[
[prefix_tok_id],
[fim_token_ids.prefix_token_id],
prefix,
[suffix_tok_id],
[fim_token_ids.suffix_token_id],
suffix,
[middle_tok_id],
[fim_token_ids.middle_token_id],
middle,
]
)
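To make the two layouts assembled above concrete, a worked illustration with placeholder ids; P, S and M stand for the prefix, suffix and middle sentinel ids carried by FIMTokenIds:

```python
# Worked illustration of the PSM and SPM orderings built in permute(); placeholder ids only.
import numpy as np

P, S, M = 1, 2, 3                # sentinel token ids (placeholders)
prefix = np.array([10, 11])      # tokenized document parts
suffix = np.array([20])
middle = np.array([30, 31])

psm = np.concatenate([[P], prefix, [S], suffix, [M], middle])   # regular variant
spm = np.concatenate([[P, S], suffix, [M], prefix, middle])     # variant 2 from the FIM paper
print(psm)  # [ 1 10 11  2 20  3 30 31]
print(spm)  # [ 1  2 20  3 10 11 30 31]
```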
18 changes: 18 additions & 0 deletions autodev/src/autodev/autocomplete/finetuning/finetuning.py
@@ -21,6 +21,7 @@
from transformers.trainer import TRAINING_ARGS_NAME

from .dataset import load_train_val_datasets, chars_token_ratio, ConstantLengthDataset
from ..fim_config import FIMTokens

log = logging.getLogger(__name__)

@@ -55,12 +56,26 @@ class FineTuningConfiguration:
log_freq: int = 1
eval_freq: int = 1000
save_freq: int = 1000
fim_tokens: Optional[FIMTokens] = None
"""the tokens to use for fill-in-the-middle (FIM); may be None if FIM is not activated (fim_rate=0)"""
fim_rate: float = 0
"""the relative frequency with which to apply the fill-in-the-middle (FIM) transform"""
fim_spm_rate: float = 0
"""
the relative frequency with which, when applying the FIM transform, to apply the SPM (suffix-prefix-middle) variant
instead of the regular variant (prefix-suffix-middle)
"""
use_lora: bool = False
"""whether to use low-rank adaptation (LoRA)"""
lora_r: int = 8
"""LoRA rank parameter"""
lora_alpha: int = 8
"""LoRA alpha parameter"""
lora_target_modules: Optional[List[str]] = None
"""
The names of the torch modules to which low-rank adaptation is to be applied; if None, try to determine automatically
(works for select models only)
"""
lora_dropout = 0.1


@@ -130,10 +145,12 @@ def create_datasets(self, tokenizer):
chars_per_token = chars_token_ratio(train_data, tokenizer, cfg.data_column)
log.info(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else cfg.eos_token_id
fim_token_ids = cfg.fim_tokens.get_token_ids(tokenizer) if cfg.fim_tokens is not None else None
train_dataset = ConstantLengthDataset(
concat_token_id,
tokenizer,
train_data,
fim_token_ids,
infinite=True,
seq_length=cfg.seq_length,
chars_per_token=chars_per_token,
@@ -146,6 +163,7 @@
concat_token_id,
tokenizer,
valid_data,
fim_token_ids,
infinite=False,
seq_length=cfg.seq_length,
chars_per_token=chars_per_token,
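Finally, a minimal sketch of a configuration that opts into FIM with explicit tokens, mirroring how run_completion_model_finetuning.py wires things up; any fields of FineTuningConfiguration not shown in this diff (model and dataset settings, output paths, etc.) are omitted here and would need to be filled in:

```python
# Illustrative configuration fragment; fields not visible in this diff are omitted.
from autodev.autocomplete.fim_config import BigCodeFIMTokens
from autodev.autocomplete.finetuning import FineTuningConfiguration

cfg = FineTuningConfiguration(
    fim_tokens=BigCodeFIMTokens(),  # required whenever fim_rate > 0
    fim_rate=0.5,                   # apply the FIM transform to half of the samples
    fim_spm_rate=0.0,               # keep the regular prefix-suffix-middle layout
    use_lora=True,
    lora_r=32,
    lora_alpha=8,
)
```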