Skip to content

Commit

Permalink
added transforms to make loading possible
Browse files Browse the repository at this point in the history
  • Loading branch information
Cryaaa committed Sep 4, 2024
1 parent d7b1a5f commit d09a58f
Showing 1 changed file with 0 additions and 142 deletions.
142 changes: 0 additions & 142 deletions src/embed_time/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,145 +174,3 @@ def __call__(self, sample):
cropped = crop_around_centroid_2D(sample,cent,self.crop_size,self.crop_size)

return np.moveaxis(cropped,0,1)

class ColorJitterBrightfield(torch.nn.Module):
"""Randomly change the brightness, contrast, saturation and hue of an image.
If the image is torch Tensor, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness.
brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
or the given [min, max]. Should be non negative numbers.
contrast (float or tuple of float (min, max)): How much to jitter contrast.
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
or the given [min, max]. Should be non-negative numbers.
saturation (float or tuple of float (min, max)): How much to jitter saturation.
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
or the given [min, max]. Should be non negative numbers.
hue (float or tuple of float (min, max)): How much to jitter hue.
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
thus it does not work if you normalize your image to an interval with negative values,
or use an interpolation that generates negative values before using this function.
"""

def __init__(
self,
brightness: Union[float, Tuple[float, float]] = 0,
contrast: Union[float, Tuple[float, float]] = 0,
saturation: Union[float, Tuple[float, float]] = 0,
hue: Union[float, Tuple[float, float]] = 0,
channel_dim: int = 0,
) -> None:
super().__init__()
self.brightness = self._check_input(brightness, "brightness")
self.contrast = self._check_input(contrast, "contrast")
self.saturation = self._check_input(saturation, "saturation")
self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
self.channel_dim = channel_dim

@torch.jit.unused
def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
if isinstance(value, numbers.Number):
if value < 0:
raise ValueError(f"If {name} is a single number, it must be non negative.")
value = [center - float(value), center + float(value)]
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
value = [float(value[0]), float(value[1])]
else:
raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError(f"{name} values should be between {bound}, but got {value}.")

# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
return None
else:
return tuple(value)

@staticmethod
def get_params(
brightness: Optional[List[float]],
contrast: Optional[List[float]],
saturation: Optional[List[float]],
hue: Optional[List[float]],
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
"""Get the parameters for the randomized transform to be applied on image.
Args:
brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
uniformly. Pass None to turn off the transformation.
contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
uniformly. Pass None to turn off the transformation.
saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
uniformly. Pass None to turn off the transformation.
hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
Pass None to turn off the transformation.
Returns:
tuple: The parameters used to apply the randomized transform
along with their random order.
"""
fn_idx = torch.randperm(4)

b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

return fn_idx, b, c, s, h


def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Input image.
Returns:
PIL Image or Tensor: Color jittered image.
"""
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
self.brightness, self.contrast, self.saturation, self.hue
)
shape = img.shape

others = img[1]
outs =[]
for tp in range(4):

out = img[0,tp]
for fn_id in fn_idx:
if fn_id == 0 and brightness_factor is not None:
out = F.adjust_brightness(out, brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
out = F.adjust_contrast(out, contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
out = F.adjust_saturation(out, saturation_factor)
elif fn_id == 3 and hue_factor is not None:
out = F.adjust_hue(out, hue_factor)
outs = torch.concat(outs,dim=0)
return torch.concat(
[
out.unsqueeze(self.channel_dim),
others.unsqueeze(self.channel_dim)
],
dim=self.channel_dim
)


def __repr__(self) -> str:
s = (
f"{self.__class__.__name__}("
f"brightness={self.brightness}"
f", contrast={self.contrast}"
f", saturation={self.saturation}"
f", hue={self.hue})"
)
return s

0 comments on commit d09a58f

Please sign in to comment.