Commit

Simplify catalog reading (#394)
* Simplify catalog reading

* Simplify more.

* Code review comments.
delucchi-cmu authored Oct 28, 2024
1 parent f2704b8 commit 1710fd7
Showing 9 changed files with 88 additions and 182 deletions.
26 changes: 1 addition & 25 deletions src/hats/catalog/association_catalog/association_catalog.py
@@ -1,17 +1,14 @@
from __future__ import annotations

from pathlib import Path
from typing import Tuple, Union
from typing import Union

import pandas as pd
import pyarrow as pa
from mocpy import MOC
from upath import UPath

from hats.catalog.association_catalog.partition_join_info import PartitionJoinInfo
from hats.catalog.dataset.table_properties import TableProperties
from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset, PixelInputTypes
from hats.io import file_io, paths


class AssociationCatalog(HealpixDataset):
@@ -54,24 +51,3 @@ def _get_partition_join_info_from_pixels(
if isinstance(join_pixels, pd.DataFrame):
return PartitionJoinInfo(join_pixels)
raise TypeError("join_pixels must be of type PartitionJoinInfo or DataFrame")

@classmethod
def _read_args(
cls, catalog_base_dir: str | Path | UPath
) -> Tuple[TableProperties, PixelInputTypes, JoinPixelInputTypes]: # type: ignore[override]
args = super()._read_args(catalog_base_dir)
partition_join_info = PartitionJoinInfo.read_from_dir(catalog_base_dir)
return args + (partition_join_info,)

@classmethod
def _check_files_exist(cls, catalog_base_dir: str | Path | UPath):
super()._check_files_exist(catalog_base_dir)
partition_join_info_file = paths.get_partition_join_info_pointer(catalog_base_dir)
metadata_file = paths.get_parquet_metadata_pointer(catalog_base_dir)
if not (
file_io.does_file_or_directory_exist(partition_join_info_file)
or file_io.does_file_or_directory_exist(metadata_file)
):
raise FileNotFoundError(
f"_metadata or partition join info file is required in catalog directory {catalog_base_dir}"
)
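
Note: the per-class `_read_args` / `_check_files_exist` hooks removed here are replaced by the centralized loader in `read_hats.py` (below), which reads the join information directly. A minimal sketch of the equivalent assembly, mirroring the kwargs the new `read_hats` builds for association catalogs and omitting the optional `moc` and `schema` (the directory path is a placeholder):

import hats catalog pieces used by the new loader:

    from hats.catalog import AssociationCatalog
    from hats.catalog.association_catalog.partition_join_info import PartitionJoinInfo
    from hats.catalog.dataset.table_properties import TableProperties
    from hats.catalog.partition_info import PartitionInfo

    catalog_dir = "/data/my_association_catalog"  # placeholder path
    catalog = AssociationCatalog(
        catalog_info=TableProperties.read_from_dir(catalog_dir),
        pixels=PartitionInfo.read_from_dir(catalog_dir),
        join_pixels=PartitionJoinInfo.read_from_dir(catalog_dir),
        catalog_path=catalog_dir,
    )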
44 changes: 10 additions & 34 deletions src/hats/catalog/dataset/dataset.py
@@ -1,26 +1,32 @@
from __future__ import annotations

from pathlib import Path
from typing import Tuple

from typing_extensions import Self
import pyarrow as pa
from upath import UPath

from hats.catalog.dataset.table_properties import TableProperties
from hats.io import file_io


# pylint: disable=too-few-public-methods
class Dataset:
"""A base HATS dataset that contains a properties file
and the data contained in parquet files"""

def __init__(self, catalog_info: TableProperties, catalog_path: str | Path | UPath | None = None) -> None:
def __init__(
self,
catalog_info: TableProperties,
catalog_path: str | Path | UPath | None = None,
schema: pa.Schema | None = None,
) -> None:
"""Initializes a Dataset
Args:
catalog_info: A TableProperties object with the catalog metadata
catalog_path: If the catalog is stored on disk, specify the location of the catalog
Does not load the catalog from this path, only store as metadata
schema (pa.Schema): The pyarrow schema for the catalog
"""
self.catalog_info = catalog_info
self.catalog_name = self.catalog_info.catalog_name
@@ -29,34 +35,4 @@ def __init__(self, catalog_info: TableProperties, catalog_path: str | Path | UPa
self.on_disk = catalog_path is not None
self.catalog_base_dir = file_io.get_upath(self.catalog_path)

@classmethod
def read_hats(cls, catalog_path: str | Path | UPath) -> Self:
"""Reads a HATS Catalog from a HATS directory
Args:
catalog_path: path to the root directory of the catalog
Returns:
The initialized catalog object
"""
catalog_base_dir = file_io.get_upath(catalog_path)
cls._check_files_exist(catalog_base_dir)
args = cls._read_args(catalog_base_dir)
kwargs = cls._read_kwargs(catalog_base_dir)
return cls(*args, **kwargs)

@classmethod
def _read_args(cls, catalog_base_dir: str | Path | UPath) -> Tuple[TableProperties]:
catalog_info = TableProperties.read_from_dir(catalog_base_dir)
return (catalog_info,)

@classmethod
def _read_kwargs(cls, catalog_base_dir: str | Path | UPath) -> dict:
return {"catalog_path": catalog_base_dir}

@classmethod
def _check_files_exist(cls, catalog_base_dir: str | Path | UPath):
if not file_io.does_file_or_directory_exist(catalog_base_dir):
raise FileNotFoundError(f"No directory exists at {str(catalog_base_dir)}")
if not file_io.does_file_or_directory_exist(catalog_base_dir / "properties"):
raise FileNotFoundError(f"No properties file found where expected: {str(catalog_base_dir)}")
self.schema = schema
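
With this change, `Dataset.__init__` accepts and stores an optional pyarrow schema, so every subclass shares the same schema handling. A minimal construction sketch, assuming `props` is a `TableProperties` instance read elsewhere (the path and column names are placeholders):

    import pyarrow as pa

    from hats.catalog.dataset import Dataset

    schema = pa.schema([("ra", pa.float64()), ("dec", pa.float64())])  # placeholder columns
    # props: a TableProperties instance, e.g. from TableProperties.read_from_dir(...)
    dataset = Dataset(catalog_info=props, catalog_path="/data/my_catalog", schema=schema)
    assert dataset.schema is schema  # stored directly on the instance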
65 changes: 2 additions & 63 deletions src/hats/catalog/healpix_dataset/healpix_dataset.py
@@ -1,8 +1,7 @@
from __future__ import annotations

import warnings
from pathlib import Path
from typing import List, Tuple, Union
from typing import List, Union

import numpy as np
import pandas as pd
@@ -11,13 +10,10 @@
from typing_extensions import Self
from upath import UPath

import hats.pixel_math.healpix_shim as hp
from hats.catalog.dataset import Dataset
from hats.catalog.dataset.table_properties import TableProperties
from hats.catalog.partition_info import PartitionInfo
from hats.inspection import plot_pixels
from hats.io import file_io, paths
from hats.io.file_io import read_parquet_metadata
from hats.pixel_math import HealpixPixel
from hats.pixel_tree import PixelAlignment, PixelAlignmentType
from hats.pixel_tree.moc_filter import filter_by_moc
@@ -55,11 +51,10 @@ def __init__(
moc (mocpy.MOC): MOC object representing the coverage of the catalog
schema (pa.Schema): The pyarrow schema for the catalog
"""
super().__init__(catalog_info, catalog_path=catalog_path)
super().__init__(catalog_info, catalog_path=catalog_path, schema=schema)
self.partition_info = self._get_partition_info_from_pixels(pixels)
self.pixel_tree = self._get_pixel_tree_from_pixels(pixels)
self.moc = moc
self.schema = schema

def get_healpix_pixels(self) -> List[HealpixPixel]:
"""Get healpix pixel objects for all pixels contained in the catalog.
@@ -89,62 +84,6 @@ def _get_pixel_tree_from_pixels(pixels: PixelInputTypes) -> PixelTree:
return PixelTree.from_healpix(pixels)
raise TypeError("Pixels must be of type PartitionInfo, PixelTree, or List[HealpixPixel]")

@classmethod
def _read_args(cls, catalog_base_dir: str | Path | UPath) -> Tuple[TableProperties, PartitionInfo]:
args = super()._read_args(catalog_base_dir)
partition_info = PartitionInfo.read_from_dir(catalog_base_dir)
return args + (partition_info,)

@classmethod
def _read_kwargs(cls, catalog_base_dir: str | Path | UPath) -> dict:
kwargs = super()._read_kwargs(catalog_base_dir)
kwargs["moc"] = cls._read_moc_from_point_map(catalog_base_dir)
kwargs["schema"] = cls._read_schema_from_metadata(catalog_base_dir)
return kwargs

@classmethod
def _read_moc_from_point_map(cls, catalog_base_dir: str | Path | UPath) -> MOC | None:
"""Reads a MOC object from the `point_map.fits` file if it exists in the catalog directory"""
point_map_path = paths.get_point_map_file_pointer(catalog_base_dir)
if not file_io.does_file_or_directory_exist(point_map_path):
return None
fits_image = file_io.read_fits_image(point_map_path)
order = hp.nside2order(hp.npix2nside(len(fits_image)))
boolean_skymap = fits_image.astype(bool)
ipix = np.where(boolean_skymap)[0]
orders = np.full(ipix.shape, order)
return MOC.from_healpix_cells(ipix, orders, order)

@classmethod
def _read_schema_from_metadata(cls, catalog_base_dir: str | Path | UPath) -> pa.Schema | None:
"""Reads the schema information stored in the _common_metadata or _metadata files."""
common_metadata_file = paths.get_common_metadata_pointer(catalog_base_dir)
common_metadata_exists = file_io.does_file_or_directory_exist(common_metadata_file)
metadata_file = paths.get_parquet_metadata_pointer(catalog_base_dir)
metadata_exists = file_io.does_file_or_directory_exist(metadata_file)
if not (common_metadata_exists or metadata_exists):
warnings.warn(
"_common_metadata or _metadata files not found for this catalog."
"The arrow schema will not be set."
)
return None
schema_file = common_metadata_file if common_metadata_exists else metadata_file
metadata = read_parquet_metadata(schema_file)
return metadata.schema.to_arrow_schema()

@classmethod
def _check_files_exist(cls, catalog_base_dir: str | Path | UPath):
super()._check_files_exist(catalog_base_dir)
partition_info_file = paths.get_partition_info_pointer(catalog_base_dir)
metadata_file = paths.get_parquet_metadata_pointer(catalog_base_dir)
if not (
file_io.does_file_or_directory_exist(partition_info_file)
or file_io.does_file_or_directory_exist(metadata_file)
):
raise FileNotFoundError(
f"_metadata or partition info file is required in catalog directory {catalog_base_dir}"
)

def __len__(self):
"""The number of rows in the catalog.
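
The `_read_moc_from_point_map` and `_read_schema_from_metadata` helpers removed here move, essentially unchanged, into `read_hats.py` (below), and the constructor now forwards `schema` to the `Dataset` base class. To illustrate the HEALPix arithmetic the point-map reader relies on: a map of `npix` cells has `nside = sqrt(npix / 12)` and `order = log2(nside)`. A worked sketch with made-up cell indices:

    import numpy as np

    npix = 192                       # a tiny order-2 map: 12 * nside**2 with nside = 4
    nside = int(np.sqrt(npix / 12))  # -> 4
    order = int(np.log2(nside))      # -> 2

    skymap = np.zeros(npix, dtype=bool)
    skymap[[3, 17, 100]] = True      # placeholder cells that contain data

    ipix = np.where(skymap)[0]           # -> array([  3,  17, 100])
    orders = np.full(ipix.shape, order)  # one order value per occupied cell
    # MOC.from_healpix_cells(ipix, orders, order) builds the coverage map from these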
85 changes: 66 additions & 19 deletions src/hats/loaders/read_hats.py
@@ -1,15 +1,23 @@
from __future__ import annotations

import warnings
from pathlib import Path
from typing import Type

import numpy as np
import pyarrow as pa
from mocpy import MOC
from upath import UPath

import hats.pixel_math.healpix_shim as hp
from hats.catalog import AssociationCatalog, Catalog, CatalogType, Dataset, MarginCatalog
from hats.catalog.association_catalog.partition_join_info import PartitionJoinInfo
from hats.catalog.dataset.table_properties import TableProperties
from hats.catalog.index.index_catalog import IndexCatalog
from hats.catalog.partition_info import PartitionInfo
from hats.io import file_io, paths
from hats.io.file_io import read_parquet_metadata

CATALOG_TYPE_TO_CLASS = {
DATASET_TYPE_TO_CLASS = {
CatalogType.OBJECT: Catalog,
CatalogType.SOURCE: Catalog,
CatalogType.ASSOCIATION: AssociationCatalog,
@@ -18,32 +26,71 @@
}


def read_hats(catalog_path: str | Path | UPath, catalog_type: CatalogType | None = None) -> Dataset:
def read_hats(catalog_path: str | Path | UPath) -> Dataset:
"""Reads a HATS Catalog from a HATS directory
Args:
catalog_path (str): path to the root directory of the catalog
catalog_type (CatalogType): Default `None`. By default, the type of the catalog is loaded
from the catalog info and the corresponding object type is returned. Python's type hints
cannot allow a return type specified by a loaded value, so to use the correct return
type for type checking, the type of the catalog can be specified here. Use by specifying
the hats class for that catalog.
Returns:
The initialized catalog object
"""
catalog_type_to_use = (
_read_dataset_class_from_metadata(catalog_path) if catalog_type is None else catalog_type
catalog_path = file_io.get_upath(catalog_path)
try:
properties = TableProperties.read_from_dir(catalog_path)
dataset_type = properties.catalog_type
if dataset_type not in DATASET_TYPE_TO_CLASS:
raise NotImplementedError(f"Cannot load catalog of type {dataset_type}")

loader = DATASET_TYPE_TO_CLASS[dataset_type]
kwargs = {
"catalog_path": catalog_path,
"catalog_info": properties,
"schema": _read_schema_from_metadata(catalog_path),
}
if _is_healpix_dataset(dataset_type):
kwargs["pixels"] = PartitionInfo.read_from_dir(catalog_path)
kwargs["moc"] = _read_moc_from_point_map(catalog_path)
if dataset_type == CatalogType.ASSOCIATION:
kwargs["join_pixels"] = PartitionJoinInfo.read_from_dir(catalog_path)
return loader(**kwargs)
except Exception as exception: # pylint: disable=broad-except
raise FileNotFoundError(f"Failed to read HATS at location {catalog_path}") from exception


def _is_healpix_dataset(dataset_type):
return dataset_type in (
CatalogType.OBJECT,
CatalogType.SOURCE,
CatalogType.ASSOCIATION,
CatalogType.MARGIN,
)
loader = _get_loader_from_catalog_type(catalog_type_to_use)
return loader.read_hats(catalog_path)


def _read_dataset_class_from_metadata(catalog_base_path: str) -> CatalogType:
return TableProperties.read_from_dir(catalog_base_path).catalog_type
def _read_moc_from_point_map(catalog_base_dir: str | Path | UPath) -> MOC | None:
"""Reads a MOC object from the `point_map.fits` file if it exists in the catalog directory"""
point_map_path = paths.get_point_map_file_pointer(catalog_base_dir)
if not file_io.does_file_or_directory_exist(point_map_path):
return None
fits_image = file_io.read_fits_image(point_map_path)
order = hp.nside2order(hp.npix2nside(len(fits_image)))
boolean_skymap = fits_image.astype(bool)
ipix = np.where(boolean_skymap)[0]
orders = np.full(ipix.shape, order)
return MOC.from_healpix_cells(ipix, orders, order)


def _get_loader_from_catalog_type(catalog_type: CatalogType) -> Type[Dataset]:
if catalog_type not in CATALOG_TYPE_TO_CLASS:
raise NotImplementedError(f"Cannot load catalog of type {catalog_type}")
return CATALOG_TYPE_TO_CLASS[catalog_type]
def _read_schema_from_metadata(catalog_base_dir: str | Path | UPath) -> pa.Schema | None:
"""Reads the schema information stored in the _common_metadata or _metadata files."""
common_metadata_file = paths.get_common_metadata_pointer(catalog_base_dir)
common_metadata_exists = file_io.does_file_or_directory_exist(common_metadata_file)
metadata_file = paths.get_parquet_metadata_pointer(catalog_base_dir)
metadata_exists = file_io.does_file_or_directory_exist(metadata_file)
if not (common_metadata_exists or metadata_exists):
warnings.warn(
"_common_metadata or _metadata files not found for this catalog."
"The arrow schema will not be set."
)
return None
schema_file = common_metadata_file if common_metadata_exists else metadata_file
metadata = read_parquet_metadata(schema_file)
return metadata.schema.to_arrow_schema()
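
With the helpers relocated, `read_hats` becomes the single entry point: the `catalog_type` parameter is gone, and the dataset class is chosen from the `catalog_type` recorded in the properties file. A usage sketch, assuming the function is importable from `hats.loaders` (the catalog path is a placeholder):

    from hats.loaders import read_hats  # assumed export; the function lives in src/hats/loaders/read_hats.py

    catalog = read_hats("/data/my_survey/object_catalog")  # placeholder path
    print(type(catalog).__name__)  # Catalog, AssociationCatalog, MarginCatalog, ... per the properties file
    print(catalog.schema)          # pyarrow schema from _common_metadata or _metadata, if present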
@@ -70,14 +70,14 @@ def test_empty_directory(tmp_path, association_catalog_info_data, association_ca
os.makedirs(catalog_path, exist_ok=True)

## Path exists but there's nothing there
with pytest.raises(FileNotFoundError, match="properties file"):
AssociationCatalog.read_hats(catalog_path)
with pytest.raises(FileNotFoundError):
read_hats(catalog_path)

## catalog_info file exists - getting closer
properties = TableProperties(**association_catalog_info_data)
properties.to_properties_file(catalog_path)

with pytest.raises(FileNotFoundError, match="metadata"):
with pytest.raises(FileNotFoundError):
read_hats(catalog_path)

## Now we create the needed _metadata and everything is right.
Expand All @@ -99,14 +99,14 @@ def test_csv_round_trip(tmp_path, association_catalog_info_data, association_cat
properties = TableProperties(**association_catalog_info_data)
properties.to_properties_file(catalog_path)

with pytest.raises(FileNotFoundError, match="partition"):
with pytest.raises(FileNotFoundError):
read_hats(catalog_path)

file_name = catalog_path / "partition_info.csv"
with open(file_name, "w", encoding="utf-8") as metadata_file:
# dump some garbage in there - just needs to exist.
metadata_file.write("Norder,Npix")
with pytest.raises(FileNotFoundError, match="partition"):
with pytest.raises(FileNotFoundError):
read_hats(catalog_path)

part_info = PartitionJoinInfo(association_catalog_join_pixels)
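
The `match=` patterns are dropped because `read_hats` now wraps any underlying failure in a single `FileNotFoundError` (see the `try/except` in `read_hats.py` above). A minimal sketch of asserting the new behavior against a missing directory, with a placeholder path:

    import pytest

    with pytest.raises(FileNotFoundError, match="Failed to read HATS at location"):
        read_hats("/nonexistent/catalog")  # any failure surfaces with this uniform message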
25 changes: 0 additions & 25 deletions tests/hats/catalog/dataset/test_dataset.py

This file was deleted.
