diff --git a/src/hats_import/verification/arguments.py b/src/hats_import/verification/arguments.py index 9182ca18..0f56d04b 100644 --- a/src/hats_import/verification/arguments.py +++ b/src/hats_import/verification/arguments.py @@ -3,43 +3,51 @@ from __future__ import annotations from dataclasses import dataclass, field -from pathlib import Path -from typing import List, Optional -from hats import read_hats -from hats.catalog import Catalog -from hats.io.validation import is_valid_catalog +import hats.io.paths +from hats.io import file_io from upath import UPath -from hats_import.runtime_arguments import RuntimeArguments - - -@dataclass -class VerificationArguments(RuntimeArguments): - """Data class for holding verification arguments""" - - ## Input - input_catalog_path: str | Path | UPath | None = None - """Path to an existing catalog that will be inspected.""" - input_catalog: Optional[Catalog] = None - """In-memory representation of a catalog. If not provided, it will be loaded - from the input_catalog_path.""" - - ## Verification options - field_distribution_cols: List[str] = field(default_factory=list) - """List of fields to get the overall distribution for. e.g. ["ra", "dec"]. - Should be valid columns in the parquet files.""" - - def __post_init__(self): - self._check_arguments() - - def _check_arguments(self): - super()._check_arguments() - if not self.input_catalog_path and not self.input_catalog: - raise ValueError("input catalog is required (either input_catalog_path or input_catalog)") - if not self.input_catalog: - if not is_valid_catalog(self.input_catalog_path): - raise ValueError("input_catalog_path not a valid catalog") - self.input_catalog = read_hats(catalog_path=self.input_catalog_path) - if not self.input_catalog_path: - self.input_catalog_path = self.input_catalog.catalog_path + +@dataclass(kw_only=True) +class VerificationArguments: + """Container for verification arguments.""" + + input_catalog_path: UPath = field() + """Path to an existing catalog that will be inspected. This must be a directory + containing (at least) the hats ancillary files and a 'dataset/' directory + containing the parquet dataset. Can be supplied as a string or path object.""" + output_path: UPath = field() + """Directory where the verification report should be written. + Can be supplied as a string or path object.""" + output_filename: str = field(default="verifier_results.csv") + """Filename for the verification report.""" + truth_total_rows: int | None = field(default=None) + """Total number of rows expected in this catalog.""" + truth_schema: UPath | None = field(default=None) + """Path to a parquet file or dataset containing the expected schema. If None (default), + the catalog's _common_metadata file will be used. This schema will be used to verify + all non-hats columns and (optionally) the file-level metadata. 
Can be supplied as a + string or path object.""" + + @property + def input_dataset_path(self) -> UPath: + """Path to the directory under `input_catalog_path` that contains the parquet dataset.""" + return file_io.append_paths_to_pointer(self.input_catalog_path, hats.io.paths.DATASET_DIR) + + @property + def output_file_path(self) -> UPath: + """Path to the output file (`output_path` / `output_filename`).""" + return file_io.append_paths_to_pointer(self.output_path, self.output_filename) + + def __post_init__(self) -> None: + self.input_catalog_path = file_io.get_upath(self.input_catalog_path) + if not self.input_catalog_path.is_dir(): + raise ValueError("input_catalog_path must be an existing directory") + + self.output_path = file_io.get_upath(self.output_path) + + if self.truth_schema is not None: + self.truth_schema = file_io.append_paths_to_pointer(self.truth_schema) + if not self.truth_schema.exists(): + raise ValueError("truth_schema must be an existing file or directory") diff --git a/src/hats_import/verification/run_verification.py b/src/hats_import/verification/run_verification.py index ea623ddf..e2e0931a 100644 --- a/src/hats_import/verification/run_verification.py +++ b/src/hats_import/verification/run_verification.py @@ -1,14 +1,458 @@ """Run pass/fail checks and generate verification report of existing hats table.""" +import datetime +import re +from dataclasses import dataclass, field +from typing import Literal + +import hats.io.paths +import hats.io.validation +import pandas as pd +import pyarrow as pa +import pyarrow.dataset as pds +from hats import read_hats +from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN + from hats_import.verification.arguments import VerificationArguments -def run(args): - """Run verification pipeline.""" +def run( + args: VerificationArguments, check_metadata: bool = False, write_mode: Literal["a", "w", "x"] = "a" +) -> "Verifier": + """Create a `Verifier` using `args`, run all tests, and write a verification report. + + Parameters + ---------- + args : VerificationArguments + Arguments to construct the Verifier. + check_metadata : bool, optional + Whether to check the metadata along with the schema. Default is False. + write_mode : Literal["a", "w", "x"], optional + Mode to be used when writing output files. Default is "a" (append). + + Returns + ------- + Verifier + The `Verifier` instance used to perform the tests. The `results_df` property contains + the same information as written to the output report. + + Raises + ------ + TypeError + If `args` is not provided or is not an instance of `VerificationArguments`. + """ if not args: raise TypeError("args is required and should be type VerificationArguments") if not isinstance(args, VerificationArguments): raise TypeError("args must be type VerificationArguments") - # implement everything else. 
- raise NotImplementedError("Verification not yet implemented.") + verifier = Verifier.from_args(args) + verifier.run(write_mode=write_mode, check_metadata=check_metadata) + + return verifier + + +def now() -> str: + """Get the current time as a string.""" + return datetime.datetime.now(datetime.timezone.utc).strftime("%Y/%m/%d %H:%M:%S %Z") + + +@dataclass(kw_only=True, frozen=True) +class Result: + """Verification test result for a single test.""" + + datetime: str = field(default_factory=now) + """The date and time when the test was run.""" + passed: bool = field() + """Whether the test passed.""" + test: str = field() + """Test name.""" + target: str = field() + """The file(s) targeted by the test.""" + description: str = field() + """Test description.""" + bad_files: list[str] = field(default_factory=list) + """List of additional files that caused the test to fail (empty if none or not applicable).""" + + +@dataclass(kw_only=True) +class Verifier: + """Run verification tests. To create an instance of this class, use `Verifier.from_args`.""" + + args: VerificationArguments = field() + """Arguments to use during verification.""" + metadata_ds: pds.Dataset = field() + """Pyarrow dataset, loaded from the _metadata file.""" + files_ds: pds.Dataset = field() + """Pyarrow dataset, loaded from the parquet data files.""" + common_metadata_schema: pa.Schema = field() + """Pyarrow schema, loaded from the _common_metadata file.""" + constructed_truth_schema: pa.Schema = field() + """Pyarrow schema treated as truth during verification. This is constructed + from `common_metadata_schema` and `args.truth_schema`. `common_metadata_schema` + is used for hats-specific columns. If provided, `args.truth_schema` is used + for all other columns plus metadata, otherwise `common_metadata_schema` is used. + """ + results: list[Result] = field(default_factory=list) + """List of results, one for each test that has been done.""" + + @classmethod + def from_args(cls, args: VerificationArguments) -> "Verifier": + """Create a `Verifier` with initialized datasets and schemas based on `args`. + + Parameters + ---------- + args : VerificationArguments + Arguments for the Verifier. + + Returns + ------- + Verifier + """ + args.output_path.mkdir(exist_ok=True, parents=True) + + print("Loading dataset and schema.") + parquet_fs = args.input_catalog_path.fs + files_ds = pds.dataset(args.input_dataset_path.path, filesystem=parquet_fs) + metadata_ds = pds.parquet_dataset( + hats.io.paths.get_parquet_metadata_pointer(args.input_catalog_path), filesystem=parquet_fs + ) + + input_truth_schema = None + if args.truth_schema is not None: + input_truth_schema = pds.parquet_dataset(args.truth_schema, filesystem=parquet_fs).schema + common_metadata_schema = pds.parquet_dataset( + hats.io.paths.get_common_metadata_pointer(args.input_catalog_path), filesystem=parquet_fs + ).schema + constructed_truth_schema = cls._construct_truth_schema( + input_truth_schema=input_truth_schema, common_metadata_schema=common_metadata_schema + ) + + return cls( + args=args, + metadata_ds=metadata_ds, + files_ds=files_ds, + common_metadata_schema=common_metadata_schema, + constructed_truth_schema=constructed_truth_schema, + ) + + @property + def results_df(self) -> pd.DataFrame: + """Test results as a dataframe.""" + return pd.DataFrame(self.results) + + def run(self, write_mode: Literal["a", "w", "x"] = "a", check_metadata: bool = False) -> None: + """Run all tests and write a verification report. 
See `results_df` property or + written report for results. + + Parameters + ---------- + write_mode : Literal["a", "w", "x"], optional + Mode to be used when writing output files. + check_metadata : bool, optional + Whether to check the metadata as well as the schema. + """ + self.test_is_valid_catalog() + self.test_file_sets() + self.test_num_rows() + self.test_schemas(check_metadata=check_metadata) + + self.write_results(write_mode=write_mode) + + def test_is_valid_catalog(self) -> bool: + """Test if the provided catalog is a valid HATS catalog. Add one `Result` to `results`. + + Returns + ------- + bool: True if the test passed, else False. + """ + version = f"hats version {hats.__version__}" + test, description = "valid hats", f"Test hats.io.validation.is_valid_catalog ({version})." + target = self.args.input_catalog_path + print(f"\nStarting: {description}") + + passed = hats.io.validation.is_valid_catalog(target, strict=True) + self.results.append(Result(test=test, description=description, passed=passed, target=target.name)) + print(f"Result: {'PASSED' if passed else 'FAILED'}") + return passed + + def test_file_sets(self) -> bool: + """Test that files in _metadata match the parquet files on disk. Add one `Result` to `results`. + + This is a simple test that can be especially useful to run after copying or moving + the catalog to a different local or cloud-based destination. + + Returns + ------- + bool: True if the file sets match, else False. + """ + # info for the report + description = "Test that files in _metadata match the data files on disk." + print(f"\nStarting: {description}") + + files_ds_files = self._relative_paths(self.files_ds.files) + metadata_ds_files = self._relative_paths(self.metadata_ds.files) + failed_files = list(set(files_ds_files).symmetric_difference(metadata_ds_files)) + passed = len(failed_files) == 0 + self.results.append( + Result( + passed=passed, + test="file sets", + target="_metadata vs data files", + bad_files=failed_files, + description=description, + ) + ) + + print(f"Result: {'PASSED' if passed else 'FAILED'}") + return passed + + def test_num_rows(self) -> bool: + """Test the number of rows in the dataset. Add `Result`s to `results`. + + File footers are compared with _metadata and the user-supplied truth (if provided). + + Returns + ------- + bool: True if all checks pass, else False. + """ + test = "num rows" + description = "Test that number of rows are equal." + print(f"\nStarting: {description}") + + catalog_prop_len = read_hats(self.args.input_catalog_path).catalog_info.total_rows + + # get the number of rows in each file, indexed by file path. we treat this as truth. + files_df = self._load_nrows(self.files_ds) + files_df_sum = files_df.num_rows.sum() + files_df_total = f"file footers ({files_df_sum:,})" + + target = "file footers vs catalog properties" + print(f"\t{target}") + passed_cat = catalog_prop_len == files_df_sum + _description = f" {files_df_total} vs catalog properties ({catalog_prop_len:,})." + self.results.append( + Result(passed=passed_cat, test=test, target=target, description=description + _description) + ) + + # check _metadata + target = "file footers vs _metadata" + print(f"\t{target}") + metadata_df = self._load_nrows(self.metadata_ds) + row_diff = files_df - metadata_df + failed_frags = row_diff.loc[row_diff.num_rows != 0].index.to_list() + passed_md = len(failed_frags) == 0 + _description = f" {files_df_total} vs _metadata ({metadata_df.num_rows.sum():,})." 
+ self.results.append( + Result( + passed=passed_md, + test=test, + bad_files=failed_frags, + target=target, + description=description + _description, + ) + ) + + # check user-supplied total, if provided + if self.args.truth_total_rows is not None: + target = "file footers vs truth" + print(f"\t{target}") + passed_th = self.args.truth_total_rows == files_df_sum + _description = f" {files_df_total} vs user-provided truth ({self.args.truth_total_rows:,})." + self.results.append( + Result(passed=passed_th, test=test, target=target, description=description + _description) + ) + else: + passed_th = True # this test did not fail. this is only needed for the return value. + + all_passed = all([passed_md, passed_th, passed_cat]) + print(f"Result: {'PASSED' if all_passed else 'FAILED'}") + return all_passed + + def _load_nrows(self, dataset: pds.Dataset) -> pd.DataFrame: + """Load the number of rows in each file in the dataset. + + Parameters + ---------- + dataset : pyarrow.dataset.Dataset + The dataset from which to load the number of rows. + + Returns + ------- + pd.DataFrame: A DataFrame with the number of rows per file, indexed by file path. + """ + num_rows = [frag.metadata.num_rows for frag in dataset.get_fragments()] + frag_names = self._relative_paths([frag.path for frag in dataset.get_fragments()]) + nrows_df = pd.DataFrame({"num_rows": num_rows, "frag_path": frag_names}) + nrows_df = nrows_df.set_index("frag_path").sort_index() + return nrows_df + + def test_schemas(self, check_metadata: bool = False) -> bool: + """Test the equality of schemas. Add `Result`s to `results`. + + This performs three tests: + 1. `common_metadata_schema` vs `constructed_truth_schema`. + 2. `metadata_ds.schema` vs `constructed_truth_schema`. + 3. File footers vs `constructed_truth_schema`. + + Parameters + ---------- + check_metadata : bool, optional + Whether to check the metadata as well as the schema. + + Returns + ------- + bool: True if all tests pass, else False. + """ + # info for the report + _include_md = "including metadata" if check_metadata else "excluding metadata" + test_info = {"test": "schema", "description": f"Test that schemas are equal, {_include_md}."} + print(f"\nStarting: {test_info['description']}") + + passed_cm = self._test_schema__common_metadata(test_info, check_metadata=check_metadata) + passed_md = self._test_schema__metadata(test_info, check_metadata=check_metadata) + passed_ff = self._test_schema_file_footers(test_info, check_metadata=check_metadata) + + all_passed = all([passed_cm, passed_md, passed_ff]) + print(f"Result: {'PASSED' if all_passed else 'FAILED'}") + return all_passed + + @staticmethod + def _construct_truth_schema( + *, input_truth_schema: pa.Schema | None, common_metadata_schema: pa.Schema + ) -> pa.Schema: + """Copy of `input_truth_schema` with HATS fields added from `common_metadata_schema`. + + If `input_truth_schema` is not provided, this is just `common_metadata_schema`. + + Parameters + ---------- + input_truth_schema : pyarrow.Schema or None + The input truth schema, if provided. + common_metadata_schema : pyarrow.Schema + The common metadata schema. + + Returns + ------- + pyarrow.Schema + The constructed truth schema. 
+ """ + if input_truth_schema is None: + return common_metadata_schema + + hats_cols = ["Norder", "Dir", "Npix"] + hats_partition_fields = [common_metadata_schema.field(fld) for fld in hats_cols] + hats_idx_fields = [] + if SPATIAL_INDEX_COLUMN in common_metadata_schema.names: + hats_cols.append(SPATIAL_INDEX_COLUMN) + hats_idx_fields.append(common_metadata_schema.field(SPATIAL_INDEX_COLUMN)) + input_truth_fields = [fld for fld in input_truth_schema if fld.name not in hats_cols] + + constructed_fields = hats_idx_fields + input_truth_fields + hats_partition_fields + constructed_schema = pa.schema(constructed_fields).with_metadata(input_truth_schema.metadata) + return constructed_schema + + def _test_schema__common_metadata(self, test_info: dict, check_metadata: bool = False) -> bool: + """Test `common_metadata_schema` against `constructed_truth_schema`. + + Parameters + ---------- + test_info : dict + Information about this test for the reported results. + check_metadata : bool, optional + Whether to check the metadata as well as the schema. + + Returns + ------- + bool: True if all tests pass, else False. + """ + targets = "_common_metadata vs truth" + print(f"\t{targets}") + passed = self.common_metadata_schema.equals( + self.constructed_truth_schema, check_metadata=check_metadata + ) + self.results.append( + Result( + passed=passed, target=targets, test=test_info["test"], description=test_info["description"] + ) + ) + return passed + + def _test_schema__metadata(self, test_info: dict, check_metadata: bool = False) -> bool: + """Test _metadata schema against the truth schema. + + Parameters + ---------- + test_info : dict + Information about this test for the reported results. + check_metadata : bool, optional + Whether to check the metadata as well as the schema. + + Returns + ------- + bool: True if both schema and metadata match the truth source, else False. + """ + targets = "_metadata vs truth" + print(f"\t{targets}") + passed = self.metadata_ds.schema.equals(self.constructed_truth_schema, check_metadata=check_metadata) + self.results.append( + Result( + passed=passed, target=targets, test=test_info["test"], description=test_info["description"] + ) + ) + return passed + + def _test_schema_file_footers(self, test_info: dict, check_metadata: bool = False) -> bool: + """Test the file footers schema and metadata against the truth schema. + + Parameters + ---------- + test_info : dict + Information about this test for the reported results. + check_metadata : bool, optional + Whether to check the metadata as well as the schema. + + Returns + ------- + bool: True if all schema and metadata tests pass, else False. 
+ """ + targets = "file footers vs truth" + print(f"\t{targets}") + + bad_files = [] + for frag in self.files_ds.get_fragments(): + if not frag.physical_schema.equals(self.constructed_truth_schema, check_metadata=check_metadata): + bad_files.append(frag.path) + bad_files = self._relative_paths(bad_files) + + passed = len(bad_files) == 0 + self.results.append( + Result( + passed=passed, + target=targets, + bad_files=bad_files, + test=test_info["test"], + description=test_info["description"], + ) + ) + return passed + + def _relative_paths(self, absolute_paths): + """Find the relative path for dataset parquet files, + assuming a pattern like /Norder=d/Dir=d/Npix=d""" + relative_path_pattern = re.compile(r".*(Norder.*)") + relative_paths = [str(relative_path_pattern.match(file).group(1)) or file for file in absolute_paths] + return relative_paths + + def write_results(self, *, write_mode: Literal["a", "w", "x"] = "a") -> None: + """Write the verification results to file at `args.output_path` / `args.output_filename`. + + Parameters + ---------- + write_mode : Literal["a", "w", "x"], optional + Mode to be used when writing output file. Passed to pandas.DataFrame.to_csv as `mode`. + """ + self.args.output_file_path.parent.mkdir(exist_ok=True, parents=True) + header = not (write_mode == "a" and self.args.output_file_path.exists()) + self.results_df.to_csv(self.args.output_file_path, mode=write_mode, header=header, index=False) + print(f"\nVerifier results written to {self.args.output_file_path}") diff --git a/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.extra_column.parquet b/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.extra_column.parquet new file mode 100644 index 00000000..adbaf0e2 Binary files /dev/null and b/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.extra_column.parquet differ diff --git a/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.missing_column.parquet b/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.missing_column.parquet new file mode 100644 index 00000000..4070c183 Binary files /dev/null and b/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.missing_column.parquet differ diff --git a/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.parquet b/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.parquet new file mode 100644 index 00000000..26f5bb80 Binary files /dev/null and b/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.parquet differ diff --git a/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.wrong_dtypes.parquet b/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.wrong_dtypes.parquet new file mode 100644 index 00000000..3db76c00 Binary files /dev/null and b/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.wrong_dtypes.parquet differ diff --git a/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.wrong_metadata.parquet b/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.wrong_metadata.parquet new file mode 100644 index 00000000..d88b1b4b Binary files /dev/null and b/tests/data/bad_schemas/dataset/Norder=0/Dir=0/Npix=11.wrong_metadata.parquet differ diff --git a/tests/data/bad_schemas/dataset/_common_metadata b/tests/data/bad_schemas/dataset/_common_metadata new file mode 100644 index 00000000..2cbcc792 Binary files /dev/null and b/tests/data/bad_schemas/dataset/_common_metadata differ diff --git a/tests/data/bad_schemas/dataset/_common_metadata.import_truth b/tests/data/bad_schemas/dataset/_common_metadata.import_truth new file mode 100644 index 00000000..6e504600 Binary files /dev/null and 
b/tests/data/bad_schemas/dataset/_common_metadata.import_truth differ diff --git a/tests/data/bad_schemas/dataset/_metadata b/tests/data/bad_schemas/dataset/_metadata new file mode 100644 index 00000000..b1f8665b Binary files /dev/null and b/tests/data/bad_schemas/dataset/_metadata differ diff --git a/tests/data/wrong_files_and_rows/dataset/Norder=0/Dir=0/Npix=11.extra_file.parquet b/tests/data/wrong_files_and_rows/dataset/Norder=0/Dir=0/Npix=11.extra_file.parquet new file mode 100644 index 00000000..30a095a6 Binary files /dev/null and b/tests/data/wrong_files_and_rows/dataset/Norder=0/Dir=0/Npix=11.extra_file.parquet differ diff --git a/tests/data/wrong_files_and_rows/dataset/Norder=0/Dir=0/Npix=11.extra_rows.parquet b/tests/data/wrong_files_and_rows/dataset/Norder=0/Dir=0/Npix=11.extra_rows.parquet new file mode 100644 index 00000000..1e3e6723 Binary files /dev/null and b/tests/data/wrong_files_and_rows/dataset/Norder=0/Dir=0/Npix=11.extra_rows.parquet differ diff --git a/tests/data/wrong_files_and_rows/dataset/Norder=0/Dir=0/Npix=11.parquet b/tests/data/wrong_files_and_rows/dataset/Norder=0/Dir=0/Npix=11.parquet new file mode 100644 index 00000000..26f5bb80 Binary files /dev/null and b/tests/data/wrong_files_and_rows/dataset/Norder=0/Dir=0/Npix=11.parquet differ diff --git a/tests/data/wrong_files_and_rows/dataset/_common_metadata b/tests/data/wrong_files_and_rows/dataset/_common_metadata new file mode 100644 index 00000000..2e494c10 Binary files /dev/null and b/tests/data/wrong_files_and_rows/dataset/_common_metadata differ diff --git a/tests/data/wrong_files_and_rows/dataset/_metadata b/tests/data/wrong_files_and_rows/dataset/_metadata new file mode 100644 index 00000000..9d2598d6 Binary files /dev/null and b/tests/data/wrong_files_and_rows/dataset/_metadata differ diff --git a/tests/data/wrong_files_and_rows/partition_info.csv b/tests/data/wrong_files_and_rows/partition_info.csv new file mode 100644 index 00000000..bf77935e --- /dev/null +++ b/tests/data/wrong_files_and_rows/partition_info.csv @@ -0,0 +1,2 @@ +Norder,Npix +0,11 diff --git a/tests/data/wrong_files_and_rows/properties b/tests/data/wrong_files_and_rows/properties new file mode 100644 index 00000000..801bc768 --- /dev/null +++ b/tests/data/wrong_files_and_rows/properties @@ -0,0 +1,8 @@ +#HATS catalog +obs_collection=wrong_files_and_rows +dataproduct_type=object +hats_nrows=600 +hats_col_ra=source_ra +hats_col_dec=source_dec +hats_order=2 + diff --git a/tests/hats_import/conftest.py b/tests/hats_import/conftest.py index 6db6f58c..bccd0109 100644 --- a/tests/hats_import/conftest.py +++ b/tests/hats_import/conftest.py @@ -305,3 +305,13 @@ def assert_parquet_file_index(file_name, expected_values): npt.assert_array_equal(values, expected_values) return assert_parquet_file_index + + +@pytest.fixture +def bad_schemas_dir(test_data_dir): + return test_data_dir / "bad_schemas" + + +@pytest.fixture +def wrong_files_and_rows_dir(test_data_dir): + return test_data_dir / "wrong_files_and_rows" diff --git a/tests/hats_import/verification/generate_malformed_catalogs.py b/tests/hats_import/verification/generate_malformed_catalogs.py new file mode 100644 index 00000000..effa14e8 --- /dev/null +++ b/tests/hats_import/verification/generate_malformed_catalogs.py @@ -0,0 +1,197 @@ +import random +import shutil +from pathlib import Path + +import attrs +import pyarrow +import pyarrow.dataset +import pyarrow.parquet + +DATA_DIR = Path(__file__).parent.parent.parent.parent / "tests/data" +INPUT_CATALOG_DIR = DATA_DIR / 
"small_sky_object_catalog" + + +def run(input_catalog_dir: Path = INPUT_CATALOG_DIR, data_dir: Path = DATA_DIR) -> None: + """Generate malformed catalogs to be used as test data for verification. + This only needs to be run once unless/until it is desirable to regenerate the dataset. + """ + Generate().run(input_catalog_dir=input_catalog_dir, data_dir=data_dir) + + +@attrs.define +class CatalogBase: + """Container to hold the input catalog (loaded as pyarrow objects) and output paths.""" + + dataset: pyarrow.dataset.Dataset = attrs.field() + frag: pyarrow.dataset.FileFragment = attrs.field() + tbl: pyarrow.Table = attrs.field() + schema: pyarrow.Schema = attrs.field() + input_catalog_dir: Path = attrs.field() + data_dir: Path = attrs.field() + cat_dir_name: str = attrs.field(factory=str) + + @classmethod + def from_dirs(cls, input_catalog_dir: Path, data_dir: Path) -> "CatalogBase": + input_ds = pyarrow.dataset.parquet_dataset(input_catalog_dir / "dataset/_metadata") + # assert input_ds.metadata is not None, "Unit tests expect " + input_frag = next(input_ds.get_fragments()) + input_tbl = input_frag.to_table() + return cls( + dataset=input_ds, + frag=input_frag, + tbl=input_tbl, + schema=input_tbl.schema, + input_catalog_dir=input_catalog_dir, + data_dir=data_dir, + ) + + @property + def fmeta(self) -> Path: + return self.data_dir / self.cat_dir_name / "dataset/_metadata" + + @property + def fcmeta(self) -> Path: + return self.data_dir / self.cat_dir_name / "dataset/_common_metadata" + + @property + def fdata(self) -> Path: + frag_key = Path(self.frag.path).relative_to(self.input_catalog_dir) + return self.data_dir / self.cat_dir_name / frag_key + + +@attrs.define +class Generate: + """Generate malformed catalogs for verification testing.""" + + def run(self, input_catalog_dir: Path = INPUT_CATALOG_DIR, data_dir: Path = DATA_DIR) -> None: + """Generate malformed catalogs to be used as test data for verification. + This only needs to be run once unless/until it is desirable to regenerate the dataset. + """ + print(f"Generating malformed catalogs from input catalog {input_catalog_dir}.") + + catbase = CatalogBase.from_dirs(input_catalog_dir=input_catalog_dir, data_dir=data_dir) + Generate().bad_schemas(catbase=catbase) + Generate().wrong_files_and_rows(catbase=catbase) + + def bad_schemas(self, catbase: CatalogBase) -> None: + """Case: Files are altered in a way that affects the schema after _metadata gets written.""" + catbase.cat_dir_name = "bad_schemas" + self._start_new_catalog(catbase) + + # Write new files with the correct schema + fextra_col = catbase.fdata.with_suffix(".extra_column.parquet") + fmissing_col = catbase.fdata.with_suffix(".missing_column.parquet") + fwrong_metadata = catbase.fdata.with_suffix(".wrong_metadata.parquet") + fwrong_types = catbase.fdata.with_suffix(".wrong_dtypes.parquet") + for _fout in [fmissing_col, fextra_col, fwrong_types]: + pyarrow.parquet.write_table(catbase.tbl, _fout) + + # Write a _metadata that is correct except for file-level metadata + extra_metadata = {b"extra key": b"extra value"} + self._collect_and_write_metadata(catbase, schema=catbase.schema.with_metadata(extra_metadata)) + + # Overwrite the new data files using incorrect schemas. 
+ # drop a column + pyarrow.parquet.write_table(catbase.tbl.drop_columns("dec_error"), fmissing_col) + # add an extra column + extra_col = pyarrow.array(random.sample(range(1000), len(catbase.tbl))) + extra_col_tbl = catbase.tbl.add_column(5, pyarrow.field("extra", pyarrow.int64()), extra_col) + pyarrow.parquet.write_table(extra_col_tbl, fextra_col) + # add or drop file-level metadata + wrong_metadata = {"bad key": "bad value"} if catbase.tbl.schema.metadata is None else None + pyarrow.parquet.write_table(catbase.tbl.replace_schema_metadata(wrong_metadata), fwrong_metadata) + # change some types + wrong_dtypes = [ + fld if not fld.name.startswith("ra") else fld.with_type(pyarrow.float16()) + for fld in catbase.schema + ] + wrong_dtypes_schema = pyarrow.schema(wrong_dtypes).with_metadata(catbase.schema.metadata) + pyarrow.parquet.write_table(catbase.tbl.cast(wrong_dtypes_schema), fwrong_types) + + # Write a _common_metadata with the wrong dtypes. + pyarrow.parquet.write_metadata(schema=wrong_dtypes_schema, where=catbase.fcmeta) + + # Write a _common_metadata without hats columns. + # This mimics a schema that could have been passed as 'use_schema_file' upon import. + fcustom_md = catbase.fcmeta.with_suffix(".import_truth") + hats_cols = ["_healpix_29", "Norder", "Dir", "Npix"] + import_fields = [fld for fld in catbase.schema if fld.name not in hats_cols] + import_schema = pyarrow.schema(import_fields) + pyarrow.parquet.write_metadata(schema=import_schema, where=fcustom_md) + + print(f"Malformed catalog written to {catbase.fmeta.parent.parent}") + + def wrong_files_and_rows(self, catbase: CatalogBase) -> None: + """Case: Dataset is altered in a way that affects the number of rows and/or files + after _metadata gets written.""" + catbase.cat_dir_name = "wrong_files_and_rows" + self._start_new_catalog(catbase) + + fmissing_file = catbase.fdata.with_suffix(".missing_file.parquet") + fextra_file = catbase.fdata.with_suffix(".extra_file.parquet") + fextra_rows = catbase.fdata.with_suffix(".extra_rows.parquet") + + # Write the "correct" dataset, including metadata. + pyarrow.parquet.write_table(catbase.tbl, fmissing_file) + pyarrow.parquet.write_table(catbase.tbl, fextra_rows) + self._collect_and_write_metadata(catbase) + + # Mangle the dataset. + fmissing_file.unlink() + pyarrow.parquet.write_table(catbase.tbl, fextra_file) + pyarrow.parquet.write_table(self._tbl_with_extra_rows(catbase), fextra_rows) + + print(f"Malformed catalog written to {catbase.fmeta.parent.parent}") + + def _tbl_with_extra_rows(self, catbase: CatalogBase) -> pyarrow.Table: + """Generate a table with extra rows.""" + # generate new rows + rng = range(len(catbase.tbl)) + nrows, new_rows = 2, {} + for col in catbase.tbl.column_names: + if col not in ("_healpix_29", "id"): + # just take a random sample + new_rows[col] = catbase.tbl.column(col).take(random.sample(rng, nrows)) + else: + # increment the max value to avoid duplicates + max_id = catbase.tbl.column(col).sort()[-1].as_py() + new_rows[col] = [i + max_id for i in range(1, nrows + 1)] + + # add the rows to the table + new_tbl = pyarrow.concat_tables( + [catbase.tbl, pyarrow.Table.from_pydict(new_rows, schema=catbase.schema)] + ) + return new_tbl + + @staticmethod + def _start_new_catalog(catbase: CatalogBase, with_ancillaries: bool = False) -> None: + # Start a new catalog by creating the directory and copying in input files. + dataset_dir = catbase.fmeta.parent + if dataset_dir.is_dir(): + print(f"Output directory exists. 
Remove it and try again.\n{dataset_dir}") + return + + catbase.fdata.parent.mkdir(parents=True) + shutil.copy(catbase.frag.path, catbase.fdata) + + if with_ancillaries: + for fin in catbase.input_catalog_dir.iterdir(): + if fin.is_file(): + shutil.copy(fin, dataset_dir.parent / fin.name) + for fin in (catbase.input_catalog_dir / "dataset").iterdir(): + if fin.is_file(): + shutil.copy(fin, dataset_dir / fin.name) + + @staticmethod + def _collect_and_write_metadata(catbase: CatalogBase, schema: pyarrow.Schema | None = None) -> None: + base_dir = catbase.fmeta.parent + schema = schema or catbase.schema + dataset = pyarrow.dataset.dataset(base_dir) + metadata_collector = [] + for frag in dataset.get_fragments(): + frag.ensure_complete_metadata() + frag.metadata.set_file_path(str(Path(frag.path).relative_to(base_dir))) + metadata_collector.append(frag.metadata) + pyarrow.parquet.write_metadata( + schema=schema, where=catbase.fmeta, metadata_collector=metadata_collector + ) diff --git a/tests/hats_import/verification/test_run_verification.py b/tests/hats_import/verification/test_run_verification.py index 33be14f5..697416e4 100644 --- a/tests/hats_import/verification/test_run_verification.py +++ b/tests/hats_import/verification/test_run_verification.py @@ -1,3 +1,4 @@ +import pandas as pd import pytest import hats_import.verification.run_verification as runner @@ -14,12 +15,132 @@ def test_bad_args(): runner.run(args) -def test_no_implementation(tmp_path, small_sky_object_catalog): - """Womp womp. Test that we don't have a verification pipeline implemented""" +def test_runner(small_sky_object_catalog, wrong_files_and_rows_dir, tmp_path): + """Runner should execute all tests and write a report to file.""" + result_cols = ["datetime", "passed", "test", "target"] + + args = VerificationArguments(input_catalog_path=small_sky_object_catalog, output_path=tmp_path) + verifier = runner.run(args, write_mode="w") + all_passed = verifier.results_df.passed.all() + assert all_passed, "good catalog failed" + written_results = pd.read_csv(args.output_path / args.output_filename) + assert written_results[result_cols].equals(verifier.results_df[result_cols]), "report failed" + + args = VerificationArguments(input_catalog_path=wrong_files_and_rows_dir, output_path=tmp_path) + verifier = runner.run(args, write_mode="w") + assert not verifier.results_df.passed.all(), "bad catalog passed" + written_results = pd.read_csv(args.output_path / args.output_filename) + assert written_results[result_cols].equals(verifier.results_df[result_cols]), "report failed" + + +def test_test_file_sets(small_sky_object_catalog, wrong_files_and_rows_dir, tmp_path): + """File set tests should fail if files listed in _metadata don't match the actual data files.""" + args = VerificationArguments(input_catalog_path=small_sky_object_catalog, output_path=tmp_path) + verifier = runner.Verifier.from_args(args) + passed = verifier.test_file_sets() + assert passed, "good catalog failed" + + args = VerificationArguments(input_catalog_path=wrong_files_and_rows_dir, output_path=tmp_path) + verifier = runner.Verifier.from_args(args) + passed = verifier.test_file_sets() + assert not passed, "bad catalog passed" + expected_bad_file_names = {"Npix=11.extra_file.parquet", "Npix=11.missing_file.parquet"} + actual_bad_file_names = { + file_name.split("/")[-1] for file_name in verifier.results_df.bad_files.squeeze() + } + assert expected_bad_file_names == actual_bad_file_names, "bad_files failed" + + +def test_test_is_valid_catalog(small_sky_object_catalog, 
wrong_files_and_rows_dir, tmp_path):
+    """`hats.is_valid_catalog` should pass for good catalogs, fail for catalogs without ancillary files."""
+    args = VerificationArguments(input_catalog_path=small_sky_object_catalog, output_path=tmp_path)
+    verifier = runner.Verifier.from_args(args)
+    passed = verifier.test_is_valid_catalog()
+    assert passed, "good catalog failed"
+
+    args = VerificationArguments(input_catalog_path=wrong_files_and_rows_dir, output_path=tmp_path)
+    verifier = runner.Verifier.from_args(args)
+    passed = verifier.test_is_valid_catalog()
+    assert not passed, "bad catalog passed"
+
+
+def test_test_num_rows(small_sky_object_catalog, wrong_files_and_rows_dir, tmp_path):
+    """Row count tests should pass if all row counts match, else fail."""
+    args = VerificationArguments(
+        input_catalog_path=small_sky_object_catalog, output_path=tmp_path, truth_total_rows=131
+    )
+    verifier = runner.Verifier.from_args(args)
+    verifier.test_num_rows()
+    all_passed = verifier.results_df.passed.all()
+    assert all_passed, "good catalog failed"
+
+    args = VerificationArguments(
+        input_catalog_path=wrong_files_and_rows_dir, output_path=tmp_path, truth_total_rows=131
+    )
+    verifier = runner.Verifier.from_args(args)
+    verifier.test_num_rows()
+    results = verifier.results_df
+    all_failed = not results.passed.any()
+    assert all_failed, "bad catalog passed"
+
+    targets = {"file footers vs catalog properties", "file footers vs _metadata", "file footers vs truth"}
+    assert targets == set(results.target), "wrong targets"
+
+    expected_bad_file_names = {
+        "Npix=11.extra_file.parquet",
+        "Npix=11.extra_rows.parquet",
+        "Npix=11.missing_file.parquet",
+    }
+    _result = results.loc[results.target == "file footers vs _metadata"].squeeze()
+    actual_bad_file_names = {file_name.split("/")[-1] for file_name in _result.bad_files}
+    assert expected_bad_file_names == actual_bad_file_names, "wrong bad_files"
+
+
+@pytest.mark.parametrize("check_metadata", [False, True])
+def test_test_schemas(small_sky_object_catalog, bad_schemas_dir, tmp_path, check_metadata):
+    """Schema tests should pass if all column names, dtypes, and (optionally) metadata match, else fail."""
+    # Show that a good catalog passes
     args = VerificationArguments(
         input_catalog_path=small_sky_object_catalog,
         output_path=tmp_path,
-        output_artifact_name="small_sky_object_verification_report",
+        truth_schema=small_sky_object_catalog / "dataset/_common_metadata",
     )
-    with pytest.raises(NotImplementedError, match="not yet implemented"):
-        runner.run(args)
+    verifier = runner.Verifier.from_args(args)
+    verifier.test_schemas(check_metadata=check_metadata)
+    all_passed = verifier.results_df.passed.all()
+    assert all_passed, "good catalog failed"
+
+    # Show that bad schemas fail.
+    args = VerificationArguments(
+        input_catalog_path=bad_schemas_dir,
+        output_path=tmp_path,
+        truth_schema=bad_schemas_dir / "dataset/_common_metadata.import_truth",
+    )
+    verifier = runner.Verifier.from_args(args)
+    verifier.test_schemas(check_metadata=check_metadata)
+    results = verifier.results_df
+
+    # Expecting _common_metadata and some file footers to always fail
+    # and _metadata to fail if check_metadata is true.
+ expect_failed = ["_common_metadata vs truth", "file footers vs truth"] + if check_metadata: + expect_passed = [] + expect_failed = expect_failed + ["_metadata vs truth"] + else: + expect_passed = ["_metadata vs truth"] + assert set(expect_passed + expect_failed) == set(results.target), "wrong targets" + assert all(results.loc[results.target.isin(expect_passed)].passed), "good targets failed" + assert not any(results.loc[results.target.isin(expect_failed)].passed), "bad targets passed" + + # Expecting data files with wrong columns or dtypes to always fail + # and files with wrong metadata to fail if check_metadata is true. + result = results.loc[results.target == "file footers vs truth"].squeeze() + expected_bad_files = [ + "Npix=11.extra_column.parquet", + "Npix=11.missing_column.parquet", + "Npix=11.wrong_dtypes.parquet", + ] + if check_metadata: + expected_bad_files = expected_bad_files + ["Npix=11.wrong_metadata.parquet"] + actual_bad_file_names = {file_name.split("/")[-1] for file_name in result.bad_files} + assert set(expected_bad_files) == set(actual_bad_file_names), "wrong bad_files" diff --git a/tests/hats_import/verification/test_verification_arguments.py b/tests/hats_import/verification/test_verification_arguments.py index 8798fa9f..f661bf96 100644 --- a/tests/hats_import/verification/test_verification_arguments.py +++ b/tests/hats_import/verification/test_verification_arguments.py @@ -1,71 +1,33 @@ """Tests of argument validation""" import pytest -from hats import read_hats from hats_import.verification.arguments import VerificationArguments -def test_none(): - """No arguments provided. Should error for required args.""" - with pytest.raises(ValueError): - VerificationArguments() - - -def test_empty_required(tmp_path): - """*Most* required arguments are provided.""" - ## Input path is missing - with pytest.raises(ValueError, match="input_catalog_path"): - VerificationArguments( - output_path=tmp_path, - output_artifact_name="small_sky_object_verification_report", - ) - - def test_invalid_paths(tmp_path, small_sky_object_catalog): """Required arguments are provided, but paths aren't found.""" ## Prove that it works with required args - VerificationArguments( - input_catalog_path=small_sky_object_catalog, - output_path=tmp_path, - output_artifact_name="small_sky_object_verification_report", - ) + VerificationArguments(input_catalog_path=small_sky_object_catalog, output_path=tmp_path) + + ## Input path is not an existing directory + with pytest.raises(ValueError, match="input_catalog_path must be an existing directory"): + VerificationArguments(input_catalog_path="path", output_path=f"{tmp_path}/path") - ## Input path is invalid catalog - with pytest.raises(ValueError, match="input_catalog_path not a valid catalog"): + # Truth schema is not an existing file + with pytest.raises(ValueError, match="truth_schema must be an existing file or directory"): VerificationArguments( - input_catalog_path="path", - output_path=f"{tmp_path}/path", - output_artifact_name="small_sky_object_verification_report", + input_catalog_path=small_sky_object_catalog, output_path=tmp_path, truth_schema="path" ) -def test_good_paths(tmp_path, small_sky_object_catalog): - """Required arguments are provided, and paths are found.""" - tmp_path_str = str(tmp_path) - args = VerificationArguments( - input_catalog_path=small_sky_object_catalog, - output_path=tmp_path, - output_artifact_name="small_sky_object_verification_report", - ) - assert args.input_catalog_path == small_sky_object_catalog - assert 
str(args.output_path) == tmp_path_str - assert str(args.tmp_path).startswith(tmp_path_str) - - @pytest.mark.timeout(5) -def test_catalog_object(tmp_path, small_sky_object_catalog): +def test_good_paths(tmp_path, small_sky_object_catalog): """Required arguments are provided, and paths are found. NB: This is currently the last test in alpha-order, and may require additional time to teardown fixtures.""" - small_sky_catalog_object = read_hats(catalog_path=small_sky_object_catalog) tmp_path_str = str(tmp_path) - args = VerificationArguments( - input_catalog=small_sky_catalog_object, - output_path=tmp_path, - output_artifact_name="small_sky_object_verification_report", - ) + args = VerificationArguments(input_catalog_path=small_sky_object_catalog, output_path=tmp_path) assert args.input_catalog_path == small_sky_object_catalog assert str(args.output_path) == tmp_path_str - assert str(args.tmp_path).startswith(tmp_path_str)
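
Usage sketch (illustrative, not part of the patch): the snippet below shows how the verification entry point added above would typically be invoked, using only the `VerificationArguments` and `run` signatures introduced in this diff; the catalog and output paths are placeholders.

    from hats_import.verification.arguments import VerificationArguments
    from hats_import.verification.run_verification import run

    # Point at an existing HATS catalog and choose where the CSV report is written.
    args = VerificationArguments(
        input_catalog_path="path/to/existing_catalog",  # placeholder: must be an existing directory
        output_path="path/to/verification_output",      # placeholder: created if needed
        truth_total_rows=131,                            # optional: expected total row count
    )

    # Run all checks; pass check_metadata=True to also compare file-level parquet metadata.
    verifier = run(args, write_mode="w", check_metadata=False)

    # Results are available in memory and in the CSV report on disk.
    print(verifier.results_df)
    print(f"Report written to {verifier.args.output_file_path}")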