Skip to content

Commit

Permalink
Add match_fields function
Browse files Browse the repository at this point in the history
  • Loading branch information
albireox committed Apr 24, 2024
1 parent 99e7f29 commit 2200af8
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/too/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@
from .datamodel import *
from .mock import *
from .tools import *
from .validate import *
from .xmatch import *
84 changes: 81 additions & 3 deletions src/too/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from __future__ import annotations

import os
import pathlib
import shutil
import tempfile
Expand All @@ -17,17 +18,20 @@
import httpx
import polars
import rich.progress
from astropy.coordinates import SkyCoord, match_coordinates_sky

from coordio.defaults import APO_MAX_FIELD_R, LCO_MAX_FIELD_R

from too.datamodel import too_dtypes


if TYPE_CHECKING:
import os

import rich.console

from sdssdb.connection import PeeweeDatabaseConnection

__all__ = ["download_file", "read_too_file"]

__all__ = ["download_file", "read_too_file", "match_fields"]


def read_too_file(path: polars.DataFrame | pathlib.Path | str) -> polars.DataFrame:
Expand Down Expand Up @@ -86,3 +90,77 @@ def download_file(

download_file.flush()
shutil.move(download_file.name, path / pathlib.Path(url).name)


def match_fields(
targets: polars.DataFrame,
database: PeeweeDatabaseConnection,
rs_version: str | None = None,
check_separation: bool = False,
) -> polars.DataFrame:
"""Matches a list of targets with their fields and observatories.
Parameters
----------
targets
The data frame of targets. It must include fully populated ``ra`` and
``dec`` columns.
database
The database connection.
rs_version
The robostrategy plan to use to select fields. Defaults to the
value of the ``$RS_VERSION`` environment variable.
check_separation
If ``True``, checks that the separation between the target and the
field centre is less than the FPS FoV.
Returns
-------
dataframe
The input dataframe with ``field_id`` and ``observatory`` columns.
"""

targets = targets.clone()

if rs_version is None:
rs_version = os.environ.get("RS_VERSION", None)
if rs_version is None:
raise ValueError("No rs_version provided and $RS_VERSION not set.")

too_sc = SkyCoord(ra=targets["ra"], dec=targets["dec"], unit="deg", frame="icrs")

fields = polars.read_database(
"SELECT f.field_id, f.racen AS field_ra, "
" f.deccen AS field_dec, o.label AS observatory "
"FROM targetdb.field f "
"JOIN targetdb.version v ON f.version_pk = v.pk "
"JOIN targetdb.observatory o ON o.pk = f.observatory_pk "
f"WHERE v.plan = '{rs_version}' AND v.robostrategy;",
database,
)

fields_sc = SkyCoord(
ra=fields["field_ra"],
dec=fields["field_dec"],
unit="deg",
frame="icrs",
)

field_idx, sep2d, _ = match_coordinates_sky(too_sc, fields_sc)

field_to_target = fields[field_idx]
field_to_target = field_to_target.with_columns(field_separation=sep2d.deg)

targets = targets.hstack(field_to_target)

if check_separation:
for obs, max_field_r in [("APO", APO_MAX_FIELD_R), ("LCO", LCO_MAX_FIELD_R)]:
obs_targets = targets.filter(polars.col.observatory == obs)
if (obs_targets["field_separation"] > max_field_r).any():
raise ValueError(
f"Targets with separation larger than {max_field_r} deg "
f"found for observatory {obs}."
)

return targets
1 change: 1 addition & 0 deletions tests/scripts/create_test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def create_test_database(
"gaia_dr3_source.csv.gz",
"twomass_psc.csv.gz",
"targetdb_field.csv.gz",
"targetdb_version.csv.gz",
]
for file in files:
if not (CACHE_PATH / file).exists():
Expand Down
12 changes: 11 additions & 1 deletion tests/scripts/sdss5db_too_test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ CREATE TABLE targetdb.field (
racen DOUBLE PRECISION NOT NULL,
deccen DOUBLE PRECISION NOT NULL,
position_angle REAL,
observatory_pk INTEGER);
observatory_pk INTEGER,
version_pk INTEGER);

CREATE TABLE targetdb.observatory (
pk SERIAL PRIMARY KEY NOT NULL,
Expand All @@ -304,6 +305,9 @@ ALTER TABLE catalogdb.twomass_psc ADD PRIMARY KEY (pts_key);
ALTER TABLE ONLY targetdb.field
ADD CONSTRAINT observatory_fk
FOREIGN KEY (observatory_pk) REFERENCES targetdb.observatory(pk);
ALTER TABLE ONLY targetdb.field
ADD CONSTRAINT version_fk
FOREIGN KEY (version_pk) REFERENCES targetdb.version(pk);

CREATE INDEX ON catalogdb.catalog (version_id);
CREATE INDEX ON catalogdb.catalog (q3c_ang2ipix(ra, dec));
Expand Down Expand Up @@ -356,6 +360,9 @@ CREATE UNIQUE INDEX ON targetdb.design_mode(label);
CREATE INDEX CONCURRENTLY ON targetdb.field (q3c_ang2ipix(racen, deccen));
CREATE INDEX CONCURRENTLY ON targetdb.field USING BTREE(field_id);
CREATE INDEX CONCURRENTLY ON targetdb.field USING BTREE(observatory_pk);
CREATE INDEX CONCURRENTLY ON targetdb.field USING BTREE(version_pk);

CREATE INDEX CONCURRENTLY ON targetdb.version USING BTREE(plan);

INSERT INTO catalogdb.version VALUES (31, '1.0.0', '1.0.0');
INSERT INTO targetdb.cadence VALUES ('bright_1x1', 1, '{0}', '{1}', '{0}', '{0}', '{1}', '{0}', 1, null, 'bright_1x1');
Expand All @@ -370,6 +377,7 @@ INSERT INTO targetdb.observatory VALUES (0, 'APO'), (1, 'LCO');
\copy catalogdb.gaia_dr3_source FROM PROGRAM '/usr/bin/gzip -dc gaia_dr3_source.csv.gz' WITH CSV HEADER;
-- \copy catalogdb.sdss_dr13_photoobj FROM PROGRAM '/usr/bin/gzip -dc sdss_dr13_photoobj.csv.gz' WITH CSV HEADER;
\copy catalogdb.twomass_psc FROM PROGRAM '/usr/bin/gzip -dc twomass_psc.csv.gz' WITH CSV HEADER;
\copy targetdb.version FROM PROGRAM '/usr/bin/gzip -dc targetdb_version.csv.gz' WITH CSV HEADER;
\copy targetdb.field FROM PROGRAM '/usr/bin/gzip -dc targetdb_field.csv.gz' WITH CSV HEADER;
\copy targetdb.design_mode FROM 'design_mode.csv' WITH CSV HEADER;

Expand All @@ -387,4 +395,6 @@ VACUUM ANALYZE targetdb.carton;
VACUUM ANALYZE targetdb.carton_to_target;
VACUUM ANALYZE targetdb.category;
VACUUM ANALYZE targetdb.version;
VACUUM ANALYZE targetdb.observatory;
VACUUM ANALYZE targetdb.field;
VACUUM ANALYZE targetdb.magnitude;
58 changes: 58 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego ([email protected])
# @Date: 2024-04-23
# @Filename: test_tools.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)

from __future__ import annotations

from typing import TYPE_CHECKING

import polars
import pytest

from too.tools import match_fields


if TYPE_CHECKING:
from sdssdb.connection import PeeweeDatabaseConnection


@pytest.mark.parametrize("rs_version", ["eta-6", None])
def test_match_fields(
too_mock: polars.DataFrame,
database: PeeweeDatabaseConnection,
rs_version: str | None,
monkeypatch: pytest.MonkeyPatch,
):

if rs_version is None:
monkeypatch.setenv("RS_VERSION", "eta-6")

targets = match_fields(too_mock[0:1000], database, rs_version=rs_version)

assert isinstance(targets, polars.DataFrame)
assert "field_ra" in targets.columns


def test_match_fields_no_rs_version(
too_mock: polars.DataFrame,
database: PeeweeDatabaseConnection,
monkeypatch: pytest.MonkeyPatch,
):

monkeypatch.delenv("RS_VERSION", raising=False)

with pytest.raises(ValueError):
match_fields(too_mock, database)


def test_match_fields_check_separation(
too_mock: polars.DataFrame,
database: PeeweeDatabaseConnection,
):
# Most random coordinates will have some that do not fall inside our tiling.
with pytest.raises(ValueError):
match_fields(too_mock[0:1000], database, "eta-6", check_separation=True)

0 comments on commit 2200af8

Please sign in to comment.