diff --git a/bin/new_astra b/bin/new_astra
index f3f02f5..7f705d3 100755
--- a/bin/new_astra
+++ b/bin/new_astra
@@ -5,7 +5,103 @@ from typing_extensions import Annotated
 
 app = typer.Typer()
 
-# TODO: make spectrum_model an enumerated type for tab
+@app.command()
+def srun(
+    task: Annotated[str, typer.Argument(help="The task name to run (e.g., `aspcap`, or `astra.pipelines.aspcap.aspcap`).")],
+    model: Annotated[str, typer.Argument(
+        help=(
+            "The input model to use (e.g., `ApogeeCombinedSpectrum`, `BossCombinedSpectrum`)."
+        )
+    )] = None,
+    nodes: Annotated[int, typer.Option(help="The number of nodes to use.")] = 1,
+    procs: Annotated[int, typer.Option(help="The number of processes to use per node.")] = 1,
+    limit: Annotated[int, typer.Option(help="Limit the number of inputs.")] = None,
+    account: Annotated[str, typer.Option(help="Slurm account")] = "sdss-np",
+    partition: Annotated[str, typer.Option(help="Slurm partition")] = "sdss-np",
+    time: Annotated[str, typer.Option(help="Wall time")] = "24:00:00",
+):
+    """Distribute an Astra task over many nodes using Slurm."""
+
+    import os
+    import sys
+    import numpy as np
+    import concurrent.futures
+    import subprocess
+    from tempfile import TemporaryDirectory
+    from peewee import JOIN
+    from importlib import import_module
+    from astra import models, __version__, generate_queries_for_task
+    from astra.utils import silenced, expand_path
+    from rich.progress import Progress, SpinnerColumn, TextColumn, TaskProgressColumn, TimeRemainingColumn, BarColumn, MofNCompleteColumn
+
+    ASTRA = "/uufs/chpc.utah.edu/common/home/sdss50/sdsswork/mwm/spectro/astra/astra/astra_dev/bin/new_astra"
+
+    _, q = next(generate_queries_for_task(task, model, limit))
+
+    total = q.count()
+    workers = nodes * procs
+    limit = int(np.ceil(total / workers))
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        BarColumn(),
+        transient=True
+    ) as p:
+
+        executor = concurrent.futures.ProcessPoolExecutor(nodes)
+
+        # Submit one srun job per node; each node runs `procs` paginated worker processes
+        futures = {}
+        with TemporaryDirectory(dir=expand_path("$PBS")) as td:
+
+            for n in range(nodes):
+                commands = ["export CLUSTER=1"]
+                for page in range(n * procs, (n + 1) * procs):
+                    commands.append(f"{ASTRA} run {task} {model} --limit {limit} --page {page} &")
+                commands.append("wait")
+
+                script_path = f"{td}/node_{n}.sh"
+                with open(script_path, "w") as fp:
+                    fp.write("\n".join(commands))
+
+                os.system(f"chmod +x {script_path}")
+
+                executable = [
+                    "srun",
+                    "--nodes=1",
+                    f"--partition={partition}",
+                    f"--account={account}",
+                    f"--job-name={task}-{n}",
+                    f"--time={time}",
+                    f"--output={td}/{n}.out",
+                    f"--error={td}/{n}.err",
+                    "bash",
+                    "-c",
+                    f"sh {script_path}"
+                ]
+
+                t = p.add_task(description=f"Running {task}-{n}", total=None)
+                job = executor.submit(
+                    subprocess.run,
+                    executable,
+                    capture_output=True
+                )
+                futures[job] = (n, t)
+
+        max_returncode = 0
+        for future in concurrent.futures.as_completed(futures.keys()):
+            n, t = futures[future]
+            result = future.result()
+            if result.returncode == 0:
+                p.update(t, description=f"Completed {task}-{n}")
+                p.remove_task(t)
+            else:
+                p.update(t, description=f"Error code {result.returncode} returned from {task}-{n}")
+                max_returncode = max(max_returncode, result.returncode)
+
+    sys.exit(max_returncode)
+
 @app.command()
 def run(
     task: Annotated[str, typer.Argument(help="The task name to run (e.g., `aspcap`, or `astra.pipelines.aspcap.aspcap`).")],
@@ -21,13 +117,11 @@ def run(
     """Run an Astra task on spectra."""
 
     from rich.progress import Progress, SpinnerColumn, TextColumn, TaskProgressColumn, TimeRemainingColumn, BarColumn, MofNCompleteColumn
-    from peewee import JOIN
-    from astra import models, __version__
-    from astra.models.source import Source
-    from astra.models.spectrum import Spectrum
-    from astra.utils import resolve_task, get_return_type, expects_spectrum_types, version_string_to_integer
+    from astra import models, __version__, generate_queries_for_task
+    from astra.utils import resolve_task
+
+    fun = resolve_task(task)
 
-    current_version = version_string_to_integer(__version__) // 1000
     with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}")) as p:
         t = p.add_task(description="Resolving task", total=None)
         fun = resolve_task(task)
@@ -40,55 +134,19 @@ def run(
         BarColumn(),
         MofNCompleteColumn(),
         TimeRemainingColumn(),
-        transient=True
+        transient=False
    ) as p:
-        if spectrum_model is None:
-            spectrum_models = expects_spectrum_types(fun)
-        else:
-            spectrum_models = (getattr(models, spectrum_model), )
-
-        output_model = get_return_type(fun)
-
-        n = 0
-        for model in spectrum_models:
-            if limit is not None and n >= limit:
-                break
+        for model, q in generate_queries_for_task(fun, spectrum_model, limit, page=page):
             t = p.add_task(description=f"Running {fun.__name__} on {model.__name__}", total=limit)
-            q = (
-                model
-                .select(model, Source)
-                .join(Source, attr="source")
-                .switch(model)
-                .join(Spectrum)
-                .join(
-                    output_model,
-                    JOIN.LEFT_OUTER,
-                    on=(
-                        (output_model.v_astra_major_minor == current_version)
-                        & (model.spectrum_pk == output_model.spectrum_pk)
-                    )
-                )
-                .where(
-                    output_model.spectrum_pk.is_null()
-                    | (model.modified > output_model.modified)
-                )
-            )
-            if limit is not None:
-                if page is not None:
-                    q = q.paginate(page, limit)
-                else:
-                    q = q.limit(limit)
-
             total = q.count()
             p.update(t, total=total)
             if total > 0:
-                for r in fun(q):
-                    p.advance(t)
+                for n, r in enumerate(fun(q), start=1):
+                    p.update(t, advance=1, refresh=True)
+                messages.append(f"Processed {n} {model.__name__} spectra with {fun.__name__}")
                 p.update(t, completed=True)
-            n += total
-            messages.append(f"Processed {n} {model.__name__} spectra with {fun.__name__}")
-
+    list(map(typer.echo, messages))
@@ -114,7 +172,10 @@ def migrate(
         style="progress.download",
     )
 
-    from astra.migrations.boss import migrate_from_spall_file
+    from astra.migrations.boss import (
+        migrate_from_spall_file,
+        migrate_specfull_metadata_from_image_headers
+    )
     from astra.migrations.apogee import (
         migrate_apvisit_metadata_from_image_headers,
     )
@@ -204,19 +265,20 @@ def migrate(
         process_task(migrate_twomass_photometry, description="Ingesting 2MASS photometry"),
         process_task(migrate_unwise_photometry, description="Ingesting unWISE photometry"),
         process_task(migrate_glimpse_photometry, description="Ingesting GLIMPSE photometry"),
+        process_task(migrate_specfull_metadata_from_image_headers, description="Ingesting specFull metadata"),
+
+
+        process_task(migrate_apvisit_metadata_from_image_headers, description="Ingesting apVisit metadata"),
         process_task(migrate_healpix, description="Ingesting HEALPix values"),
         process_task(migrate_tic_v8_identifier, description="Ingesting TIC v8 identifiers"),
-        process_task(migrate_apvisit_metadata_from_image_headers, description="Ingesting apVisit metadata"),
         process_task(update_galactic_coordinates, description="Computing Galactic coordinates"),
         process_task(fix_unsigned_apogee_flags, description="Fix unsigned APOGEE flags"),
-        #process_task(migrate_carton_assignments_to_bigbitfield, description="Ingesting targeting cartons"),
         process_task(migrate_targeting_cartons, description="Ingesting targeting cartons"),
-        process_task(compute_f_night_time_for_boss_visits, description="Computing f_night for BOSS visits"),
         process_task(compute_f_night_time_for_apogee_visits, description="Computing f_night for APOGEE visits"),
         process_task(update_visit_spectra_counts, description="Updating visit spectra counts"),
     ]
 
     # reddening needs unwise, 2mass, glimpse,
-    task_gaia, task_twomass, task_unwise, task_glimpse, *_ = [t for p, t, q in ptq]
+    task_gaia, task_twomass, task_unwise, task_glimpse, task_specfull, *_ = [t for p, t, q in ptq]
     reddening_requires = {task_twomass, task_unwise, task_glimpse, task_gaia}
     started_reddening = False
     awaiting = set(t for p, t, q in ptq)
@@ -242,6 +304,10 @@ def migrate(
             ]
             reddening_requires.update({t for p, t, q in new_tasks[:3]}) # reddening needs Gaia astrometry, Zhang parameters, and Bailer-Jones distances
             additional_tasks.extend(new_tasks)
+        if t == task_specfull:
+            additional_tasks.append(
+                process_task(compute_f_night_time_for_boss_visits, description="Computing f_night for BOSS visits")
+            )
         if t == task_unwise:
             additional_tasks.append(
                 process_task(compute_w1mag_and_w2mag, description="Computing W1, W2 mags")
@@ -298,6 +364,7 @@ def init(
     init_model_packages = (
         "apogee",
         "boss",
+        "bossnet",
         "apogeenet",
         "astronn_dist",
         "astronn",
diff --git a/src/astra/__init__.py b/src/astra/__init__.py
index dfa41d5..b5c14d0 100644
--- a/src/astra/__init__.py
+++ b/src/astra/__init__.py
@@ -1,9 +1,9 @@
 from inspect import isgeneratorfunction
 from decorator import decorator
-from peewee import IntegrityError
+from peewee import IntegrityError, JOIN
 from sdsstools.configuration import get_config
 
-from astra.utils import log, Timer
+from astra.utils import log, Timer, resolve_task, resolve_model, get_return_type, expects_spectrum_types, version_string_to_integer, get_task_group_by_string
 
 NAME = "astra"
 __version__ = "0.7.0"
@@ -12,6 +12,7 @@
 def task(
     function,
     *args,
+    group_by=None,
     batch_size: int = 1000,
     write_frequency: int = 300,
     write_to_database: bool = True,
@@ -24,7 +25,7 @@ def task(
     :param function:
         The callable to decorate.
 
-    :param \*args:
+    :param *args:
         The arguments to the task.
 
     :param batch_size: [optional]
@@ -39,7 +40,7 @@ def task(
     :param re_raise_exceptions: [optional]
         If `True` (default), exceptions raised in the task will be raised. Otherwise, they will be logged and ignored.
 
-    :param \**kwargs:
+    :param **kwargs:
         Keyword arguments for the task and the task decorator. See below.
     """
@@ -114,7 +115,7 @@ def bulk_insert_or_replace_pipeline_results(results):
     :param results:
         A list of records to create (e.g., sub-classes of `astra.models.BaseModel`).
     """
-    log.info(f"Bulk inserting {len(results)} into the database")
+    #log.info(f"Bulk inserting {len(results)} into the database")
 
     first = results[0]
     database, model = (first._meta.database, first.__class__)
@@ -150,7 +151,63 @@ def bulk_insert_or_replace_pipeline_results(results):
 
     if len(results_dict) > 0:
         raise IntegrityError("Failed to insert all results into the database.")
+
+
+
+
+def generate_queries_for_task(task, input_model=None, limit=None, page=None):
+    """
+    Generate queries for input data that need to be processed by the given task.
+
+    :param task:
+        The task name, or callable.
+
+    :param input_model: [optional]
+        The input spectrum model. If `None` is given, a query will be generated for each
+        spectrum model expected by the task, based on the task function signature.
+
+    :param limit: [optional]
+        Limit the number of rows for each spectrum model query.
+
+    :param page: [optional]
+        Page through the rows for each spectrum model query (requires `limit`).
+    """
+    from astra.models.source import Source
+    from astra.models.spectrum import Spectrum
+
+    current_version = version_string_to_integer(__version__) // 1000
+
+    fun = resolve_task(task)
+
+    input_models = expects_spectrum_types(fun) if input_model is None else (resolve_model(input_model), )
+    output_model = get_return_type(fun)
+    group_by_string = get_task_group_by_string(fun)
+    for input_model in input_models:
+        where = (
+            output_model.spectrum_pk.is_null()
+            | (input_model.modified > output_model.modified)
+        )
+        on = (
+            (output_model.v_astra_major_minor == current_version)
+            & (input_model.spectrum_pk == output_model.spectrum_pk)
+        )
+        q = (
+            input_model
+            .select(input_model, Source)
+            .join(Source, attr="source")
+            .switch(input_model)
+            .join(Spectrum)
+            .join(output_model, JOIN.LEFT_OUTER, on=on)
+            .where(where)
+        )
+        if limit is not None:
+            if page is not None:
+                q = q.paginate(page, limit)
+            else:
+                q = q.limit(limit)
+        yield (input_model, q)
+
+
 try:
     config = get_config(NAME)
diff --git a/src/astra/migrations/boss.py b/src/astra/migrations/boss.py
index 52fd8d0..fa63152 100644
--- a/src/astra/migrations/boss.py
+++ b/src/astra/migrations/boss.py
@@ -2,7 +2,6 @@
 from astropy.io import fits
 from astropy.table import Table
 from astropy.time import Time
-from tqdm import tqdm
 import numpy as np
 import subprocess
 import concurrent.futures
@@ -11,7 +10,7 @@
 from astra.models.base import database
 from astra.models.boss import BossVisitSpectrum
 from astra.models.source import Source
-from astra.migrations.utils import enumerate_new_spectrum_pks, upsert_many
+from astra.migrations.utils import enumerate_new_spectrum_pks, upsert_many, NoQueue
 
 from peewee import (
     chunked,
@@ -22,9 +21,7 @@
 )
 
-
-
-def migrate_from_spall_file(run2d, queue, gzip=True, limit=None, batch_size=1000):
+def migrate_from_spall_file(run2d, queue, gzip=True, limit=1_000_000, batch_size=1000):
     """
     Migrate all new BOSS visit information (`specFull` files) stored in the spAll file, which is generated
     by the SDSS-V BOSS data reduction pipeline.
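# A minimal sketch (not part of this patch) of the `NoQueue` helper imported
# above from `astra.migrations.utils`. Its implementation is not shown in this
# diff; presumably it just swallows progress payloads, so migration tasks can
# call `queue.put(...)` unconditionally whether or not the caller supplied a
# real progress queue:
#
#     class NoQueue:
#         """A stand-in queue that silently discards progress updates."""
#         def put(self, item):
#             pass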
@@ -103,8 +100,8 @@ class Meta:
         "telescope": lambda x: f"{x.lower()}25m",
         "moon_dist_mean": lambda x: np.mean(tuple(map(float, x.split()))),
         "moon_phase_mean": lambda x: np.mean(tuple(map(float, x.split()))),
-        "delta_ra": lambda x: np.array(x.split(), dtype=float),
-        "delta_dec": lambda x: np.array(x.split(), dtype=float),
+        "delta_ra": lambda x: list(map(float, x.split())),
+        "delta_dec": lambda x: list(map(float, x.split()))
     }
 
     #with fits.open(path) as hdul:
@@ -135,31 +132,7 @@ class Meta:
 
     queue.put(dict(description=f"Converting BOSS {run2d} data types", total=None, completed=0))
 
-    spectrum_data = dict_to_iterable(spectrum_data_dicts)
-
-    """
-    spectrum_data = []
-    for i, row in enumerate(spAll):
-
-        row_data = dict(zip(row.keys(), row.values()))
-
-        sanitised_row_data = {
-            "release": "sdss5",
-            "run2d": run2d,
-            "filetype": "specFull",
-        }
-        for from_key, to in translations.items():
-            if isinstance(to, str):
-                sanitised_row_data[to] = row_data[from_key]
-            else:
-                to_key, to_callable = to
-                sanitised_row_data[to_key] = to_callable(row_data[from_key])
-
-        offset = np.abs(sanitised_row_data["delta_ra"]) + np.abs(sanitised_row_data["delta_dec"])
-        sanitised_row_data["fiber_offset"] = np.any(offset > 0)
-        spectrum_data.append(sanitised_row_data)
-        queue.put({"advance": 1})
-    """
+    spectrum_data = list(dict_to_iterable(spectrum_data_dicts))
 
     # We need to get sdss_id and catalog information for each source.
     source_data = {}
@@ -266,8 +239,8 @@ class Meta:
             None
 
-    if n_warnings > 0:
-        log.warning(f"There were {n_warnings} spectra with no source_pk, probably because of missing or fake catalogids")
+    #if n_warnings > 0:
+    #    log.warning(f"There were {n_warnings} spectra with no source_pk, probably because of missing or fake catalogids")
 
     pks = upsert_many(
         BossVisitSpectrum,
@@ -281,7 +254,7 @@ class Meta:
     # Assign spectrum_pk values to any spectra missing it.
     N = len(pks)
     if pks:
-        queue.update(dict(description=f"Assigning primary keys to BOSS {run2d} spectra", total=N, completed=0))
+        queue.put(dict(description=f"Assigning primary keys to BOSS {run2d} spectra", total=N, completed=0))
         N_assigned = 0
         for batch in chunked(pks, batch_size):
             B = (
@@ -294,263 +267,15 @@ class Meta:
                 .where(BossVisitSpectrum.pk.in_(batch))
                 .execute()
             )
-            queue.update(dict(advance=B))
+            queue.put(dict(advance=B))
             N_assigned += B
-        log.info(f"There were {N} spectra inserted and we assigned {N_assigned} spectra with new spectrum_pk values")
-    else:
-        log.info(f"No new spectra inserted")
+        #log.info(f"There were {N} spectra inserted and we assigned {N_assigned} spectra with new spectrum_pk values")
 
     queue.put(Ellipsis)
     return None
 
 
-def migrate_spectra_from_spall_file(
-    run2d: Optional[str] = "v6_1_1",
-    gzip: Optional[bool] = True,
-    limit: Optional[int] = None,
-    batch_size: Optional[int] = 1000
-):
-    """
-    Migrate all new BOSS visit information (`specFull` files) stored in the spAll file, which is generated
-    by the SDSS-V BOSS data reduction pipeline.
-    """
-
-    from astra.migrations.sdss5db.catalogdb import (
-        Catalog,
-        CatalogToGaia_DR2,
-        CatalogToGaia_DR3,
-        CatalogdbModel
-    )
-
-    class SDSS_ID_Flat(CatalogdbModel):
-        class Meta:
-            table_name = "sdss_id_flat"
-
-    class SDSS_ID_Stacked(CatalogdbModel):
-        class Meta:
-            table_name = "sdss_id_stacked"
-
-
-    path = f"$BOSS_SPECTRO_REDUX/{run2d}/spAll-{run2d}.fits"
-    if gzip:
-        path += ".gz"
-
-    spAll = Table.read(expand_path(path))
-    spAll.sort(["CATALOGID"])
-
-    if limit is not None:
-        spAll = spAll[:limit]
-
-    translations = {
-        "NEXP": "n_exp",
-        "XCSAO_RV": "xcsao_v_rad",
-        "XCSAO_ERV": "xcsao_e_v_rad",
-        "XCSAO_RXC": "xcsao_rxc",
-        "XCSAO_TEFF": "xcsao_teff",
-        "XCSAO_ETEFF": "xcsao_e_teff",
-        "XCSAO_LOGG": "xcsao_logg",
-        "XCSAO_ELOGG": "xcsao_e_logg",
-        "XCSAO_FEH": "xcsao_fe_h",
-        "XCSAO_EFEH": "xcsao_e_fe_h",
-        "ZWARNING": ("zwarning_flags", lambda x: x or 0),
-        "EXPTIME": "exptime",
-
-        # Not yet done: gri_gaia_transform, because it is accidentally missing from the IPL3 files
-        "AIRMASS": "airmass",
-        "SEEING50": "seeing",
-
-        "OBS": ("telescope", lambda x: f"{x.lower()}25m"),
-        "MOON_DIST": ("moon_dist_mean", lambda x: np.mean(tuple(map(float, x.split())))),
-        "MOON_PHASE": ("moon_phase_mean", lambda x: np.mean(tuple(map(float, x.split())))),
-
-        "FIELD": "fieldid",
-        "MJD": "mjd",
-        "CATALOGID": "catalogid",
-        "HEALPIX": "healpix",
-        "DELTA_RA_LIST": ("delta_ra", lambda x: np.array(x.split(), dtype=float)),
-        "DELTA_DEC_LIST": ("delta_dec", lambda x: np.array(x.split(), dtype=float)),
-        "SN_MEDIAN_ALL": "snr",
-
-        # Some additional identifiers that we don't necessarily need, but will take for now
-        "CATALOGID_V0": "catalogid_v0",
-        "CATALOGID_V0P5": "catalogid_v0p5",
-        "SDSS_ID": "sdss_id",
-        "GAIA_ID": "gaia_dr2_source_id",
-        "FIRSTCARTON": "carton_0"
-    }
-    source_keys_only = ("catalogid_v0", "catalogid_v0p5", "sdss_id", "gaia_dr2_source_id", "carton_0")
-
-    spectrum_data = []
-    for i, row in enumerate(tqdm(spAll)):
-
-        row_data = dict(zip(row.keys(), row.values()))
-
-        sanitised_row_data = {
-            "release": "sdss5",
-            "run2d": run2d,
-            "filetype": "specFull",
-        }
-        for from_key, to in translations.items():
-            if isinstance(to, str):
-                sanitised_row_data[to] = row_data[from_key]
-            else:
-                to_key, to_callable = to
-                sanitised_row_data[to_key] = to_callable(row_data[from_key])
-
-        offset = np.abs(sanitised_row_data["delta_ra"]) + np.abs(sanitised_row_data["delta_dec"])
-        sanitised_row_data["fiber_offset"] = np.any(offset > 0)
-        spectrum_data.append(sanitised_row_data)
-
-    # We need to get sdss_id and catalog information for each source.
-    source_data = {}
-    with tqdm(total=len(spectrum_data), desc="Linking to Catalog") as pb:
-        for chunk in chunked(spectrum_data, batch_size):
-
-            chunk_catalogids = []
-            gaia_dr2_source_id_given_catalogid = {}
-            for row in chunk:
-                for key in ("catalogid", "catalogid_v0", "catalogid_v0p5"):
-                    try:
-                        if np.all(row[key].mask):
-                            continue
-                    except:
-                        chunk_catalogids.append(row[key])
-                        gaia_dr2_source_id_given_catalogid[row[key]] = row["gaia_dr2_source_id"]
-
-            q = (
-                Catalog
-                .select(
-                    Catalog.ra,
-                    Catalog.dec,
-                    Catalog.catalogid,
-                    Catalog.version_id.alias("version_id"),
-                    Catalog.lead,
-                    CatalogToGaia_DR3.target.alias("gaia_dr3_source_id"),
-                    SDSS_ID_Flat.sdss_id,
-                    SDSS_ID_Flat.n_associated,
-                    SDSS_ID_Stacked.catalogid21,
-                    SDSS_ID_Stacked.catalogid25,
-                    SDSS_ID_Stacked.catalogid31,
-                )
-                .join(SDSS_ID_Flat, JOIN.LEFT_OUTER, on=(Catalog.catalogid == SDSS_ID_Flat.catalogid))
-                .join(SDSS_ID_Stacked, JOIN.LEFT_OUTER, on=(SDSS_ID_Stacked.sdss_id == SDSS_ID_Flat.sdss_id))
-                .join(CatalogToGaia_DR3, JOIN.LEFT_OUTER, on=(SDSS_ID_Stacked.catalogid31 == CatalogToGaia_DR3.catalog))
-                .where(Catalog.catalogid.in_(chunk_catalogids))
-                .dicts()
-            )
-
-            reference_key = "catalogid"
-            for row in q:
-                if row[reference_key] in source_data:
-                    for key, value in row.items():
-                        if source_data[row[reference_key]][key] is None and value is not None:
-                            if key == "sdss_id":
-                                source_data[row[reference_key]][key] = min(source_data[row[reference_key]][key], value)
-                            else:
-                                source_data[row[reference_key]][key] = value
-                    continue
-
-                source_data[row[reference_key]] = row
-                gaia_dr2_source_id = gaia_dr2_source_id_given_catalogid[row[reference_key]]
-                if gaia_dr2_source_id < 0:
-                    gaia_dr2_source_id = None
-                source_data[row[reference_key]]["gaia_dr2_source_id"] = gaia_dr2_source_id
-
-            pb.update(batch_size)
-
-
-    # Upsert the sources
-    with database.atomic():
-        with tqdm(desc="Upserting sources", total=len(source_data)) as pb:
-            for chunk in chunked(source_data.values(), batch_size):
-                (
-                    Source
-                    .insert_many(chunk)
-                    .on_conflict_ignore()
-                    .execute()
-                )
-                pb.update(min(batch_size, len(chunk)))
-                pb.refresh()
-
-    log.info(f"Getting data for sources")
-
-    q = (
-        Source
-        .select(
-            Source.pk,
-            Source.catalogid,
-            Source.catalogid21,
-            Source.catalogid25,
-            Source.catalogid31
-        )
-        .tuples()
-    )
-
-    source_pk_by_catalogid = {}
-    for pk, *catalogids in q.iterator():
-        for catalogid in catalogids:
-            source_pk_by_catalogid[catalogid] = pk
-
-    n_warnings = 0
-    for each in spectrum_data:
-        try:
-            each["source_pk"] = source_pk_by_catalogid[each["catalogid"]]
-        except:
-            # log warning?
-            n_warnings += 1
-        finally:
-            for source_key_only in source_keys_only:
-                each.pop(source_key_only, None)
-
-        try:
-            # Missing catalogid!
-            if np.all(each["catalogid"].mask):
-                each["catalogid"] = -1 # cannot be null
-        except:
-            None
-
-    if n_warnings > 0:
-        log.warning(f"There were {n_warnings} spectra with no source_pk, probably because of missing or fake catalogids")
-
-    pks = upsert_many(
-        BossVisitSpectrum,
-        BossVisitSpectrum.pk,
-        spectrum_data,
-        batch_size,
-        queue,
-        "Upserting spectra"
-    )
-
-    # Assign spectrum_pk values to any spectra missing it.
-    N = len(pks)
-    if pks:
-        with tqdm(total=N, desc="Assigning primary keys to spectra") as pb:
-            N_assigned = 0
-            for batch in chunked(pks, batch_size):
-                B = (
-                    BossVisitSpectrum
-                    .update(
-                        spectrum_pk=Case(None, (
-                            (BossVisitSpectrum.pk == pk, spectrum_pk) for spectrum_pk, pk in enumerate_new_spectrum_pks(batch)
-                        ))
-                    )
-                    .where(BossVisitSpectrum.pk.in_(batch))
-                    .execute()
-                )
-                pb.update(B)
-                N_assigned += B
-
-        log.info(f"There were {N} spectra inserted and we assigned {N_assigned} spectra with new spectrum_pk values")
-    else:
-        log.info(f"No new spectra inserted")
-
-    return N
-
-
-
-def _migrate_specfull_metadata(spectra, fields, raise_exceptions=False, full_output=False):
+def _migrate_specfull_metadata(spectra, fields, raise_exceptions=True, full_output=False):
 
     K = len(fields)
     keys_str = "|".join([f"({k})" for k in fields.values()])
@@ -628,10 +353,13 @@ def _migrate_specfull_metadata(spectra, fields, raise_exceptions=False, full_out
 
 def migrate_specfull_metadata_from_image_headers(
     where=(BossVisitSpectrum.alt.is_null() & (BossVisitSpectrum.catalogid > 0)),
-    max_workers: Optional[int] = 8,
+    max_workers: Optional[int] = 128,
     limit: Optional[int] = None,
     batch_size: Optional[int] = 100,
+    queue = None
 ):
+    if queue is None:
+        queue = NoQueue()
 
     q = (
         BossVisitSpectrum
@@ -640,11 +368,7 @@ def migrate_specfull_metadata_from_image_headers(
     if where:
         q = q.where(where)
 
-    q = (
-        q
-        .limit(limit)
-        .iterator()
-    )
+    q = q.limit(limit)
 
     fields = {
         BossVisitSpectrum.plateid: "PLATEID",
@@ -681,16 +405,16 @@ def migrate_specfull_metadata_from_image_headers(
         BossVisitSpectrum.schi2max: "SCHI2MAX",
     }
 
-    executor = concurrent.futures.ProcessPoolExecutor(max_workers)
+    executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
 
+    queue.put(dict(total=q.count()))
     specFulls, futures, total = ({}, [], 0)
-    with tqdm(total=limit or 0, desc="Submitting work", unit="spectra") as pb:
-        for chunk in chunked(q, batch_size):
-            futures.append(executor.submit(_migrate_specfull_metadata, chunk, fields))
-            for total, spec in enumerate(chunk, start=1 + total):
-                specFulls[spec.pk] = spec
-                pb.update()
-
+    for chunk in chunked(q, batch_size):
+        futures.append(executor.submit(_migrate_specfull_metadata, chunk, fields))
+        #for total, spec in enumerate(chunk, start=1 + total):
+        for spec in chunk:
+            specFulls[spec.pk] = spec
+
     defaults = {
         "n_guide": -1,
         "airtemp": np.nan,
@@ -699,35 +423,36 @@ def migrate_specfull_metadata_from_image_headers(
     }
 
     all_missing_counts = {}
-    with tqdm(total=total, desc="Collecting headers", unit="spectra") as pb:
-        for future in concurrent.futures.as_completed(futures):
-            metadata, missing_counts = future.result()
-
-            for name, missing_count in missing_counts.items():
-                all_missing_counts.setdefault(name, 0)
-                all_missing_counts[name] += missing_count
-
-            for pk, meta in metadata.items():
-                for key, value in meta.items():
+    for future in concurrent.futures.as_completed(futures):
+        metadata, missing_counts = future.result()
+        for name, missing_count in missing_counts.items():
+            all_missing_counts.setdefault(name, 0)
+            all_missing_counts[name] += missing_count
+
+        for pk, meta in metadata.items():
+            for key, value in meta.items():
+                setattr(specFulls[pk], key, value)
+            for key, value in defaults.items():
+                if key not in meta:
                     setattr(specFulls[pk], key, value)
-                for key, value in defaults.items():
-                    if key not in meta:
-                        setattr(specFulls[pk], key, value)
-                pb.update()
-
-    if all_missing_counts:
-        log.warning(f"There were missing keys:")
-        for name, count in all_missing_counts.items():
-            log.warning(f"\t{name}: {count} missing")
-
-    with tqdm(total=total, desc="Updating", unit="spectra") as pb:
-        for chunk in chunked(specFulls.values(), batch_size):
-            pb.update(
-                BossVisitSpectrum
-                .bulk_update(
-                    chunk,
-                    fields=list(fields.keys())
-                )
-            )
+
+            queue.put(dict(advance=1))
 
-    return pb.n
+    #if all_missing_counts:
+    #    log.warning(f"There were missing keys:")
+    #    for name, count in all_missing_counts.items():
+    #        log.warning(f"\t{name}: {count} missing")
+
+    queue.put(dict(total=len(specFulls), completed=0, description="Ingesting specFull metadata"))
+    for chunk in chunked(specFulls.values(), batch_size):
+        (
+            BossVisitSpectrum
+            .bulk_update(
+                chunk,
+                fields=list(fields.keys())
+            )
+        )
+        queue.put(dict(advance=batch_size))
+
+    queue.put(Ellipsis)
+    return None
diff --git a/src/astra/migrations/targeting.py b/src/astra/migrations/targeting.py
index aa32306..3d564ab 100644
--- a/src/astra/migrations/targeting.py
+++ b/src/astra/migrations/targeting.py
@@ -80,14 +80,15 @@ class Meta:
                 chunk_dict[sdss_id].sdss5_target_flags.set_bit(bit)
                 update_dict[sdss_id] = chunk_dict[sdss_id]
 
-        with database.atomic():
-            (
-                Source
-                .bulk_update(
-                    update_dict.values(),
-                    fields=[Source.sdss5_target_flags]
+        if update_dict:
+            with database.atomic():
+                (
+                    Source
+                    .bulk_update(
+                        update_dict.values(),
+                        fields=[Source.sdss5_target_flags]
+                    )
                 )
-            )
 
         queue.put(dict(advance=batch_size))
 
     queue.put(Ellipsis)
diff --git a/src/astra/models/apogee.py b/src/astra/models/apogee.py
index 04f314a..d447694 100644
--- a/src/astra/models/apogee.py
+++ b/src/astra/models/apogee.py
@@ -267,11 +267,21 @@ class Meta:
             ),
             True,
         ),
+        # The following index is just to make updating visit counts faster.
+        (
+            (
+                "source_pk",
+                "telescope",
+                "mjd",
+                "fiber",
+                "plate",
+                "field"
+            ),
+            False
+        )
     )
 
 
-
-
 class ApogeeVisitSpectrumInApStar(BaseModel, SpectrumMixin):
 
     """An APOGEE stacked spectrum, stored in an apStar data product."""
diff --git a/src/astra/models/boss.py b/src/astra/models/boss.py
index 899b6cc..52c1c23 100644
--- a/src/astra/models/boss.py
+++ b/src/astra/models/boss.py
@@ -22,6 +22,7 @@ class BossVisitSpectrum(BaseModel, SpectrumMixin):
         index=True,
         unique=True,
         lazy_load=False,
+        column_name="spectrum_pk"
     )
     source = ForeignKeyField(
         Source,
@@ -31,6 +32,9 @@ class BossVisitSpectrum(BaseModel, SpectrumMixin):
         backref="boss_visit_spectra"
     )
 
+    created = DateTimeField(default=datetime.datetime.now)
+    modified = DateTimeField(default=datetime.datetime.now)
+
     #> Spectral data
     wavelength = PixelArray(
         ext=1,
diff --git a/src/astra/models/mwm.py b/src/astra/models/mwm.py
index 2903672..816f02e 100644
--- a/src/astra/models/mwm.py
+++ b/src/astra/models/mwm.py
@@ -10,6 +10,7 @@
     DeferredForeignKey,
     fn,
 )
+import datetime
 import numpy as np
 from astra import __version__
 from astra.utils import log
@@ -385,6 +386,7 @@ class ApogeeCombinedSpectrum(MWMStarMixin, SpectrumMixin):
         index=True,
         unique=True,
         lazy_load=False,
+        column_name="spectrum_pk",
         help_text=Glossary.spectrum_pk,
     )
     # Won't appear in a header group because it is first referenced in `Source`.
@@ -407,6 +409,9 @@ class ApogeeCombinedSpectrum(MWMStarMixin, SpectrumMixin):
     healpix = IntegerField(help_text=Glossary.healpix) # This should be the same as the Source-level field.
     sdss_id = BigIntegerField(index=True, unique=False, null=True, help_text="SDSS-5 unique identifier")
 
+    created = DateTimeField(default=datetime.datetime.now)
+    modified = DateTimeField(default=datetime.datetime.now)
+
     #> Related Data Product Keywords
     apred = TextField(help_text=Glossary.apred)
     obj = TextField(help_text=Glossary.obj)
@@ -451,8 +456,8 @@ class ApogeeCombinedSpectrum(MWMStarMixin, SpectrumMixin):
     autofwhm = FloatField(null=True, help_text=Glossary.autofwhm)
     n_components = IntegerField(null=True, help_text=Glossary.n_components)
 
-    #> Provenance
-    input_spectrum_pks = ArrayField(IntegerField, null=True, help_text="DRP visit PKs")
+    ##> Provenance
+    #input_spectrum_pks = ArrayField(IntegerField, null=True, help_text="DRP visit PKs")
 
     #> Spectral Data
     wavelength = PixelArray(
diff --git a/src/astra/pipelines/apogeenet/__init__.py b/src/astra/pipelines/apogeenet/__init__.py
index bda69d6..965a61b 100644
--- a/src/astra/pipelines/apogeenet/__init__.py
+++ b/src/astra/pipelines/apogeenet/__init__.py
@@ -267,35 +267,17 @@ def reverse_inverse_error(inverse_error: np.array, default_error: int) -> np.arr
 
 from astra import task, __version__
 from astra.utils import log, expand_path
-from astra.models import ApogeeCoaddedSpectrumInApStar, ApogeeVisitSpectrumInApStar
+from astra.models.apogee import ApogeeCoaddedSpectrumInApStar, ApogeeVisitSpectrumInApStar
 from astra.models.apogeenet import ApogeeNet
-from peewee import JOIN, ModelSelect
 from typing import Optional, Iterable, Union
 
 @task
-def apogeenet(
-    spectra: Optional[Iterable[Union[ApogeeVisitSpectrumInApStar, ApogeeCoaddedSpectrumInApStar]]] = (
-        ApogeeCoaddedSpectrumInApStar
-        .select()
-        .join(
-            ApogeeNet,
-            JOIN.LEFT_OUTER,
-            on=(
-                (ApogeeCoaddedSpectrumInApStar.spectrum_pk == ApogeeNet.spectrum_pk)
-                & (ApogeeNet.v_astra == __version__)
-            )
-        )
-        .where(ApogeeNet.spectrum_pk.is_null())
-    ),
-    num_uncertainty_draws: Optional[int] = 20,
-    limit=None,
-    **kwargs
-) -> Iterable[ApogeeNet]:
+def apogeenet(spectra: Iterable[Union[ApogeeVisitSpectrumInApStar, ApogeeCoaddedSpectrumInApStar]], num_uncertainty_draws: Optional[int] = 20, **kwargs) -> Iterable[ApogeeNet]:
     """
     Run the ANet (APOGEENet III) pipeline.
-    
+
     """
@@ -311,13 +293,7 @@ def apogeenet(
     if torch.cuda.is_available():
         model.cuda()
 
-    if isinstance(spectra, ModelSelect):
-        if limit is not None:
-            spectra = spectra.limit(limit)
-        # Note: if you don't use the `.iterator()` you may get out-of-memory issues from the GPU nodes
-        spectra = spectra.iterator()
-
-    for spectrum in tqdm(spectra, total=0):
+    for spectrum in spectra:
         try:
             flux = np.nan_to_num(spectrum.flux, nan=0.0).astype(np.float32)
@@ -329,15 +305,10 @@ def apogeenet(
             log_G,log_Teff,FeH,log_G_std,log_Teff_std,Feh_std = make_prediction(flux, e_flux, None, num_uncertainty_draws,model,device)
         except:
             log.exception(f"Exception when running ApogeeNet on {spectrum}")
-            yield ApogeeNet(
-                spectrum_pk=spectrum.spectrum_pk,
-                source_pk=spectrum.source_pk,
-                flag_runtime_exception=True
-            )
+            yield ApogeeNet.from_spectrum(spectrum, flag_runtime_exception=True)
         else:
-            yield ApogeeNet(
-                spectrum_pk=spectrum.spectrum_pk,
-                source_pk=spectrum.source_pk,
+            yield ApogeeNet.from_spectrum(
+                spectrum,
                 fe_h=FeH,
                 e_fe_h=Feh_std,
                 logg=log_G,
diff --git a/src/astra/pipelines/bossnet/__init__.py b/src/astra/pipelines/bossnet/__init__.py
index 945893c..4c6801d 100644
--- a/src/astra/pipelines/bossnet/__init__.py
+++ b/src/astra/pipelines/bossnet/__init__.py
@@ -349,8 +349,8 @@ def make_prediction(spectra, error, wavlen,num_uncertainty_draws,model,device):
 
 from astra import task, __version__
 from astra.utils import log, expand_path
-from astra.models import BossVisitSpectrum
-from astra.models import BossNet
+from astra.models.boss import BossVisitSpectrum
+from astra.models.bossnet import BossNet
 from peewee import JOIN
 from typing import Optional, Iterable
 
@@ -360,22 +360,7 @@ def make_prediction(spectra, error, wavlen,num_uncertainty_draws,model,device):
 interpolate_flux_err = partial(interpolate_flux_err, linear_grid=linear_grid)
 
 @task
-def bossnet(
-    spectra: Optional[Iterable[BossVisitSpectrum]] = (
-        BossVisitSpectrum
-        .select()
-        .join(
-            BossNet,
-            JOIN.LEFT_OUTER,
-            on=(
-                (BossVisitSpectrum.spectrum_pk == BossNet.spectrum_pk)
-                & (BossNet.v_astra == __version__)
-            )
-        )
-        .where(BossNet.spectrum_pk.is_null())
-    ),
-    num_uncertainty_draws: Optional[int] = 20
-) -> Iterable[BossNet]:
+def bossnet(spectra: Iterable[BossVisitSpectrum], num_uncertainty_draws: Optional[int] = 20) -> Iterable[BossNet]:
 
     model = BossNetModel()
     model_path = expand_path("$MWM_ASTRA/pipelines/BossNet/deconstructed_model")
@@ -392,7 +377,7 @@ def bossnet(
         # Note: if you don't use the `.iterator()` you may get out-of-memory issues from the GPU nodes
         spectra = spectra.iterator()
 
-    for spectrum in tqdm(spectra, total=0):
+    for spectrum in spectra:
         try:
             flux = np.nan_to_num(spectrum.flux, nan=0.0).astype(np.float32)
@@ -421,8 +406,8 @@ def bossnet(
                 e_logg=log_G_std,
                 teff=10**log_Teff,
                 e_teff=10**log_Teff * log_Teff_std * np.log(10),
-                v_rad=rv,
-                e_v_rad=rv_std
+                bn_v_r=rv,
+                e_bn_v_r=rv_std
             )
diff --git a/src/astra/utils/__init__.py b/src/astra/utils/__init__.py
index 09ffb28..ee592c9 100644
--- a/src/astra/utils/__init__.py
+++ b/src/astra/utils/__init__.py
@@ -1,4 +1,5 @@
 import os
+import re
 from sdsstools.logger import get_logger as _get_logger, StreamFormatter
 import warnings
 import inspect
@@ -43,6 +44,17 @@ def version_string_to_integer(version_string):
     return sum(p * 10**(6 - i * 3) for i, p in enumerate(parts))
 
+
+def get_task_group_by_string(fun, safe=True):
+    try:
+        return re.search(r"group_by=(.*)", inspect.getsource(fun)).group(1).strip("()")
+    except:
+        if safe:
+            return None
+        else:
+            raise
+
+
 def get_logger(kwargs=None):
     logger = _get_logger("astra", **(kwargs or {}))
     # https://stackoverflow.com/questions/6729268/log-messages-appearing-twice-with-python-logging
@@ -75,9 +87,16 @@ def get_return_type(fun):
     raise ValueError(f"Cannot infer output model for task {fun}, is it missing a type annotation?")
 
 
+def resolve_model(model_str):
+    *pkg, name = model_str.split(".")
+    parent = import_module(f"astra.models.{pkg[0]}" if pkg else "astra.models")
+    return getattr(parent, name)
+
+
 def resolve_task(task_str):
-    for prefix in ("", "astra.", "astra.pipelines.", f"astra.pipelines.{task_str}."):
+    for prefix in (f"astra.pipelines.{task_str}.", "", "astra.", "astra.pipelines.",):
         try:
             resolved_task = f"{prefix}{task_str}"
             f = callable(resolved_task)
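# Usage sketch (illustrative, not part of the patch): both `new_astra run` and
# `new_astra srun` consume the new `generate_queries_for_task` generator added
# in src/astra/__init__.py, which yields one (input_model, query) pair per
# spectrum model the task expects; `srun` additionally splits each query across
# Slurm nodes via `--limit` and `--page`. The task and model names below are
# examples only:
#
#     from astra import generate_queries_for_task
#
#     for input_model, q in generate_queries_for_task("apogeenet", "ApogeeCoaddedSpectrumInApStar", limit=1000, page=1):
#         print(input_model.__name__, q.count())  # spectra still awaiting results for this version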