-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtag_segments.py
53 lines (43 loc) · 1.4 KB
/
tag_segments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# %%
from tqdm import tqdm
import pathlib
import PIL.Image
from .llava_encapsulate import LLaVA
import random
def shuffled(iterable):
    """Return a new list holding the items of *iterable* in random order.

    The input is copied into a list first, so any iterable is accepted
    (generators, sets, dict views, ...), not only sequences, and the
    caller's object is never modified.
    """
    items = list(iterable)
    # Sampling the whole population yields a random permutation.
    return random.sample(items, len(items))
# Vision-language model that is asked to label each segment image.
llava = LLaVA(
    config=LLaVA.LLaVAConfig(
        llava_id="liuhaotian/llava-v1.5-7b",
    ),
)

# Root directory; each sub-directory holds the segment images of one source.
pickle_path = pathlib.Path("../data/n11939491/segments")

# Labels accepted from the model; any other reply is ignored by the loop below.
categories = [
    "petal",
    "flower head",
    "whole flower",
    "center disk",
    "stem",
    "leaf",
    "other",
]

# Instruction sent with every image; the label list is joined without spaces.
prompt = (
    "Please categorize the part highlighted in red as one of: "
    + ",".join(categories)
    + ". If the segment is not part of a daisy, answer none.\n"
    "Only answer with the category."
)
def validate_size(image: PIL.Image.Image, size: int = 25):
    """Tell whether *image* is at least *size* pixels along both axes.

    Args:
        image: Object whose ``size`` attribute is a ``(width, height)`` pair.
        size: Minimum number of pixels required for each dimension.

    Returns:
        ``True`` when neither dimension is smaller than *size*.
    """
    width, height = image.size
    # The smaller side decides: if it is large enough, so is the other one.
    return min(width, height) >= size
# Visit every segment image below ``pickle_path`` in random order, ask LLaVA
# which category it shows, and file a symlink to the image under
# ``<segment dir>/sorted/<category>/``.
for mask_file in tqdm(
    shuffled(
        [
            mask
            for mask_dir in pickle_path.glob("*")
            if mask_dir.is_dir()
            for mask in mask_dir.iterdir()
            if mask.is_file()
        ]
    )
):
    # Open the file exactly once and make sure the handle is closed again.
    # The size check runs on the freshly opened image so that segments which
    # are too small are skipped before any conversion work is done.
    with PIL.Image.open(mask_file) as raw_image:
        if not validate_size(raw_image):
            continue
        # Hand the model the RGB-converted image (the original code converted
        # it but then overwrote the result with a second, unconverted open).
        mask_image = raw_image.convert("RGB")

    # Tolerate surrounding whitespace/newlines in the model's reply.
    category = llava.infer(mask_image, prompt).strip().lower()
    if category not in categories:
        # Anything outside the known labels (e.g. "none") is left unsorted.
        continue

    outdir = mask_file.parent / "sorted" / category
    outdir.mkdir(parents=True, exist_ok=True)
    link = outdir / mask_file.name
    # ``exists()`` follows symlinks and reports False for a dangling link, so
    # also test ``is_symlink()`` to avoid FileExistsError on a re-run.
    if not (link.exists() or link.is_symlink()):
        # A relative target would be resolved against the link's directory,
        # not the working directory, so point the link at the absolute path.
        link.symlink_to(mask_file.resolve())