Skip to content

Commit

Permalink
Move slow imoport inside the functions
Browse files Browse the repository at this point in the history
  • Loading branch information
albireox committed Apr 30, 2024
1 parent d92cfc7 commit 4da7b0c
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 36 deletions.
5 changes: 0 additions & 5 deletions src/too/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,8 @@

import sdssdb
from sdsstools import get_logger, get_package_version
from target_selection import log as ts_log
from target_selection.exceptions import TargetSelectionImportWarning


ts_log.setLevel(10000) # Disable target_selection logging
warnings.simplefilter("ignore", TargetSelectionImportWarning)

sdssdb.autoconnect = False


Expand Down
21 changes: 12 additions & 9 deletions src/too/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
import polars
from click_option_group import OptionGroup

from too import (
connect_to_database,
load_too_targets,
log,
read_too_file,
run_too_carton,
too_dtypes,
xmatch_too_targets,
)


run_options = OptionGroup(
"Run options",
Expand Down Expand Up @@ -80,15 +90,8 @@ def too_cli(
):
"""Command line interface for ToOs."""

from too import (
connect_to_database,
load_too_targets,
log,
read_too_file,
run_too_carton,
too_dtypes,
xmatch_too_targets,
)
if len(files) == 0:
raise click.UsageError("At least one file must be passed.")

if verbose:
log.sh.setLevel("DEBUG")
Expand Down
8 changes: 8 additions & 0 deletions src/too/carton.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from __future__ import annotations

import warnings

from too import log


Expand All @@ -20,6 +22,12 @@
def run_too_carton():
"""Runs the ToO carton."""

from target_selection import log as ts_log
from target_selection.exceptions import TargetSelectionImportWarning

ts_log.setLevel(10000) # Disable target_selection logging
warnings.simplefilter("ignore", TargetSelectionImportWarning)

from target_selection.cartons.too import ToO_Carton # Slow import

too_carton = ToO_Carton(TOO_TARGET_PLAN)
Expand Down
14 changes: 10 additions & 4 deletions src/too/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@
import os
import pathlib

import polars
from typing import TYPE_CHECKING

from sdssdb.connection import PeeweeDatabaseConnection
from sdssdb.peewee.sdss5db import catalogdb
from sdsstools.time import get_sjd
import polars

from too import log
from too.datamodel import too_dtypes, too_metadata_columns
from too.tools import read_too_file
from too.validate import validate_too_targets


if TYPE_CHECKING:
from sdssdb.connection import PeeweeDatabaseConnection


__all__ = [
"connect_to_database",
"get_database_uri",
Expand All @@ -43,6 +45,8 @@ def connect_to_database(
):
"""Connects the ``sdssdb`` ``sdss5db`` models to the database."""

from sdssdb.peewee.sdss5db import catalogdb

if port is None:
port = int(os.environ.get("PGPORT", 5432))

Expand Down Expand Up @@ -204,6 +208,8 @@ def get_active_targets(database: PeeweeDatabaseConnection):
"""

from sdsstools.time import get_sjd

assert database.connected, "Database is not connected."

database_uri = database_uri_from_connection(database)
Expand Down
11 changes: 8 additions & 3 deletions src/too/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@
import os
import pathlib

from typing import TYPE_CHECKING

import numpy
import polars
from peewee import JOIN, fn

from sdssdb.connection import PeeweeDatabaseConnection
from sdssdb.peewee.sdss5db import catalogdb, opsdb

from too import log
from too.tools import match_fields
from too.validate import add_bright_limits_columns


if TYPE_CHECKING:
from sdssdb.connection import PeeweeDatabaseConnection


__all__ = ["dump_to_parquet"]


Expand All @@ -39,6 +42,8 @@ def dump_to_parquet(
"""

from sdssdb.peewee.sdss5db import catalogdb, opsdb

path = pathlib.Path(path)

observatory = observatory.upper()
Expand Down
7 changes: 4 additions & 3 deletions src/too/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
import httpx
import polars
import rich.progress
from astropy.coordinates import SkyCoord, match_coordinates_sky

from coordio.defaults import APO_MAX_FIELD_R, LCO_MAX_FIELD_R

from too.datamodel import too_dtypes

Expand Down Expand Up @@ -121,6 +118,10 @@ def match_fields(
"""

from astropy.coordinates import SkyCoord, match_coordinates_sky

from coordio.defaults import APO_MAX_FIELD_R, LCO_MAX_FIELD_R

targets = targets.clone()

if rs_version is None:
Expand Down
24 changes: 16 additions & 8 deletions src/too/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,8 @@
import polars
from astropy.coordinates import SkyCoord
from astropy.time import Time
from astropy_healpix import HEALPix
from erfa import ErfaWarning

from coordio.utils import Moffat2dInterp, _offset_radec, object_offset
from sdssdb.connection import PeeweeDatabaseConnection
from sdssdb.peewee.sdss5db.targetdb import DesignMode as DesignModeDB

from too import log
from too.datamodel import mag_columns, too_dtypes
from too.exceptions import ValidationError
Expand All @@ -42,9 +37,6 @@
# load enviornment variable with path to healpix maps
BN_HEALPIX = os.getenv("BN_HEALPIX")

# add function for offseting
fmagloss = Moffat2dInterp()


def check_assign_mag_limit(
mag_metric_min,
Expand Down Expand Up @@ -126,6 +118,8 @@ def allDesignModes(database: PeeweeDatabaseConnection):
"""

from sdssdb.peewee.sdss5db.targetdb import DesignMode as DesignModeDB

desmodes = DesignModeDB.select()
dmd = collections.OrderedDict()
for desmode in desmodes:
Expand Down Expand Up @@ -187,6 +181,8 @@ class DesignMode:

def __init__(self, database: PeeweeDatabaseConnection, label: str | None = None):

from sdssdb.peewee.sdss5db.targetdb import DesignMode as DesignModeDB

DesignModeDB._meta.database(database) # type:ignore

if label is not None:
Expand All @@ -204,6 +200,8 @@ def fromdb(self, label: str):
"""

from sdssdb.peewee.sdss5db.targetdb import DesignMode as DesignModeDB

self.desmode_label = label

desmode = DesignModeDB.select().where(DesignModeDB.label == label)[0]
Expand Down Expand Up @@ -342,8 +340,14 @@ def calculate_offsets(
offset_flag: np.array
flags associated with the offseting
"""

from coordio.utils import Moffat2dInterp, object_offset

# add function for offseting
fmagloss = Moffat2dInterp()

delta_ra = np.zeros(len(targets), dtype=float)
delta_dec = np.zeros(len(targets), dtype=float)
offset_flag = np.zeros(len(targets), dtype=int)
Expand Down Expand Up @@ -428,6 +432,10 @@ def bn_validation(
"""

from astropy_healpix import HEALPix

from coordio.utils import _offset_radec

log.debug(
f"Running bright neighbour validation for observatory {observatory} "
f"and design mode {design_mode}"
Expand Down
16 changes: 12 additions & 4 deletions src/too/xmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,19 @@

from __future__ import annotations

from typing import TYPE_CHECKING

import peewee
import polars

from sdssdb.connection import PeeweeDatabaseConnection
from sdssdb.peewee.sdss5db.catalogdb import ToO_Target, Version
from target_selection.xmatch import XMatchPlanner

from too import log
from too.database import get_database_uri


if TYPE_CHECKING:
from sdssdb.connection import PeeweeDatabaseConnection


__all__ = ["xmatch_too_targets"]


Expand Down Expand Up @@ -122,6 +124,12 @@ def xmatch_too_targets(
"""

from sdssdb.peewee.sdss5db.catalogdb import ToO_Target, Version
from target_selection import log as ts_log
from target_selection.xmatch import XMatchPlanner

ts_log.setLevel(10000) # Disable target_selection logging

assert database.connected, "Database is not connected"

too_target_schema: str = ToO_Target._meta.schema # type:ignore
Expand Down

0 comments on commit 4da7b0c

Please sign in to comment.