From d09a58f2d8ca4769d6aa5e6dfe212fa69e50b01b Mon Sep 17 00:00:00 2001 From: Cryaaa Date: Wed, 4 Sep 2024 17:05:23 +0000 Subject: [PATCH] added transforms to make loading possible --- src/embed_time/transforms.py | 142 ----------------------------------- 1 file changed, 142 deletions(-) diff --git a/src/embed_time/transforms.py b/src/embed_time/transforms.py index 02da672..ec6f69a 100644 --- a/src/embed_time/transforms.py +++ b/src/embed_time/transforms.py @@ -174,145 +174,3 @@ def __call__(self, sample): cropped = crop_around_centroid_2D(sample,cent,self.crop_size,self.crop_size) return np.moveaxis(cropped,0,1) - -class ColorJitterBrightfield(torch.nn.Module): - """Randomly change the brightness, contrast, saturation and hue of an image. - If the image is torch Tensor, it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported. - - Args: - brightness (float or tuple of float (min, max)): How much to jitter brightness. - brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] - or the given [min, max]. Should be non negative numbers. - contrast (float or tuple of float (min, max)): How much to jitter contrast. - contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] - or the given [min, max]. Should be non-negative numbers. - saturation (float or tuple of float (min, max)): How much to jitter saturation. - saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] - or the given [min, max]. Should be non negative numbers. - hue (float or tuple of float (min, max)): How much to jitter hue. - hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. - Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. - To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space; - thus it does not work if you normalize your image to an interval with negative values, - or use an interpolation that generates negative values before using this function. - """ - - def __init__( - self, - brightness: Union[float, Tuple[float, float]] = 0, - contrast: Union[float, Tuple[float, float]] = 0, - saturation: Union[float, Tuple[float, float]] = 0, - hue: Union[float, Tuple[float, float]] = 0, - channel_dim: int = 0, - ) -> None: - super().__init__() - self.brightness = self._check_input(brightness, "brightness") - self.contrast = self._check_input(contrast, "contrast") - self.saturation = self._check_input(saturation, "saturation") - self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) - self.channel_dim = channel_dim - - @torch.jit.unused - def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True): - if isinstance(value, numbers.Number): - if value < 0: - raise ValueError(f"If {name} is a single number, it must be non negative.") - value = [center - float(value), center + float(value)] - if clip_first_on_zero: - value[0] = max(value[0], 0.0) - elif isinstance(value, (tuple, list)) and len(value) == 2: - value = [float(value[0]), float(value[1])] - else: - raise TypeError(f"{name} should be a single number or a list/tuple with length 2.") - - if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError(f"{name} values should be between {bound}, but got {value}.") - - # if value is 0 or (1., 1.) for brightness/contrast/saturation - # or (0., 0.) for hue, do nothing - if value[0] == value[1] == center: - return None - else: - return tuple(value) - - @staticmethod - def get_params( - brightness: Optional[List[float]], - contrast: Optional[List[float]], - saturation: Optional[List[float]], - hue: Optional[List[float]], - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: - """Get the parameters for the randomized transform to be applied on image. - - Args: - brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen - uniformly. Pass None to turn off the transformation. - contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen - uniformly. Pass None to turn off the transformation. - saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen - uniformly. Pass None to turn off the transformation. - hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. - Pass None to turn off the transformation. - - Returns: - tuple: The parameters used to apply the randomized transform - along with their random order. - """ - fn_idx = torch.randperm(4) - - b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) - c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) - s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) - h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) - - return fn_idx, b, c, s, h - - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Input image. - - Returns: - PIL Image or Tensor: Color jittered image. - """ - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( - self.brightness, self.contrast, self.saturation, self.hue - ) - shape = img.shape - - others = img[1] - outs =[] - for tp in range(4): - - out = img[0,tp] - for fn_id in fn_idx: - if fn_id == 0 and brightness_factor is not None: - out = F.adjust_brightness(out, brightness_factor) - elif fn_id == 1 and contrast_factor is not None: - out = F.adjust_contrast(out, contrast_factor) - elif fn_id == 2 and saturation_factor is not None: - out = F.adjust_saturation(out, saturation_factor) - elif fn_id == 3 and hue_factor is not None: - out = F.adjust_hue(out, hue_factor) - outs = torch.concat(outs,dim=0) - return torch.concat( - [ - out.unsqueeze(self.channel_dim), - others.unsqueeze(self.channel_dim) - ], - dim=self.channel_dim - ) - - - def __repr__(self) -> str: - s = ( - f"{self.__class__.__name__}(" - f"brightness={self.brightness}" - f", contrast={self.contrast}" - f", saturation={self.saturation}" - f", hue={self.hue})" - ) - return s \ No newline at end of file