diff --git a/vision/transforms/transforms.py b/vision/transforms/transforms.py index c53b38e..1d6ed8b 100644 --- a/vision/transforms/transforms.py +++ b/vision/transforms/transforms.py @@ -259,7 +259,7 @@ class RandomSampleCrop(object): """ def __init__(self): - self.sample_options = ( + self.sample_options = np.array([ # using entire original input image None, # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9 @@ -269,7 +269,7 @@ def __init__(self): (0.9, None), # randomly sample a patch (None, None), - ) + ], dtype=object) def __call__(self, image, boxes=None, labels=None): height, width, _ = image.shape @@ -364,7 +364,7 @@ class RandomSampleCrop_v2(object): """ def __init__(self): - self.sample_options = ( + self.sample_options = np.array([ # using entire original input image None, # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9 @@ -374,7 +374,7 @@ def __init__(self): (1, None), (1, None), (1, None), - ) + ], dtype=object) def __call__(self, image, boxes=None, labels=None): height, width, _ = image.shape