Adding SQL Dataset related files to the build script
Summary: Now that we have SQLAlchemy 2.0, we can fully use these SQL dataset files.

Reviewed By: bottler

Differential Revision: D66920096

fbshipit-source-id: 25c0ea1c4f7361e66348035519627dc961b9e6e6
shapovalov authored and facebook-github-bot committed Dec 24, 2024
1 parent 055ab3a commit 64a5bfa
Showing 2 changed files with 57 additions and 35 deletions.
86 changes: 54 additions & 32 deletions pytorch3d/implicitron/dataset/sql_dataset.py
@@ -8,7 +8,8 @@
 import json
 import logging
 import os
-from dataclasses import dataclass
+import urllib
+from dataclasses import dataclass, Field, field
 from typing import (
     Any,
     ClassVar,
@@ -29,9 +30,9 @@
 import torch
 from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
 
-from pytorch3d.implicitron.dataset.frame_data import (  # noqa
+from pytorch3d.implicitron.dataset.frame_data import (
     FrameData,
-    FrameDataBuilder,
+    FrameDataBuilder,  # noqa
     FrameDataBuilderBase,
 )
 from pytorch3d.implicitron.tools.config import (
@@ -51,7 +52,7 @@
 
 
 @registry.register
-class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
+class SqlIndexDataset(DatasetBase, ReplaceableBase):
     """
     A dataset with annotations stored as SQLite tables. This is an index-based dataset.
     The length is returned after all sequence and frame filters are applied (see param
@@ -125,9 +126,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
     seed: int = 0
     remove_empty_masks_poll_whole_table_threshold: int = 300_000
     # we set it manually in the constructor
-    # _index: pd.DataFrame = field(init=False)
-
-    frame_data_builder: FrameDataBuilderBase
+    _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
+    _sql_engine: sa.engine.Engine = field(
+        init=False, metadata={"omegaconf_ignore": True}
+    )
+    eval_batches: Optional[List[Any]] = field(
+        init=False, metadata={"omegaconf_ignore": True}
+    )
+
+    frame_data_builder: FrameDataBuilderBase  # pyre-ignore[13]
     frame_data_builder_class_type: str = "FrameDataBuilder"
 
     def __post_init__(self) -> None:
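
Note: the hunk above replaces a commented-out field with real dataclass fields
that are excluded from __init__ and tagged with omegaconf_ignore metadata,
which the new pre_expand hook at the bottom of this file consumes. A minimal
sketch of the pattern on a hypothetical Config class (not part of this
commit), showing how tooling reads the tag back via dataclasses.fields:

    from dataclasses import dataclass, field, fields

    @dataclass
    class Config:
        name: str = "demo"
        # excluded from __init__ and tagged so that config tooling skips it
        _cache: dict = field(
            default_factory=dict, init=False, metadata={"omegaconf_ignore": True}
        )

    ignored = [f.name for f in fields(Config) if f.metadata.get("omegaconf_ignore")]
    print(ignored)  # ['_cache']
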
@@ -138,17 +145,23 @@ def __post_init__(self) -> None:
             raise ValueError("sqlite_metadata_file must be set")
 
         if self.dataset_root:
-            frame_builder_type = self.frame_data_builder_class_type
-            getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
-                "dataset_root"
-            ] = self.dataset_root
+            frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args"
+            getattr(self, frame_args)["dataset_root"] = self.dataset_root
+            getattr(self, frame_args)["path_manager"] = self.path_manager
 
         run_auto_creation(self)
-        self.frame_data_builder.path_manager = self.path_manager
 
-        # pyre-ignore # NOTE: sqlite-specific args (read-only mode).
+        if self.path_manager is not None:
+            self.sqlite_metadata_file = self.path_manager.get_local_path(
+                self.sqlite_metadata_file
+            )
+            self.subset_lists_file = self.path_manager.get_local_path(
+                self.subset_lists_file
+            )
+
+        # NOTE: sqlite-specific args (read-only mode).
         self._sql_engine = sa.create_engine(
-            f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
+            f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
        )
 
         sequences = self._get_filtered_sequences_if_any()
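
The percent-encoding added above matters because SQLite URI filenames share
URL syntax: an unquoted "?" or "#" in the database path would be parsed as the
start of the query string rather than as part of the file name. A minimal
sketch of the same construction with a hypothetical path:

    import urllib.parse

    import sqlalchemy as sa

    path = "/data/my set#1/metadata.sqlite"  # hypothetical path with tricky characters
    url = f"sqlite:///file:{urllib.parse.quote(path)}?mode=ro&uri=true"
    engine = sa.create_engine(url)  # read-only; connects lazily on first use
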
@@ -166,16 +179,15 @@
         if len(index) == 0:
             raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
 
-        self._index = index.set_index(["sequence_name", "frame_number"])  # pyre-ignore
+        self._index = index.set_index(["sequence_name", "frame_number"])
 
-        self.eval_batches = None  # pyre-ignore
+        self.eval_batches = None
         if self.eval_batches_file:
             self.eval_batches = self._load_filter_eval_batches()
 
         logger.info(str(self))
 
     def __len__(self) -> int:
-        # pyre-ignore[16]
         return len(self._index)
 
     def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
@@ -250,7 +262,6 @@ def _get_item(
         return frame_data
 
     def __str__(self) -> str:
-        # pyre-ignore[16]
         return f"SqlIndexDataset #frames={len(self._index)}"
 
     def sequence_names(self) -> Iterable[str]:
@@ -335,12 +346,12 @@ def sequence_frames_in_order(
         rows = self._index.index.get_loc(seq_name)
         if isinstance(rows, slice):
             assert rows.stop is not None, "Unexpected result from pandas"
-            rows = range(rows.start or 0, rows.stop, rows.step or 1)
+            rows_seq = range(rows.start or 0, rows.stop, rows.step or 1)
         else:
-            rows = np.where(rows)[0]
+            rows_seq = list(np.where(rows)[0])
 
         index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
-            rows, seq_name, subset_filter
+            rows_seq, seq_name, subset_filter
         )
         index_slice["idx"] = idx
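
The rows -> rows_seq rename above exists because pandas Index.get_loc can
return three different types, and reusing one name for all of them defeats
type checking. A toy illustration of the three cases:

    import pandas as pd

    idx = pd.Index(["a", "a", "b", "a"])
    print(idx.get_loc("b"))  # 2 -- an int for a unique key
    print(idx.get_loc("a"))  # [ True  True False  True] -- a boolean mask

    idx2 = pd.Index(["a", "a", "b"])
    print(idx2.get_loc("a"))  # slice(0, 2, None) -- contiguous duplicates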

@@ -461,14 +472,15 @@ def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
         return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
 
     def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
-        assert self.subsets is not None
+        subsets = self.subsets
+        assert subsets is not None
         with open(subset_lists_path, "r") as f:
             subset_to_seq_frame = json.load(f)
 
         seq_frame_list = sum(
             (
                 [(*row, subset) for row in subset_to_seq_frame[subset]]
-                for subset in self.subsets
+                for subset in subsets
             ),
             [],
         )
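
Binding self.subsets to a local before the assert is a type-narrowing idiom:
checkers reliably narrow an Optional local variable, but narrowing of an
attribute does not survive into the nested generator expression that follows.
A toy example of the same idiom on a hypothetical class:

    from typing import List, Optional

    class Loader:
        subsets: Optional[List[str]] = None

        def names(self) -> List[str]:
            subsets = self.subsets
            assert subsets is not None
            # subsets stays List[str] even inside the comprehension below
            return [s.upper() for s in subsets]
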
@@ -522,7 +534,7 @@ def _build_index_from_subset_lists(
             stmt = sa.select(
                 self.frame_annotations_type.sequence_name,
                 self.frame_annotations_type.frame_number,
-            ).where(self.frame_annotations_type._mask_mass == 0)
+            ).where(self.frame_annotations_type._mask_mass == 0)  # pyre-ignore[16]
             with Session(self._sql_engine) as session:
                 to_remove = session.execute(stmt).all()
 
@@ -586,7 +598,7 @@ def _build_index_from_db(self, sequences: Optional[pd.Series]):
         stmt = sa.select(
             self.frame_annotations_type.sequence_name,
             self.frame_annotations_type.frame_number,
-            self.frame_annotations_type._image_path,
+            self.frame_annotations_type._image_path,  # pyre-ignore[16]
             sa.null().label("subset"),
         )
         where_conditions = []
@@ -600,7 +612,7 @@ def _build_index_from_db(self, sequences: Optional[pd.Series]):
             logger.info(" excluding samples with empty masks")
             where_conditions.append(
                 sa.or_(
-                    self.frame_annotations_type._mask_mass.is_(None),
+                    self.frame_annotations_type._mask_mass.is_(None),  # pyre-ignore[16]
                     self.frame_annotations_type._mask_mass != 0,
                 )
             )
@@ -634,15 +646,18 @@ def _load_filter_eval_batches(self):
         assert self.eval_batches_file
         logger.info(f"Loading eval batches from {self.eval_batches_file}")
 
-        if not os.path.isfile(self.eval_batches_file):
+        if (
+            self.path_manager and not self.path_manager.isfile(self.eval_batches_file)
+        ) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)):
             # The batch indices file does not exist.
             # Most probably the user has not specified the root folder.
             raise ValueError(
                 f"Looking for dataset json file in {self.eval_batches_file}. "
                 + "Please specify a correct dataset_root folder."
             )
 
-        with open(self.eval_batches_file, "r") as f:
+        eval_batches_file = self._local_path(self.eval_batches_file)
+        with open(eval_batches_file, "r") as f:
             eval_batches = json.load(f)
 
         # limit the dataset to sequences to allow multiple evaluations in one file
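
The new condition checks existence through the PathManager when one is
configured and falls back to os.path otherwise; _local_path then resolves a
possibly remote file to a local copy before open. A minimal sketch of the
existence-check half; the helper name file_exists is ours, not the commit's:

    import os
    from typing import Optional

    from iopath.common.file_io import PathManager

    def file_exists(path: str, path_manager: Optional[PathManager] = None) -> bool:
        # a PathManager understands registered non-local schemes;
        # plain local paths work either way
        if path_manager is not None:
            return path_manager.isfile(path)
        return os.path.isfile(path)
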
@@ -758,11 +773,18 @@ def _get_temp_index_table_instance(self, table_name: str = "__index"):
             prefixes=["TEMP"],  # NOTE SQLite specific!
         )
 
+    @classmethod
+    def pre_expand(cls) -> None:
+        # remove dataclass annotations that are not meant to be init params
+        # because they cause troubles for OmegaConf
+        for attr, attr_value in list(cls.__dict__.items()):  # need to copy as we mutate
+            if isinstance(attr_value, Field) and attr_value.metadata.get(
+                "omegaconf_ignore", False
+            ):
+                delattr(cls, attr)
+                del cls.__annotations__[attr]
+
 
 def _seq_name_to_seed(seq_name) -> int:
     """Generates numbers in [0, 2 ** 28)"""
     return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
 
 
 def _safe_as_tensor(data, dtype):
     return torch.tensor(data, dtype=dtype) if data is not None else None
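
The new pre_expand hook is invoked before the class is expanded into an
OmegaConf schema (hence the name); values like a live sa.engine.Engine cannot
be represented in a config, so the tagged fields must disappear from the class
annotations first. A toy demonstration of the deletion mechanics on a
hypothetical class (no dataclass decorator needed to see the effect):

    from dataclasses import Field, field

    class Toy:
        x: int = 3
        _engine: object = field(init=False, metadata={"omegaconf_ignore": True})

    # the same loop as in the diff, applied to the toy class
    for attr, attr_value in list(Toy.__dict__.items()):  # copy: we mutate the dict
        if isinstance(attr_value, Field) and attr_value.metadata.get(
            "omegaconf_ignore", False
        ):
            delattr(Toy, attr)
            del Toy.__annotations__[attr]

    print(Toy.__annotations__)  # {'x': <class 'int'>}
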
6 changes: 3 additions & 3 deletions pytorch3d/implicitron/dataset/sql_dataset_provider.py
@@ -43,7 +43,7 @@
 
 
 @registry.register
-class SqlIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
+class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
     """
     Generates the training, validation, and testing dataset objects for
     a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
@@ -193,9 +193,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
 
     # this is a mould that is never constructed, used to build self._dataset_map values
     dataset_class_type: str = "SqlIndexDataset"
-    dataset: SqlIndexDataset
+    dataset: SqlIndexDataset  # pyre-ignore [13]
 
-    path_manager_factory: PathManagerFactory
+    path_manager_factory: PathManagerFactory  # pyre-ignore [13]
     path_manager_factory_class_type: str = "PathManagerFactory"
 
     def __post_init__(self):
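
In the provider, the class-level pyre-ignore [13] (which suppressed Pyre's
uninitialized-attribute error for the whole class) moves onto the two specific
fields that the config machinery populates rather than __init__, so any newly
added uninitialized attribute is no longer silently accepted.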