Support non-local file systems
delucchi-cmu committed Jan 7, 2025
1 parent 575d06e commit 11c14ca
Showing 1 changed file with 32 additions and 25 deletions.
57 changes: 32 additions & 25 deletions src/hats_import/verification/run_verification.py
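The substance of the change: instead of opening pyarrow datasets from bare local paths and relativizing file names with pathlib, the verifier now takes the filesystem off the catalog path object and hands it to pyarrow explicitly, so catalogs on non-local stores can be verified too. A minimal sketch of that pattern outside of hats-import; it assumes a UPath-style path object (universal_pathlib) exposing `.fs` and `.path`, and the s3 bucket is purely illustrative:

    import pyarrow.dataset as pds
    from upath import UPath  # assumed here: provides .fs (an fsspec filesystem) and .path (the bare path string)

    catalog_path = UPath("s3://example-bucket/my_catalog/dataset")  # hypothetical non-local catalog
    # Hand pyarrow the path's own filesystem instead of letting it assume local disk.
    files_ds = pds.dataset(catalog_path.path, filesystem=catalog_path.fs)
    print(files_ds.count_rows())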
@@ -1,13 +1,14 @@
"""Run pass/fail checks and generate verification report of existing hats table."""

import datetime
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal

import hats.io.paths
import hats.io.validation
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pds
from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN

@@ -79,13 +80,13 @@ class Verifier:

args: VerificationArguments = field()
"""Arguments to use during verification."""
metadata_ds: pyarrow.dataset.Dataset = field()
metadata_ds: pds.Dataset = field()
"""Pyarrow dataset, loaded from the _metadata file."""
files_ds: pyarrow.dataset.Dataset = field()
files_ds: pds.Dataset = field()
"""Pyarrow dataset, loaded from the parquet data files."""
common_metadata_schema: pyarrow.Schema = field()
common_metadata_schema: pa.Schema = field()
"""Pyarrow schema, loaded from the _common_metadata file."""
constructed_truth_schema: pyarrow.Schema = field()
constructed_truth_schema: pa.Schema = field()
"""Pyarrow schema treated as truth during verification. This is constructed
from `common_metadata_schema` and `args.truth_schema`. `common_metadata_schema`
is used for hats-specific columns. If provided, `args.truth_schema` is used
@@ -110,16 +111,17 @@ def from_args(cls, args: VerificationArguments) -> "Verifier":
args.output_path.mkdir(exist_ok=True, parents=True)

print("Loading dataset and schema.")
files_ds = pyarrow.dataset.dataset(args.input_dataset_path)
metadata_ds = pyarrow.dataset.parquet_dataset(
hats.io.paths.get_parquet_metadata_pointer(args.input_catalog_path)
parquet_fs = args.input_catalog_path.fs
files_ds = pds.dataset(args.input_dataset_path.path, filesystem=parquet_fs)
metadata_ds = pds.parquet_dataset(
hats.io.paths.get_parquet_metadata_pointer(args.input_catalog_path), filesystem=parquet_fs
)

input_truth_schema = None
if args.truth_schema is not None:
input_truth_schema = pds.parquet_dataset(args.truth_schema).schema
common_metadata_schema = pyarrow.dataset.parquet_dataset(
hats.io.paths.get_common_metadata_pointer(args.input_catalog_path)
input_truth_schema = pds.parquet_dataset(args.truth_schema, filesystem=parquet_fs).schema
common_metadata_schema = pds.parquet_dataset(
hats.io.paths.get_common_metadata_pointer(args.input_catalog_path), filesystem=parquet_fs
).schema
constructed_truth_schema = cls._construct_truth_schema(
input_truth_schema=input_truth_schema, common_metadata_schema=common_metadata_schema
@@ -187,8 +189,9 @@ def test_file_sets(self) -> bool:
description = "Test that files in _metadata match the data files on disk."
print(f"\nStarting: {description}")

_failed_files = set(self.files_ds.files).symmetric_difference(self.metadata_ds.files)
failed_files = [str(Path(file).relative_to(self.args.input_dataset_path)) for file in _failed_files]
files_ds_files = self._relative_paths(self.files_ds.files)
metadata_ds_files = self._relative_paths(self.metadata_ds.files)
failed_files = list(set(files_ds_files).symmetric_difference(metadata_ds_files))
passed = len(failed_files) == 0
self.results.append(
Result(
@@ -254,7 +257,7 @@ def test_num_rows(self) -> bool:
print(f"Result: {'PASSED' if all_passed else 'FAILED'}")
return all_passed

def _load_nrows(self, dataset: pyarrow.dataset.Dataset) -> pd.DataFrame:
def _load_nrows(self, dataset: pds.Dataset) -> pd.DataFrame:
"""Load the number of rows in each file in the dataset.
Parameters
@@ -266,13 +269,9 @@ def _load_nrows(self, dataset: pyarrow.dataset.Dataset) -> pd.DataFrame:
-------
pd.DataFrame: A DataFrame with the number of rows per file, indexed by file path.
"""
nrows_df = pd.DataFrame(
columns=["num_rows", "frag_path"],
data=[
(frag.metadata.num_rows, str(Path(frag.path).relative_to(self.args.input_dataset_path)))
for frag in dataset.get_fragments()
],
)
num_rows = [frag.metadata.num_rows for frag in dataset.get_fragments()]
frag_names = self._relative_paths([frag.path for frag in dataset.get_fragments()])
nrows_df = pd.DataFrame({"num_rows": num_rows, "frag_path": frag_names})
nrows_df = nrows_df.set_index("frag_path").sort_index()
return nrows_df
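
The rewritten `_load_nrows` keeps its footer-only behavior: per-file row counts come from each fragment's parquet metadata, so no column data is scanned. A standalone sketch of that pattern (the `rows_per_file` name is illustrative, not part of this module):

    import pyarrow.dataset as pds

    def rows_per_file(dataset: pds.Dataset) -> dict:
        """Map each parquet file path to its row count, using only footer metadata."""
        counts = {}
        for frag in dataset.get_fragments():  # one ParquetFileFragment per data file
            counts[frag.path] = frag.metadata.num_rows  # read from the parquet footer, no data scan
        return counts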

@@ -308,8 +307,8 @@ def test_schemas(self, check_metadata: bool = False) -> bool:

@staticmethod
def _construct_truth_schema(
*, input_truth_schema: pyarrow.Schema | None, common_metadata_schema: pyarrow.Schema
) -> pyarrow.Schema:
*, input_truth_schema: pa.Schema | None, common_metadata_schema: pa.Schema
) -> pa.Schema:
"""Copy of `input_truth_schema` with HATS fields added from `common_metadata_schema`.
If `input_truth_schema` is not provided, this is just `common_metadata_schema`.
@@ -338,7 +337,7 @@ def _construct_truth_schema(
input_truth_fields = [fld for fld in input_truth_schema if fld.name not in hats_cols]

constructed_fields = hats_idx_fields + input_truth_fields + hats_partition_fields
constructed_schema = pyarrow.schema(constructed_fields).with_metadata(input_truth_schema.metadata)
constructed_schema = pa.schema(constructed_fields).with_metadata(input_truth_schema.metadata)
return constructed_schema

def _test_schema__common_metadata(self, test_info: dict, check_metadata: bool = False) -> bool:
@@ -411,7 +410,8 @@ def _test_schema_file_footers(self, test_info: dict, check_metadata: bool = Fals
bad_files = []
for frag in self.files_ds.get_fragments():
if not frag.physical_schema.equals(self.constructed_truth_schema, check_metadata=check_metadata):
bad_files.append(str(Path(frag.path).relative_to(self.args.input_dataset_path)))
bad_files.append(frag.path)
bad_files = self._relative_paths(bad_files)

passed = len(bad_files) == 0
self.results.append(
@@ -425,6 +425,13 @@ def _test_schema_file_footers(test_info: dict, check_metadata: bool = Fals
)
return passed

def _relative_paths(self, absolute_paths):
"""Find the relative path for dataset parquet files,
assuming a pattern like <base_path>/Norder=d/Dir=d/Npix=d"""
relative_path_pattern = re.compile(r".*(Norder.*)")
relative_paths = [str(relative_path_pattern.match(file).group(1)) or file for file in absolute_paths]
return relative_paths

def write_results(self, *, write_mode: Literal["a", "w", "x"] = "a") -> None:
"""Write the verification results to file at `args.output_path` / `args.output_filename`.

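The `_relative_paths` helper added above replaces the earlier `pathlib.Path.relative_to` calls, which assume local-style paths, by slicing each file path down to its `Norder=.../Dir=.../Npix=...` suffix. A hedged sketch of the same extraction with an explicit fallback when the pattern is absent (in the committed comprehension, a non-matching path raises AttributeError before the `or file` fallback can apply); the `relative_paths` and `NORDER_PATTERN` names are illustrative:

    import re

    NORDER_PATTERN = re.compile(r".*(Norder.*)")  # same pattern as the helper above

    def relative_paths(absolute_paths):
        """Keep only the Norder=/Dir=/Npix= suffix of each path; leave non-matching paths untouched."""
        out = []
        for file in absolute_paths:
            match = NORDER_PATTERN.match(file)
            out.append(match.group(1) if match else file)
        return out

    # e.g. "s3://bucket/catalog/dataset/Norder=1/Dir=0/Npix=44.parquet" -> "Norder=1/Dir=0/Npix=44.parquet"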