diff --git a/scripts/amg.py b/scripts/amg.py index f2dbf676a..93c96086a 100644 --- a/scripts/amg.py +++ b/scripts/amg.py @@ -4,15 +4,17 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import cv2 # type: ignore - -from segment_anything import SamAutomaticMaskGenerator, sam_model_registry - import argparse import json -import os +from pathlib import Path from typing import Any, Dict, List +import cv2 # type: ignore +import tqdm +import torch.utils.data + +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry + parser = argparse.ArgumentParser( description=( "Runs automatic mask generation on an input image or directory of images, " @@ -53,6 +55,9 @@ ) parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.") +parser.add_argument("--rank", type=int, default=0, help="Rank of the current process.") +parser.add_argument("--world", type=int, default=1, help="Number of processes.") +parser.add_argument("--num-workers", type=int, default=4, help="Dataloader workers.") parser.add_argument( "--convert-to-rle", @@ -149,13 +154,13 @@ ) -def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: +def write_masks_to_folder(masks: List[Dict[str, Any]], path: Path) -> None: header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa metadata = [header] for i, mask_data in enumerate(masks): mask = mask_data["segmentation"] filename = f"{i}.png" - cv2.imwrite(os.path.join(path, filename), mask * 255) + cv2.imwrite((path / filename).as_posix(), mask * 255) mask_metadata = [ str(i), str(mask_data["area"]), @@ -167,8 +172,7 @@ def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: ] row = ",".join(mask_metadata) metadata.append(row) - metadata_path = os.path.join(path, "metadata.csv") - with open(metadata_path, "w") as f: + with open(path / "metadata.csv", "w") as f: f.write("\n".join(metadata)) return @@ -192,6 +196,21 @@ def get_amg_kwargs(args): return amg_kwargs +class ImageDataset(torch.utils.data.Dataset): + def __init__(self, paths: List[Path], base: Path): + self.paths = paths + self.base = base + + def __len__(self): + return len(self.paths) + + def __getitem__(self, idx): + path = self.paths[idx] + image = cv2.imread((self.base / path).as_posix()) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return path, image + + def main(args: argparse.Namespace) -> None: print("Loading model...") sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) @@ -200,34 +219,54 @@ def main(args: argparse.Namespace) -> None: amg_kwargs = get_amg_kwargs(args) generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) - if not os.path.isdir(args.input): - targets = [args.input] + args.input = Path(args.input).expanduser().resolve() + args.output = Path(args.output).expanduser().resolve() + + if not args.input.is_dir(): + targets = ImageDataset([args.input.name], args.input.parent) else: targets = [ - f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f)) + f.relative_to(args.input) for f in args.input.rglob("*") + if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] ] - targets = [os.path.join(args.input, f) for f in targets] - - os.makedirs(args.output, exist_ok=True) - - for t in targets: - print(f"Processing '{t}'...") - image = cv2.imread(t) - if image is None: - print(f"Could not load '{t}' as an image, skipping...") - continue - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - + print(f"Found {len(targets)} images in {args.input}.") + + # Per-process split + if args.world > 1: + targets = targets[args.rank::args.world] + print(f"Rank {args.rank}/{args.world} will process {len(targets)} images.") + + # Skip existing + if output_mode == "binary_mask": + targets = [ + f for f in targets + if not Path.is_dir(args.output / f.with_suffix("")) + ] + else: + targets = [ + f for f in targets + if not Path.is_file(args.output / f.with_suffix(".json")) + ] + print(f"Skip already processed images, {len(targets)} remain to do.") + + targets = torch.utils.data.DataLoader( + ImageDataset(targets, args.input), + batch_size=None, + shuffle=False, + num_workers=args.num_workers, + collate_fn=lambda x: x, + ) + + for path, image in tqdm.tqdm(targets, ncols=0): masks = generator.generate(image) - base = os.path.basename(t) - base = os.path.splitext(base)[0] - save_base = os.path.join(args.output, base) if output_mode == "binary_mask": - os.makedirs(save_base, exist_ok=False) - write_masks_to_folder(masks, save_base) + save_dir = args.output / path.with_suffix("") + save_dir.mkdir(parents=True, exist_ok=False) + write_masks_to_folder(masks, save_dir) else: - save_file = save_base + ".json" + save_file = args.output / path.with_suffix(".json") + save_file.parent.mkdir(parents=True, exist_ok=True) with open(save_file, "w") as f: json.dump(masks, f) print("Done!")