diff --git a/README.md b/README.md
index 27984b0c..acdce7af 100644
--- a/README.md
+++ b/README.md
@@ -40,7 +40,7 @@ Install with PIP
 pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
 
 # install clip-interrogator
-pip install clip-interrogator==0.5.4
+pip install clip-interrogator==0.5.5
 ```
 
 You can then use it in your script
@@ -67,3 +67,17 @@ The `Config` object lets you configure CLIP Interrogator's processing.
 On systems with low VRAM you can call `config.apply_low_vram_defaults()` to reduce the amount of VRAM needed (at the cost of some speed and quality). The default settings use about 6.3GB of VRAM and the low VRAM settings use about 2.7GB.
 
 See the [run_cli.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_cli.py) and [run_gradio.py](https://github.com/pharmapsychotic/clip-interrogator/blob/main/run_gradio.py) for more examples on using Config and Interrogator classes.
+
+
+## Ranking against your own list of terms
+
+```python
+from clip_interrogator import Config, Interrogator, LabelTable, load_list
+from PIL import Image
+
+ci = Interrogator(Config(blip_model_type=None))
+image = Image.open(image_path).convert('RGB')
+table = LabelTable(load_list('terms.txt'), 'terms', ci)
+best_match = table.rank(ci.image_to_features(image), top_count=1)[0]
+print(best_match)
+```
\ No newline at end of file
diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py
index 4317c318..9a2936ad 100644
--- a/clip_interrogator/__init__.py
+++ b/clip_interrogator/__init__.py
@@ -1,4 +1,4 @@
-from .clip_interrogator import Interrogator, Config
+from .clip_interrogator import Config, Interrogator, LabelTable, load_list
 
-__version__ = '0.5.4'
+__version__ = '0.5.5'
 __author__ = 'pharmapsychotic'
\ No newline at end of file
diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py
index 284ef941..5d936fe1 100644
--- a/clip_interrogator/clip_interrogator.py
+++ b/clip_interrogator/clip_interrogator.py
@@ -29,20 +29,20 @@
 @dataclass
 class Config:
     # models can optionally be passed in directly
-    blip_model: BLIP_Decoder = None
+    blip_model: Optional[BLIP_Decoder] = None
     clip_model = None
     clip_preprocess = None
 
     # blip settings
     blip_image_eval_size: int = 384
     blip_max_length: int = 32
-    blip_model_type: str = 'large' # choose between 'base' or 'large'
+    blip_model_type: Optional[str] = 'large' # use 'base', 'large' or None
     blip_num_beams: int = 8
     blip_offload: bool = False
 
     # clip settings
     clip_model_name: str = 'ViT-L-14/openai'
-    clip_model_path: str = None
+    clip_model_path: Optional[str] = None
     clip_offload: bool = False
 
     # interrogator settings
@@ -68,7 +68,7 @@ def __init__(self, config: Config):
         self.blip_offloaded = True
         self.clip_offloaded = True
 
-        if config.blip_model is None:
+        if config.blip_model is None and config.blip_model_type:
             if not config.quiet:
                 print("Loading BLIP model...")
             blip_path = os.path.dirname(inspect.getfile(blip_decoder))
@@ -121,17 +121,17 @@ def load_clip_model(self):
         trending_list.extend(["featured on "+site for site in sites])
         trending_list.extend([site+" contest winner" for site in sites])
 
-        raw_artists = _load_list(config.data_path, 'artists.txt')
+        raw_artists = load_list(config.data_path, 'artists.txt')
         artists = [f"by {a}" for a in raw_artists]
         artists.extend([f"inspired by {a}" for a in raw_artists])
 
         self._prepare_clip()
-        self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
-        self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model, self.tokenize, config)
'flavors.txt'), "flavors", self.clip_model, self.tokenize, config) - self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model, self.tokenize, config) - self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config) - self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config) - self.negative = LabelTable(_load_list(config.data_path, 'negative.txt'), "negative", self.clip_model, self.tokenize, config) + self.artists = LabelTable(artists, "artists", self) + self.flavors = LabelTable(load_list(config.data_path, 'flavors.txt'), "flavors", self) + self.mediums = LabelTable(load_list(config.data_path, 'mediums.txt'), "mediums", self) + self.movements = LabelTable(load_list(config.data_path, 'movements.txt'), "movements", self) + self.trendings = LabelTable(trending_list, "trendings", self) + self.negative = LabelTable(load_list(config.data_path, 'negative.txt'), "negative", self) end_time = time.time() if not config.quiet: @@ -183,6 +183,7 @@ def check(addition: str, idx: int) -> bool: return best_prompt def generate_caption(self, pil_image: Image) -> str: + assert self.blip_model is not None, "No BLIP model loaded." self._prepare_blip() size = self.config.blip_image_eval_size @@ -310,13 +311,14 @@ def _prepare_clip(self): class LabelTable(): - def __init__(self, labels:List[str], desc:str, clip_model, tokenize, config: Config): + def __init__(self, labels:List[str], desc:str, ci: Interrogator): + clip_model, config = ci.clip_model, ci.config self.chunk_size = config.chunk_size self.config = config self.device = config.device self.embeds = [] self.labels = labels - self.tokenize = tokenize + self.tokenize = ci.tokenize hash = hashlib.sha256(",".join(labels).encode()).hexdigest() sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_') @@ -423,11 +425,6 @@ def _download_file(url: str, filepath: str, chunk_size: int = 4*1024*1024, quiet progress.update(len(chunk)) progress.close() -def _load_list(data_path: str, filename: str) -> List[str]: - with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f: - items = [line.strip() for line in f.readlines()] - return items - def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable: m = LabelTable([], None, None, None, config) for table in tables: @@ -447,3 +444,11 @@ def _truncate_to_fit(text: str, tokenize) -> str: break new_text += ', ' + part return new_text + +def load_list(data_path: str, filename: Optional[str] = None) -> List[str]: + """Load a list of strings from a file.""" + if filename is not None: + data_path = os.path.join(data_path, filename) + with open(data_path, 'r', encoding='utf-8', errors='replace') as f: + items = [line.strip() for line in f.readlines()] + return items diff --git a/setup.py b/setup.py index 81fcfef8..f2806e0b 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="clip-interrogator", - version="0.5.4", + version="0.5.5", license='MIT', author='pharmapsychotic', author_email='me@pharmapsychotic.com',