Skip to content

Commit

Permalink
improvements to CLI and migration
Browse files Browse the repository at this point in the history
  • Loading branch information
andycasey committed Nov 12, 2024
1 parent 1d66c36 commit 3dd5149
Show file tree
Hide file tree
Showing 10 changed files with 299 additions and 458 deletions.
175 changes: 121 additions & 54 deletions bin/new_astra
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,103 @@ from typing_extensions import Annotated

app = typer.Typer()

# TODO: make `spectrum_model` an enumerated type so that it supports tab completion
@app.command()
def srun(
    task: Annotated[str, typer.Argument(help="The task name to run (e.g., `aspcap`, or `astra.pipelines.aspcap.aspcap`).")],
    model: Annotated[str, typer.Argument(
        help=(
            "The input model to use (e.g., `ApogeeCombinedSpectrum`, `BossCombinedSpectrum`). "
        )
    )] = None,
    nodes: Annotated[int, typer.Option(help="The number of nodes to use.")] = 1,
    procs: Annotated[int, typer.Option(help="The number of processes to use per node.")] = 1,
    limit: Annotated[int, typer.Option(help="Limit the number of inputs.")] = None,
    account: Annotated[str, typer.Option(help="Slurm account")] = "sdss-np",
    partition: Annotated[str, typer.Option(help="Slurm partition")] = "sdss-np",
    time: Annotated[str, typer.Option(help="Wall-time")] = "24:00:00",
):
    """Distribute an Astra task over many nodes using Slurm."""
    # NOTE: the docstring above is shown verbatim as the CLI help text, so the
    # detailed description lives in comments instead.
    #
    # Overview:
    #   1. Count the inputs awaiting processing for `task` (optionally
    #      restricted to `model` and capped at `limit`).
    #   2. Split them into `nodes * procs` pages of equal size.
    #   3. For each node, write a shell script that starts `procs` background
    #      `run` commands (one page each) and waits for them.
    #   4. Launch one blocking `srun` per node and wait for all of them.
    #   5. Exit with a non-zero status if any `srun` failed.

    import os
    import sys
    import math
    import shlex
    import concurrent.futures
    import subprocess
    from tempfile import TemporaryDirectory
    from astra import generate_queries_for_task
    from astra.utils import expand_path
    from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn

    ASTRA = "/uufs/chpc.utah.edu/common/home/sdss50/sdsswork/mwm/spectro/astra/astra/astra_dev/bin/new_astra"

    if nodes < 1 or procs < 1:
        raise typer.BadParameter("`nodes` and `procs` must both be positive integers.")

    # Every worker applies the same `--limit`/`--page` to each input model's
    # query, so the page size has to cover the largest of those queries.
    # With an explicit `model` there is exactly one query.
    totals = [q.count() for _, q in generate_queries_for_task(task, model, limit)]
    total = max(totals, default=0)
    if total == 0:
        typer.echo("No inputs to process.")
        sys.exit(0)

    workers = nodes * procs
    page_size = math.ceil(total / workers)

    # Only forward the model when one was given; otherwise the literal string
    # "None" would be passed to `run` as the model name.
    base_command = [ASTRA, "run", task]
    if model is not None:
        base_command.append(model)
    base_command = " ".join(shlex.quote(part) for part in base_command)

    exit_code = 0
    failures = []

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        transient=True
    ) as p:
        # The scripts (and the Slurm --output/--error files) live in this
        # directory, which is removed on exit, so it must stay open until
        # every job has finished.
        with TemporaryDirectory(dir=expand_path("$PBS")) as td:
            # Each worker only blocks on a subprocess, so threads suffice.
            with concurrent.futures.ThreadPoolExecutor(max_workers=nodes) as executor:
                futures = {}
                for n in range(nodes):
                    commands = ["export CLUSTER=1"]
                    # Pages are 1-based: peewee's `paginate` treats page 0 and
                    # page 1 identically, so starting at 0 would process the
                    # first page twice and never reach the last page.
                    for page in range(n * procs + 1, (n + 1) * procs + 1):
                        commands.append(f"{base_command} --limit {page_size} --page {page} &")
                    commands.append("wait")

                    script_path = f"{td}/node_{n}.sh"
                    with open(script_path, "w") as fp:
                        fp.write("\n".join(commands) + "\n")
                    os.chmod(script_path, 0o755)

                    executable = [
                        "srun",
                        "--nodes=1",
                        f"--partition={partition}",
                        f"--account={account}",
                        f"--job-name={task}-{n}",
                        f"--time={time}",
                        f"--output={td}/{n}.out",
                        f"--error={td}/{n}.err",
                        "bash",
                        "-c",
                        f"sh {script_path}"
                    ]

                    t = p.add_task(description=f"Running {task}-{n}", total=None)
                    job = executor.submit(
                        subprocess.run,
                        executable,
                        capture_output=True,
                        text=True
                    )
                    futures[job] = (n, t)

                for future in concurrent.futures.as_completed(futures):
                    n, t = futures[future]
                    result = future.result()
                    if result.returncode == 0:
                        p.update(t, description="Completed")
                        p.remove_task(t)
                        continue

                    message = f"Error code {result.returncode} returned from {task}-{n}"
                    p.update(t, description=message)
                    stderr = (result.stderr or "").strip()
                    failures.append(f"{message}: {stderr}" if stderr else message)
                    # A negative return code means the process was killed by a
                    # signal; it must still produce a non-zero exit status.
                    exit_code = max(exit_code, result.returncode if result.returncode > 0 else 1)

    # The progress display is transient, so report failures after it closes.
    for failure in failures:
        typer.echo(failure, err=True)

    sys.exit(exit_code)

@app.command()
def run(
task: Annotated[str, typer.Argument(help="The task name to run (e.g., `aspcap`, or `astra.pipelines.aspcap.aspcap`).")],
Expand All @@ -21,13 +117,11 @@ def run(
"""Run an Astra task on spectra."""
from rich.progress import Progress, SpinnerColumn, TextColumn, TaskProgressColumn, TimeRemainingColumn, BarColumn, MofNCompleteColumn

from peewee import JOIN
from astra import models, __version__
from astra.models.source import Source
from astra.models.spectrum import Spectrum
from astra.utils import resolve_task, get_return_type, expects_spectrum_types, version_string_to_integer
from astra import models, __version__, generate_queries_for_task
from astra.utils import resolve_task

fun = resolve_task(task)

current_version = version_string_to_integer(__version__) // 1000
with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}")) as p:
t = p.add_task(description="Resolving task", total=None)
fun = resolve_task(task)
Expand All @@ -40,55 +134,19 @@ def run(
BarColumn(),
MofNCompleteColumn(),
TimeRemainingColumn(),
transient=True
transient=False
) as p:

if spectrum_model is None:
spectrum_models = expects_spectrum_types(fun)
else:
spectrum_models = (getattr(models, spectrum_model), )

output_model = get_return_type(fun)

n = 0
for model in spectrum_models:
if limit is not None and n >= limit:
break
for model, q in generate_queries_for_task(fun, spectrum_model, limit, page=page):
t = p.add_task(description=f"Running {fun.__name__} on {model.__name__}", total=limit)
q = (
model
.select(model, Source)
.join(Source, attr="source")
.switch(model)
.join(Spectrum)
.join(
output_model,
JOIN.LEFT_OUTER,
on=(
(output_model.v_astra_major_minor == current_version)
& (model.spectrum_pk == output_model.spectrum_pk)
)
)
.where(
output_model.spectrum_pk.is_null()
| (model.modified > output_model.modified)
)
)
if limit is not None:
if page is not None:
q = q.paginate(page, limit)
else:
q = q.limit(limit)

total = q.count()
p.update(t, total=total)
if total > 0:
for r in fun(q):
p.advance(t)
for n, r in enumerate(fun(q), start=1):
p.update(t, advance=1, refresh=True)
messages.append(f"Processed {n} {model.__name__} spectra with {fun.__name__}")
p.update(t, completed=True)
n += total
messages.append(f"Processed {n} {model.__name__} spectra with {fun.__name__}")


list(map(typer.echo, messages))


Expand All @@ -114,7 +172,10 @@ def migrate(
style="progress.download",
)

from astra.migrations.boss import migrate_from_spall_file
from astra.migrations.boss import (
migrate_from_spall_file,
migrate_specfull_metadata_from_image_headers
)
from astra.migrations.apogee import (
migrate_apvisit_metadata_from_image_headers,
)
Expand Down Expand Up @@ -204,19 +265,20 @@ def migrate(
process_task(migrate_twomass_photometry, description="Ingesting 2MASS photometry"),
process_task(migrate_unwise_photometry, description="Ingesting unWISE photometry"),
process_task(migrate_glimpse_photometry, description="Ingesting GLIMPSE photometry"),
process_task(migrate_specfull_metadata_from_image_headers, description="Ingesting specFull metadata"),


process_task(migrate_apvisit_metadata_from_image_headers, description="Ingesting apVisit metadata"),
process_task(migrate_healpix, description="Ingesting HEALPix values"),
process_task(migrate_tic_v8_identifier, description="Ingesting TIC v8 identifiers"),
process_task(migrate_apvisit_metadata_from_image_headers, description="Ingesting apVisit metadata"),
process_task(update_galactic_coordinates, description="Computing Galactic coordinates"),
process_task(fix_unsigned_apogee_flags, description="Fix unsigned APOGEE flags"),
#process_task(migrate_carton_assignments_to_bigbitfield, description="Ingesting targeting cartons"),
process_task(migrate_targeting_cartons, description="Ingesting targeting cartons"),
process_task(compute_f_night_time_for_boss_visits, description="Computing f_night for BOSS visits"),
process_task(compute_f_night_time_for_apogee_visits, description="Computing f_night for APOGEE visits"),
process_task(update_visit_spectra_counts, description="Updating visit spectra counts"),
]
# reddening needs unwise, 2mass, glimpse,
task_gaia, task_twomass, task_unwise, task_glimpse, *_ = [t for p, t, q in ptq]
task_gaia, task_twomass, task_unwise, task_glimpse, task_specfull, *_ = [t for p, t, q in ptq]
reddening_requires = {task_twomass, task_unwise, task_glimpse, task_gaia}
started_reddening = False
awaiting = set(t for p, t, q in ptq)
Expand All @@ -242,6 +304,10 @@ def migrate(
]
reddening_requires.update({t for p, t, q in new_tasks[:3]}) # reddening needs Gaia astrometry, Zhang parameters, and Bailer-Jones distances
additional_tasks.extend(new_tasks)
if t == task_specfull:
additional_tasks.append(
process_task(compute_f_night_time_for_boss_visits, description="Computing f_night for BOSS visits")
)
if t == task_unwise:
additional_tasks.append(
process_task(compute_w1mag_and_w2mag, description="Computing W1, W2 mags")
Expand Down Expand Up @@ -298,6 +364,7 @@ def init(
init_model_packages = (
"apogee",
"boss",
"bossnet",
"apogeenet",
"astronn_dist",
"astronn",
Expand Down
64 changes: 59 additions & 5 deletions src/astra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from inspect import isgeneratorfunction
from decorator import decorator
from peewee import IntegrityError
from peewee import IntegrityError, JOIN
from sdsstools.configuration import get_config

from astra.utils import log, Timer
from astra.utils import log, Timer, resolve_task, resolve_model, get_return_type, expects_spectrum_types, version_string_to_integer, get_task_group_by_string

NAME = "astra"
__version__ = "0.7.0"
Expand All @@ -12,6 +12,7 @@
def task(
function,
*args,
group_by=None,
batch_size: int = 1000,
write_frequency: int = 300,
write_to_database: bool = True,
Expand All @@ -24,7 +25,7 @@ def task(
:param function:
The callable to decorate.
:param \*args:
:param *args:
The arguments to the task.
:param batch_size: [optional]
Expand All @@ -39,7 +40,7 @@ def task(
:param re_raise_exceptions: [optional]
If `True` (default), exceptions raised in the task will be raised. Otherwise, they will be logged and ignored.
:param \**kwargs:
:param **kwargs:
Keyword arguments for the task and the task decorator. See below.
"""

Expand Down Expand Up @@ -114,7 +115,7 @@ def bulk_insert_or_replace_pipeline_results(results):
:param results:
A list of records to create (e.g., sub-classes of `astra.models.BaseModel`).
"""
log.info(f"Bulk inserting {len(results)} into the database")
#log.info(f"Bulk inserting {len(results)} into the database")

first = results[0]
database, model = (first._meta.database, first.__class__)
Expand Down Expand Up @@ -150,7 +151,60 @@ def bulk_insert_or_replace_pipeline_results(results):

if len(results_dict) > 0:
raise IntegrityError("Failed to insert all results into the database.")




def generate_queries_for_task(task, input_model=None, limit=None, page=None):
    """
    Yield one query per input spectrum model, selecting the rows that the given task
    still needs to process.

    A row is selected when no result from the task's output model exists for it at the
    current Astra major/minor version, or when the input row has been modified more
    recently than the existing result.

    :param task:
        The task name, or callable.

    :param input_model: [optional]
        The input spectrum model. If `None` is given then a query is generated for each
        spectrum model expected by the task, based on the task function signature.

    :param limit: [optional]
        Limit the number of rows for each spectrum model query.

    :param page: [optional]
        The page of results to select, where each page has `limit` rows. This is only
        used when `limit` is given.
        NOTE(review): peewee documents `paginate` page numbers as starting at 1 --
        confirm that callers never pass page 0.

    :returns:
        A generator that yields two-length tuples of `(input_model, query)`.
    """
    from astra.models.source import Source
    from astra.models.spectrum import Spectrum

    current_version = version_string_to_integer(__version__) // 1000

    fun = resolve_task(task)

    if input_model is None:
        input_models = expects_spectrum_types(fun)
    else:
        input_models = (resolve_model(input_model), )

    output_model = get_return_type(fun)
    # NOTE(review): not used below; kept because the call itself may validate the task.
    group_by_string = get_task_group_by_string(fun)

    for spectrum_model in input_models:
        # Match existing results at the current version for the same spectrum.
        same_version = (output_model.v_astra_major_minor == current_version)
        same_spectrum = (spectrum_model.spectrum_pk == output_model.spectrum_pk)

        # Either there is no result yet, or the result is older than the input row.
        has_no_result = output_model.spectrum_pk.is_null()
        is_stale = (spectrum_model.modified > output_model.modified)

        query = spectrum_model.select(spectrum_model, Source)
        query = query.join(Source, attr="source")
        query = query.switch(spectrum_model)
        query = query.join(Spectrum)
        query = query.join(output_model, JOIN.LEFT_OUTER, on=(same_version & same_spectrum))
        query = query.where(has_no_result | is_stale)

        if limit is not None:
            query = query.limit(limit) if page is None else query.paginate(page, limit)

        yield (spectrum_model, query)


try:
config = get_config(NAME)

Expand Down
Loading

0 comments on commit 3dd5149

Please sign in to comment.