Skip to content

Commit

Permalink
Added some documentation to transforms. I'd been using [1,1] as identity
Browse files Browse the repository at this point in the history
blur parameters, which is false.
  • Loading branch information
alex404 committed Oct 12, 2024
1 parent 4ae29b9 commit c35076d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from runner.train import train
from runner.util import delete_results

# Load the eval resolver for config files
# Load the eval resolver for OmegaConf
OmegaConf.register_new_resolver("eval", eval)


Expand Down
48 changes: 48 additions & 0 deletions retinal_rl/datasets/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,17 @@


class IlluminationTransform(nn.Module):
"""Apply random illumination (brightness) adjustment to the input image."""

def __init__(self, brightness_range: Tuple[float, float]) -> None:
"""Initialize the IlluminationTransform.
Args:
----
brightness_range (Tuple[float, float]): Range of brightness adjustment factors. For an identity transform, set the range to (1, 1).
"""
super().__init__()
self.brightness_range = brightness_range

Expand All @@ -40,7 +50,16 @@ def forward(self, img: Image.Image) -> Image.Image:


class BlurTransform(nn.Module):
"""Apply random Gaussian blur to the input image."""

def __init__(self, blur_range: Tuple[float, float]) -> None:
"""Initialize the BlurTransform.
Args:
----
blur_range (Tuple[float, float]): Range of blur radii. For an identity transform, set the range to (0, 0).
"""
super().__init__()
self.blur_range = blur_range

Expand All @@ -61,12 +80,23 @@ def forward(self, img: Image.Image) -> Image.Image:


class ScaleShiftTransform(nn.Module):
"""Apply random scale and shift transformations to the input image."""

def __init__(
self,
vision_width: int,
vision_height: int,
image_rescale_range: List[float],
) -> None:
"""Initialize the ScaleShiftTransform.
Args:
----
vision_width (int): The width of the visual field.
vision_height (int): The height of the visual field.
image_rescale_range (List[float]): Range of image rescaling factors. For an identity transform, set the range to [1, 1].
"""
super().__init__()
self.vision_width = vision_width
self.vision_height = vision_height
Expand Down Expand Up @@ -122,7 +152,16 @@ def forward(self, img: Image.Image) -> Image.Image:


class ShotNoiseTransform(nn.Module):
"""Apply random shot noise to the input image."""

def __init__(self, lambda_range: Tuple[float, float]) -> None:
"""Initialize the ShotNoiseTransform.
Args:
----
lambda_range (Tuple[float, float]): Range of shot noise intensity factors. For an identity transform, set the range to (1, 1).
"""
super().__init__()
self.lambda_range = lambda_range

Expand Down Expand Up @@ -151,7 +190,16 @@ def forward(self, img: Image.Image) -> Image.Image:


class ContrastTransform(nn.Module):
"""Apply random contrast adjustment to the input image."""

def __init__(self, contrast_range: Tuple[float, float]) -> None:
"""Initialize the ContrastTransform.
Args:
----
contrast_range (Tuple[float, float]): Range of contrast adjustment factors. For an identity transform, set the range to (1, 1).
"""
super().__init__()
self.contrast_range = contrast_range

Expand Down
2 changes: 1 addition & 1 deletion runner/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

import omegaconf
import torch
import wandb
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from torch.optim.optimizer import Optimizer

import wandb
from retinal_rl.models.brain import Brain
from runner.util import save_checkpoint

Expand Down

0 comments on commit c35076d

Please sign in to comment.