Add OpenFlamingo #2237

Merged (44 commits) on Mar 4, 2024

Commits
915f7cc  add openflamingo (michiyasunaga, Jan 14, 2024)
367aa30  fix ver (michiyasunaga, Jan 14, 2024)
799f41f  fix ver (michiyasunaga, Jan 14, 2024)
0b7bcb4  fix ver (michiyasunaga, Jan 14, 2024)
70a8e34  fix ver (michiyasunaga, Jan 14, 2024)
1eef0a5  fix ver (michiyasunaga, Jan 14, 2024)
abe19e3  fix ver (michiyasunaga, Jan 14, 2024)
b221cbf  fix ver (michiyasunaga, Jan 14, 2024)
63011e7  fix ver (michiyasunaga, Jan 14, 2024)
7729aca  add openflamingo (michiyasunaga, Jan 14, 2024)
1d23a9a  add openflamingo (michiyasunaga, Jan 14, 2024)
d27893a  add openflamingo (michiyasunaga, Jan 14, 2024)
c9fd286  add openflamingo (michiyasunaga, Jan 14, 2024)
5367f05  add openflamingo (michiyasunaga, Jan 14, 2024)
0afa051  add openflamingo (michiyasunaga, Jan 14, 2024)
014aada  add openflamingo (michiyasunaga, Jan 14, 2024)
1bee70c  add openflamingo (michiyasunaga, Jan 14, 2024)
3613907  fix GHA build - define openflamingo dependencies (teetone, Jan 16, 2024)
1affed4  Merge branch 'main' of https://github.com/stanford-crfm/benchmarking … (teetone, Jan 16, 2024)
9c1ad0a  address code review (michiyasunaga, Jan 20, 2024)
1ca8408  fix transformers version (michiyasunaga, Feb 8, 2024)
292456e  Merge branch 'main' into michi_openflamingo (michiyasunaga, Feb 8, 2024)
12d535e  merge main (JosselinSomervilleRoberts, Feb 22, 2024)
84cc573  Add some parameters to the model deployment (JosselinSomervilleRoberts, Feb 22, 2024)
3756295  Fixing einops dependency conflict (JosselinSomervilleRoberts, Feb 22, 2024)
9dc70dc  Remove duplicated crfm-helm['image'] dependency (JosselinSomervilleRoberts, Feb 22, 2024)
6f531ac  Merge branch 'main' of https://github.com/stanford-crfm/benchmarking … (teetone, Feb 26, 2024)
8c44df7  more logging for model init (teetone, Feb 27, 2024)
1fd7961  fix token init in openflamingo (teetone, Feb 27, 2024)
472bacb  fix token init in openflamingo (teetone, Feb 27, 2024)
5a64bb3  Merge branch 'main' of https://github.com/stanford-crfm/benchmarking … (teetone, Feb 28, 2024)
65e2950  resolve (teetone, Feb 28, 2024)
0ea9ced  fix tokenizer (teetone, Feb 28, 2024)
6128ad2  update conf (teetone, Feb 28, 2024)
163b13c  Merge branch 'main' of https://github.com/stanford-crfm/benchmarking … (teetone, Feb 28, 2024)
e51e550  Merge branch 'main' of https://github.com/stanford-crfm/benchmarking … (teetone, Feb 28, 2024)
865b33a  disable temporarily (teetone, Feb 28, 2024)
a4f0586  resolve merge conflicts (teetone, Mar 4, 2024)
8f0e763  undo (teetone, Mar 4, 2024)
e2404a9  fix paths (teetone, Mar 4, 2024)
83eaefa  get in-context learning examples to work (teetone, Mar 4, 2024)
1404de4  fix decoding (teetone, Mar 4, 2024)
ac19049  fix sequence construction (teetone, Mar 4, 2024)
8c6cdcb  include num_completions in cache key (teetone, Mar 4, 2024)
Files changed
18 changes: 18 additions & 0 deletions requirements.txt
@@ -30,6 +30,10 @@ certifi==2024.2.2
 cffi==1.16.0
 cfgv==3.4.0
 charset-normalizer==2.1.1
+Cython==0.29.32
+einops-exts==0.0.4
+emoji==2.1.0
+et-xmlfile==1.1.0
 chex==0.1.7
 click==8.1.7
 clip-anytorch==2.5.2
@@ -145,6 +149,10 @@ nltk==3.8.1
 nodeenv==1.8.0
 NudeNet==2.0.9
 numba==0.56.4
+open-clip-torch==2.24.0
+openpyxl==3.0.10
+outcome==1.2.0
+pathy==0.10.2
 numpy==1.23.5
 oauthlib==3.2.2
 omegaconf==2.3.0
@@ -253,6 +261,15 @@ timm==0.6.13
 tokenizers==0.15.2
 toml==0.10.2
 tomli==2.0.1
+trio==0.22.0
+trio-websocket==0.9.2
+types-Pillow==9.3.0.4
+types-pytz==2022.4.0.0
+types-redis==4.3.21.1
+types-requests==2.28.11.2
+types-tabulate==0.9.0.0
+types-urllib3==1.26.25
+typing==3.7.4.3
 toolz==0.12.1
 torch~=2.1.2
 torch-fidelity==0.3.0
@@ -281,3 +298,4 @@ xxhash==3.4.1
 yarl==1.9.4
 zipp==3.17.0
 zstandard==0.18.0
+fairlearn==0.9.0
12 changes: 9 additions & 3 deletions setup.cfg
@@ -148,9 +148,14 @@ models =
     crfm-helm[yandex]
 
 vlm =
+    # For OpenFlamingo
+    einops~=0.7.0
+    einops-exts~=0.0.4
+    open-clip-torch~=2.24.0
+
     # VLM models
     crfm-helm[openai]
-    torch~=2.1.2  # For IDEFICS
+    torch~=2.1.2  # For IDEFICS
 
     # VLM scenarios
     crfm-helm[images]
@@ -178,7 +183,7 @@ heim =
     crfm-helm[openai]
 
     # For model, kakaobrain/mindall-e
-    einops~=0.6.0
+    einops~=0.7.0
    omegaconf~=2.3.0
    pytorch-lightning~=2.0.5
@@ -259,6 +264,7 @@ exclude =
     venv/*
     src/helm/clients/image_generation/dalle_mini/*
     src/helm/clients/image_generation/mindalle/*
+    src/helm/clients/vision_language/open_flamingo/*
 
 # Ignore completely:
 # E203 - White space before ':', (conflicts with black)
@@ -276,7 +282,7 @@ check_untyped_defs = True
 disable_error_code = annotation-unchecked
 # TODO: Change disallow_untyped_defs to True
 disallow_untyped_defs = False
-exclude = dalle_mini|mindalle
+exclude = dalle_mini|mindalle|open_flamingo
 
 [tool:pytest]
 addopts =
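Note: with these extras in place, the OpenFlamingo dependencies are expected to come in through the vlm extra, e.g. `pip install "crfm-helm[vlm]"` (install invocation assumed for illustration; it is not part of this diff).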
4 changes: 4 additions & 0 deletions src/helm/benchmark/model_metadata_registry.py
@@ -61,13 +61,17 @@
 IDEFICS_MODEL_TAG: str = "IDEFICS_MODEL_TAG"
 # Llava should use a special prompt format (see `LlavaRunExpander`)
 LLAVA_MODEL_TAG: str = "LLAVA_MODEL_TAG"
+# OpenFlamingo has a special prompt format (see `OpenFlamingoRunExpander`)
+OPEN_FLAMINGO_MODEL_TAG: str = "OPEN_FLAMINGO_MODEL_TAG"
 # Some VLMs do not support multiple images in the prompt
 LIMITED_FUNCTIONALITY_VLM_TAG: str = "LIMITED_FUNCTIONALITY_VLM_TAG"
 FULL_FUNCTIONALITY_VLM_TAG: str = "FULL_FUNCTIONALITY_VLM_TAG"
 
 
 # Frozen is set to false as the model_deployment_registry.py file
 # might populate the deployment_names field.
+
+
 @dataclass(frozen=False)
 class ModelMetadata:
     name: str
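For orientation, a hedged sketch of how a model entry might carry the new tag (only the name field is visible above; a tags field is implied by model.tags in run_spec_factory.py below, and all other fields are omitted):

    # Hypothetical, illustrative values only; the real entry lives in HELM's
    # model configuration, not in this file.
    openflamingo_metadata = ModelMetadata(
        name="openflamingo/OpenFlamingo-9B-vitl-mpt7b",
        tags=[OPEN_FLAMINGO_MODEL_TAG, LIMITED_FUNCTIONALITY_VLM_TAG],
    )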
5 changes: 5 additions & 0 deletions (a VLM run entries .conf file; filename not captured)
@@ -10,6 +10,11 @@ entries: [
   # sheetmusic2lilypond
   {description: "sheetmusic2lilypond:model=vlm", priority: 1}
 
+  # webpages
+  {description: "image2webpage:subset=css,model=vlm", priority: 1, groups: ["image2webpage"]}
+  {description: "image2webpage:subset=html,model=vlm", priority: 1, groups: ["image2webpage"]}
+  {description: "image2webpage:subset=javascript,model=vlm", priority: 1, groups: ["image2webpage"]}
+
   # chart2csv
   # {description: "chart2csv:model=vlm", priority: 1}
 ]
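These entries would typically be executed through HELM's runner, e.g. `helm-run --conf-paths <path to this conf> --suite <suite name>` (invocation shown for illustration; exact flags follow HELM's CLI, not this diff).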
20 changes: 20 additions & 0 deletions src/helm/benchmark/run_expander.py
@@ -447,6 +447,26 @@ def expand(self, run_spec: RunSpec) -> List[RunSpec]:
         ]
 
 
+class OpenFlamingoRunExpander(RunExpander):
+    """
+    Custom prompt for OpenFlamingo following: https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b
+    """
+
+    name = "open_flamingo"
+
+    def expand(self, run_spec: RunSpec) -> List[RunSpec]:
+        return [
+            replace(
+                run_spec,
+                name=run_spec.name,
+                adapter_spec=replace(
+                    run_spec.adapter_spec,
+                    input_prefix=f"<|endofchunk|>{run_spec.adapter_spec.input_prefix}",
+                ),
+            ),
+        ]
+
+
 class FormatPromptRunExpander(RunExpander):
     """Adds a prefix and suffix to the prompt."""
 
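The effect on the adapter, as a tiny self-contained sketch (not part of the diff; RunSpec/AdapterSpec plumbing elided, since the string transformation is the whole story):

    # Illustrative only: OpenFlamingoRunExpander prepends OpenFlamingo's
    # end-of-chunk marker to whatever input_prefix the adapter already uses.
    before = "Question: "
    after = f"<|endofchunk|>{before}"
    assert after == "<|endofchunk|>Question: "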
6 changes: 6 additions & 0 deletions src/helm/benchmark/run_spec_factory.py
@@ -17,6 +17,7 @@
     GOOGLE_PALM_2_MODEL_TAG,
     IDEFICS_INSTRUCT_MODEL_TAG,
     LLAVA_MODEL_TAG,
+    OPEN_FLAMINGO_MODEL_TAG,
     NLG_PREFIX_TAG,
     NO_NEWLINES_TAG,
     OPENAI_CHATGPT_MODEL_TAG,
@@ -33,6 +34,7 @@
     IDEFICSInstructRunExpander,
     IncreaseTemperatureRunExpander,
     LlavaRunExpander,
+    OpenFlamingoRunExpander,
     OpenAIRunExpander,
     MistralRunExpander,
     StopRunExpander,
@@ -147,6 +149,10 @@ def alter_run_spec(run_spec: RunSpec) -> RunSpec:
     if LLAVA_MODEL_TAG in model.tags:
         run_spec = singleton(LlavaRunExpander().expand(run_spec))
 
+    # OpenFlamingo
+    if OPEN_FLAMINGO_MODEL_TAG in model.tags:
+        run_spec = singleton(OpenFlamingoRunExpander().expand(run_spec))
+
     # For multiple choice
     if BUGGY_TEMP_0_TAG in model.tags and run_spec.adapter_spec.temperature == 0:
         increase_temperature_expander = IncreaseTemperatureRunExpander(value=1e-4)
2 changes: 2 additions & 0 deletions src/helm/clients/vision_language/open_flamingo/__init__.py
@@ -0,0 +1,2 @@
from .src.flamingo import Flamingo
from .src.factory import create_model_and_transforms
Empty file (presumably src/helm/clients/vision_language/open_flamingo/src/__init__.py, required for the .src imports above).
147 changes: 147 additions & 0 deletions src/helm/clients/vision_language/open_flamingo/src/factory.py
@@ -0,0 +1,147 @@
"""
Source: https://github.com/mlfoundations/open_flamingo
"""

from typing import Optional

from transformers import AutoModelForCausalLM, AutoTokenizer

from helm.common.general import handle_module_not_found_error
from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .utils import extend_instance


def create_model_and_transforms(
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_encoder_path: str,
    tokenizer_path: str,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: Optional[str] = None,
    freeze_lm_embeddings: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
"""
Initialize a Flamingo model from a pretrained vision encoder and language encoder.
Appends special tokens to the tokenizer and freezes backbones.

Args:
clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
lang_encoder_path (str): path to pretrained language encoder
tokenizer_path (str): path to pretrained tokenizer
cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
use_local_files (bool, optional): whether to use local files. Defaults to False.
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
Returns:
Flamingo: Flamingo model from pretrained vision and language encoders
Image processor: Pipeline to preprocess input images
Tokenizer: A tokenizer for the language model
"""
    try:
        import open_clip
    except ModuleNotFoundError as e:
        handle_module_not_found_error(e, ["vlm"])

    vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
        clip_vision_encoder_path,
        pretrained=clip_vision_encoder_pretrained,
        cache_dir=cache_dir,
    )
    # set the vision encoder to output the visual features
    vision_encoder.visual.output_tokens = True

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    # add Flamingo special tokens to the tokenizer
    text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
    if text_tokenizer.pad_token is None:
        # Issue: GPT models don't have a pad token, which we use to
        # modify labels for the loss.
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    # hacks for MPT-1B, which doesn't have a get_input_embeddings method
    if "mpt-1b-redpajama-200b" in lang_encoder_path:

        class EmbeddingFnMixin:
            def get_input_embeddings(self):
                return self.transformer.wte

            def set_input_embeddings(self, new_embeddings):
                self.transformer.wte = new_embeddings

        extend_instance(lang_encoder, EmbeddingFnMixin)

    # convert LM to FlamingoLM
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

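    # The [-1] below takes the last id returned by encode(), so a BOS token
    # prepended by some tokenizers is ignored and only the special token's id
    # is passed to Flamingo.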
    model = Flamingo(
        vision_encoder,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<image>")[-1],
        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    # Freeze all parameters
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
    model.perceiver.requires_grad_(True)
    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)
        # TODO: investigate also training the output embeddings when untied

    print(
        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
    )

    return model, image_processor, text_tokenizer


def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. "
        "Please supply this string manually."
    )


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
}
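For reference, a minimal usage sketch of this factory. The parameter values are assumptions taken from the OpenFlamingo-9B-vitl-mpt7b model card linked in the run expander above, not from this PR:

    # Hypothetical invocation; encoder paths and cross_attn_every_n_layers are
    # illustrative assumptions for the 9B checkpoint.
    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path="anas-awadalla/mpt-7b-redpajama-200b",
        tokenizer_path="anas-awadalla/mpt-7b-redpajama-200b",
        cross_attn_every_n_layers=4,
    )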