From d3e56cf7fdaec378370b61653d3c745b8d0682d3 Mon Sep 17 00:00:00 2001 From: mkshing Date: Wed, 6 Mar 2024 16:17:41 +0900 Subject: [PATCH] complete vlm --- .github/workflows/black.yaml | 18 +++ .github/workflows/test-build.yaml | 22 ++++ .gitignore | 1 + README.md | 58 +++++++++- configs/llm/evollm-jp.yaml | 7 ++ configs/vlm/evovlm-jp.yaml | 23 ++++ configs/vlm/jsvlm.yaml | 16 +++ configs/vlm/llava-1-6-mistral-7b.yaml | 23 ++++ evaluate.py | 70 ++++++++++++ evomerge/__init__.py | 3 + evomerge/eval/__init__.py | 2 + evomerge/eval/base.py | 155 +++++++++++++++++++++++++ evomerge/eval/ja_vg_vqa.py | 68 +++++++++++ evomerge/eval/ja_vlm_wild.py | 50 ++++++++ evomerge/eval/metrics.py | 100 ++++++++++++++++ evomerge/models/__init__.py | 5 + evomerge/models/base.py | 158 ++++++++++++++++++++++++++ evomerge/models/causallm.py | 74 ++++++++++++ evomerge/models/jsvlm.py | 81 +++++++++++++ evomerge/models/llava.py | 82 +++++++++++++ evomerge/models/prompt_templates.py | 42 +++++++ evomerge/utils.py | 59 ++++++++++ requirements.txt | 14 +++ setup.py | 19 ++++ 24 files changed, 1148 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/black.yaml create mode 100644 .github/workflows/test-build.yaml create mode 100644 configs/llm/evollm-jp.yaml create mode 100644 configs/vlm/evovlm-jp.yaml create mode 100644 configs/vlm/jsvlm.yaml create mode 100644 configs/vlm/llava-1-6-mistral-7b.yaml create mode 100644 evaluate.py create mode 100644 evomerge/__init__.py create mode 100644 evomerge/eval/__init__.py create mode 100644 evomerge/eval/base.py create mode 100644 evomerge/eval/ja_vg_vqa.py create mode 100644 evomerge/eval/ja_vlm_wild.py create mode 100644 evomerge/eval/metrics.py create mode 100644 evomerge/models/__init__.py create mode 100644 evomerge/models/base.py create mode 100644 evomerge/models/causallm.py create mode 100644 evomerge/models/jsvlm.py create mode 100644 evomerge/models/llava.py create mode 100644 evomerge/models/prompt_templates.py create mode 100644 evomerge/utils.py create mode 100644 requirements.txt create mode 100644 setup.py diff --git a/.github/workflows/black.yaml b/.github/workflows/black.yaml new file mode 100644 index 0000000..a9834a9 --- /dev/null +++ b/.github/workflows/black.yaml @@ -0,0 +1,18 @@ +name: Run black +on: + push: + branches: [ main ] + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install venv + run: | + sudo apt-get -y install python3.10-venv + - uses: psf/black@stable + with: + options: "--check --verbose -l88" + src: "." \ No newline at end of file diff --git a/.github/workflows/test-build.yaml b/.github/workflows/test-build.yaml new file mode 100644 index 0000000..3d81f9e --- /dev/null +++ b/.github/workflows/test-build.yaml @@ -0,0 +1,22 @@ +name: Build package + +on: + push: + branches: [ main ] + pull_request: + +jobs: + build: + name: Build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.10 + uses: actions/setup-python@v2 + with: + python-version: 3.10 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install . \ No newline at end of file diff --git a/.gitignore b/.gitignore index 68bc17f..732c190 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +.vscode \ No newline at end of file diff --git a/README.md b/README.md index e5f937b..0ee1035 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,56 @@ -# evolving-merged-models -Evolutionary Optimization of Model Merging Recipes +# Evolutionary Optimization of Model Merging Recipes + +This is an official repository of [Evolutionary Optimization of Model Merging Recipes](https://arxiv.org/). + +## Installation + +### 1. Clone the repo + +```bash +git clone https://github.com/SakanaAI/evolving-merged-models.git +cd evolving-merged-models +``` + +### 2. Install necesarry libraries + +```bash +pip install -r requirements.txt +``` + +* We tested under the following environment: + * Python Version: 3.10 + * CUDA Version: 12.3 + +## Models + +### EvoLLM-JP + +| Model | MGSM (Accuracy) | [lm-eval-harness](https://github.com/Stability-AI/lm-evaluation-harness/tree/jp-stable) (Average) | +| :-- | --: | --: | +| [shisa-gamma-7b-v1](https://huggingface.co/augmxnt/shisa-gamma-7b-v1) | 0.10 | 66.06 | +| [WizardMath-7B-V1.1](https://huggingface.co/WizardLM/WizardMath-7B-V1.1) | 0.18 | 60.07 | +| [Abel-7B-002](https://huggingface.co/GAIR/Abel-7B-002) | 0.31 | 56.51 | +| [Ours: EvoLLM-JP](https://huggingface.co/SakanaAI/evollm-jp) | **0.52** | **70.51** | + +### EvoVLM-JP + +| Model | Ja-VG-VQA-500 (Ja-R-L) | JaVLM-Bench-In-the-Wild (Ja-R-L) | +| :-- | --: | --: | +| [LLaVA-1.6-Mistral-7B](https://llava-vl.github.io/blog/2024-01-30-llava-next/) | 14.32 | 41.10 | +| [JSVLM](https://huggingface.co/stabilityai/japanese-stable-vlm) | - | 40.50 | +| [Ours: EvoVLM-JP](https://huggingface.co/SakanaAI/evovlm-jp) | **19.70** | **51.94** | + +## Evaluation + +To launch evaluation, run the following script with a certain config. All configs used for the paper are in `configs`. + +```bash +python evaluate.py --config_path {path-to-config} +``` + +## Citation + +```bibtex + + +``` diff --git a/configs/llm/evollm-jp.yaml b/configs/llm/evollm-jp.yaml new file mode 100644 index 0000000..05a5030 --- /dev/null +++ b/configs/llm/evollm-jp.yaml @@ -0,0 +1,7 @@ +model: + target: evomerge.CausalLM + params: + model_path: SakanaAI/EvoLLM-JP + template: ja-shisa-vqa + model_kwargs: + torch_dtype: torch.float16 diff --git a/configs/vlm/evovlm-jp.yaml b/configs/vlm/evovlm-jp.yaml new file mode 100644 index 0000000..4ab2dc0 --- /dev/null +++ b/configs/vlm/evovlm-jp.yaml @@ -0,0 +1,23 @@ +model: + target: evomerge.LLaVA + params: + model_path: SakanaAI/EvoVLM-JP + template: ja-shisa-vqa + model_kwargs: + torch_dtype: torch.float16 + generation_config: + max_new_tokens: 512 + do_sample: false + num_beams: 5 + repetition_penalty: 1.5 +eval: + - target: evomerge.eval.JaVGVQA + params: + loader_kwargs: + batch_size: 4 + num_workers: 2 + - target: evomerge.eval.JaVLMBenchIntheWild + params: + loader_kwargs: + batch_size: 4 + num_workers: 2 \ No newline at end of file diff --git a/configs/vlm/jsvlm.yaml b/configs/vlm/jsvlm.yaml new file mode 100644 index 0000000..f2a68fc --- /dev/null +++ b/configs/vlm/jsvlm.yaml @@ -0,0 +1,16 @@ +model: + target: evomerge.JSVLM + params: + model_kwargs: + torch_dtype: torch.float16 + generation_config: + do_sample: false + max_new_tokens: 512 + num_beams: 5 + repetition_penalty: 1.5 +eval: + - target: evomerge.eval.JaVLMBenchIntheWild + params: + loader_kwargs: + batch_size: 4 + num_workers: 2 \ No newline at end of file diff --git a/configs/vlm/llava-1-6-mistral-7b.yaml b/configs/vlm/llava-1-6-mistral-7b.yaml new file mode 100644 index 0000000..59b4590 --- /dev/null +++ b/configs/vlm/llava-1-6-mistral-7b.yaml @@ -0,0 +1,23 @@ +model: + target: evomerge.LLaVA + params: + model_path: llava-hf/llava-v1.6-mistral-7b-hf + template: ja-shisa-vqa + model_kwargs: + torch_dtype: torch.float16 + generation_config: + max_new_tokens: 512 + do_sample: false + num_beams: 5 + repetition_penalty: 1.5 +eval: + - target: evomerge.eval.JaVGVQA + params: + loader_kwargs: + batch_size: 4 + num_workers: 2 + - target: evomerge.eval.JaVLMBenchIntheWild + params: + loader_kwargs: + batch_size: 4 + num_workers: 2 \ No newline at end of file diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..41f59f7 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,70 @@ +import os +import argparse +import gc +import json +import logging +import os +from dataclasses import asdict + +import torch + +from evomerge import instantiate_from_config, load_config, set_seed + +logger = logging.getLogger(__name__) +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True, help="config path") + parser.add_argument("--output_path", type=str, default=None) + args = parser.parse_args() + # validation + if args.output_path is None: + args.output_path = ( + os.path.splitext(os.path.basename(args.config_path))[0] + ".json" + ) + args.output_path = f"results/{args.output_path}" + os.makedirs("results", exist_ok=True) + assert args.output_path.endswith(".json"), "`output_path` must be json file" + return args + + +def main(args): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + config = load_config(args.config_path) + logger.info(f"Config:\n{json.dumps(config, indent=2, ensure_ascii=False)}") + set_seed(42) + + # 1. load model (it's already moved to device) + model = instantiate_from_config(config["model"]) + logger.info(f"Model: {model.__class__.__name__}") + + eval_configs = config["eval"] + if isinstance(eval_configs, str): + eval_configs = [eval_configs] + + results = {} + for eval_config in eval_configs: + # 2. load evaluator + evaluator = instantiate_from_config(eval_config) + logger.info(f"Evaluator: {evaluator.__class__.__name__}") + # 3. Run! + outputs = evaluator(model) + logger.info(f"Result:\n{outputs.metrics}") + results[evaluator.name] = asdict(outputs) + del evaluator + torch.cuda.empty_cache() + gc.collect() + + with open(args.output_path, "w") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/evomerge/__init__.py b/evomerge/__init__.py new file mode 100644 index 0000000..543265f --- /dev/null +++ b/evomerge/__init__.py @@ -0,0 +1,3 @@ +from .eval import * +from .models import * +from .utils import * diff --git a/evomerge/eval/__init__.py b/evomerge/eval/__init__.py new file mode 100644 index 0000000..df7d88a --- /dev/null +++ b/evomerge/eval/__init__.py @@ -0,0 +1,2 @@ +from .ja_vg_vqa import JaVGVQA +from .ja_vlm_wild import JaVLMBenchIntheWild diff --git a/evomerge/eval/base.py b/evomerge/eval/base.py new file mode 100644 index 0000000..3133c65 --- /dev/null +++ b/evomerge/eval/base.py @@ -0,0 +1,155 @@ +import logging +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from datasets import Dataset, IterableDataset, load_dataset +from torch.utils.data import DataLoader +from tqdm import tqdm + +from ..models import LLMInference +from ..utils import instantiate_from_config + +logger = logging.getLogger(__name__) + + +@dataclass +class EvalOutput: + metrics: Dict[str, float] + results: List[Dict[str, Any]] + + +# copied from https://github.com/Stability-AI/datapipelines/blob/main/sdata/dataset.py#L15 +def dict_collation_fn( + samples: List, combine_tensors: bool = True, combine_scalars: bool = True +) -> Dict: + """Take a list of samples (as dictionary) and create a batch, preserving the keys. + If `tensors` is True, `ndarray` objects are combined into + tensor batches. + :param dict samples: list of samples + :param bool tensors: whether to turn lists of ndarrays into a single ndarray + :returns: single sample consisting of a batch + :rtype: dict + """ + keys = set.intersection(*[set(sample.keys()) for sample in samples]) + batched = {key: [] for key in keys} + + for s in samples: + [batched[key].append(s[key]) for key in batched] + + result = {} + for key in batched: + if isinstance(batched[key][0], (int, float)): + if combine_scalars: + result[key] = np.array(list(batched[key])) + elif isinstance(batched[key][0], torch.Tensor): + if combine_tensors: + result[key] = torch.stack(list(batched[key])) + elif isinstance(batched[key][0], np.ndarray): + if combine_tensors: + result[key] = np.array(list(batched[key])) + else: + result[key] = list(batched[key]) + + del samples + del batched + return result + + +def flatten_list( + results: List[Dict[str, List[Union[str, bool]]]] +) -> Dict[str, List[Union[str, bool]]]: + flatten_results = {} + for res in results: + for k in res: + if k not in flatten_results: + flatten_results[k] = res[k] + else: + flatten_results[k].extend(res[k]) + return flatten_results + + +class Evaluator: + name: str = "" + dataset_path: Optional[str] = None + dataset_split: str = "test" + + def __init__( + self, + device: str = "cuda", + verbose: bool = False, + shuffle: bool = False, + buffer_size: int = 1000, + slice_indices: Optional[Tuple[int]] = None, + dataset_kwargs: Optional[dict] = None, + loader_kwargs: Optional[dict] = None, + ): + self.dataset: Union[Dataset, IterableDataset] = None + self.verbose = verbose + self.shuffle = shuffle + self.buffer_size = buffer_size + self.slice_indices = slice_indices + self.dataset_kwargs = dataset_kwargs if dataset_kwargs else {} + self.loader_kwargs = loader_kwargs if loader_kwargs else {} + self.collate_fn = None + self.dataset = self._load_dataset(slice_indices=slice_indices) + + def slice_dataset( + self, dataset: Dataset, slice_indices: Optional[Tuple[int]] = None + ) -> Dataset: + assert not self.dataset_kwargs.get( + "streaming", False + ), "You can't slice IterableDataset" + start, end = slice_indices + indices = list(range(start, end)) + dataset = dataset.select(indices) + return dataset + + def _load_dataset(self, slice_indices: Optional[Tuple[int]] = None): + path = self.dataset_kwargs.pop("path", self.dataset_path) + split = self.dataset_kwargs.pop("split", self.dataset_split) + dataset = load_dataset(path=path, split=split, **self.dataset_kwargs) + if slice_indices is not None: + dataset = self.slice_dataset(dataset, slice_indices) + return dataset + + @abstractmethod + def evaluate(self, model: LLMInference, example: Dict[str, Any]) -> Dict[str, Any]: + pass + + @abstractmethod + def compute_score( + self, results: List[Dict[str, Any]] + ) -> Dict[str, Union[int, float]]: + pass + + def prepare_data(self) -> Union[Dataset, IterableDataset]: + dataset = self.dataset + if self.shuffle: + if isinstance(self.dataset, Dataset): + dataset = dataset.shuffle(seed=42) + elif isinstance(self.dataset, IterableDataset): + dataset = dataset.shuffle(seed=42, buffer_size=self.buffer_size) + else: + raise RuntimeError(f"{type(self.dataset)} is not supported") + return dataset + + def prepare_loader(self, dataset): + if self.collate_fn is None: + self.collate_fn = self.loader_kwargs.pop("collate_fn", dict_collation_fn) + if not callable(self.collate_fn): + self.collate_fn = instantiate_from_config(self.collate_fn) + return DataLoader(dataset, collate_fn=self.collate_fn, **self.loader_kwargs) + + def __call__(self, model: LLMInference) -> EvalOutput: + results = [] + dataset = self.prepare_data() + dataloader = self.prepare_loader(dataset) + for example in tqdm(dataloader, desc=f"Evaluating {self.name}"): + res = self.evaluate(model, example) + results.append(res) + results = flatten_list(results) + metrics = self.compute_score(results) + return EvalOutput(metrics=metrics, results=results) diff --git a/evomerge/eval/ja_vg_vqa.py b/evomerge/eval/ja_vg_vqa.py new file mode 100644 index 0000000..cd636f5 --- /dev/null +++ b/evomerge/eval/ja_vg_vqa.py @@ -0,0 +1,68 @@ +""" +Japanese VQA of Visual Genome +https://github.com/yahoojapan/ja-vg-vqa +""" + +from typing import Any, Dict, List, Union + +from ..models import VLMInference +from .base import Evaluator +from .metrics import LanguageDetector, rouge_ja + + +def extract_qa(example): + qa_list = example["qas"] + # TODO: for now, always take the first example + # should we evaluate everything? Or, take one example randomely? + qa = qa_list[0] + example["question"] = qa["question"] + example["answer"] = qa["answer"] + return example + + +class JaVGVQA(Evaluator): + name = "JaVGVQA" + dataset_path = "SakanaAI/ja-vg-vqa-2500" + dataset_split = "test" + + def __init__(self, strict_japanese: bool = True, **kwargs): + super().__init__(**kwargs) + # extract qa + self.dataset = self.dataset.map(extract_qa) + # filter column + self.dataset = self.dataset.select_columns(["question", "answer", "image"]) + self.lang_detect = None + if strict_japanese: + self.lang_detect = LanguageDetector() + + def evaluate( + self, model: VLMInference, example: Dict[str, List[Any]] + ) -> Dict[str, List[Union[str, bool]]]: + question = example["question"] + answer = example["answer"] + image = example["image"] + + # generate responses + resps = model(text=question, image=image) + return { + "question": question, + "answer": answer, + "prediction": resps, + } + + def compute_score(self, results: Dict[str, List[Any]]) -> Dict[str, float]: + res_dict = rouge_ja(refs=results["answer"], preds=results["prediction"]) + # detect Japanese by fasttext and replace empty string if it's not Ja + if self.lang_detect: + preds = [] + for answer, pred in zip(results["answer"], results["prediction"]): + # if answer is English, pass + if self.lang_detect(answer).get("__label__ja", 0.0) >= 0.5: + res = self.lang_detect(pred) + if res.get("__label__ja", 0.0) < 0.5: + pred = "" + preds.append(pred) + res_dict_ja = rouge_ja(refs=results["answer"], preds=preds) + res_dict_ja = {f"{k}_ja": v for k, v in res_dict_ja.items()} + res_dict.update(res_dict_ja) + return res_dict diff --git a/evomerge/eval/ja_vlm_wild.py b/evomerge/eval/ja_vlm_wild.py new file mode 100644 index 0000000..137919e --- /dev/null +++ b/evomerge/eval/ja_vlm_wild.py @@ -0,0 +1,50 @@ +from typing import Any, Dict, List, Union + +from ..models import VLMInference +from .base import Evaluator +from .metrics import LanguageDetector, rouge_ja + + +class JaVLMBenchIntheWild(Evaluator): + name = "JaVLMBenchIntheWild" + dataset_path = "SakanaAI/japanese-vlm-bench-in-the-wild" + dataset_split = "test" + + def __init__(self, strict_japanese: bool = True, **kwargs): + super().__init__(**kwargs) + self.dataset = self.dataset.select_columns(["question", "answer", "image"]) + self.lang_detect = None + if strict_japanese: + self.lang_detect = LanguageDetector() + + def evaluate( + self, model: VLMInference, example: Dict[str, List[Any]] + ) -> Dict[str, List[Union[str, bool]]]: + question = example["question"] + answer = example["answer"] + image = example["image"] + + # generate responses + resps = model(text=question, image=image) + return { + "question": question, + "answer": answer, + "prediction": resps, + } + + def compute_score(self, results: Dict[str, List[Any]]) -> Dict[str, float]: + res_dict = rouge_ja(refs=results["answer"], preds=results["prediction"]) + # detect Japanese by fasttext and replace empty string if it's not Ja + if self.lang_detect: + preds = [] + for answer, pred in zip(results["answer"], results["prediction"]): + # if answer is English, pass + if self.lang_detect(answer).get("__label__ja", 0.0) >= 0.5: + res = self.lang_detect(pred) + if res.get("__label__ja", 0.0) < 0.5: + pred = "" + preds.append(pred) + res_dict_ja = rouge_ja(refs=results["answer"], preds=preds) + res_dict_ja = {f"{k}_ja": v for k, v in res_dict_ja.items()} + res_dict.update(res_dict_ja) + return res_dict diff --git a/evomerge/eval/metrics.py b/evomerge/eval/metrics.py new file mode 100644 index 0000000..1847fbb --- /dev/null +++ b/evomerge/eval/metrics.py @@ -0,0 +1,100 @@ +import re + +from huggingface_hub import hf_hub_download +from rouge_score import rouge_scorer, scoring +from fasttext.FastText import _FastText + + +def rouge(refs, preds): + """ + Returns `t5` style ROUGE scores. See the related implementation: + https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L68 + + :param refs: + A `list` of reference `strs`. + :param preds: + A `list` of predicted `strs`. + """ + rouge_types = ["rouge1", "rouge2", "rougeL"] + scorer = rouge_scorer.RougeScorer(rouge_types) + + def _prepare_summary(summary): + summary = summary.replace(" . ", ".\n") + return summary + + # Accumulate confidence intervals. + aggregator = scoring.BootstrapAggregator() + for ref, pred in zip(refs, preds): + ref = _prepare_summary(ref) + pred = _prepare_summary(pred) + aggregator.add_scores(scorer.score(ref, pred)) + result = aggregator.aggregate() + return {type: result[type].mid.fmeasure * 100 for type in rouge_types} + + +class MecabTokenizer: + def __init__(self) -> None: + from fugashi import Tagger + + self.tagger = Tagger("-Owakati") + + def normalize_answer(self, text): + """Lower case text, remove punctuation and extra whitespace, etc.""" + import emoji + import neologdn + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_emoji(text): + text = "".join(["" if emoji.is_emoji(c) else c for c in text]) + emoji_pattern = re.compile( + "[" + "\U0001F600-\U0001F64F" # emoticons + "\U0001F300-\U0001F5FF" # symbols & pictographs + "\U0001F680-\U0001F6FF" # transport & map symbols + "\U0001F1E0-\U0001F1FF" # flags (iOS) + "\U00002702-\U000027B0" + "]+", + flags=re.UNICODE, + ) + return emoji_pattern.sub(r"", text) + + text = remove_emoji(text) + # see neologdn docs for details, but handles things like full/half width variation + text = neologdn.normalize(text) + text = white_space_fix(text) + return text + + def tokenize(self, text): + return self.tagger.parse(self.normalize_answer(text)).split() + + +def rouge_ja(refs, preds): + """This uses a MeCab tokenizer for Japanese text.""" + tokenizer = MecabTokenizer() + rouge_types = ["rouge1", "rouge2", "rougeL"] + # mecab-based rouge + scorer = rouge_scorer.RougeScorer( + rouge_types, + tokenizer=tokenizer, + ) + + # Accumulate confidence intervals. + aggregator = scoring.BootstrapAggregator() + for ref, pred in zip(refs, preds): + aggregator.add_scores(scorer.score(ref, pred)) + result = aggregator.aggregate() + return {type: result[type].mid.fmeasure * 100 for type in rouge_types} + + +class LanguageDetector: + def __init__(self): + repo_id = "mkshing/fasttext-language-detection" + model_path = hf_hub_download( + repo_id=repo_id, repo_type="model", filename="model.bin" + ) + self.model = _FastText(model_path) + + def __call__(self, text: str) -> dict: + return dict(zip(*self.model.predict(text.replace("\n", ""), k=-1))) diff --git a/evomerge/models/__init__.py b/evomerge/models/__init__.py new file mode 100644 index 0000000..a374f09 --- /dev/null +++ b/evomerge/models/__init__.py @@ -0,0 +1,5 @@ +from .base import LLMInference, VLMInference +from .jsvlm import JSVLM +from .llava import LLaVA +from .causallm import CausalLM +from .prompt_templates import PROMPT_TEMPLATES diff --git a/evomerge/models/base.py b/evomerge/models/base.py new file mode 100644 index 0000000..aeeecad --- /dev/null +++ b/evomerge/models/base.py @@ -0,0 +1,158 @@ +import gc +import logging +from abc import abstractmethod +from functools import partial +from typing import List, Optional, Union + +import deepspeed +import torch +from PIL import Image +from torch import nn +from transformers import AutoModelForCausalLM +from transformers.modeling_utils import ModuleUtilsMixin + +from .prompt_templates import PROMPT_TEMPLATES + +logger = logging.getLogger(__name__) +STR2DTYPE = { + "torch.float16": torch.float16, + "torch.bfloat16": torch.bfloat16, + "torch.float32": torch.float32, + "auto": "auto", +} + + +INFERENCE_ENGINES = ["deepspeed", "vllm"] + + +class LLMInference(nn.Module, ModuleUtilsMixin): + default_template: str = None + default_generation_config: dict = None + + def __init__( + self, + model_path: str = None, + llm_model_path: str = None, + template: Optional[str] = None, + verbose: bool = False, + device: Optional[Union[str, torch.device]] = None, + model_kwargs: Optional[dict] = None, + generation_config: Optional[dict] = None, + torch_dtype: Optional[str] = None, + inference_engine: Optional[str] = None, + ): + super().__init__() + if inference_engine is not None: + assert ( + inference_engine in INFERENCE_ENGINES + ), f"{inference_engine} is not supported. You must choose from {INFERENCE_ENGINES}" + + self.model = None + self.tokenizer = None + self.processor = None + self.verbose = verbose + if template is None: + template = self.default_template + if template in PROMPT_TEMPLATES: + template = PROMPT_TEMPLATES[template] + logger.info(f"prompt template:\n{template}") + + self.llm_model_path = llm_model_path + self.template = template + self.model_kwargs = model_kwargs if model_kwargs else {} + self.generation_config = ( + generation_config if generation_config else self.default_generation_config + ) + torch_dtype = self.model_kwargs.pop("torch_dtype", torch_dtype) + if torch_dtype is not None and isinstance(torch_dtype, str): + self.model_kwargs["torch_dtype"] = STR2DTYPE[torch_dtype] + self.inference_engine = inference_engine + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self._device = device + + def post_init(self): + if self.llm_model_path is not None: + logger.info(f"loading llm from {self.llm_model_path}") + llm = AutoModelForCausalLM.from_pretrained( + self.llm_model_path, **self.model_kwargs + ).eval() + self.set_llm(llm) + logger.info(f"Successfully set the llm to VLM!") + self.prepare_model() + + def _deepspeed_init(self): + self.model = deepspeed.init_inference( + self.model, + dtype=self.model.dtype, + replace_with_kernel_inject=True, + ) + + def prepare_model(self): + self.to(self._device) + if self.inference_engine == "deepspeed": + self._deepspeed_init() + elif self.inference_engine == "vllm": + pass + # free unused memory + gc.collect() + torch.cuda.empty_cache() + + @abstractmethod + def build_prompt(self, text: Union[str, List[str]]) -> List[str]: + pass + + def set_template(self, template: str): + if template in PROMPT_TEMPLATES: + template = PROMPT_TEMPLATES[template] + self.template = template + + def set_llm(self, llm: torch.nn.Module): + self.model.set_llm(llm) + + def get_output_ids(self, input_ids: torch.Tensor, output_ids: torch.Tensor): + input_token_len = input_ids.shape[1] + n_diff_input_output = ( + (input_ids != output_ids[:, :input_token_len]).sum().item() + ) + if n_diff_input_output > 0: + logger.warn( + f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids" + ) + return output_ids[:, input_token_len:] + + def forward( + self, + text: Union[str, List[str]], + **generation_config, + ) -> List[str]: + pass + + +# for deepspeed +def get_input_embeddings(self): + llm = self.language_model + if not hasattr(llm, "get_input_embeddings"): + return llm.module.get_input_embeddings() + return llm.get_input_embeddings() + + +class VLMInference(LLMInference): + def _deepspeed_init(self): + self.model.model.language_model = deepspeed.init_inference( + self.model.model.language_model, + dtype=self.model.dtype, + replace_with_kernel_inject=True, + ) + # monkey patching + self.model.model.get_input_embeddings = partial( + get_input_embeddings, self=self.model.model + ) + + def forward( + self, + text: Union[str, List[str]], + image: Optional[Union[Image.Image, List[Image.Image]]], + **generation_config, + ) -> List[str]: + pass diff --git a/evomerge/models/causallm.py b/evomerge/models/causallm.py new file mode 100644 index 0000000..cd30998 --- /dev/null +++ b/evomerge/models/causallm.py @@ -0,0 +1,74 @@ +import logging +from typing import List, Union + +import torch +from PIL import Image +from transformers import AutoTokenizer, AutoModelForCausalLM + +from . import LLMInference +from .prompt_templates import JA_ALPACA_COT_TEMPLATE + +logger = logging.getLogger(__name__) + + +class CausalLM(LLMInference): + default_template = JA_ALPACA_COT_TEMPLATE + # taken from https://github.com/haotian-liu/LLaVA/blob/main/predict.py#L87 + default_generation_config = { + "do_sample": False, + "max_new_tokens": 1024, + "temperature": 0, + "top_p": 1.0, + "repetition_penalty": 1.0, + } + + def __init__(self, model_path: str, **kwargs): + super().__init__(**kwargs) + self.model = ( + AutoModelForCausalLM.from_pretrained(model_path, **self.model_kwargs) + .eval() + .requires_grad_(False) + ) + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.post_init() + + def build_prompt(self, text: Union[str, List[str]]) -> List[str]: + if isinstance(text, str): + text = [text] + return [self.template.format(input=t) for t in text] + + def forward( + self, + text: Union[str, List[str]], + **generation_config, + ) -> List[str]: + """ + Assume text is question string + """ + if len(generation_config) == 0: + generation_config = self.generation_config + if self.verbose: + logger.info( + f"Setting generation config to default\n{generation_config}" + ) + text = self.build_prompt(text) + if self.verbose: + logger.info( + "Sample of actual inputs:\n" + "-" * 100 + f"\n{text[0]}\n" + "-" * 100 + ) + inputs = self.tokenizer(text=text, padding=True, return_tensors="pt") + # generate + with torch.inference_mode(): + output_ids = self.model.generate( + **inputs.to(self.device), + **generation_config, + ) + # output_ids contains input_ids as well. So, return only output_ids + output_ids = self.get_output_ids( + input_ids=inputs.input_ids, output_ids=output_ids + ) + generated_text = self.processor.batch_decode( + output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + generated_text = [text.strip() for text in generated_text] + return generated_text diff --git a/evomerge/models/jsvlm.py b/evomerge/models/jsvlm.py new file mode 100644 index 0000000..5684421 --- /dev/null +++ b/evomerge/models/jsvlm.py @@ -0,0 +1,81 @@ +# https://huggingface.co/stabilityai/japanese-stable-vlm +import logging +from typing import List, Union + +import torch +from PIL import Image +from transformers import AutoImageProcessor, AutoModelForVision2Seq, AutoTokenizer + +from . import VLMInference +from .prompt_templates import JSVLM_TEMPLATE + +logger = logging.getLogger(__name__) + + +class JSVLM(VLMInference): + default_template = JSVLM_TEMPLATE + default_generation_config = { + "do_sample": False, + "num_beams": 5, + "max_new_tokens": 128, + "min_length": 1, + "repetition_penalty": 1.5, + } + + def __init__(self, model_path: str = "stabilityai/japanese-stable-vlm", **kwargs): + super().__init__(**kwargs) + self.model = AutoModelForVision2Seq.from_pretrained( + model_path, trust_remote_code=True, **self.model_kwargs + ).eval() + self.processor = AutoImageProcessor.from_pretrained(model_path) + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.post_init() + + def build_prompt(self, text: str): + if isinstance(text, str): + text = [text] + return [self.template.format(input=t) for t in text] + + def forward( + self, + text: Union[str, List[str]], + image: Union[Image.Image, List[Image.Image]], + **generation_config, + ) -> List[str]: + if len(generation_config) == 0: + generation_config = self.generation_config + if self.verbose: + logger.info( + f"Setting generation config to default\n{generation_config}" + ) + if not isinstance(image, list): + image = [image] + + text = self.build_prompt(text) + if self.verbose: + logger.info( + "Sample of actual inputs:\n" + "-" * 100 + f"\n{text[0]}\n" + "-" * 100 + ) + assert len(text) == len(image) + inputs = self.processor(images=image, return_tensors="pt") + text_encoding = self.tokenizer( + text, + add_special_tokens=False, + return_tensors="pt", + padding=True, + ) + inputs.update(text_encoding) + # generate + with torch.inference_mode(): + output_ids = self.model.generate( + **inputs.to(self.device, dtype=self.dtype), + **generation_config, + bos_token_id=self.tokenizer.bos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.eos_token_id, + ) + generated_text = self.tokenizer.batch_decode( + output_ids, skip_special_tokens=True + ) + generated_text = [text.strip() for text in generated_text] + return generated_text diff --git a/evomerge/models/llava.py b/evomerge/models/llava.py new file mode 100644 index 0000000..8623305 --- /dev/null +++ b/evomerge/models/llava.py @@ -0,0 +1,82 @@ +import logging +from typing import List, Union + +import torch +from PIL import Image +from transformers import AutoProcessor, LlavaForConditionalGeneration + +from . import VLMInference +from .prompt_templates import LLAVA_MISTRAL_TEMPLATE + +logger = logging.getLogger(__name__) + + +class LLaVA(VLMInference): + default_template = LLAVA_MISTRAL_TEMPLATE + # taken from https://github.com/haotian-liu/LLaVA/blob/main/predict.py#L87 + default_generation_config = { + "do_sample": True, + "max_new_tokens": 1024, + "temperature": 0.2, + "top_p": 1.0, + "use_cache": True, + } + + def __init__(self, model_path: str = "llava-hf/llava-v1.6-mistral-7b-hf", **kwargs): + super().__init__(**kwargs) + self.model = ( + LlavaForConditionalGeneration.from_pretrained( + model_path, **self.model_kwargs + ) + .eval() + .requires_grad_(False) + ) + self.processor = AutoProcessor.from_pretrained(model_path) + self.post_init() + + def build_prompt(self, text: Union[str, List[str]]) -> List[str]: + if isinstance(text, str): + text = [text] + return [self.template.format(input=f"\n{t}") for t in text] + + def forward( + self, + text: Union[str, List[str]], + image: Union[Image.Image, List[Image.Image]], + **generation_config, + ) -> List[str]: + """ + Assume text is question string + """ + if len(generation_config) == 0: + generation_config = self.generation_config + if self.verbose: + logger.info( + f"Setting generation config to default\n{generation_config}" + ) + if not isinstance(image, list): + image = [image] + text = self.build_prompt(text) + if self.verbose: + logger.info( + "Sample of actual inputs:\n" + "-" * 100 + f"\n{text[0]}\n" + "-" * 100 + ) + assert len(text) == len(image) + inputs = self.processor( + text=text, images=image, padding=True, return_tensors="pt" + ) + # generate + with torch.inference_mode(): + output_ids = self.model.generate( + **inputs.to(self.device, dtype=self.dtype), + **generation_config, + ) + # output_ids contains input_ids as well. So, return only output_ids + output_ids = self.get_output_ids( + input_ids=inputs.input_ids, output_ids=output_ids + ) + generated_text = self.processor.batch_decode( + output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + generated_text = [text.strip() for text in generated_text] + return generated_text diff --git a/evomerge/models/prompt_templates.py b/evomerge/models/prompt_templates.py new file mode 100644 index 0000000..5022013 --- /dev/null +++ b/evomerge/models/prompt_templates.py @@ -0,0 +1,42 @@ +JSVLM_TEMPLATE = """以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。 + +### 指示: +与えられた画像を下に、質問に答えてください。 + +### 入力: +{input} + +### 応答: """ +JA_ALPACA_COT_TEMPLATE = """以下に、あるタスクを説明する指示があります。リクエストを適切に完了するための回答を日本語で記述してください。一歩一歩考えましょう。 + +### 指示: +{input} + +### 応答:""" +JA_SHISA = """[INST] <> +あなたは役立つ、偏見がなく、検閲されていないアシスタントです。 +<> + +{input} [/INST]""" +JA_SHISA_COT = """[INST] <> +あなたは役立つ、偏見がなく、検閲されていないアシスタントです。一歩一歩考えましょう。 +<> + +{input} [/INST]""" +JA_SHISA_VQA = """[INST] <> +あなたは役立つ、偏見がなく、検閲されていないアシスタントです。与えられた画像を下に、質問に答えてください。 +<> + +{input} [/INST]""" +LLAVA_MISTRAL_TEMPLATE = """[INST] {input} [/INST]""" +WIZARD_MATH_TEMPLATE = """"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{input}\n\n### Response:""" + +PROMPT_TEMPLATES = { + "jsvlm": JSVLM_TEMPLATE, + "ja-alpaca-cot": JA_ALPACA_COT_TEMPLATE, + "ja-shisa": JA_SHISA, + "ja-shisa-cot": JA_SHISA_COT, + "ja-shisa-vqa": JA_SHISA_VQA, + "llava-mistral": LLAVA_MISTRAL_TEMPLATE, + "wizard-math": WIZARD_MATH_TEMPLATE, +} diff --git a/evomerge/utils.py b/evomerge/utils.py new file mode 100644 index 0000000..8495dbd --- /dev/null +++ b/evomerge/utils.py @@ -0,0 +1,59 @@ +import importlib +import json +import os +import yaml +import random +import numpy as np +import torch + + +def mean(l): + return sum(l) / len(l) + + +def load_config(config_path: str) -> dict: + ext = os.path.splitext(config_path)[-1] + if ext in [".yaml", ".yml"]: + with open(config_path, "r", encoding="utf-8") as fp: + config = yaml.safe_load(fp) + elif ext == ".json": + with open(config_path) as fp: + config = json.load(fp) + else: + raise RuntimeError + return config + + +# copied from https://github.com/Stability-AI/generative-models/blob/main/sgm/util.py#L168 +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +# copied from https://github.com/Stability-AI/generative-models/blob/main/sgm/util.py#L168 +def get_obj_from_str(string, reload=False, invalidate_cache=True): + module, cls = string.rsplit(".", 1) + if invalidate_cache: + importlib.invalidate_caches() + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def set_seed(seed: int): + """ + Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch``. + + Args: + seed (:obj:`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c6371a7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +torchvision +Pillow +transformers>=4.35.3 +datasets +gradio +accelerate +deepspeed +bitsandbytes +rouge-score>=0.0.4 +emoji +fugashi +neologdn>=0.5.2 +unidic-lite +fasttext \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b4c4f82 --- /dev/null +++ b/setup.py @@ -0,0 +1,19 @@ +from setuptools import find_packages, setup + + +def _requires_from_file(filename): + return open(filename).read().splitlines() + + +setup( + name="evomerge", + version="0.1.0", + author="Sakana AI", + url="https://github.com/SakanaAI/evolving-merged-models", + description="Evolutionary Optimization of Model Merging Recipes", + install_requires=_requires_from_file('requirements.txt'), + packages=find_packages(exclude=("examples")), + license = 'MIT', + long_description=open("README.md", "r", encoding="utf-8").read(), + long_description_content_type="text/markdown", +) \ No newline at end of file