diff --git a/src/too/__init__.py b/src/too/__init__.py index 58b5ecc..829470b 100644 --- a/src/too/__init__.py +++ b/src/too/__init__.py @@ -39,4 +39,5 @@ from .datamodel import * from .mock import * from .tools import * +from .validate import * from .xmatch import * diff --git a/src/too/tools.py b/src/too/tools.py index 26433f9..3f63488 100644 --- a/src/too/tools.py +++ b/src/too/tools.py @@ -8,6 +8,7 @@ from __future__ import annotations +import os import pathlib import shutil import tempfile @@ -17,17 +18,20 @@ import httpx import polars import rich.progress +from astropy.coordinates import SkyCoord, match_coordinates_sky + +from coordio.defaults import APO_MAX_FIELD_R, LCO_MAX_FIELD_R from too.datamodel import too_dtypes if TYPE_CHECKING: - import os - import rich.console + from sdssdb.connection import PeeweeDatabaseConnection -__all__ = ["download_file", "read_too_file"] + +__all__ = ["download_file", "read_too_file", "match_fields"] def read_too_file(path: polars.DataFrame | pathlib.Path | str) -> polars.DataFrame: @@ -86,3 +90,77 @@ def download_file( download_file.flush() shutil.move(download_file.name, path / pathlib.Path(url).name) + + +def match_fields( + targets: polars.DataFrame, + database: PeeweeDatabaseConnection, + rs_version: str | None = None, + check_separation: bool = False, +) -> polars.DataFrame: + """Matches a list of targets with their fields and observatories. + + Parameters + ---------- + targets + The data frame of targets. It must include fully populated ``ra`` and + ``dec`` columns. + database + The database connection. + rs_version + The robostrategy plan to use to select fields. Defaults to the + value of the ``$RS_VERSION`` environment variable. + check_separation + If ``True``, checks that the separation between the target and the + field centre is less than the FPS FoV. + + Returns + ------- + dataframe + The input dataframe with ``field_id`` and ``observatory`` columns. + + """ + + targets = targets.clone() + + if rs_version is None: + rs_version = os.environ.get("RS_VERSION", None) + if rs_version is None: + raise ValueError("No rs_version provided and $RS_VERSION not set.") + + too_sc = SkyCoord(ra=targets["ra"], dec=targets["dec"], unit="deg", frame="icrs") + + fields = polars.read_database( + "SELECT f.field_id, f.racen AS field_ra, " + " f.deccen AS field_dec, o.label AS observatory " + "FROM targetdb.field f " + "JOIN targetdb.version v ON f.version_pk = v.pk " + "JOIN targetdb.observatory o ON o.pk = f.observatory_pk " + f"WHERE v.plan = '{rs_version}' AND v.robostrategy;", + database, + ) + + fields_sc = SkyCoord( + ra=fields["field_ra"], + dec=fields["field_dec"], + unit="deg", + frame="icrs", + ) + + field_idx, sep2d, _ = match_coordinates_sky(too_sc, fields_sc) + + field_to_target = fields[field_idx] + field_to_target = field_to_target.with_columns(field_separation=sep2d.deg) + + targets = targets.hstack(field_to_target) + + if check_separation: + for obs, max_field_r in [("APO", APO_MAX_FIELD_R), ("LCO", LCO_MAX_FIELD_R)]: + obs_targets = targets.filter(polars.col.observatory == obs) + if (obs_targets["field_separation"] > max_field_r).any(): + raise ValueError( + f"Targets with separation larger than {max_field_r} deg " + f"found for observatory {obs}." + ) + + return targets diff --git a/tests/scripts/create_test_database.py b/tests/scripts/create_test_database.py index f4a063d..dab0e46 100644 --- a/tests/scripts/create_test_database.py +++ b/tests/scripts/create_test_database.py @@ -80,6 +80,7 @@ def create_test_database( "gaia_dr3_source.csv.gz", "twomass_psc.csv.gz", "targetdb_field.csv.gz", + "targetdb_version.csv.gz", ] for file in files: if not (CACHE_PATH / file).exists(): diff --git a/tests/scripts/sdss5db_too_test.sql b/tests/scripts/sdss5db_too_test.sql index 932aacc..2245cd2 100644 --- a/tests/scripts/sdss5db_too_test.sql +++ b/tests/scripts/sdss5db_too_test.sql @@ -278,7 +278,8 @@ CREATE TABLE targetdb.field ( racen DOUBLE PRECISION NOT NULL, deccen DOUBLE PRECISION NOT NULL, position_angle REAL, - observatory_pk INTEGER); + observatory_pk INTEGER, + version_pk INTEGER); CREATE TABLE targetdb.observatory ( pk SERIAL PRIMARY KEY NOT NULL, @@ -304,6 +305,9 @@ ALTER TABLE catalogdb.twomass_psc ADD PRIMARY KEY (pts_key); ALTER TABLE ONLY targetdb.field ADD CONSTRAINT observatory_fk FOREIGN KEY (observatory_pk) REFERENCES targetdb.observatory(pk); +ALTER TABLE ONLY targetdb.field + ADD CONSTRAINT version_fk + FOREIGN KEY (version_pk) REFERENCES targetdb.version(pk); CREATE INDEX ON catalogdb.catalog (version_id); CREATE INDEX ON catalogdb.catalog (q3c_ang2ipix(ra, dec)); @@ -356,6 +360,9 @@ CREATE UNIQUE INDEX ON targetdb.design_mode(label); CREATE INDEX CONCURRENTLY ON targetdb.field (q3c_ang2ipix(racen, deccen)); CREATE INDEX CONCURRENTLY ON targetdb.field USING BTREE(field_id); CREATE INDEX CONCURRENTLY ON targetdb.field USING BTREE(observatory_pk); +CREATE INDEX CONCURRENTLY ON targetdb.field USING BTREE(version_pk); + +CREATE INDEX CONCURRENTLY ON targetdb.version USING BTREE(plan); INSERT INTO catalogdb.version VALUES (31, '1.0.0', '1.0.0'); INSERT INTO targetdb.cadence VALUES ('bright_1x1', 1, '{0}', '{1}', '{0}', '{0}', '{1}', '{0}', 1, null, 'bright_1x1'); @@ -370,6 +377,7 @@ INSERT INTO targetdb.observatory VALUES (0, 'APO'), (1, 'LCO'); \copy catalogdb.gaia_dr3_source FROM PROGRAM '/usr/bin/gzip -dc gaia_dr3_source.csv.gz' WITH CSV HEADER; -- \copy catalogdb.sdss_dr13_photoobj FROM PROGRAM '/usr/bin/gzip -dc sdss_dr13_photoobj.csv.gz' WITH CSV HEADER; \copy catalogdb.twomass_psc FROM PROGRAM '/usr/bin/gzip -dc twomass_psc.csv.gz' WITH CSV HEADER; +\copy targetdb.version FROM PROGRAM '/usr/bin/gzip -dc targetdb_version.csv.gz' WITH CSV HEADER; \copy targetdb.field FROM PROGRAM '/usr/bin/gzip -dc targetdb_field.csv.gz' WITH CSV HEADER; \copy targetdb.design_mode FROM 'design_mode.csv' WITH CSV HEADER; @@ -387,4 +395,6 @@ VACUUM ANALYZE targetdb.carton; VACUUM ANALYZE targetdb.carton_to_target; VACUUM ANALYZE targetdb.category; VACUUM ANALYZE targetdb.version; +VACUUM ANALYZE targetdb.observatory; +VACUUM ANALYZE targetdb.field; VACUUM ANALYZE targetdb.magnitude; diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 0000000..9a3c2c2 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# @Author: José Sánchez-Gallego (gallegoj@uw.edu) +# @Date: 2024-04-23 +# @Filename: test_tools.py +# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause) + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import polars +import pytest + +from too.tools import match_fields + + +if TYPE_CHECKING: + from sdssdb.connection import PeeweeDatabaseConnection + + +@pytest.mark.parametrize("rs_version", ["eta-6", None]) +def test_match_fields( + too_mock: polars.DataFrame, + database: PeeweeDatabaseConnection, + rs_version: str | None, + monkeypatch: pytest.MonkeyPatch, +): + + if rs_version is None: + monkeypatch.setenv("RS_VERSION", "eta-6") + + targets = match_fields(too_mock[0:1000], database, rs_version=rs_version) + + assert isinstance(targets, polars.DataFrame) + assert "field_ra" in targets.columns + + +def test_match_fields_no_rs_version( + too_mock: polars.DataFrame, + database: PeeweeDatabaseConnection, + monkeypatch: pytest.MonkeyPatch, +): + + monkeypatch.delenv("RS_VERSION", raising=False) + + with pytest.raises(ValueError): + match_fields(too_mock, database) + + +def test_match_fields_check_separation( + too_mock: polars.DataFrame, + database: PeeweeDatabaseConnection, +): + # Most random coordinates will have some that do not fall inside our tiling. + with pytest.raises(ValueError): + match_fields(too_mock[0:1000], database, "eta-6", check_separation=True)