-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathautotag
executable file
·80 lines (68 loc) · 3.51 KB
/
autotag
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python
import json
import click
import itertools
import logging
import PIL
from fastai.vision.core import PILImage
from autotagger import Autotagger
from pathlib import Path
from more_itertools import ichunked
@click.command(help="Automatically generate tags for a list of images.", context_settings=dict(max_content_width=140))
@click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.")
@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.")
@click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.")
@click.option("-c", "--csv", is_flag=True, help="Output CSV instead of JSON.")
@click.option("-i", "--input-file", type=click.File(), help="Read list of image filenames from text file.")
@click.option("-g/-f", "--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.")
@click.option("-N", "--name-only", is_flag=True, help="Output only the filename without the full path or extension.")
@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.")
@click.argument("files", nargs=-1, type=click.Path(exists=True, allow_dash=True, path_type=Path))
def main(files, threshold, limit, bs, csv, input_file, group_tags, name_only, model):
    """Tag every input image and print one result per image (or per tag).

    Inputs come from ``input_file`` (one filename per line) when it is given,
    otherwise from ``files``; directories in ``files`` are expanded
    recursively. With neither, the command prints its help text and exits.

    Args:
        files: Tuple of ``Path`` objects given on the command line.
        threshold: Minimum confidence passed to ``Autotagger.predict``.
        limit: Maximum number of tags per image passed to ``predict``.
        bs: Number of images opened and predicted per batch.
        csv: Emit CSV rows instead of JSON objects.
        input_file: Open text file listing image filenames, or ``None``.
        group_tags: One row per image when true, one row per tag otherwise.
        name_only: Report only the filename stem.
        model: Path of the model file handed to ``Autotagger``.
    """
    if input_file:
        # A blank line would turn into Path("") (the current directory) and
        # only produce a spurious "Skipped" warning, so drop it up front.
        paths = (Path(line.rstrip()) for line in input_file if line.strip())
    elif files:
        paths = get_filepaths(files)
    else:
        ctx = click.get_current_context()
        click.echo(ctx.get_help())
        ctx.exit()

    # Load the model only once we know there is something to tag, so the
    # help-and-exit path above stays cheap.
    autotagger = Autotagger(model)

    for batch in ichunked(paths, bs):
        # open_image returns None for anything it could not read.
        opened = [result for result in map(open_image, batch) if result is not None]
        if not opened:
            # Nothing in this batch was readable; do not call the model with
            # an empty list.
            continue

        batch_paths = [path for path, _ in opened]
        images = [image for _, image in opened]
        predictions = autotagger.predict(images, threshold=threshold, limit=limit, bs=bs)

        for path, tags in zip(batch_paths, predictions):
            # NOTE(review): only the final path component is reported, so files
            # with the same name in different directories are indistinguishable
            # in the output. The --name-only help text suggests the full path
            # was intended here; confirm before changing the output format.
            output_result(Path(path.name), tags, csv, group_tags, name_only)
def _format_csv_row(fields):
    """Return *fields* encoded as a single CSV record without a line ending."""
    # Imported here under an alias because output_result's ``csv`` flag
    # parameter shadows the module name inside that function.
    import csv as csv_module
    import io

    buffer = io.StringIO()
    csv_module.writer(buffer, lineterminator="\n").writerow(fields)
    # Drop the single trailing "\n"; click.echo appends its own newline.
    return buffer.getvalue()[:-1]


def output_result(filepath, tags, csv, group_tags, name_only):
    """Print the tags predicted for one image in the requested format.

    Args:
        filepath: ``Path`` of the image; its stem is printed when
            ``name_only`` is true, otherwise ``str(filepath)``.
        tags: Mapping of tag name to score.
        csv: Emit CSV rows when true, JSON objects otherwise.
        group_tags: When true, print a single row holding all tags of the
            image; when false, print one ``name, tag, score`` row per tag.
        name_only: Print only the filename stem.
    """
    name = filepath.stem if name_only else str(filepath)

    if csv:
        # Rows go through the csv writer so that a filename or tag containing
        # a comma or quote is quoted instead of silently adding columns.
        if group_tags:
            click.echo(_format_csv_row([name, " ".join(sorted(tags))]))
        else:
            for tag, score in tags.items():
                click.echo(_format_csv_row([name, tag, f"{score}"]))
    elif group_tags:
        click.echo(json.dumps({"filename": name, "tags": tags}))
    else:
        for tag, score in tags.items():
            click.echo(json.dumps({"filename": name, "tag": tag, "score": score}))
def get_filepaths(paths):
    """Lazily expand *paths* into a flat iterator of file paths.

    Each entry that is a directory is replaced by the files found beneath it
    (via ``recurse_dir``); every other entry is yielded unchanged.

    Args:
        paths: Iterable of ``Path`` objects.

    Returns:
        An iterator over ``Path`` objects.
    """
    # chain.from_iterable pulls from the generator on demand, whereas
    # chain(*generator) would inspect every input path before yielding the
    # first result.
    return itertools.chain.from_iterable(
        recurse_dir(path) if path.is_dir() else (path,) for path in paths
    )
def recurse_dir(directory):
    """Iterate over every non-directory entry found anywhere under *directory*."""

    def is_not_directory(entry):
        return not entry.is_dir()

    # "**/*" matches entries at every depth; directories themselves are dropped.
    return filter(is_not_directory, directory.glob("**/*"))
def open_image(filepath):
    """Open *filepath* as an image, logging and skipping anything unreadable.

    Args:
        filepath: Path of the image to open; ``-`` is meant to select stdin.

    Returns:
        A ``(filepath, image)`` tuple on success, or ``None`` when the file
        could not be opened or decoded (a warning is logged in that case).
    """
    try:
        # Pass a plain string so that "-" is still recognised as the stdin
        # marker on click versions that compare the filename against the
        # literal string rather than accepting a Path.
        with click.open_file(str(filepath), "rb") as file:
            return (filepath, PILImage.create(file))
    except PIL.UnidentifiedImageError:
        logging.warning(f"Skipped {filepath} (not an image)")
        return None
    except Exception as err:
        # Deliberately broad: one unreadable file must not abort the whole run.
        logging.warning(f"Skipped {filepath} ({type(err).__name__}: {str(err)})")
        return None
# Script entry point: invoke the click command only when run directly,
# not when this file is imported as a module.
if __name__ == "__main__":
    main()