Skip to content

Commit

Permalink
Add SaintGallDataset class
Browse files Browse the repository at this point in the history
  • Loading branch information
vittoriopippi committed Feb 11, 2025
1 parent 67625df commit 24ac36c
Showing 1 changed file with 50 additions and 15 deletions.
65 changes: 50 additions & 15 deletions hwd/datasets/saint_gall.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,58 @@
from .base_dataset import BaseDataset
from pathlib import Path
from .shtg.base_dataset import extract_zip, download_file


# Release asset holding the Saint Gall database archive.
SAINTGALL_URL = 'https://github.com/aimagelab/HWD/releases/download/saintgall/saintgalldb-v1.0.zip'

# Per-user cache directory, and the location of the downloaded archive inside it.
SAINTGALL_DIR_PATH = Path('~/.cache/saintgall').expanduser()
SAINTGALL_ZIP_PATH = Path('~/.cache/saintgall/saintgalldb-v1.0.zip').expanduser()

# Multi-character tokens found in the ground-truth transcription and the
# single character each one is replaced with when labels are decoded.
SPECIAL_MAP = {
    "pt": ".",
    "et": "&",
}

class SaintGallDataset(BaseDataset):
    """Line images of the Saint Gall database with their transcriptions.

    The archive is fetched from ``SAINTGALL_URL`` and unpacked under
    ``SAINTGALL_DIR_PATH`` the first time the dataset is instantiated.
    Every sample is assigned author id ``0`` because all samples are from
    the same author.
    """

    def __init__(self, transform=None, nameset='train', dataset_type='lines'):
        """Build the dataset for one split and one image variant.

        Args:
            transform (callable, optional): Transform forwarded to
                ``BaseDataset``.
            nameset (str): Name of the split; ``sets/<nameset>.txt`` inside
                the extracted archive lists the ids belonging to it.
            dataset_type (str): ``'lines'`` to read ``data/line_images`` or
                ``'lines_normalized'`` to read ``data/line_images_normalized``.

        Raises:
            FileNotFoundError: If the split file for ``nameset`` is missing.
            ValueError: If ``dataset_type`` is not one of the supported values.
            KeyError: If the transcription contains a multi-character token
                missing from ``SPECIAL_MAP``, or an image has no transcription.
        """
        saintgall_unzip_path = SAINTGALL_DIR_PATH / 'saintgalldb-v1.0'
        # Test for the extracted folder rather than its parent directory, so a
        # download or extraction that was interrupted is retried on the next
        # run instead of leaving a permanently broken cache.
        if not saintgall_unzip_path.exists():
            download_file(SAINTGALL_URL, SAINTGALL_ZIP_PATH)
            extract_zip(SAINTGALL_ZIP_PATH, SAINTGALL_DIR_PATH, delete=True)

        nameset_path = saintgall_unzip_path / 'sets' / f'{nameset}.txt'
        # Raised explicitly: an ``assert`` would be stripped under ``python -O``.
        if not nameset_path.exists():
            raise FileNotFoundError(f'The nameset file {nameset_path} does not exist')
        split_ids = set(nameset_path.read_text().splitlines())

        if dataset_type == 'lines':
            self.path = saintgall_unzip_path / 'data' / 'line_images'
        elif dataset_type == 'lines_normalized':
            self.path = saintgall_unzip_path / 'data' / 'line_images_normalized'
        else:
            raise ValueError(f'Invalid dataset_type: {dataset_type}. Available types: ["lines", "lines_normalized"]')

        # Keep only the images whose 10-character stem prefix is listed in the
        # split file.
        self.imgs = [img for img in self.path.rglob('*.png') if img.stem[:10] in split_ids]

        self.author_ids = [0, ] * len(self.imgs)  # All samples are from the same author
        super().__init__(
            saintgall_unzip_path,
            self.imgs,
            self.author_ids,
            [0, ],  # All samples are from the same author
            transform=transform,
        )

        labels_path = saintgall_unzip_path / 'ground_truth' / 'transcription.txt'
        self.labels_dict = {}
        for line in labels_path.read_text().splitlines():
            if not line.strip():
                continue  # tolerate blank lines instead of failing the unpack below
            # Each record is three space-separated fields; the third is unused.
            img_id, label, _ = line.split(' ')
            self.labels_dict[img_id] = self._decode_label(label)

        self.labels = [self.labels_dict[img.stem] for img in self.imgs]
        self.has_labels = True

    @staticmethod
    def _decode_label(encoded):
        """Turn an encoded transcription field into plain text.

        ``'|'`` becomes a space, ``'-'`` separates tokens, and every token
        longer than one character is looked up in ``SPECIAL_MAP``.
        """
        tokens = encoded.replace('|', '- -').split('-')
        return ''.join(tok if len(tok) <= 1 else SPECIAL_MAP[tok] for tok in tokens)

0 comments on commit 24ac36c

Please sign in to comment.