diff --git a/spotiflow/cli/predict.py b/spotiflow/cli/predict.py index 0a518e6..0f99e7e 100644 --- a/spotiflow/cli/predict.py +++ b/spotiflow/cli/predict.py @@ -70,9 +70,19 @@ def get_args(): type=str, required=False, default="auto", choices=["auto", "cpu", "cuda", "mps"], help="Device to run model on. Defaults to 'auto'.") + utils = parser.add_argument_group(title="Utility arguments", + description="Diverse utility arguments, e.g. I/O related.") + utils.add_argument("--exclude-hidden-files", action="store_true", required=False, default=False, help="Exclude hidden files in the input directory. Defaults to False.") + args = parser.parse_args() return args +def _imread_wrapped(fname): + try: + return imread(fname) + except Exception as e: + log.error(f"Could not read image {fname}. Execution will halt.") + raise e def main(): # Get arguments from command line @@ -108,6 +118,8 @@ def main(): image_files = sorted( tuple(chain(*tuple(args.data_path.glob(f"*.{ext}") for ext in ALLOWED_EXTENSIONS))) ) + if args.exclude_hidden_files: + image_files = tuple(f for f in image_files if not f.name.startswith(".")) if len(image_files) == 0: raise ValueError(f"No valid image files found in directory {args.data_path}. Allowed extensions are: {ALLOWED_EXTENSIONS}") if out_dir is None: @@ -119,7 +131,7 @@ def main(): out_dir.mkdir(exist_ok=True, parents=True) # Predict spots in images and write to CSV - images = [imread(img) for img in image_files] + images = [_imread_wrapped(img) for img in image_files] for img, fname in tqdm(zip(images, image_files), desc="Predicting", total=len(images)): spots, _ = model.predict(img, prob_thresh=args.probability_threshold,