diff --git a/geo_inference/geo_inference.py b/geo_inference/geo_inference.py index 8f9720f..d849e5a 100644 --- a/geo_inference/geo_inference.py +++ b/geo_inference/geo_inference.py @@ -100,23 +100,21 @@ def __init__( map_location=self.device, ) if transformers: - - if transformer_flip and not transformer_rotate: + if transformer_flip and transformer_rotate: # do all + transforms = tta.aliases.d4_transform() + elif transformer_rotate: # do rotate only transforms = tta.Compose( [ - tta.HorizontalFlip(), - tta.VerticalFlip(), + tta.Rotate90(angles=[90]), ] ) - elif not transformer_flip and transformer_rotate: + elif transformer_flip: # do flip only transforms = tta.Compose( [ - tta.Rotate90(angles=[90]), + tta.HorizontalFlip(), + tta.VerticalFlip(), ] ) - elif transformer_flip and transformer_rotate: - transforms = tta.aliases.d4_transform() - self.model = tta.SegmentationTTAWrapper(self.model, transforms, merge_mode='mean') self.mask_to_vec = mask_to_vec self.mask_to_coco = mask_to_coco