Skip to content

Commit

Permalink
Merge branch 'main' into issue/453/histogram
Browse files Browse the repository at this point in the history
  • Loading branch information
delucchi-cmu authored Jan 6, 2025
2 parents 7a26c0f + 1aa15a8 commit c559473
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[project]
name = "hats-import"
license = {file = "LICENSE"}
description = "Utility for ingesting large survey data into HATS structure"
readme = "README.md"
authors = [
{ name = "LINCC Frameworks", email = "[email protected]" }
Expand Down
22 changes: 19 additions & 3 deletions src/hats_import/catalog/run_import.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,45 @@
"""Import a set of non-hats files using dask for parallelization
Methods in this file set up a dask pipeline using futures.
Methods in this file set up a dask pipeline using futures.
The actual logic of the map reduce is in the `map_reduce.py` file.
"""

import os
import pickle
from pathlib import Path

import hats.io.file_io as io
from hats.catalog import PartitionInfo
from hats.io import paths
from hats.io.parquet_metadata import write_parquet_metadata
from hats.io.validation import is_valid_catalog

import hats_import.catalog.map_reduce as mr
from hats_import.catalog.arguments import ImportArguments
from hats_import.catalog.resume_plan import ResumePlan


def run(args, client):
"""Run catalog creation pipeline."""
def _validate_arguments(args):
"""
Verify that the args for run are valid: they exist, are of the appropriate type,
and do not specify an output which is a valid catalog.
Raises ValueError if they are invalid.
"""
if not args:
raise ValueError("args is required and should be type ImportArguments")
if not isinstance(args, ImportArguments):
raise ValueError("args must be type ImportArguments")

potential_path = Path(args.output_path) / args.output_artifact_name
if is_valid_catalog(potential_path):
raise ValueError(f"Output path {potential_path} already contains a valid catalog")


def run(args, client):
"""Run catalog creation pipeline."""
_validate_arguments(args)

resume_plan = ResumePlan(import_args=args)

pickled_reader_file = os.path.join(resume_plan.tmp_path, "reader.pickle")
Expand Down
8 changes: 8 additions & 0 deletions src/hats_import/margin_cache/margin_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from pathlib import Path

from hats.catalog import PartitionInfo
from hats.io import file_io, parquet_metadata, paths
from hats.io.validation import is_valid_catalog

import hats_import.margin_cache.margin_cache_map_reduce as mcmr
from hats_import.margin_cache.margin_cache_resume_plan import MarginCachePlan
Expand All @@ -15,6 +18,11 @@ def generate_margin_cache(args, client):
args (MarginCacheArguments): A valid `MarginCacheArguments` object.
client (dask.distributed.Client): A dask distributed client object.
"""
potential_path = Path(args.output_path) / args.output_artifact_name
# Verify that the planned output path is not occupied by a valid catalog
if is_valid_catalog(potential_path):
raise ValueError(f"Output path {potential_path} already contains a valid catalog")

resume_plan = MarginCachePlan(args)
original_catalog_metadata = paths.get_common_metadata_pointer(args.input_catalog_path)

Expand Down
30 changes: 30 additions & 0 deletions tests/hats_import/catalog/test_run_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from hats.pixel_math.sparse_histogram import SparseHistogram

import hats_import.catalog.run_import as runner
import hats_import.margin_cache.margin_cache as margin_runner
from hats_import.catalog.arguments import ImportArguments
from hats_import.catalog.file_readers import CsvReader
from hats_import.catalog.resume_plan import ResumePlan
from hats_import.margin_cache.margin_cache_arguments import MarginCacheArguments


def test_empty_args():
Expand All @@ -31,6 +33,34 @@ def test_bad_args():
runner.run(args, None)


def test_no_import_overwrite(small_sky_object_catalog, parquet_shards_dir):
    """Importing on top of an existing valid catalog must raise a ValueError."""
    # Aim the import output at the exact location of the fixture catalog.
    import_args = ImportArguments(
        input_path=parquet_shards_dir,
        output_path=small_sky_object_catalog.parent,
        output_artifact_name=small_sky_object_catalog.name,
        file_reader="parquet",
    )
    expected_message = "already contains a valid catalog"
    with pytest.raises(ValueError, match=expected_message):
        # No dask client is supplied; the test expects the error without one.
        runner.run(import_args, None)


def test_no_margin_cache_overwrite(small_sky_object_catalog):
    """Margin cache generation must raise a ValueError if its output is a valid catalog."""
    # Use the fixture catalog both as the input and as the planned output location.
    margin_args = MarginCacheArguments(
        input_catalog_path=small_sky_object_catalog,
        output_path=small_sky_object_catalog.parent,
        margin_threshold=10.0,
        output_artifact_name=small_sky_object_catalog.name,
    )
    expected_message = "already contains a valid catalog"
    with pytest.raises(ValueError, match=expected_message):
        # No dask client is supplied; the test expects the error without one.
        margin_runner.generate_margin_cache(margin_args, None)


@pytest.mark.dask
def test_resume_dask_runner(
dask_client,
Expand Down

0 comments on commit c559473

Please sign in to comment.