feat(data): start building dataloader
charitarthchugh committed Nov 26, 2024
1 parent 57aa85d commit c4466d1
Showing 1 changed file with 65 additions and 23 deletions.
88 changes: 65 additions & 23 deletions src/lightningsparseinst/data/datamodule.py
@@ -3,9 +3,12 @@
from pathlib import Path
from typing import List, Mapping, Optional

import fiftyone as fo
import hydra
import lightning.pytorch as pl
import lightning as L
import omegaconf
from albumentations import Compose
from fiftyone import ViewField
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
@@ -14,6 +17,8 @@
from nn_core.common import PROJECT_ROOT
from nn_core.nn_types import Split

from lightningsparseinst.data.dataset import SegmentationDataset

pylogger = logging.getLogger(__name__)


@@ -76,9 +81,7 @@ def load(src_path: Path) -> "MetaData":
key, value = line.strip().split("\t")
class_vocab[key] = value

return MetaData(
class_vocab=class_vocab,
)
return MetaData(class_vocab=class_vocab)

def __repr__(self) -> str:
attributes = ",\n ".join([f"{key}={value}" for key, value in self.__dict__.items()])
@@ -99,29 +102,31 @@ def collate_fn(samples: List, split: Split, metadata: MetaData):
return default_collate(samples)


class MyDataModule(pl.LightningDataModule):
class DataModule(L.LightningDataModule):
def __init__(
self,
dataset: DictConfig,
num_workers: DictConfig,
batch_size: DictConfig,
split_names: DictConfig,
accelerator: str,
# example
val_images_fixed_idxs: List[int],
):
super().__init__()
self.dataset = dataset
self.num_workers = num_workers
self.batch_size = batch_size
self.split_names = split_names
# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#gpus
self.pin_memory: bool = accelerator is not None and str(accelerator) == "gpu"

self.fiftyone_dataset: Optional[fo.Dataset] = None
self.classes: Optional[List[str] | str] = None

self.train_dataset: Optional[Dataset] = None
self.val_dataset: Optional[Dataset] = None
self.test_dataset: Optional[Dataset] = None

# example
self.val_images_fixed_idxs: List[int] = val_images_fixed_idxs
self.transform: Optional[Compose] = None

@cached_property
def metadata(self) -> MetaData:
@@ -136,25 +141,62 @@ def metadata(self) -> MetaData:
if self.train_dataset is None:
self.setup(stage="fit")

return MetaData(class_vocab={i: name for i, name in enumerate(self.train_dataset.features["y"].names)})
return MetaData(class_vocab=self.train_dataset.labels_map_rev)
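Note that labels_map_rev is not defined anywhere in this diff; it is assumed here to be an index-to-class-name mapping exposed by SegmentationDataset, so that MetaData carries the same vocabulary that load() reads back from disk:

# Hypothetical shape of labels_map_rev (values are illustrative only)
labels_map_rev = {0: "person", 1: "car", 2: "bicycle"}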

def prepare_data(self) -> None:
# download only
pass

def setup(self, stage: Optional[str] = None):
self.transform = hydra.utils.instantiate(self.dataset.transforms)

self.hf_datasets = hydra.utils.instantiate(self.dataset)
self.hf_datasets.set_transform(self.transform)

# Here you should instantiate your dataset; you may also split train into train and validation if needed.
self.fiftyone_dataset = fo.load_dataset(self.dataset.ref)
self.fiftyone_dataset.compute_metadata()

self.transform = hydra.utils.instantiate(self.dataset.transform)
# Label filtering logic
self.classes = self.dataset.get("classes")
if self.classes:
if isinstance(self.classes, list):
self.fiftyone_dataset = self.fiftyone_dataset.filter_labels(
f"{self.dataset.gt_field}.{self.dataset.detection_field}", ViewField("label").is_in(self.classes)
)
elif isinstance(self.classes, str):
# regex case
self.fiftyone_dataset = self.fiftyone_dataset.filter_labels(
f"{self.dataset.gt_field}.{self.dataset.detection_field}", ViewField("label").re_match(self.classes)
)
else:
self.classes = self.fiftyone_dataset.distinct(
f"{self.dataset.gt_field}.{self.dataset.detection_field}.label"
)
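For context, the two filtering branches above map onto FiftyOne's documented filter_labels / ViewField API as follows; this standalone sketch uses placeholder dataset and field names, not values from this repo:

import fiftyone as fo
from fiftyone import ViewField

ds = fo.load_dataset("my-dataset")  # placeholder: any loaded FiftyOne dataset

# list branch: keep only detections whose label is in an explicit list
view = ds.filter_labels("ground_truth", ViewField("label").is_in(["cat", "dog"]))

# str branch: treat the string as a regex over labels
view = ds.filter_labels("ground_truth", ViewField("label").re_match(r"^(cat|dog)$"))

# else branch: fall back to every label present in the dataset
classes = ds.distinct("ground_truth.detections.label")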
if (stage is None or stage == "fit") and (self.train_dataset is None and self.val_dataset is None):
self.train_dataset = self.hf_datasets["train"]
self.val_dataset = self.hf_datasets["val"]

self.train_dataset = SegmentationDataset(
self.fiftyone_dataset,
split=self.split_names["train"],
gt_field=self.dataset.gt_field,
detection_field=self.dataset.detection_field,
transform=self.transform,
max_num_instances_per_image=self.dataset.max_num_instances_per_image,
)
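# note: unlike the train split, no transform is passed to the val/test datasets below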
self.val_dataset = SegmentationDataset(
self.fiftyone_dataset,
split=self.split_names["validation"],
gt_field=self.dataset.gt_field,
detection_field=self.dataset.detection_field,
max_num_instances_per_image=self.dataset.max_num_instances_per_image,
)
if stage is None or stage == "test":
self.test_dataset = self.hf_datasets["test"]
self.test_dataset = SegmentationDataset(
self.fiftyone_dataset,
split=self.split_names["test"],
gt_field=self.dataset.gt_field,
detection_field=self.dataset.detection_field,
max_num_instances_per_image=self.dataset.max_num_instances_per_image,
)
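As a smoke-test sketch of the setup() contract above: every value below is a placeholder (the real configs are resolved by Hydra from the conf/ directory), and "my-dataset" is assumed to already exist in FiftyOne:

from omegaconf import OmegaConf

dm = DataModule(
    dataset=OmegaConf.create({
        "ref": "my-dataset",
        "gt_field": "ground_truth",
        "detection_field": "detections",
        "transform": {"_target_": "albumentations.Compose", "transforms": []},
        "max_num_instances_per_image": 100,
    }),
    num_workers=OmegaConf.create({"train": 4, "val": 4, "test": 4}),
    batch_size=OmegaConf.create({"train": 8, "val": 8, "test": 8}),
    split_names=OmegaConf.create({"train": "train", "validation": "val", "test": "test"}),
    accelerator="gpu",
    val_images_fixed_idxs=[0, 1, 2],
)
dm.setup(stage="fit")   # builds train_dataset and val_dataset
dm.setup(stage="test")  # builds test_dataset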

def train_dataloader(self) -> DataLoader:
return DataLoader(
@@ -190,14 +232,14 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(" f"{self.dataset=}, " f"{self.num_workers=}, " f"{self.batch_size=})"


@hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
@hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default", version_base="1.1")
def main(cfg: omegaconf.DictConfig) -> None:
"""Debug main to quickly develop the DataModule.
Args:
cfg: the hydra configuration
"""
m: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
m: L.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
m.metadata  # access once to force metadata computation
m.setup()

