Skip to content

Commit

Permalink
Load ToO targets into database and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
albireox committed Feb 12, 2024
1 parent ae36528 commit dd83788
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 19 deletions.
50 changes: 49 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ polars = "^0.20.7"
httpx = "^0.26.0"
rich = "^13.7.0"
sdssdb = "^0.8.3"
pyarrow = "^15.0.0"

[tool.poetry.group.dev.dependencies]
ipython = ">=8.21.0"
Expand Down
67 changes: 67 additions & 0 deletions src/too/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@

from __future__ import annotations

import pathlib

import adbc_driver_postgresql.dbapi as dbapi
import peewee
import polars
from sdssdb.peewee import BaseModel

from too import log
from too.datamodel import too_dtypes


__all__ = ["ToO_Target"]

Expand Down Expand Up @@ -56,3 +63,63 @@ class ToO_Target(ToOBaseModel):

class Meta:
table_name = "too_target"


def get_database_uri(
dbname: str,
host: str = "localhost",
port: int | None = None,
user: str | None = None,
password: str | None = None,
):
"""Returns the URI to the database."""

if user is None and password is None:
auth: str = ""
elif user is not None and password is None:
auth: str = f"{user}@"
elif user is not None and password is not None:
auth: str = f"{user}:{password}@"
else:
raise ValueError("Passing a password requires also passing a user.")

host_port: str = f"{host}" if port is None else f"{host}:{port}"

return f"postgresql://{auth}{host_port}/{dbname}"


def load_too_targets(
targets: polars.DataFrame | str | pathlib.Path,
database_uri: str,
update_existing: bool = False,
):
"""Loads a list of ToO targets into the database."""

if update_existing:
raise NotImplementedError("update_existing not yet implemented.")

if isinstance(targets, (str, pathlib.Path)):
targets = polars.read_parquet(targets)

with dbapi.connect(database_uri) as conn:
current_targets = polars.read_database(
"SELECT * from catalogdb.too_target",
conn, # type: ignore
)
current_targets = current_targets.cast(too_dtypes) # type: ignore

new_targets = targets.filter(~polars.col.too_id.is_in(current_targets["too_id"]))

if len(new_targets) == 0:
log.info("No new ToO targets to add.")
return 0

log.info(f"Loading {len(new_targets)} new ToO targets into the database.")
n_added = new_targets.write_database(
"catalogdb.too_target",
database_uri,
if_table_exists="append",
engine="adbc",
)

return n_added
25 changes: 18 additions & 7 deletions src/too/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,22 @@

from __future__ import annotations

from typing import Mapping

import polars
import polars.type_aliases as pta


__all__ = ["too_dtypes"]

too_dtypes = {
"too_id": polars.UInt64,
PolarsTypeMapping = Mapping[str, pta.PolarsDataType]

too_dtypes: PolarsTypeMapping = {
"too_id": polars.Int64,
"fiber_type": polars.String,
"catalogid": polars.UInt64,
"sdss_id": polars.UInt64,
"gaia_dr3_source_id": polars.UInt64,
"catalogid": polars.Int64,
"sdss_id": polars.Int64,
"gaia_dr3_source_id": polars.Int64,
"twomass_pts_key": polars.Int32,
"sky_brightness_mode": polars.String,
"ra": polars.Float64,
Expand Down Expand Up @@ -48,7 +53,13 @@
"observed": polars.Boolean,
}

too_fixed_columns = ["catalogid", "sdss_id", "gaia_dr3_source_id", "twomass_pts_key"]
too_fixed_columns = [
"catalogid",
"sdss_id",
"gaia_dr3_source_id",
"twomass_pts_key",
"fiber_type",
]
mag_columns = [
"u_mag",
"g_mag",
Expand All @@ -60,4 +71,4 @@
"gaia_g_mag",
"h_mag",
]
fiber_type_values = ["APOGEE", "BOSS", ""]
fiber_type_values = ["APOGEE", "BOSS"]
10 changes: 9 additions & 1 deletion src/too/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,15 @@ def get_sample_targets(
"psfmag_z": "z_mag",
}

if "source_id" in sample.columns:
sample = sample.cast({"source_id": polars.Int64})
elif "pts_key" in sample.columns:
sample = sample.cast({"pts_key": polars.Int32})

sample = sample.select(*list(col_mapping), "catalogid", "sdss_id")
sample = sample.rename(col_mapping)
sample = sample.cast({"catalogid": polars.Int64, "sdss_id": polars.Int64})

sample = sample.with_columns(polars.col(["ra", "dec"]).cast(polars.Float64))

return sample
Expand Down Expand Up @@ -254,7 +261,7 @@ def create_mock_too_catalogue(
sdss_id=polars.when(polars.col.keep_cid).then(polars.col.sdss_id),
)
df.drop_in_place("keep_cid")
df = df.sample(df.height) # Shuffle
df = df.sample(df.height, shuffle=True) # Shuffle

df = df.with_columns(polars.int_range(1, df.height + 1).alias("too_id"))

Expand All @@ -268,6 +275,7 @@ def create_mock_too_catalogue(
observed=False,
active=True,
priority=polars.lit(5, dtype=polars.Int16),
n_exposures=polars.lit(3, dtype=polars.Int16),
)

return df
14 changes: 5 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

import pytest
from sdssdb.peewee.sdss5db import catalogdb

from too.mock import create_mock_too_catalogue

Expand All @@ -17,16 +18,11 @@


@pytest.fixture(autouse=True, scope="session")
def sdss5db():
"""A fixture that returns a connection to the ToO test database."""
def connect_and_revert_database():
"""Reverts the database to the original state."""

from sdssdb.peewee import sdss5db

sdss5db.database.connect(DBNAME)

yield

sdss5db.database.close()
catalogdb.database.connect(DBNAME)
catalogdb.database.execute_sql("TRUNCATE TABLE catalogdb.too_target;")


@pytest.fixture(scope="session")
Expand Down
53 changes: 52 additions & 1 deletion tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

from __future__ import annotations

import polars
import pytest
from conftest import DBNAME
from sdssdb.peewee.sdss5db import catalogdb

from too.database import ToO_Target
from too.database import ToO_Target, get_database_uri, load_too_targets


def test_database_exists():
Expand Down Expand Up @@ -41,3 +43,52 @@ def test_models_exist():
ToO_Target.bind(catalogdb.database)
assert ToO_Target.table_exists()
assert ToO_Target.select().count() == 0


@pytest.mark.parametrize(
"dbname,user,password,host,port,expected",
[
("testdb", None, None, "localhost", None, "localhost/testdb"),
("testdb", "user", "1234", "localhost", None, "user:1234@localhost/testdb"),
("testdb", "user", None, "localhost", 5432, "user@localhost:5432/testdb"),
],
)
def test_get_database_uri(
dbname: str,
user: str | None,
password: str | None,
host: str,
port: int | None,
expected: str,
):

uri = get_database_uri(
dbname,
user=user,
password=password,
host=host,
port=port,
)

assert uri == f"postgresql://{expected}"


def test_get_database_uri_password_fails():
with pytest.raises(ValueError):
get_database_uri("testdb", password="1234")


def test_load_too_targets(too_mock: polars.DataFrame):
n_added = load_too_targets(too_mock[0:10], get_database_uri(DBNAME))

assert n_added == 10
assert ToO_Target.select().count() == 10

# Repeat. No new targets should be added.
n_added = load_too_targets(too_mock[0:10], get_database_uri(DBNAME))
assert n_added == 0
assert ToO_Target.select().count() == 10

n_added = load_too_targets(too_mock[5:100000], get_database_uri(DBNAME))
assert n_added == 99990
assert ToO_Target.select().count() == 100000

0 comments on commit dd83788

Please sign in to comment.