Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synthetic image-caption dataset #65

Merged
merged 11 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions diffusion/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@
from diffusion.datasets.coco import StreamingCOCOCaption, build_streaming_cocoval_dataloader
from diffusion.datasets.image_caption import StreamingImageCaptionDataset, build_streaming_image_caption_dataloader
from diffusion.datasets.laion import StreamingLAIONDataset, build_streaming_laion_dataloader
from diffusion.datasets.synthetic_image_caption import (SyntheticImageCaptionDataset,
build_synthetic_image_caption_dataloader)

# Public names re-exported by the datasets package, one per line. Each entry is
# listed exactly once and followed by a comma so adjacent string literals can
# never be implicitly concatenated into a bogus name.
__all__ = [
    'build_streaming_laion_dataloader',
    'StreamingLAIONDataset',
    'build_streaming_cocoval_dataloader',
    'StreamingCOCOCaption',
    'build_streaming_image_caption_dataloader',
    'StreamingImageCaptionDataset',
    'build_synthetic_image_caption_dataloader',
    'SyntheticImageCaptionDataset',
]
68 changes: 68 additions & 0 deletions diffusion/datasets/synthetic_image_caption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Synthetic Image-Caption dataset."""

from typing import Dict, Optional

import torch
from composer.utils import dist
from torch.utils.data import DataLoader, Dataset


class SyntheticImageCaptionDataset(Dataset):
    """Synthetic dataset imitating a dataset containing image-caption pairs.

    Captions are sampled once at construction time. Images are *not* stored:
    each one is generated on demand in ``__getitem__`` from a generator seeded
    by the sample index, so the same index always yields the same image while
    memory use stays independent of ``num_samples``. (Materializing every image
    up front with the default arguments would need a
    ``100_000 x 3 x 512 x 512`` float32 tensor, roughly 315 GB.)

    Args:
        image_size (int): Size of the synthetic images. Default: ``512``.
        caption_length (int): Length of the synthetic captions. Default: ``77``.
        num_samples (int): Number of samples in the synthetic dataset. Default: ``100_000``.
    """

    def __init__(self, image_size: int = 512, caption_length: int = 77, num_samples: int = 100_000):

        super().__init__()
        self.image_size = image_size
        self.num_samples = num_samples
        # Base seed comes from the global RNG, so ``torch.manual_seed`` still
        # controls which synthetic images the dataset produces.
        self._image_seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
        # Captions are small (num_samples x caption_length int64), keep them in memory.
        self.captions = torch.randint(0, 128, (num_samples, caption_length), dtype=torch.long)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the sample at integer index ``idx``.

        Args:
            idx (int): Sample index; negative values count from the end.

        Returns:
            dict: ``{'image': Tensor of shape (3, image_size, image_size),
            'captions': Tensor of shape (caption_length,)}``.

        Raises:
            IndexError: If ``idx`` is outside ``[-num_samples, num_samples)``.
        """
        idx = int(idx)
        if idx < 0:
            idx += self.num_samples
        if not 0 <= idx < self.num_samples:
            raise IndexError(f'index {idx} is out of range for dataset with {self.num_samples} samples')

        # A dedicated generator per index makes repeated lookups (and lookups
        # from different dataloader workers) return identical images.
        generator = torch.Generator()
        generator.manual_seed(self._image_seed + idx)
        image = torch.randn(3, self.image_size, self.image_size, generator=generator)
        return {'image': image, 'captions': self.captions[idx]}


def build_synthetic_image_caption_dataloader(
    batch_size: int,
    image_size: int = 512,
    caption_length: int = 77,
    num_samples: int = 100_000,
    dataloader_kwargs: Optional[Dict] = None,
):
    """Create a ``DataLoader`` over a freshly built ``SyntheticImageCaptionDataset``.

    The sampler is obtained from ``composer.utils.dist.get_sampler`` for the new
    dataset; any extra keyword arguments are forwarded to ``DataLoader`` unchanged.

    Args:
        batch_size (int): Batch size for the dataloader.
        image_size (int): Size of the synthetic images. Default: ``512``.
        caption_length (int): Length of the synthetic captions. Default: ``77``.
        num_samples (int): Number of samples in the synthetic dataset. Default: ``100_000``.
        dataloader_kwargs (optional, dict): Additional arguments to pass to the dataloader. Default ``None``.

    Returns:
        DataLoader: Dataloader yielding batches from the synthetic dataset.
    """
    # Treat a missing kwargs dict as "no extra arguments".
    extra_kwargs = {} if dataloader_kwargs is None else dataloader_kwargs

    synthetic_data = SyntheticImageCaptionDataset(image_size=image_size,
                                                  caption_length=caption_length,
                                                  num_samples=num_samples)
    sampler = dist.get_sampler(synthetic_data)

    return DataLoader(dataset=synthetic_data, sampler=sampler, batch_size=batch_size, **extra_kwargs)