
Commit ac74904

Expose LabelTable and load_list and give example in README how they can be used to rank your own list of terms.
pharmapsychotic committed Mar 20, 2023
1 parent d2c6e07 commit ac74904
Showing 4 changed files with 41 additions and 22 deletions.
16 changes: 15 additions & 1 deletion README.md
@@ -40,7 +40,7 @@ Install with PIP
pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
# install clip-interrogator
-pip install clip-interrogator==0.5.4
+pip install clip-interrogator==0.5.5
```

You can then use it in your script
@@ -67,3 +67,17 @@ The `Config` object lets you configure CLIP Interrogator's processing.
On systems with low VRAM you can call `config.apply_low_vram_defaults()` to reduce the amount of VRAM needed (at the cost of some speed and quality). The default settings use about 6.3GB of VRAM and the low VRAM settings use about 2.7GB.

See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes.
+
+
+## Ranking against your own list of terms
+
+```python
+from clip_interrogator import Config, Interrogator, LabelTable, load_list
+from PIL import Image
+
+ci = Interrogator(Config(blip_model_type=None))
+image = Image.open(image_path).convert('RGB')
+table = LabelTable(load_list('terms.txt'), 'terms', ci)
+best_match = table.rank(ci.image_to_features(image), top_count=1)[0]
+print(best_match)
+```
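In the new README snippet, `image_path` is a placeholder for your own image file, and `Config(blip_model_type=None)` relies on the `blip_model_type` check added in `Interrogator.__init__` below, so no BLIP captioning model is loaded for ranking-only use. As a minimal sketch of going beyond a single best match (assuming the same `ci`, `table`, and `image` as above; `rank` appears to return labels ordered by CLIP similarity):

```python
# Sketch: retrieve the five closest terms instead of only the best one.
features = ci.image_to_features(image)
for term in table.rank(features, top_count=5):
    print(term)
```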
4 changes: 2 additions & 2 deletions clip_interrogator/__init__.py
@@ -1,4 +1,4 @@
-from .clip_interrogator import Interrogator, Config
+from .clip_interrogator import Config, Interrogator, LabelTable, load_list

-__version__ = '0.5.4'
+__version__ = '0.5.5'
__author__ = 'pharmapsychotic'
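This `__init__.py` change is what makes the README example work: `LabelTable` and `load_list` move from internal helpers to the package's public surface. A one-line illustration:

```python
# All four names can now be imported directly from the package
# (previously only Config and Interrogator were exported).
from clip_interrogator import Config, Interrogator, LabelTable, load_list
```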
41 changes: 23 additions & 18 deletions clip_interrogator/clip_interrogator.py
@@ -29,20 +29,20 @@
@dataclass
class Config:
# models can optionally be passed in directly
-blip_model: BLIP_Decoder = None
+blip_model: Optional[BLIP_Decoder] = None
clip_model = None
clip_preprocess = None

# blip settings
blip_image_eval_size: int = 384
blip_max_length: int = 32
-blip_model_type: str = 'large' # choose between 'base' or 'large'
+blip_model_type: Optional[str] = 'large' # use 'base', 'large' or None
blip_num_beams: int = 8
blip_offload: bool = False

# clip settings
clip_model_name: str = 'ViT-L-14/openai'
-clip_model_path: str = None
+clip_model_path: Optional[str] = None
clip_offload: bool = False

# interrogator settings
@@ -68,7 +68,7 @@ def __init__(self, config: Config):
self.blip_offloaded = True
self.clip_offloaded = True

-if config.blip_model is None:
+if config.blip_model is None and config.blip_model_type:
if not config.quiet:
print("Loading BLIP model...")
blip_path = os.path.dirname(inspect.getfile(blip_decoder))
@@ -121,17 +121,17 @@ def load_clip_model(self):
trending_list.extend(["featured on "+site for site in sites])
trending_list.extend([site+" contest winner" for site in sites])

-raw_artists = _load_list(config.data_path, 'artists.txt')
+raw_artists = load_list(config.data_path, 'artists.txt')
artists = [f"by {a}" for a in raw_artists]
artists.extend([f"inspired by {a}" for a in raw_artists])

self._prepare_clip()
-self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
-self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
-self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config)
-self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
-self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
-self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config)
+self.artists = LabelTable(artists, "artists", self)
+self.flavors = LabelTable(load_list(config.data_path, 'flavors.txt'), "flavors", self)
+self.mediums = LabelTable(load_list(config.data_path, 'mediums.txt'), "mediums", self)
+self.movements = LabelTable(load_list(config.data_path, 'movements.txt'), "movements", self)
+self.trendings = LabelTable(trending_list, "trendings", self)
+self.negative = LabelTable(load_list(config.data_path, 'negative.txt'), "negative", self)

end_time = time.time()
if not config.quiet:
@@ -183,6 +183,7 @@ def check(addition: str, idx: int) -> bool:
return best_prompt

def generate_caption(self, pil_image: Image) -> str:
+assert self.blip_model is not None, "No BLIP model loaded."
self._prepare_blip()

size = self.config.blip_image_eval_size
@@ -310,13 +311,14 @@ def _prepare_clip(self):


class LabelTable():
-def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config):
+def __init__(self, labels:List[str], desc:str, ci: Interrogator):
+clip_model, config = ci.clip_model, ci.config
self.chunk_size = config.chunk_size
self.config = config
self.device = config.device
self.embeds = []
self.labels = labels
-self.tokenize = tokenize
+self.tokenize = ci.tokenize

hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_')
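With the refactored constructor, callers no longer thread `clip_model`, `tokenize`, and `config` through by hand; the `Interrogator` instance carries all three. A minimal sketch of building a table from an in-memory list rather than a file, assuming `ci` and `image` as in the README example (the `moods` terms are illustrative):

```python
# Sketch: LabelTable accepts any list of strings, not only loaded term files.
moods = ["serene", "chaotic", "melancholic", "jubilant"]
mood_table = LabelTable(moods, "moods", ci)  # the desc string identifies the table
print(mood_table.rank(ci.image_to_features(image), top_count=2))
```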
@@ -423,11 +425,6 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet
progress.update(len(chunk))
progress.close()

-def _load_list(data_path: str, filename: str) -> List[str]:
-with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
-items = [line.strip() for line in f.readlines()]
-return items
-
def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m = LabelTable([], None, None, None, config)
for table in tables:
@@ -447,3 +444,11 @@ def _truncate_to_fit(text: str, tokenize) -> str:
break
new_text += ', ' + part
return new_text
+
+def load_list(data_path: str, filename: Optional[str] = None) -> List[str]:
+"""Load a list of strings from a file."""
+if filename is not None:
+data_path = os.path.join(data_path, filename)
+with open(data_path, 'r', encoding='utf-8', errors='replace') as f:
+items = [line.strip() for line in f.readlines()]
+return items
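The relocated `load_list` keeps the two-argument `(data_path, filename)` form used above for the bundled data files, while the now-optional `filename` allows the single-path call the README example uses. Both calls below should be equivalent ways of reading one term per line (paths are illustrative):

```python
# Sketch: the two call forms of the newly public load_list.
terms = load_list('terms.txt')                           # single full path
flavors = load_list(ci.config.data_path, 'flavors.txt')  # directory + filename
```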
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
name="clip-interrogator",
version="0.5.4",
version="0.5.5",
license='MIT',
author='pharmapsychotic',
author_email='[email protected]',