From 6470ec2fca566cb8b66011283be283d80c3efe46 Mon Sep 17 00:00:00 2001 From: ibarraz5 Date: Wed, 27 Mar 2024 18:58:37 -0600 Subject: [PATCH] use a class attribute for type checking --- segment_anything/utils/amg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/segment_anything/utils/amg.py b/segment_anything/utils/amg.py index be064071e..9c3a4d535 100644 --- a/segment_anything/utils/amg.py +++ b/segment_anything/utils/amg.py @@ -6,7 +6,6 @@ import numpy as np import torch - import math from copy import deepcopy from itertools import product @@ -18,17 +17,18 @@ class MaskData: A structure for storing masks and their related data in batched format. Implements basic filtering and concatenation. """ + SUPPORTED_TYPES = (list, np.ndarray, torch.Tensor) def __init__(self, **kwargs) -> None: for v in kwargs.values(): assert isinstance( - v, (list, np.ndarray, torch.Tensor) + v, (self.SUPPORTED_TYPES) ), "MaskData only supports list, numpy arrays, and torch tensors." self._stats = dict(**kwargs) def __setitem__(self, key: str, item: Any) -> None: assert isinstance( - item, (list, np.ndarray, torch.Tensor) + item, (self.SUPPORTED_TYPES) ), "MaskData only supports list, numpy arrays, and torch tensors." self._stats[key] = item