Skip to content

Commit

Permalink
Synthetic image-caption dataset (#65)
Browse files Browse the repository at this point in the history
* Add synthetic image-caption dataset
  • Loading branch information
Landanjs authored Oct 6, 2023
1 parent 35f5a57 commit 80e2af5
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 2 deletions.
12 changes: 10 additions & 2 deletions diffusion/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@
from diffusion.datasets.coco import StreamingCOCOCaption, build_streaming_cocoval_dataloader
from diffusion.datasets.image_caption import StreamingImageCaptionDataset, build_streaming_image_caption_dataloader
from diffusion.datasets.laion import StreamingLAIONDataset, build_streaming_laion_dataloader
from diffusion.datasets.synthetic_image_caption import (SyntheticImageCaptionDataset,
build_synthetic_image_caption_dataloader)

# Names re-exported as the public API of this package. Keep one name per line,
# each followed by a comma: adjacent string literals without a separating
# comma are silently concatenated into a single (non-existent) name.
__all__ = [
    'build_streaming_laion_dataloader',
    'StreamingLAIONDataset',
    'build_streaming_cocoval_dataloader',
    'StreamingCOCOCaption',
    'build_streaming_image_caption_dataloader',
    'StreamingImageCaptionDataset',
    'build_synthetic_image_caption_dataloader',
    'SyntheticImageCaptionDataset',
]
68 changes: 68 additions & 0 deletions diffusion/datasets/synthetic_image_caption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Synthetic Image-Caption dataset."""

from typing import Dict, Optional

import torch
from composer.utils import dist
from torch.utils.data import DataLoader, Dataset


class SyntheticImageCaptionDataset(Dataset):
    """Map-style dataset of random tensors standing in for image-caption pairs.

    Every sample is generated eagerly in the constructor and held in memory:
    images come from ``torch.randn`` with shape ``(num_samples, 3, image_size, image_size)``
    and captions from ``torch.randint`` as ``torch.long`` values in ``[0, 128)`` with shape
    ``(num_samples, caption_length)``.

    NOTE(review): because the image tensor is allocated up front, memory use scales with
    ``num_samples * 3 * image_size ** 2``. Assuming torch's default float32 dtype, the
    default arguments amount to roughly 315 GB, so pass smaller values when memory is limited.

    Args:
        image_size (int): Height and width of each synthetic image. Default: ``512``.
        caption_length (int): Number of token ids in each synthetic caption. Default: ``77``.
        num_samples (int): Number of samples in the synthetic dataset. Default: ``100_000``.
    """

    def __init__(self, image_size: int = 512, caption_length: int = 77, num_samples: int = 100_000):
        super().__init__()
        self.num_samples = num_samples

        image_shape = (num_samples, 3, image_size, image_size)
        caption_shape = (num_samples, caption_length)

        # Images are drawn before captions so the sequence of draws from the
        # global torch RNG stays the same for a given seed.
        self.images = torch.randn(image_shape)
        self.captions = torch.randint(0, 128, caption_shape, dtype=torch.long)

    def __len__(self):
        # The leading dimension of the image tensor is the sample count.
        return self.images.shape[0]

    def __getitem__(self, idx):
        image = self.images[idx]
        caption = self.captions[idx]
        return {'image': image, 'captions': caption}


def build_synthetic_image_caption_dataloader(
    batch_size: int,
    image_size: int = 512,
    caption_length: int = 77,
    num_samples: int = 100_000,
    dataloader_kwargs: Optional[Dict] = None,
):
    """Create a ``DataLoader`` over a freshly built ``SyntheticImageCaptionDataset``.

    The sampler is obtained from ``dist.get_sampler`` for the new dataset, and any
    extra keyword arguments are forwarded unchanged to the ``DataLoader`` constructor.

    Args:
        batch_size (int): Batch size for the dataloader.
        image_size (int): Size of the synthetic images. Default: ``512``.
        caption_length (int): Length of the synthetic captions. Default: ``77``.
        num_samples (int): Number of samples in the synthetic dataset. Default: ``100_000``.
        dataloader_kwargs (optional, dict): Additional arguments to pass to the dataloader. Default ``None``.

    Returns:
        DataLoader: The dataloader wrapping the synthetic dataset.
    """
    # ``None`` means "no extra options"; a caller-supplied dict is used as-is.
    extra_kwargs = {} if dataloader_kwargs is None else dataloader_kwargs

    synthetic_dataset = SyntheticImageCaptionDataset(
        image_size=image_size,
        caption_length=caption_length,
        num_samples=num_samples,
    )
    sampler = dist.get_sampler(synthetic_dataset)

    return DataLoader(
        dataset=synthetic_dataset,
        sampler=sampler,
        batch_size=batch_size,
        **extra_kwargs,
    )

0 comments on commit 80e2af5

Please sign in to comment.