Skip to content

Commit

Permalink
Introduced the UnfoldDataset. A dataset to divide the input images in…
Browse files Browse the repository at this point in the history
…to patches
  • Loading branch information
vittoriopippi committed Oct 16, 2024
1 parent 48ac367 commit d4d91ff
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
2 changes: 1 addition & 1 deletion datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base_dataset import BaseDataset
from .folder_dataset import FolderDataset
from .folder_dataset import FolderDataset, UnfoldDataset
from .cvl import CVLDataset
from .iam import IAMDataset
from .leopardi import LeopardiDataset
Expand Down
4 changes: 2 additions & 2 deletions datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import copy
from pathlib import Path

class BaseDataset(Dataset):
def __init__(self, path, transform=None, nameset=None, preprocess=None):
Expand Down Expand Up @@ -35,8 +35,8 @@ def __getitem__(self, index):
tuple: (image, label) where label is index of the target class.
"""
img = self.imgs[index]
img = Image.open(img).convert('RGB') if isinstance(img, Path) else img
label = self.labels[index]
img = Image.open(img).convert('RGB')
if self.preprocess is not None:
img = self.preprocess(img)
if self.transform is not None:
Expand Down
52 changes: 52 additions & 0 deletions datasets/folder_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from .base_dataset import BaseDataset
from pathlib import Path
from PIL import Image
import numpy as np
import math


class FolderDataset(BaseDataset):
Expand All @@ -10,3 +13,52 @@ def __init__(self, path, extension='png', **kwargs):
assert len(self.imgs) > 0, 'No images found.'
self.labels = [img.parent.name for img in self.imgs]
self.author_ids = sorted(set(self.labels))

class UnfoldDataset(FolderDataset):
    """Folder dataset whose samples are vertical slices (patches) of each image.

    Every image found by ``FolderDataset`` is opened once at construction
    time, cut into full-height patches along its width, and the resulting
    in-memory PIL images replace the file paths in ``self.imgs``. Each patch
    inherits the label of the image it was cut from, so ``self.imgs`` and
    ``self.labels`` stay aligned.

    NOTE(review): all patches are kept in memory for the lifetime of the
    dataset; presumably the base ``__getitem__`` accepts already-loaded
    images as well as paths -- confirm against ``BaseDataset``.

    Args:
        path: Root folder, forwarded to ``FolderDataset``.
        extension: Image file extension, forwarded to ``FolderDataset``.
        patch_width: Width of each patch in pixels. ``None`` means "use the
            image height", i.e. square patches.
        stride: Horizontal step between consecutive patches in pixels.
            ``None`` means "use the image height".
        pad_img: If true, pad the right edge with white so that the trailing
            part of the image is not discarded. If false, any tail that does
            not fill a whole patch is dropped (an image narrower than
            ``patch_width`` then contributes no patches at all).
        **kwargs: Forwarded to ``FolderDataset``.
    """

    def __init__(self, path, extension='png', patch_width=None, stride=None, pad_img=True, **kwargs):
        super().__init__(path, extension, **kwargs)

        patch_imgs = []
        patch_labels = []
        for img_path, label in zip(self.imgs, self.labels):
            # convert() returns an independent copy, so the file handle can be
            # released immediately instead of staying open per image.
            with Image.open(img_path) as handle:
                img = handle.convert('RGB')
            # Forward the user-supplied geometry; previously these arguments
            # were accepted by __init__ but never reached _unfold_img.
            img_patches = self._unfold_img(
                img,
                patch_width=patch_width,
                stride=stride,
                pad_img=pad_img,
            )
            patch_imgs.extend(img_patches)
            patch_labels.extend([label] * len(img_patches))
        self.imgs = patch_imgs
        self.labels = patch_labels

    def _unfold_img(self, img, patch_width=None, stride=None, pad_img=True):
        """Cut ``img`` into full-height patches taken left to right.

        Args:
            img: Image convertible to an ``(height, width, channels)`` array
                (e.g. an RGB PIL image).
            patch_width: Patch width in pixels; defaults to the image height.
            stride: Step between patch origins in pixels; defaults to the
                image height.
            pad_img: Whether to pad the right edge with the value 255 so the
                last patch is complete.

        Returns:
            list: PIL images, one per patch, in left-to-right order.

        Raises:
            ValueError: If the resolved ``patch_width`` or ``stride`` is not
                strictly positive.
        """
        image_array = np.array(img)

        # A 3-D array is required: the padding spec and the slicing below
        # both address a trailing channel axis.
        height, width, _ = image_array.shape
        patch_width = patch_width if patch_width is not None else height
        stride = stride if stride is not None else height
        if patch_width <= 0 or stride <= 0:
            raise ValueError(
                f'patch_width and stride must be positive, got '
                f'patch_width={patch_width}, stride={stride}'
            )

        if pad_img:
            # Number of patches needed so that every patch starts inside the
            # image and, when patches overlap or tile (patch_width >= stride),
            # the last one reaches the right edge. Padding to a multiple of
            # `stride` alone is only correct when patch_width == stride.
            n_patches = max(1, min(
                math.ceil(width / stride),
                math.ceil(max(width - patch_width, 0) / stride) + 1,
            ))
            new_width = max(width, (n_patches - 1) * stride + patch_width)
            remaining_width = new_width - width
            if remaining_width > 0:
                image_array = np.pad(
                    image_array,
                    ((0, 0), (0, remaining_width), (0, 0)),
                    mode='constant',
                    constant_values=255,  # white background
                )
            width = new_width

        # Each patch spans the full height; only the horizontal window moves.
        return [
            Image.fromarray(image_array[:, x_start:x_start + patch_width, :])
            for x_start in range(0, width - patch_width + 1, stride)
        ]

0 comments on commit d4d91ff

Please sign in to comment.