Refactor IAMDataset to support multiple dataset types and split configurations
vittoriopippi committed Feb 4, 2025
1 parent 55e13d9 commit 1a4a3fd
Showing 1 changed file with 79 additions and 115 deletions.
194 changes: 79 additions & 115 deletions hwd/datasets/iam.py
@@ -1,12 +1,26 @@
from .base_dataset import BaseDataset
from pathlib import Path
from PIL import Image
import random
import xml.etree.ElementTree as ET
from .shtg.base_dataset import extract_zip, extract_tgz
from .shtg.iam import download_file, extract_lines_from_xml, extract_words_from_xml
from .shtg.iam import IAM_XML_DIR_PATH, IAM_XML_TGZ_PATH, IAM_XML_URL
from .shtg.iam import IAM_WORDS_DIR_PATH, IAM_WORDS_TGZ_PATH, IAM_WORDS_URL
from .shtg.iam import IAM_LINES_DIR_PATH, IAM_LINES_TGZ_PATH, IAM_LINES_URL
from .shtg.iam import SHTG_AUTHORS_URL, SHTG_AUTHORS_PATH

IAM_SPLITS_URL = 'https://github.com/aimagelab/HWD/releases/download/iam/largeWriterIndependentTextLineRecognitionTask.1.zip'
IAM_SPLITS_ZIP_PATH = Path('~/.cache/iam/largeWriterIndependentTextLineRecognitionTask.1.zip').expanduser()
IAM_SPLITS_DIR_PATH = Path('~/.cache/iam/largeWriterIndependentTextLineRecognitionTask.1').expanduser()

IAM_SPLITS_PATHS = {
    'train': IAM_SPLITS_DIR_PATH / 'trainset.txt',
    'test': IAM_SPLITS_DIR_PATH / 'testset.txt',
    'val1': IAM_SPLITS_DIR_PATH / 'validationset1.txt',
    'val2': IAM_SPLITS_DIR_PATH / 'validationset2.txt',
}


class IAMDataset(BaseDataset):
    def __init__(self, path, transform=None, nameset=None, dataset_type='words'):
    def __init__(self, transform=None, nameset='train', dataset_type='words', split_type='shtg'):
        """
        Args:
            path (string): Path folder of the dataset.
@@ -16,115 +30,65 @@ def __init__(self, path, transform=None, nameset=None, dataset_type='words'):
            nameset (string, optional): Name of the dataset.
            max_samples (int, optional): Maximum number of samples to consider.
        """
        super().__init__(path, transform, nameset)

        self.imgs = list(Path(path, dataset_type).rglob('*.png'))
        xml_files = list(Path(path, 'xmls').rglob('*.xml'))
        author_dict = {xml_file.stem: ET.parse(xml_file).getroot().attrib['writer-id'] for xml_file in xml_files}

        self.author_ids = list(set(author_dict.values()))
        self.labels = [author_dict[img.parent.stem] for img in self.imgs]

# from PIL import Image
# from .base_dataset import BaseSHTGDataset, download_file, extract_tgz
# from pathlib import Path
# import json
# import gzip
# import xml.etree.ElementTree as ET

# SHTG_IAM_LINES_URL = 'https://github.com/aimagelab/HWD/releases/download/iam/shtg_iam_lines.json.gz'
# SHTG_IAM_WORDS_URL = 'https://github.com/aimagelab/HWD/releases/download/iam/shtg_iam_words.json.gz'
# SHTG_IAM_LINES_PATH = Path('.cache/iam/shtg_iam_lines.json.gz')
# SHTG_IAM_WORDS_PATH = Path('.cache/iam/shtg_iam_words.json.gz')

# SHTG_AUTHORS_URL = 'https://github.com/aimagelab/HWD/releases/download/iam/gan.iam.test.gt.filter27.txt'
# SHTG_AUTHORS_PATH = Path('.cache/iam/gan.iam.test.gt.filter27.txt')

# IAM_LINES_URL = 'https://github.com/aimagelab/HWD/releases/download/iam/lines.tgz'
# IAM_WORDS_URL = 'https://github.com/aimagelab/HWD/releases/download/iam/words.tgz'
# IAM_XML_URL = 'https://github.com/aimagelab/HWD/releases/download/iam/xml.tgz'
# IAM_ASCII_URL = 'https://github.com/aimagelab/HWD/releases/download/iam/ascii.tgz'

# IAM_LINES_TGZ_PATH = Path('.cache/iam/lines.tgz')
# IAM_WORDS_TGZ_PATH = Path('.cache/iam/words.tgz')
# IAM_XML_TGZ_PATH = Path('.cache/iam/xml.tgz')
# IAM_ASCII_TGZ_PATH = Path('.cache/iam/ascii.tgz')

# IAM_LINES_DIR_PATH = Path('.cache/iam/lines')
# IAM_WORDS_DIR_PATH = Path('.cache/iam/words')
# IAM_XML_DIR_PATH = Path('.cache/iam/xml')
# IAM_ASCII_DIR_PATH = Path('.cache/iam/ascii')


# def extract_lines_from_xml(xml_string):
#     # Parse the XML string
#     root = ET.fromstring(xml_string)

#     lines_info = []

#     # Find all line elements within the handwritten-part
#     for line in root.findall('.//line'):
#         line_data = {
#             'id': line.get('id'),
#             'text': line.get('text'),
#             'writer_id': root.attrib['writer-id']
#         }
#         lines_info.append(line_data)

#     return lines_info


# class IAMLines(BaseSHTGDataset):
#     def __init__(self, load_style_samples=True, num_style_samples=1):
#         self.load_style_samples = load_style_samples
#         self.num_style_samples = num_style_samples

#         download_file(SHTG_IAM_LINES_URL, SHTG_IAM_LINES_PATH, exist_ok=True)
#         download_file(SHTG_AUTHORS_URL, SHTG_AUTHORS_PATH, exist_ok=True)
#         self.authors = set()
#         for author_line in SHTG_AUTHORS_PATH.read_text().splitlines():
#             self.authors.add(author_line.split(',')[0])

#         if not IAM_LINES_DIR_PATH.exists():
#             download_file(IAM_LINES_URL, IAM_LINES_TGZ_PATH)
#             extract_tgz(IAM_LINES_TGZ_PATH, IAM_LINES_DIR_PATH, delete=True)

#         if not IAM_XML_DIR_PATH.exists():
#             download_file(IAM_XML_URL, IAM_XML_TGZ_PATH)
#             extract_tgz(IAM_XML_TGZ_PATH, IAM_XML_DIR_PATH, delete=True)

#         with gzip.open(SHTG_IAM_LINES_PATH, 'rt', encoding='utf-8') as file:
#             self.data = json.load(file)

#         self.imgs = {img_path.stem: img_path for img_path in IAM_LINES_DIR_PATH.rglob('*.png')}

#         lines = []
#         for xml_file in IAM_XML_DIR_PATH.rglob('*.xml'):
#             lines.extend(extract_lines_from_xml(xml_file.read_text()))
#         lines = [line for line in lines if line['writer_id'] in self.authors]
#         self.lines = {line['id']: line for line in lines}

#         # Switching from words ids to lines ids
#         for sample in self.data:
#             filtered_style_ids = []
#             for style_sample in sample['style_ids']:
#                 if self.lines[style_sample[:-3]]['text'] != sample['word']:
#                     filtered_style_ids.append(style_sample[:-3])
#             sample['style_ids'] = list(set(filtered_style_ids))
#             assert len(sample['style_ids']) > 0

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         sample = self.data[idx]
#         output = {}
#         output['gen_text'] = sample['word']
#         output['author'] = Path(sample['dst']).parent.name
#         output['dst_path'] = sample['dst']
#         output['style_ids'] = sample['style_ids'][:self.num_style_samples]
#         output['style_imgs_path'] = [self.imgs[id] for id in output['style_ids']]
#         if self.load_style_samples:
#             output['style_imgs'] = [Image.open(self.imgs[id]) for id in output['style_ids']]
#         return output

        if not IAM_XML_DIR_PATH.exists():
            download_file(IAM_XML_URL, IAM_XML_TGZ_PATH)
            extract_tgz(IAM_XML_TGZ_PATH, IAM_XML_DIR_PATH, delete=True)

        if dataset_type == 'lines':
            self.data = []
            for xml_file in IAM_XML_DIR_PATH.rglob('*.xml'):
                self.data.extend(extract_lines_from_xml(xml_file.read_text()))

            if not IAM_LINES_DIR_PATH.exists():
                download_file(IAM_LINES_URL, IAM_LINES_TGZ_PATH)
                extract_tgz(IAM_LINES_TGZ_PATH, IAM_LINES_DIR_PATH, delete=True)

            self.dict_path = {img_path.stem: img_path for img_path in IAM_LINES_DIR_PATH.rglob('*.png')}

        else:
            self.data = []
            for xml_file in IAM_XML_DIR_PATH.rglob('*.xml'):
                self.data.extend(extract_words_from_xml(xml_file.read_text()))

            if not IAM_WORDS_DIR_PATH.exists():
                download_file(IAM_WORDS_URL, IAM_WORDS_TGZ_PATH)
                extract_tgz(IAM_WORDS_TGZ_PATH, IAM_WORDS_DIR_PATH, delete=True)

            self.dict_path = {img_path.stem: img_path for img_path in IAM_WORDS_DIR_PATH.rglob('*.png')}

        if split_type == 'shtg':
            assert nameset in ['train', 'test'], f"Invalid nameset: {nameset}. Available namesets: ['train', 'test']"

            if not SHTG_AUTHORS_PATH.exists():
                download_file(SHTG_AUTHORS_URL, SHTG_AUTHORS_PATH, exist_ok=True)

            test_authors = SHTG_AUTHORS_PATH.read_text().splitlines()
            test_authors = {line.split(',')[0] for line in test_authors}

            if nameset == 'train':
                self.data = [sample for sample in self.data if sample['writer_id'] not in test_authors]
            else:
                self.data = [sample for sample in self.data if sample['writer_id'] in test_authors]

        elif split_type == 'original':
            assert nameset in IAM_SPLITS_PATHS, f"Invalid nameset: {nameset}. Available namesets: {list(IAM_SPLITS_PATHS.keys())}"

            if not IAM_SPLITS_DIR_PATH.exists():
                download_file(IAM_SPLITS_URL, IAM_SPLITS_ZIP_PATH)
                extract_zip(IAM_SPLITS_ZIP_PATH, IAM_SPLITS_DIR_PATH, delete=True)

            allowed_ids = set(IAM_SPLITS_PATHS[nameset].read_text().splitlines())

            def _id_line(id):
                return id[:-3] if dataset_type == 'words' else id
            self.data = [sample for sample in self.data if _id_line(sample['id']) in allowed_ids]
        else:
            raise ValueError(f"Invalid split type: {split_type}. Available split types: ['shtg', 'original']")

        self.imgs = [self.dict_path[sample['id']] for sample in self.data]
        self.author_ids = [sample['writer_id'] for sample in self.data]
        super().__init__(IAM_XML_DIR_PATH.parent, self.imgs, self.data, self.author_ids, nameset, transform=transform)

        self.labels = [sample['text'] for sample in self.data]
        self.has_labels = True
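After this refactor, IAMDataset no longer takes a dataset path: the XML metadata, line/word images, and split files are downloaded into a local cache on first use (the split files under ~/.cache/iam, the rest under the paths defined in .shtg.iam). A minimal usage sketch of the new constructor, assuming the package is importable as hwd; the argument values are illustrative, not part of the commit:

from hwd.datasets.iam import IAMDataset

# SHTG split: 'train' keeps writers absent from the SHTG test-author list,
# 'test' keeps only those writers.
train_words = IAMDataset(nameset='train', dataset_type='words', split_type='shtg')

# Original IAM task splits, read from trainset.txt / testset.txt /
# validationset1.txt / validationset2.txt.
val_lines = IAMDataset(nameset='val1', dataset_type='lines', split_type='original')

# Image paths, writer ids, and transcriptions are aligned lists.
print(len(val_lines.imgs), len(val_lines.author_ids), len(val_lines.labels))

Note that with split_type='original' and dataset_type='words', each sample id is truncated by three characters (_id_line) before matching against the split files, which apparently maps an IAM word id such as a01-000u-00-00 to its parent line id a01-000u-00, since the split files list line ids.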
