From dd83788d4f8f7290a22a4b8813c439c011e57599 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20S=C3=A1nchez-Gallego?= Date: Mon, 12 Feb 2024 00:51:17 -0800 Subject: [PATCH] Load ToO targets into database and tests --- poetry.lock | 50 ++++++++++++++++++++++++++++++- pyproject.toml | 1 + src/too/database.py | 67 ++++++++++++++++++++++++++++++++++++++++++ src/too/datamodel.py | 25 +++++++++++----- src/too/mock.py | 10 ++++++- tests/conftest.py | 14 ++++----- tests/test_database.py | 53 ++++++++++++++++++++++++++++++++- 7 files changed, 201 insertions(+), 19 deletions(-) diff --git a/poetry.lock b/poetry.lock index d487b29..fd1cf0f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2128,6 +2128,54 @@ files = [ {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"}, ] +[[package]] +name = "pyarrow" +version = "15.0.0" +description = "Python library for Apache Arrow" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyarrow-15.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:0a524532fd6dd482edaa563b686d754c70417c2f72742a8c990b322d4c03a15d"}, + {file = "pyarrow-15.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60a6bdb314affa9c2e0d5dddf3d9cbb9ef4a8dddaa68669975287d47ece67642"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66958fd1771a4d4b754cd385835e66a3ef6b12611e001d4e5edfcef5f30391e2"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f500956a49aadd907eaa21d4fff75f73954605eaa41f61cb94fb008cf2e00c6"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6f87d9c4f09e049c2cade559643424da84c43a35068f2a1c4653dc5b1408a929"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:85239b9f93278e130d86c0e6bb455dcb66fc3fd891398b9d45ace8799a871a1e"}, + {file = "pyarrow-15.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5b8d43e31ca16aa6e12402fcb1e14352d0d809de70edd185c7650fe80e0769e3"}, + {file = "pyarrow-15.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:fa7cd198280dbd0c988df525e50e35b5d16873e2cdae2aaaa6363cdb64e3eec5"}, + {file = "pyarrow-15.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8780b1a29d3c8b21ba6b191305a2a607de2e30dab399776ff0aa09131e266340"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe0ec198ccc680f6c92723fadcb97b74f07c45ff3fdec9dd765deb04955ccf19"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036a7209c235588c2f07477fe75c07e6caced9b7b61bb897c8d4e52c4b5f9555"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2bd8a0e5296797faf9a3294e9fa2dc67aa7f10ae2207920dbebb785c77e9dbe5"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e8ebed6053dbe76883a822d4e8da36860f479d55a762bd9e70d8494aed87113e"}, + {file = "pyarrow-15.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:17d53a9d1b2b5bd7d5e4cd84d018e2a45bc9baaa68f7e6e3ebed45649900ba99"}, + {file = "pyarrow-15.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9950a9c9df24090d3d558b43b97753b8f5867fb8e521f29876aa021c52fda351"}, + {file = "pyarrow-15.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:003d680b5e422d0204e7287bb3fa775b332b3fce2996aa69e9adea23f5c8f970"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f75fce89dad10c95f4bf590b765e3ae98bcc5ba9f6ce75adb828a334e26a3d40"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ca9cb0039923bec49b4fe23803807e4ef39576a2bec59c32b11296464623dc2"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ed5a78ed29d171d0acc26a305a4b7f83c122d54ff5270810ac23c75813585e4"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6eda9e117f0402dfcd3cd6ec9bfee89ac5071c48fc83a84f3075b60efa96747f"}, + {file = "pyarrow-15.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a3a6180c0e8f2727e6f1b1c87c72d3254cac909e609f35f22532e4115461177"}, + {file = "pyarrow-15.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:19a8918045993349b207de72d4576af0191beef03ea655d8bdb13762f0cd6eac"}, + {file = "pyarrow-15.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0ec076b32bacb6666e8813a22e6e5a7ef1314c8069d4ff345efa6246bc38593"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5db1769e5d0a77eb92344c7382d6543bea1164cca3704f84aa44e26c67e320fb"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2617e3bf9df2a00020dd1c1c6dce5cc343d979efe10bc401c0632b0eef6ef5b"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:d31c1d45060180131caf10f0f698e3a782db333a422038bf7fe01dace18b3a31"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:c8c287d1d479de8269398b34282e206844abb3208224dbdd7166d580804674b7"}, + {file = "pyarrow-15.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:07eb7f07dc9ecbb8dace0f58f009d3a29ee58682fcdc91337dfeb51ea618a75b"}, + {file = "pyarrow-15.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:47af7036f64fce990bb8a5948c04722e4e3ea3e13b1007ef52dfe0aa8f23cf7f"}, + {file = "pyarrow-15.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93768ccfff85cf044c418bfeeafce9a8bb0cee091bd8fd19011aff91e58de540"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6ee87fd6892700960d90abb7b17a72a5abb3b64ee0fe8db6c782bcc2d0dc0b4"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:001fca027738c5f6be0b7a3159cc7ba16a5c52486db18160909a0831b063c4e4"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:d1c48648f64aec09accf44140dccb92f4f94394b8d79976c426a5b79b11d4fa7"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:972a0141be402bb18e3201448c8ae62958c9c7923dfaa3b3d4530c835ac81aed"}, + {file = "pyarrow-15.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:f01fc5cf49081426429127aa2d427d9d98e1cb94a32cb961d583a70b7c4504e6"}, + {file = "pyarrow-15.0.0.tar.gz", hash = "sha256:876858f549d540898f927eba4ef77cd549ad8d24baa3207cf1b72e5788b50e83"}, +] + +[package.dependencies] +numpy = ">=1.16.6,<2" + [[package]] name = "pydl" version = "1.0.0" @@ -2846,4 +2894,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.12,<4.0" -content-hash = "e2a9c22b31e52189b01720c40710aa50aa40cf65894b7d09912bcf249c90b595" +content-hash = "6b3ef7d3c917e63d3669029b72434cf0470eaa5f5b9723a271f787eae5e024fe" diff --git a/pyproject.toml b/pyproject.toml index 16cc405..dd91d58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ polars = "^0.20.7" httpx = "^0.26.0" rich = "^13.7.0" sdssdb = "^0.8.3" +pyarrow = "^15.0.0" [tool.poetry.group.dev.dependencies] ipython = ">=8.21.0" diff --git a/src/too/database.py b/src/too/database.py index be6a255..312820a 100644 --- a/src/too/database.py +++ b/src/too/database.py @@ -8,9 +8,16 @@ from __future__ import annotations +import pathlib + +import adbc_driver_postgresql.dbapi as dbapi import peewee +import polars from sdssdb.peewee import BaseModel +from too import log +from too.datamodel import too_dtypes + __all__ = ["ToO_Target"] @@ -56,3 +63,63 @@ class ToO_Target(ToOBaseModel): class Meta: table_name = "too_target" + + +def get_database_uri( + dbname: str, + host: str = "localhost", + port: int | None = None, + user: str | None = None, + password: str | None = None, +): + """Returns the URI to the database.""" + + if user is None and password is None: + auth: str = "" + elif user is not None and password is None: + auth: str = f"{user}@" + elif user is not None and password is not None: + auth: str = f"{user}:{password}@" + else: + raise ValueError("Passing a password requires also passing a user.") + + host_port: str = f"{host}" if port is None else f"{host}:{port}" + + return f"postgresql://{auth}{host_port}/{dbname}" + + +def load_too_targets( + targets: polars.DataFrame | str | pathlib.Path, + database_uri: str, + update_existing: bool = False, +): + """Loads a list of ToO targets into the database.""" + + if update_existing: + raise NotImplementedError("update_existing not yet implemented.") + + if isinstance(targets, (str, pathlib.Path)): + targets = polars.read_parquet(targets) + + with dbapi.connect(database_uri) as conn: + current_targets = polars.read_database( + "SELECT * from catalogdb.too_target", + conn, # type: ignore + ) + current_targets = current_targets.cast(too_dtypes) # type: ignore + + new_targets = targets.filter(~polars.col.too_id.is_in(current_targets["too_id"])) + + if len(new_targets) == 0: + log.info("No new ToO targets to add.") + return 0 + + log.info(f"Loading {len(new_targets)} new ToO targets into the database.") + n_added = new_targets.write_database( + "catalogdb.too_target", + database_uri, + if_table_exists="append", + engine="adbc", + ) + + return n_added diff --git a/src/too/datamodel.py b/src/too/datamodel.py index e77f16e..e12039e 100644 --- a/src/too/datamodel.py +++ b/src/too/datamodel.py @@ -8,17 +8,22 @@ from __future__ import annotations +from typing import Mapping + import polars +import polars.type_aliases as pta __all__ = ["too_dtypes"] -too_dtypes = { - "too_id": polars.UInt64, +PolarsTypeMapping = Mapping[str, pta.PolarsDataType] + +too_dtypes: PolarsTypeMapping = { + "too_id": polars.Int64, "fiber_type": polars.String, - "catalogid": polars.UInt64, - "sdss_id": polars.UInt64, - "gaia_dr3_source_id": polars.UInt64, + "catalogid": polars.Int64, + "sdss_id": polars.Int64, + "gaia_dr3_source_id": polars.Int64, "twomass_pts_key": polars.Int32, "sky_brightness_mode": polars.String, "ra": polars.Float64, @@ -48,7 +53,13 @@ "observed": polars.Boolean, } -too_fixed_columns = ["catalogid", "sdss_id", "gaia_dr3_source_id", "twomass_pts_key"] +too_fixed_columns = [ + "catalogid", + "sdss_id", + "gaia_dr3_source_id", + "twomass_pts_key", + "fiber_type", +] mag_columns = [ "u_mag", "g_mag", @@ -60,4 +71,4 @@ "gaia_g_mag", "h_mag", ] -fiber_type_values = ["APOGEE", "BOSS", ""] +fiber_type_values = ["APOGEE", "BOSS"] diff --git a/src/too/mock.py b/src/too/mock.py index 70e3948..281031f 100644 --- a/src/too/mock.py +++ b/src/too/mock.py @@ -122,8 +122,15 @@ def get_sample_targets( "psfmag_z": "z_mag", } + if "source_id" in sample.columns: + sample = sample.cast({"source_id": polars.Int64}) + elif "pts_key" in sample.columns: + sample = sample.cast({"pts_key": polars.Int32}) + sample = sample.select(*list(col_mapping), "catalogid", "sdss_id") sample = sample.rename(col_mapping) + sample = sample.cast({"catalogid": polars.Int64, "sdss_id": polars.Int64}) + sample = sample.with_columns(polars.col(["ra", "dec"]).cast(polars.Float64)) return sample @@ -254,7 +261,7 @@ def create_mock_too_catalogue( sdss_id=polars.when(polars.col.keep_cid).then(polars.col.sdss_id), ) df.drop_in_place("keep_cid") - df = df.sample(df.height) # Shuffle + df = df.sample(df.height, shuffle=True) # Shuffle df = df.with_columns(polars.int_range(1, df.height + 1).alias("too_id")) @@ -268,6 +275,7 @@ def create_mock_too_catalogue( observed=False, active=True, priority=polars.lit(5, dtype=polars.Int16), + n_exposures=polars.lit(3, dtype=polars.Int16), ) return df diff --git a/tests/conftest.py b/tests/conftest.py index 4e741cd..6bc4dd7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ from __future__ import annotations import pytest +from sdssdb.peewee.sdss5db import catalogdb from too.mock import create_mock_too_catalogue @@ -17,16 +18,11 @@ @pytest.fixture(autouse=True, scope="session") -def sdss5db(): - """A fixture that returns a connection to the ToO test database.""" +def connect_and_revert_database(): + """Reverts the database to the original state.""" - from sdssdb.peewee import sdss5db - - sdss5db.database.connect(DBNAME) - - yield - - sdss5db.database.close() + catalogdb.database.connect(DBNAME) + catalogdb.database.execute_sql("TRUNCATE TABLE catalogdb.too_target;") @pytest.fixture(scope="session") diff --git a/tests/test_database.py b/tests/test_database.py index 3df33c9..0699c64 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -8,10 +8,12 @@ from __future__ import annotations +import polars +import pytest from conftest import DBNAME from sdssdb.peewee.sdss5db import catalogdb -from too.database import ToO_Target +from too.database import ToO_Target, get_database_uri, load_too_targets def test_database_exists(): @@ -41,3 +43,52 @@ def test_models_exist(): ToO_Target.bind(catalogdb.database) assert ToO_Target.table_exists() assert ToO_Target.select().count() == 0 + + +@pytest.mark.parametrize( + "dbname,user,password,host,port,expected", + [ + ("testdb", None, None, "localhost", None, "localhost/testdb"), + ("testdb", "user", "1234", "localhost", None, "user:1234@localhost/testdb"), + ("testdb", "user", None, "localhost", 5432, "user@localhost:5432/testdb"), + ], +) +def test_get_database_uri( + dbname: str, + user: str | None, + password: str | None, + host: str, + port: int | None, + expected: str, +): + + uri = get_database_uri( + dbname, + user=user, + password=password, + host=host, + port=port, + ) + + assert uri == f"postgresql://{expected}" + + +def test_get_database_uri_password_fails(): + with pytest.raises(ValueError): + get_database_uri("testdb", password="1234") + + +def test_load_too_targets(too_mock: polars.DataFrame): + n_added = load_too_targets(too_mock[0:10], get_database_uri(DBNAME)) + + assert n_added == 10 + assert ToO_Target.select().count() == 10 + + # Repeat. No new targets should be added. + n_added = load_too_targets(too_mock[0:10], get_database_uri(DBNAME)) + assert n_added == 0 + assert ToO_Target.select().count() == 10 + + n_added = load_too_targets(too_mock[5:100000], get_database_uri(DBNAME)) + assert n_added == 99990 + assert ToO_Target.select().count() == 100000