Simplify catalog reading #394

Merged
4 commits merged on Oct 28, 2024
Changes from 3 commits
26 changes: 1 addition & 25 deletions src/hats/catalog/association_catalog/association_catalog.py
@@ -1,17 +1,14 @@
from __future__ import annotations

from pathlib import Path
from typing import Tuple, Union
from typing import Union

import pandas as pd
import pyarrow as pa
from mocpy import MOC
from upath import UPath

from hats.catalog.association_catalog.partition_join_info import PartitionJoinInfo
from hats.catalog.dataset.table_properties import TableProperties
from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset, PixelInputTypes
from hats.io import file_io, paths


class AssociationCatalog(HealpixDataset):
@@ -54,24 +51,3 @@ def _get_partition_join_info_from_pixels(
if isinstance(join_pixels, pd.DataFrame):
return PartitionJoinInfo(join_pixels)
raise TypeError("join_pixels must be of type PartitionJoinInfo or DataFrame")

@classmethod
def _read_args(
cls, catalog_base_dir: str | Path | UPath
) -> Tuple[TableProperties, PixelInputTypes, JoinPixelInputTypes]: # type: ignore[override]
args = super()._read_args(catalog_base_dir)
partition_join_info = PartitionJoinInfo.read_from_dir(catalog_base_dir)
return args + (partition_join_info,)

@classmethod
def _check_files_exist(cls, catalog_base_dir: str | Path | UPath):
super()._check_files_exist(catalog_base_dir)
partition_join_info_file = paths.get_partition_join_info_pointer(catalog_base_dir)
metadata_file = paths.get_parquet_metadata_pointer(catalog_base_dir)
if not (
file_io.does_file_or_directory_exist(partition_join_info_file)
or file_io.does_file_or_directory_exist(metadata_file)
):
raise FileNotFoundError(
f"_metadata or partition join info file is required in catalog directory {catalog_base_dir}"
)
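
With the per-class `_read_args` and `_check_files_exist` hooks removed, association catalogs are now constructed by the central loader rather than by `AssociationCatalog.read_hats`. A minimal usage sketch, assuming a valid HATS association catalog at a hypothetical path:

    from hats.loaders import read_hats

    # read_hats dispatches on the catalog_type recorded in the properties file;
    # for an association catalog it returns an AssociationCatalog instance.
    catalog = read_hats("path/to/association_catalog")
    print(type(catalog).__name__)  # AssociationCatalog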
43 changes: 9 additions & 34 deletions src/hats/catalog/dataset/dataset.py
@@ -1,20 +1,25 @@
from __future__ import annotations

from pathlib import Path
from typing import Tuple

from typing_extensions import Self
import pyarrow as pa
from upath import UPath

from hats.catalog.dataset.table_properties import TableProperties
from hats.io import file_io


# pylint: disable=too-few-public-methods
class Dataset:
"""A base HATS dataset that contains a properties file
and the data contained in parquet files"""

def __init__(self, catalog_info: TableProperties, catalog_path: str | Path | UPath | None = None) -> None:
def __init__(
self,
catalog_info: TableProperties,
catalog_path: str | Path | UPath | None = None,
schema: pa.Schema | None = None,
) -> None:
"""Initializes a Dataset

Args:
@@ -29,34 +34,4 @@ def __init__(self, catalog_info: TableProperties, catalog_path: str | Path | UPa
self.on_disk = catalog_path is not None
self.catalog_base_dir = file_io.get_upath(self.catalog_path)

@classmethod
def read_hats(cls, catalog_path: str | Path | UPath) -> Self:
"""Reads a HATS Catalog from a HATS directory

Args:
catalog_path: path to the root directory of the catalog

Returns:
The initialized catalog object
"""
catalog_base_dir = file_io.get_upath(catalog_path)
cls._check_files_exist(catalog_base_dir)
args = cls._read_args(catalog_base_dir)
kwargs = cls._read_kwargs(catalog_base_dir)
return cls(*args, **kwargs)

@classmethod
def _read_args(cls, catalog_base_dir: str | Path | UPath) -> Tuple[TableProperties]:
catalog_info = TableProperties.read_from_dir(catalog_base_dir)
return (catalog_info,)

@classmethod
def _read_kwargs(cls, catalog_base_dir: str | Path | UPath) -> dict:
return {"catalog_path": catalog_base_dir}

@classmethod
def _check_files_exist(cls, catalog_base_dir: str | Path | UPath):
if not file_io.does_file_or_directory_exist(catalog_base_dir):
raise FileNotFoundError(f"No directory exists at {str(catalog_base_dir)}")
if not file_io.does_file_or_directory_exist(catalog_base_dir / "properties"):
raise FileNotFoundError(f"No properties file found where expected: {str(catalog_base_dir)}")
self.schema = schema
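
Since `Dataset` now accepts and stores the pyarrow schema itself, subclasses no longer need their own schema handling. A sketch of direct construction under the new signature, with hypothetical column names:

    import pyarrow as pa

    from hats.catalog.dataset import Dataset
    from hats.catalog.dataset.table_properties import TableProperties

    base_dir = "path/to/catalog"  # hypothetical directory with a valid properties file
    props = TableProperties.read_from_dir(base_dir)
    schema = pa.schema([("ra", pa.float64()), ("dec", pa.float64())])  # hypothetical columns
    dataset = Dataset(props, catalog_path=base_dir, schema=schema)
    assert dataset.schema is schema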
65 changes: 2 additions & 63 deletions src/hats/catalog/healpix_dataset/healpix_dataset.py
@@ -1,8 +1,7 @@
from __future__ import annotations

import warnings
from pathlib import Path
from typing import List, Tuple, Union
from typing import List, Union

import numpy as np
import pandas as pd
@@ -11,13 +10,10 @@
from typing_extensions import Self
from upath import UPath

import hats.pixel_math.healpix_shim as hp
from hats.catalog.dataset import Dataset
from hats.catalog.dataset.table_properties import TableProperties
from hats.catalog.partition_info import PartitionInfo
from hats.inspection import plot_pixels
from hats.io import file_io, paths
from hats.io.file_io import read_parquet_metadata
from hats.pixel_math import HealpixPixel
from hats.pixel_tree import PixelAlignment, PixelAlignmentType
from hats.pixel_tree.moc_filter import filter_by_moc
@@ -55,11 +51,10 @@ def __init__(
moc (mocpy.MOC): MOC object representing the coverage of the catalog
schema (pa.Schema): The pyarrow schema for the catalog
"""
super().__init__(catalog_info, catalog_path=catalog_path)
super().__init__(catalog_info, catalog_path=catalog_path, schema=schema)
self.partition_info = self._get_partition_info_from_pixels(pixels)
self.pixel_tree = self._get_pixel_tree_from_pixels(pixels)
self.moc = moc
self.schema = schema

def get_healpix_pixels(self) -> List[HealpixPixel]:
"""Get healpix pixel objects for all pixels contained in the catalog.
@@ -89,62 +84,6 @@ def _get_pixel_tree_from_pixels(pixels: PixelInputTypes) -> PixelTree:
return PixelTree.from_healpix(pixels)
raise TypeError("Pixels must be of type PartitionInfo, PixelTree, or List[HealpixPixel]")

@classmethod
def _read_args(cls, catalog_base_dir: str | Path | UPath) -> Tuple[TableProperties, PartitionInfo]:
args = super()._read_args(catalog_base_dir)
partition_info = PartitionInfo.read_from_dir(catalog_base_dir)
return args + (partition_info,)

@classmethod
def _read_kwargs(cls, catalog_base_dir: str | Path | UPath) -> dict:
kwargs = super()._read_kwargs(catalog_base_dir)
kwargs["moc"] = cls._read_moc_from_point_map(catalog_base_dir)
kwargs["schema"] = cls._read_schema_from_metadata(catalog_base_dir)
return kwargs

@classmethod
def _read_moc_from_point_map(cls, catalog_base_dir: str | Path | UPath) -> MOC | None:
"""Reads a MOC object from the `point_map.fits` file if it exists in the catalog directory"""
point_map_path = paths.get_point_map_file_pointer(catalog_base_dir)
if not file_io.does_file_or_directory_exist(point_map_path):
return None
fits_image = file_io.read_fits_image(point_map_path)
order = hp.nside2order(hp.npix2nside(len(fits_image)))
boolean_skymap = fits_image.astype(bool)
ipix = np.where(boolean_skymap)[0]
orders = np.full(ipix.shape, order)
return MOC.from_healpix_cells(ipix, orders, order)

@classmethod
def _read_schema_from_metadata(cls, catalog_base_dir: str | Path | UPath) -> pa.Schema | None:
"""Reads the schema information stored in the _common_metadata or _metadata files."""
common_metadata_file = paths.get_common_metadata_pointer(catalog_base_dir)
common_metadata_exists = file_io.does_file_or_directory_exist(common_metadata_file)
metadata_file = paths.get_parquet_metadata_pointer(catalog_base_dir)
metadata_exists = file_io.does_file_or_directory_exist(metadata_file)
if not (common_metadata_exists or metadata_exists):
warnings.warn(
"_common_metadata or _metadata files not found for this catalog."
"The arrow schema will not be set."
)
return None
schema_file = common_metadata_file if common_metadata_exists else metadata_file
metadata = read_parquet_metadata(schema_file)
return metadata.schema.to_arrow_schema()

@classmethod
def _check_files_exist(cls, catalog_base_dir: str | Path | UPath):
super()._check_files_exist(catalog_base_dir)
partition_info_file = paths.get_partition_info_pointer(catalog_base_dir)
metadata_file = paths.get_parquet_metadata_pointer(catalog_base_dir)
if not (
file_io.does_file_or_directory_exist(partition_info_file)
or file_io.does_file_or_directory_exist(metadata_file)
):
raise FileNotFoundError(
f"_metadata or partition info file is required in catalog directory {catalog_base_dir}"
)

def __len__(self):
"""The number of rows in the catalog.

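
The `_read_moc_from_point_map` helper removed here (it now lives in `read_hats.py`, below) recovers the HEALPix order from the flat skymap length, using npix = 12 * nside**2 and nside = 2**order. A worked sketch of that arithmetic for a hypothetical order-5 map:

    import numpy as np
    from mocpy import MOC

    npix = 12288                     # length of a hypothetical point_map.fits skymap
    nside = int(np.sqrt(npix / 12))  # npix = 12 * nside**2, so nside = 32
    order = int(np.log2(nside))      # nside = 2**order, so order = 5

    boolean_skymap = np.zeros(npix, dtype=bool)
    boolean_skymap[:100] = True      # hypothetical occupied pixels
    ipix = np.where(boolean_skymap)[0]
    moc = MOC.from_healpix_cells(ipix, np.full(ipix.shape, order), order)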
76 changes: 63 additions & 13 deletions src/hats/loaders/read_hats.py
@@ -1,15 +1,23 @@
from __future__ import annotations

import warnings
from pathlib import Path
from typing import Type

import numpy as np
import pyarrow as pa
from mocpy import MOC
from upath import UPath

import hats.pixel_math.healpix_shim as hp
from hats.catalog import AssociationCatalog, Catalog, CatalogType, Dataset, MarginCatalog
from hats.catalog.association_catalog.partition_join_info import PartitionJoinInfo
from hats.catalog.dataset.table_properties import TableProperties
from hats.catalog.index.index_catalog import IndexCatalog
from hats.catalog.partition_info import PartitionInfo
from hats.io import file_io, paths
from hats.io.file_io import read_parquet_metadata

CATALOG_TYPE_TO_CLASS = {
DATASET_TYPE_TO_CLASS = {
CatalogType.OBJECT: Catalog,
CatalogType.SOURCE: Catalog,
CatalogType.ASSOCIATION: AssociationCatalog,
@@ -18,7 +26,7 @@
}


def read_hats(catalog_path: str | Path | UPath, catalog_type: CatalogType | None = None) -> Dataset:
def read_hats(catalog_path: str | Path | UPath) -> Dataset:
"""Reads a HATS Catalog from a HATS directory

Args:
@@ -32,18 +40,60 @@
Returns:
The initialized catalog object
"""
catalog_type_to_use = (
_read_dataset_class_from_metadata(catalog_path) if catalog_type is None else catalog_type
catalog_path = file_io.get_upath(catalog_path)
try:
properties = TableProperties.read_from_dir(catalog_path)
dataset_type = properties.catalog_type
if dataset_type not in DATASET_TYPE_TO_CLASS:
raise NotImplementedError(f"Cannot load catalog of type {dataset_type}")
loader = DATASET_TYPE_TO_CLASS[dataset_type]
kwargs = {"catalog_path": catalog_path, "catalog_info": properties}
kwargs["schema"] = _read_schema_from_metadata(catalog_path)
if _is_healpix_dataset(dataset_type):
kwargs["pixels"] = PartitionInfo.read_from_dir(catalog_path)
kwargs["moc"] = _read_moc_from_point_map(catalog_path)
if dataset_type == CatalogType.ASSOCIATION:
kwargs["join_pixels"] = PartitionJoinInfo.read_from_dir(catalog_path)
return loader(**kwargs)
except Exception as exception: # pylint: disable=broad-except
raise FileNotFoundError(f"Failed to read HATS at location {catalog_path}") from exception


def _is_healpix_dataset(dataset_type):
return dataset_type in (
CatalogType.OBJECT,
CatalogType.SOURCE,
CatalogType.ASSOCIATION,
CatalogType.MARGIN,
)
loader = _get_loader_from_catalog_type(catalog_type_to_use)
return loader.read_hats(catalog_path)


def _read_dataset_class_from_metadata(catalog_base_path: str) -> CatalogType:
return TableProperties.read_from_dir(catalog_base_path).catalog_type
def _read_moc_from_point_map(catalog_base_dir: str | Path | UPath) -> MOC | None:
"""Reads a MOC object from the `point_map.fits` file if it exists in the catalog directory"""
point_map_path = paths.get_point_map_file_pointer(catalog_base_dir)
if not file_io.does_file_or_directory_exist(point_map_path):
return None
fits_image = file_io.read_fits_image(point_map_path)
order = hp.nside2order(hp.npix2nside(len(fits_image)))
boolean_skymap = fits_image.astype(bool)
ipix = np.where(boolean_skymap)[0]
orders = np.full(ipix.shape, order)
return MOC.from_healpix_cells(ipix, orders, order)


def _get_loader_from_catalog_type(catalog_type: CatalogType) -> Type[Dataset]:
if catalog_type not in CATALOG_TYPE_TO_CLASS:
raise NotImplementedError(f"Cannot load catalog of type {catalog_type}")
return CATALOG_TYPE_TO_CLASS[catalog_type]
def _read_schema_from_metadata(catalog_base_dir: str | Path | UPath) -> pa.Schema | None:
"""Reads the schema information stored in the _common_metadata or _metadata files."""
common_metadata_file = paths.get_common_metadata_pointer(catalog_base_dir)
common_metadata_exists = file_io.does_file_or_directory_exist(common_metadata_file)
metadata_file = paths.get_parquet_metadata_pointer(catalog_base_dir)
metadata_exists = file_io.does_file_or_directory_exist(metadata_file)
if not (common_metadata_exists or metadata_exists):
warnings.warn(
"_common_metadata or _metadata files not found for this catalog."
"The arrow schema will not be set."
)
return None
schema_file = common_metadata_file if common_metadata_exists else metadata_file
metadata = read_parquet_metadata(schema_file)
return metadata.schema.to_arrow_schema()
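
After this change `read_hats` takes only the catalog path: the class to construct comes from the `catalog_type` stored in the properties file, and any failure during loading is re-raised as a `FileNotFoundError`. A usage sketch with a hypothetical location:

    from hats.loaders import read_hats

    try:
        catalog = read_hats("path/to/catalog")  # hypothetical path
    except FileNotFoundError as err:
        # Missing properties, partition info, or parquet metadata all surface
        # here, with the original error preserved as __cause__.
        print(err)  # Failed to read HATS at location path/to/catalog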
(association catalog tests; file header not captured)
@@ -70,14 +70,14 @@ def test_empty_directory(tmp_path, association_catalog_info_data, association_ca
os.makedirs(catalog_path, exist_ok=True)

## Path exists but there's nothing there
with pytest.raises(FileNotFoundError, match="properties file"):
AssociationCatalog.read_hats(catalog_path)
with pytest.raises(FileNotFoundError):
read_hats(catalog_path)

## catalog_info file exists - getting closer
properties = TableProperties(**association_catalog_info_data)
properties.to_properties_file(catalog_path)

with pytest.raises(FileNotFoundError, match="metadata"):
with pytest.raises(FileNotFoundError):
read_hats(catalog_path)

## Now we create the needed _metadata and everything is right.
@@ -99,14 +99,14 @@ def test_csv_round_trip(tmp_path, association_catalog_info_data, association_cat
properties = TableProperties(**association_catalog_info_data)
properties.to_properties_file(catalog_path)

with pytest.raises(FileNotFoundError, match="partition"):
with pytest.raises(FileNotFoundError):
read_hats(catalog_path)

file_name = catalog_path / "partition_info.csv"
with open(file_name, "w", encoding="utf-8") as metadata_file:
# dump some garbage in there - just needs to exist.
metadata_file.write("Norder,Npix")
with pytest.raises(FileNotFoundError, match="partition"):
with pytest.raises(FileNotFoundError):
read_hats(catalog_path)

part_info = PartitionJoinInfo(association_catalog_join_pixels)
25 changes: 0 additions & 25 deletions tests/hats/catalog/dataset/test_dataset.py

This file was deleted.

7 changes: 0 additions & 7 deletions tests/hats/catalog/loaders/test_read_hats.py
@@ -1,13 +1,6 @@
import pytest

from hats.loaders import read_hats


def test_read_hats_wrong_catalog_type(small_sky_dir):
with pytest.raises(NotImplementedError, match="load catalog of type"):
read_hats(small_sky_dir, catalog_type="unknown")


def test_read_hats_branches(
small_sky_dir,
small_sky_order1_dir,
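
The deleted `catalog_type="unknown"` test relied on the removed override argument; an unrecognized type recorded on disk now surfaces through the loader's blanket `FileNotFoundError` instead. A hedged sketch of an equivalent check, assuming a hypothetical fixture that writes a properties file with a bad catalog_type:

    import pytest

    from hats.loaders import read_hats

    def test_read_hats_unknown_type(catalog_dir_with_bad_type):  # hypothetical fixture
        with pytest.raises(FileNotFoundError, match="Failed to read HATS"):
            read_hats(catalog_dir_with_bad_type)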