Support for many different caption models:
blip-base, blip-large, blip2-2.7b, blip2-flan-t5-xl, git-large-coco
pharmapsychotic committed Mar 20, 2023
1 parent ac74904 commit ce9d271
Showing 7 changed files with 89 additions and 128 deletions.
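In practical terms, the change lets callers pick any of the new caption backends through `Config.caption_model_name`. A minimal usage sketch based on the fields and model keys introduced in the diff below (the image path is a hypothetical placeholder):

```python
from PIL import Image
from clip_interrogator import Config, Interrogator

config = Config()
# Any CAPTION_MODELS key: blip-base, blip-large, blip2-2.7b, blip2-flan-t5-xl, git-large-coco
config.caption_model_name = 'git-large-coco'
ci = Interrogator(config)

image = Image.open('photo.jpg').convert('RGB')  # placeholder input image
print(ci.interrogate(image))
```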
53 changes: 9 additions & 44 deletions clip_interrogator.ipynb
@@ -7,7 +7,7 @@
"id": "3jm8RYrLqvzz"
},
"source": [
"# CLIP Interrogator 2.3 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"# CLIP Interrogator 2.4 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"\n",
"Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
"\n",
@@ -29,7 +29,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "aP9FjmWxtLKJ"
@@ -42,7 +42,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xpPKQR40qvz2"
@@ -54,8 +54,7 @@
"\n",
"def setup():\n",
" install_cmds = [\n",
" ['pip', 'install', 'transformers==4.15.0'],\n",
" ['pip', 'install', 'gradio'],\n",
" ['pip', 'install', 'gradio'],\n",
" ['pip', 'install', 'open_clip_torch'],\n",
" ['pip', 'install', 'clip-interrogator'],\n",
" ]\n",
@@ -65,16 +64,15 @@
"setup()\n",
"\n",
"\n",
"caption_model_name = 'blip-large' #@param [\"blip-base\", \"blip-large\", \"git-large-coco\"]\n",
"clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
"\n",
"\n",
"import gradio as gr\n",
"from clip_interrogator import Config, Interrogator\n",
"\n",
"config = Config()\n",
"config.blip_num_beams = 64\n",
"config.blip_offload = False\n",
"config.clip_model_name = clip_model_name\n",
"config.caption_model_name = caption_model_name\n",
"ci = Interrogator(config)\n",
"\n",
"def image_analysis(image):\n",
@@ -112,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {
@@ -122,40 +120,7 @@
"id": "Pf6qkFG6MPRj",
"outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
"\n",
"Using Embedded Colab Mode (NEW). If you have issues, please use share=True and file an issue at https://github.com/gradio-app/gradio/\n",
"Note: opening the browser inspector may crash Embedded Colab Mode.\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n <div style=\"font-family: monospace; margin-bottom: 0.5rem\">\n Running on <a href=${new URL(path, url).toString()} target=\"_blank\">\n https://localhost:${port}${path}\n </a>\n </div>\n `;\n element.appendChild(external_link);\n\n const iframe = document.createElement('iframe');\n iframe.src = new URL(path, url).toString();\n iframe.height = height;\n iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n iframe.width = width;\n iframe.style.border = 0;\n element.appendChild(iframe);\n })(7860, \"/\", \"100%\", 500, false, window.element)",
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(<gradio.routes.App at 0x7f894e553710>, 'http://127.0.0.1:7860/', None)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"#@title Image to prompt! 🖼️ -> 📝\n",
" \n",
@@ -291,7 +256,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.15 (default, Nov 24 2022, 18:44:54) [MSC v.1916 64 bit (AMD64)]"
"version": "3.9.5"
},
"orig_nbformat": 4,
"vscode": {
4 changes: 2 additions & 2 deletions clip_interrogator/__init__.py
@@ -1,4 +1,4 @@
from .clip_interrogator import Config, Interrogator, LabelTable, load_list
from .clip_interrogator import Config, Interrogator, LabelTable, list_caption_models, list_clip_models, load_list

__version__ = '0.5.5'
__version__ = '0.6.0'
__author__ = 'pharmapsychotic'
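The two helpers added to the public API give a programmatic way to discover the supported backends. Expected usage, roughly (a sketch, not library documentation):

```python
from clip_interrogator import list_caption_models, list_clip_models

print(list_caption_models())  # ['blip-base', 'blip-large', 'blip2-2.7b', 'blip2-flan-t5-xl', 'git-large-coco']
print(list_clip_models())     # every 'architecture/pretrained-tag' pair reported by open_clip.list_pretrained()
```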
130 changes: 63 additions & 67 deletions clip_interrogator/clip_interrogator.py
@@ -1,5 +1,4 @@
import hashlib
import inspect
import math
import numpy as np
import open_clip
@@ -9,18 +8,19 @@
import torch

from dataclasses import dataclass
from blip.models.blip import blip_decoder, BLIP_Decoder
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, Blip2ForConditionalGeneration
from tqdm import tqdm
from typing import List, Optional

from safetensors.numpy import load_file, save_file

BLIP_MODELS = {
'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
CAPTION_MODELS = {
'blip-base': 'Salesforce/blip-image-captioning-base', # 990MB
'blip-large': 'Salesforce/blip-image-captioning-large', # 1.9GB
'blip2-2.7b': 'Salesforce/blip2-opt-2.7b', # 15.5GB
'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl', # 15.77GB
'git-large-coco': 'microsoft/git-large-coco', # 1.58GB
}

CACHE_URL_BASE = 'https://huggingface.co/pharma/ci-preprocess/resolve/main/'
@@ -29,16 +29,15 @@
@dataclass
class Config:
# models can optionally be passed in directly
blip_model: Optional[BLIP_Decoder] = None
caption_model = None
caption_processor = None
clip_model = None
clip_preprocess = None

# blip settings
blip_image_eval_size: int = 384
blip_max_length: int = 32
blip_model_type: Optional[str] = 'large' # use 'base', 'large' or None
blip_num_beams: int = 8
blip_offload: bool = False
caption_max_length: int = 32
caption_model_name: Optional[str] = 'blip-large' # use a key from CAPTION_MODELS or None
caption_offload: bool = False

# clip settings
clip_model_name: str = 'ViT-L-14/openai'
@@ -55,8 +54,8 @@ class Config:
quiet: bool = False # when quiet progress bars are not shown

def apply_low_vram_defaults(self):
self.blip_model_type = 'base'
self.blip_offload = True
self.caption_model_name = 'blip-base'
self.caption_offload = True
self.clip_offload = True
self.chunk_size = 1024
self.flavor_intermediate_count = 1024
@@ -65,29 +64,33 @@ class Interrogator():
def __init__(self, config: Config):
self.config = config
self.device = config.device
self.blip_offloaded = True
self.dtype = torch.float16 if self.device == 'cuda' else torch.float32
self.caption_offloaded = True
self.clip_offloaded = True
self.load_caption_model()
self.load_clip_model()

if config.blip_model is None and config.blip_model_type:
if not config.quiet:
print("Loading BLIP model...")
blip_path = os.path.dirname(inspect.getfile(blip_decoder))
configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
med_config = os.path.join(configs_path, 'med_config.json')
blip_model = blip_decoder(
pretrained=BLIP_MODELS[config.blip_model_type],
image_size=config.blip_image_eval_size,
vit=config.blip_model_type,
med_config=med_config
)
blip_model.eval()
if not self.config.blip_offload:
blip_model = blip_model.to(config.device)
self.blip_model = blip_model
def load_caption_model(self):
if self.config.caption_model is None and self.config.caption_model_name:
if not self.config.quiet:
print(f"Loading caption model {self.config.caption_model_name}...")

model_path = CAPTION_MODELS[self.config.caption_model_name]
if self.config.caption_model_name.startswith('git-'):
caption_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
elif self.config.caption_model_name.startswith('blip2-'):
caption_model = Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
else:
caption_model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
self.caption_processor = AutoProcessor.from_pretrained(model_path)

caption_model.eval()
if not self.config.caption_offload:
caption_model = caption_model.to(self.config.device)
self.caption_model = caption_model
else:
self.blip_model = config.blip_model

self.load_clip_model()
self.caption_model = self.config.caption_model
self.caption_processor = self.config.caption_processor

def load_clip_model(self):
start_time = time.time()
@@ -97,7 +100,7 @@ def load_clip_model(self):

if config.clip_model is None:
if not config.quiet:
print("Loading CLIP model...")
print(f"Loading CLIP model {config.clip_model_name}...")

self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
clip_model_name,
@@ -183,26 +186,13 @@ def check(addition: str, idx: int) -> bool:
return best_prompt

def generate_caption(self, pil_image: Image) -> str:
assert self.blip_model is not None, "No BLIP model loaded."
self._prepare_blip()

size = self.config.blip_image_eval_size
gpu_image = transforms.Compose([
transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(pil_image).unsqueeze(0).to(self.device)

with torch.no_grad():
caption = self.blip_model.generate(
gpu_image,
sample=False,
num_beams=self.config.blip_num_beams,
max_length=self.config.blip_max_length,
min_length=5
)

return caption[0]
assert self.caption_model is not None, "No caption model loaded."
self._prepare_caption()
inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(self.device)
if not self.config.caption_model_name.startswith('git-'):
inputs = inputs.to(self.dtype)
tokens = self.caption_model.generate(**inputs, max_new_tokens=self.config.caption_max_length)
return self.caption_processor.batch_decode(tokens, skip_special_tokens=True)[0].strip()

def image_to_features(self, image: Image) -> torch.Tensor:
self._prepare_clip()
@@ -237,7 +227,7 @@ def interrogate_fast(self, image: Image, max_flavors: int=32, caption: Optional[
are less readable."""
caption = caption or self.generate_caption(image)
image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
tops = merged.rank(image_features, max_flavors)
return _truncate_to_fit(caption + ", " + ", ".join(tops), self.tokenize)

@@ -254,7 +244,7 @@ def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32, cap
caption = caption or self.generate_caption(image)
image_features = self.image_to_features(image)

merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self)
flaves = merged.rank(image_features, self.config.flavor_intermediate_count)
best_prompt, best_sim = caption, self.similarity(image_features, caption)
best_prompt = self.chain(image_features, flaves, best_prompt, best_sim, min_count=min_flavors, max_count=max_flavors, desc="Flavor chain")
@@ -293,18 +283,18 @@ def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> L
similarity = text_features @ image_features.T
return similarity.T[0].tolist()

def _prepare_blip(self):
def _prepare_caption(self):
if self.config.clip_offload and not self.clip_offloaded:
self.clip_model = self.clip_model.to('cpu')
self.clip_offloaded = True
if self.blip_offloaded:
self.blip_model = self.blip_model.to(self.device)
self.blip_offloaded = False
if self.caption_offloaded:
self.caption_model = self.caption_model.to(self.device)
self.caption_offloaded = False

def _prepare_clip(self):
if self.config.blip_offload and not self.blip_offloaded:
self.blip_model = self.blip_model.to('cpu')
self.blip_offloaded = True
if self.config.caption_offload and not self.caption_offloaded:
self.caption_model = self.caption_model.to('cpu')
self.caption_offloaded = True
if self.clip_offloaded:
self.clip_model = self.clip_model.to(self.device)
self.clip_offloaded = False
@@ -425,8 +415,8 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet
progress.update(len(chunk))
progress.close()

def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m = LabelTable([], None, None, None, config)
def _merge_tables(tables: List[LabelTable], ci: Interrogator) -> LabelTable:
m = LabelTable([], None, ci)
for table in tables:
m.labels.extend(table.labels)
m.embeds.extend(table.embeds)
@@ -445,6 +435,12 @@ def _truncate_to_fit(text: str, tokenize) -> str:
new_text += ', ' + part
return new_text

def list_caption_models() -> List[str]:
return list(CAPTION_MODELS.keys())

def list_clip_models() -> List[str]:
return ['/'.join(x) for x in open_clip.list_pretrained()]

def load_list(data_path: str, filename: Optional[str] = None) -> List[str]:
"""Load a list of strings from a file."""
if filename is not None:
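Taken together, the new captioning path is the standard Hugging Face transformers pattern. A simplified sketch of what `load_caption_model` and `generate_caption` now amount to for the BLIP branch (BLIP-2 uses `Blip2ForConditionalGeneration` and GIT uses `AutoModelForCausalLM` in float32, but the flow is the same; the image path is a placeholder):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device == 'cuda' else torch.float32

model_path = 'Salesforce/blip-image-captioning-large'  # CAPTION_MODELS['blip-large']
processor = AutoProcessor.from_pretrained(model_path)
model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=dtype).eval().to(device)

image = Image.open('photo.jpg').convert('RGB')  # placeholder input image
inputs = processor(images=image, return_tensors='pt').to(device)
inputs = inputs.to(dtype)  # BLIP/BLIP-2 inputs are cast to the model dtype; GIT inputs stay float32
tokens = model.generate(**inputs, max_new_tokens=32)  # 32 matches the default caption_max_length
print(processor.batch_decode(tokens, skip_special_tokens=True)[0].strip())
```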
4 changes: 2 additions & 2 deletions requirements.txt
@@ -5,5 +5,5 @@ requests
safetensors
tqdm
open_clip_torch
blip-ci
transformers>=4.15.0,<=4.26.1
accelerate
transformers>=4.27.1
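For existing installs, the dependency change drops `blip-ci` in favor of a newer `transformers` plus `accelerate`. Something along these lines should bring an environment up to date (a sketch in the spirit of the notebook's setup cell; adjust to your own environment):

```python
import subprocess

# Upgrade to the pins from the new requirements.txt; blip-ci is no longer required.
for cmd in (
    ['pip', 'install', '--upgrade', 'transformers>=4.27.1', 'accelerate'],
    ['pip', 'install', '--upgrade', 'clip-interrogator'],
):
    subprocess.run(cmd, check=True)
```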
5 changes: 2 additions & 3 deletions run_cli.py
@@ -1,12 +1,11 @@
#!/usr/bin/env python3
import argparse
import csv
import open_clip
import os
import requests
import torch
from PIL import Image
from clip_interrogator import Interrogator, Config
from clip_interrogator import Interrogator, Config, list_clip_models

def inference(ci, image, mode):
image = image.convert('RGB')
@@ -36,7 +35,7 @@ def main():
exit(1)

# validate clip model name
models = ['/'.join(x) for x in open_clip.list_pretrained()]
models = list_clip_models()
if args.clip not in models:
print(f"Could not find CLIP model {args.clip}!")
print(f" available models: {models}")