aspcap restart in list mode
andycasey committed Jan 16, 2025
1 parent 9e0a4e0 commit 4d66853
Showing 7 changed files with 468 additions and 115 deletions.
5 changes: 1 addition & 4 deletions src/astra/models/pipeline.py
@@ -40,8 +40,5 @@ def from_spectrum(cls, spectrum, **kwargs):
if given_spectrum_pk is not None and given_spectrum_pk != spectrum.spectrum_pk:
raise ValueError(f"`spectrum_pk` mismatch between `spectrum` and `spectrum_pk` argument ({spectrum.spectrum_pk} != {given_spectrum_pk})")

kwds.update({
"spectrum_pk": spectrum.spectrum_pk,
"source_pk": spectrum.source_pk,
})
kwds.update(spectrum_pk=spectrum.spectrum_pk, source_pk=spectrum.source_pk)
return cls(**kwds)
364 changes: 285 additions & 79 deletions src/astra/pipelines/aspcap/__init__.py

Large diffs are not rendered by default.

27 changes: 16 additions & 11 deletions src/astra/pipelines/aspcap/abundances.py
@@ -154,6 +154,7 @@ def plan_abundances_stage(
parent_dir: str,
stellar_parameter_results,
element_weight_paths: str,
use_ferre_list_mode: Optional[bool] = False,
continuum_order: Optional[int] = -1,
continuum_flag: Optional[int] = 0,
continuum_observations_flag: Optional[int] = 0,
@@ -304,20 +305,24 @@ def plan_abundances_stage(


# Group together as necessary
grouped = {}
for plan in plans:
short_grid_name = parse_header_path(plan["header_path"])["short_grid_name"]
if use_ferre_list_mode:
grouped = {}
for plan in plans:
short_grid_name = parse_header_path(plan["header_path"])["short_grid_name"]

grouped.setdefault(short_grid_name, [])
grouped[short_grid_name].append(plan)
grouped.setdefault(short_grid_name, [])
grouped[short_grid_name].append(plan)

#spectra_with_no_stellar_parameters -= set(grid_kwds["spectra"])

grouped_plans = list(grouped.values())
#spectra_with_no_stellar_parameters = tuple(spectra_with_no_stellar_parameters)
#return (plans, spectra_with_no_stellar_parameters)
return grouped_plans
#spectra_with_no_stellar_parameters -= set(grid_kwds["spectra"])

grouped_plans = list(grouped.values())
#spectra_with_no_stellar_parameters = tuple(spectra_with_no_stellar_parameters)
#return (plans, spectra_with_no_stellar_parameters)
return grouped_plans

else:
# In each directory, create symbolic links to the flux/e_flux arrays
return [[plan] for plan in plans]


def get_species(weight_path):
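A minimal sketch of the grouping behaviour above, with hypothetical plan dicts and parse_header_path mocked for illustration (the real parser lives in astra):

import os

def parse_header_path(header_path):
    # Stand-in for astra's parser: take the "short grid name" to be the
    # parent directory of the grid header file.
    return {"short_grid_name": os.path.basename(os.path.dirname(header_path))}

plans = [
    {"header_path": "grids/GKg/p_apsGKg.hdr"},
    {"header_path": "grids/Mg/p_apsMg.hdr"},
    {"header_path": "grids/GKg/p_apsGKg.hdr"},
]

grouped = {}
for plan in plans:
    short_grid_name = parse_header_path(plan["header_path"])["short_grid_name"]
    grouped.setdefault(short_grid_name, [])
    grouped[short_grid_name].append(plan)

print([len(group) for group in grouped.values()])  # [2, 1]

With use_ferre_list_mode=True the two GKg plans are returned as one group and run as a single FERRE list-mode execution; with it False, every plan becomes its own single-plan execution.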
6 changes: 6 additions & 0 deletions src/astra/pipelines/ferre/operator.py
@@ -154,7 +154,13 @@ def post_execution_interpolation(pwd, n_threads=128, f_access=1, epsilon=0.001):
control_kwds = parse_control_kwds(input_path)

# parse synthfile headers to get the edges
# TODO: so hacky. just give post_execution_interpolation an input nml path and a reference dir.

synthfile = control_kwds["SYNTHFILE(1)"]
for check in (synthfile, f"{pwd}/{synthfile}", f"{pwd}/../{synthfile}"):
if os.path.exists(check):
synthfile = check
break
headers = read_ferre_headers(synthfile)

output_parameter_path = os.path.join(f"{pwd}/{os.path.basename(control_kwds['OPFILE'])}")
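As a worked example of the lookup order above (names hypothetical): for SYNTHFILE(1) = 'p_apsGK.hdr' and pwd = '/pwd/abundances/Mg', the candidates tried are 'p_apsGK.hdr' relative to the current working directory, then '/pwd/abundances/Mg/p_apsGK.hdr', then '/pwd/abundances/p_apsGK.hdr'. The parent-directory fallback is what matches the symbolic links that pre_process_ferre creates one level above each abundance working directory.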
51 changes: 35 additions & 16 deletions src/astra/pipelines/ferre/post_process.py
@@ -37,27 +37,36 @@ def post_process_ferre(input_nml_path, **kwargs) -> list[dict]:
for kwds in post_process_ferre(dir, ref_dir, skip_pixel_arrays=skip_pixel_arrays, **kwargs):
yield FerreChemicalAbundances(**kwds)
"""
is_abundance_mode = input_nml_path.endswith("input_list.nml")
if is_abundance_mode:
#is_abundance_mode = input_nml_path.endswith("input_list.nml")
#if is_abundance_mode:
if "input_list.nml" in input_nml_path:
abundance_dir = os.path.dirname(input_nml_path)
with open(input_nml_path, "r") as fp:
dirs = [os.path.join(abundance_dir, line.split("/")[0]) for line in fp.read().strip().split("\n")]

# TODO: this might be slow we can probably use -l mode
for d in dirs:
post_execution_interpolation(d)

dirs = [os.path.join(abundance_dir, line.split("/")[0]) for line in fp.read().strip().split("\n")]
ref_dir = os.path.dirname(input_nml_path)
v = []
for d in dirs:
v.extend(list(_post_process_ferre(d, ref_dir, skip_pixel_arrays=True, **kwargs)))
return v

else:
directory = os.path.dirname(input_nml_path)
post_execution_interpolation(directory)
v = list(_post_process_ferre(directory, **kwargs))
dirs = [os.path.dirname(input_nml_path)]
if "/abundances/" in input_nml_path:
ref_dir = os.path.dirname(os.path.dirname(input_nml_path))
else:
ref_dir = os.path.dirname(input_nml_path)

# TODO: this might be slow we can probably use -l mode
for d in dirs:
post_execution_interpolation(d)

v = []
for d in dirs:
v.extend(list(_post_process_ferre(d, ref_dir, skip_pixel_arrays=True, **kwargs)))
return v
#else:
# directory = os.path.dirname(input_nml_path)
# post_execution_interpolation(directory)
# v = list(_post_process_ferre(directory, **kwargs))
return v


def _post_process_ferre(dir, pwd=None, skip_pixel_arrays=False, **kwargs) -> Iterable[dict]:
"""
Post-process results from a FERRE execution.
@@ -139,9 +148,17 @@ def _post_process_ferre(dir, pwd=None, skip_pixel_arrays=False, **kwargs) -> Iterable[dict]:
else:
is_missing_rectified_model_flux = ~np.all(np.isfinite(rectified_model_flux), axis=1)
'''

#np.atleast_1d(np.loadtxt(f"{prefix}/model_flux.output", usecols=(0, ), dtype=str)),
#np.atleast_1d(np.loadtxt(f"{prefix}/rectified_flux.output", usecols=(0, ), dtype=str)),
#np.atleast_1d(np.loadtxt(f"{prefix}/rectified_model_flux.output", usecols=(0, ), dtype=str)),

#print(f"DIR -> {dir}\n{offile_path}")
parameter_input_path = os.path.join(dir, "parameter.input")
os.system(f"vaffoff {parameter_input_path} {offile_path}")
for basename in ("model_flux.output", "rectified_flux.output", "rectified_model_flux.output"):
if os.path.exists(os.path.join(dir, basename)):
os.system(f"vaffoff {parameter_input_path} {os.path.join(dir, basename)}")

is_missing_rectified_model_flux = ~np.isfinite(np.atleast_1d(np.loadtxt(offile_path, usecols=(1, ), dtype=float)))
names_with_missing_rectified_model_flux = input_names[is_missing_rectified_model_flux]

@@ -208,6 +225,8 @@ def _post_process_ferre(dir, pwd=None, skip_pixel_arrays=False, **kwargs) -> Iterable[dict]:

# Create some boolean flags.
header_path = control_kwds["SYNTHFILE(1)"]
if not os.path.exists(header_path):
header_path = os.path.join(ref_dir, header_path)
headers, *segment_headers = read_ferre_headers(expand_path(header_path))
bad_lower = headers["LLIMITS"] + headers["STEPS"] / 8
bad_upper = headers["ULIMITS"] - headers["STEPS"] / 8
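Two small worked examples of the logic above, with hypothetical paths and grid values. First, the ref_dir selection in post_process_ferre:

import os

p = "/pwd/abundances/Mg/input.nml"  # hypothetical abundance-stage path
ref_dir = os.path.dirname(os.path.dirname(p)) if "/abundances/" in p else os.path.dirname(p)
assert ref_dir == "/pwd/abundances"  # one level up, where the shared synthfile links live

Second, the edge flags: for a grid dimension with LLIMITS = 3500 and STEPS = 250 (say, Teff in K), bad_lower = 3500 + 250 / 8 = 3531.25, so any result within an eighth of a grid step of the grid edge is flagged.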
41 changes: 36 additions & 5 deletions src/astra/pipelines/ferre/pre_process.py
@@ -25,8 +25,7 @@ def pre_process_ferres(plans):
input_nml_paths.append(input_nml_path_[len(abundance_dir) + 1:]) # paths too long
total += n_obj_
for spectrum, kwds in skipped_:
skipped.append((spectrum, kwds))

skipped.append((spectrum, kwds))

# Create a FERRE list file.
input_nml_path = os.path.join(abundance_dir, "input_list.nml")
@@ -221,8 +220,13 @@ def pre_process_ferre(
continue

if pre_computed_continuum is not None:
flux /= pre_computed_continuum[index]
e_flux /= pre_computed_continuum[index]
continuum = pre_computed_continuum[index]
flux /= continuum
e_flux /= continuum

#bad = ((flux < 0) | (e_flux <= 0))
#flux[bad] = 0.01
e_flux = np.clip(e_flux, 0.005, np.inf)

batch_flux.append(flux[mask])
batch_e_flux.append(e_flux[mask])
@@ -236,12 +240,22 @@
batch_initial_parameters.append(initial_parameters)
index += 1

assert np.all(np.array(batch_e_flux) > 0)

#if len(skipped) > 0:
# log.warning(f"Skipping {len(skipped)} spectra ({100 * len(skipped) / len(spectra):.0f}%; of {len(spectra)})")

if not batch_initial_parameters:
return (pwd, 0, 0, skipped)



synthfile_full_path = control_kwds["synthfile(1)"]
# In both modes the synthfile is referenced by basename; the symbolic links created below make it resolve.
control_kwds["synthfile(1)"] = os.path.basename(synthfile_full_path)

control_kwds_formatted = utils.format_ferre_control_keywords(control_kwds, n_obj=1 + index)

# Convert list of dicts of initial parameters to array.
@@ -254,6 +268,20 @@
)
# Create directory and write the control file
os.makedirs(absolute_pwd, exist_ok=True)
if reference_pixel_arrays_for_abundance_run:
if write_input_pixel_arrays:
target_path_prefix = f"{absolute_pwd}/../{os.path.basename(synthfile_full_path)}"[:-4]
if not os.path.exists(f"{target_path_prefix}.hdf"):
os.system(f"ln -s {synthfile_full_path} {target_path_prefix}.hdr")
if not os.path.exists(f"{target_path_prefix}.unf"):
os.system(f"ln -s {synthfile_full_path[:-4]}.unf {target_path_prefix}.unf")
else:
target_path_prefix = f"{absolute_pwd}/{os.path.basename(synthfile_full_path)}"[:-4]
if not os.path.exists(f"{target_path_prefix}.hdr"):
os.system(f"ln -s {synthfile_full_path} {target_path_prefix}.hdr")
if not os.path.exists(f"{target_path_prefix}.unf"):
os.system(f"ln -s {synthfile_full_path[:-4]}.unf {target_path_prefix}.unf")

with open(os.path.join(absolute_pwd, "input.nml"), "w") as fp:
fp.write(control_kwds_formatted)

@@ -287,7 +315,10 @@ def pre_process_ferre(
savetxt_kwds = dict(fmt="%.4e")#footer="\n")
np.savetxt(flux_path, batch_flux, **savetxt_kwds)
np.savetxt(e_flux_path, batch_e_flux, **savetxt_kwds)


#if reference_pixel_arrays_for_abundance_run:
# for basename in (control_kwds["ffile"], control_kwds["erfile"]):
# os.system(f"ln -s {absolute_pwd}/../{basename} {absolute_pwd}/{basename}")
n_obj = len(batch_names)
return (f"{pwd}/input.nml", n_obj, min(n_threads, n_obj), skipped)

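For orientation, the input_list.nml written by pre_process_ferres might look like this (grid subdirectory names hypothetical), one input.nml path per line, relative to the abundance directory; this is the file handed to FERRE's list (-l) mode, and post_process_ferre splits each line on "/" to recover the per-grid working directory:

Al/input.nml
Mg/input.nml
Si/input.nml

The ln -s calls above could equally be written with os.symlink; a minimal equivalent sketch, taking the same variables computed in pre_process_ferre:

import os

def link_synthfile(synthfile_full_path, target_path_prefix):
    # Link the FERRE grid header (.hdr) and the pixel grid (.unf) next to
    # the working directory, so that the basename stored in synthfile(1)
    # resolves at run time.
    pairs = (
        (synthfile_full_path, f"{target_path_prefix}.hdr"),
        (f"{synthfile_full_path[:-4]}.unf", f"{target_path_prefix}.unf"),
    )
    for source, target in pairs:
        if not os.path.exists(target):
            os.symlink(source, target)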
89 changes: 89 additions & 0 deletions src/astra/pipelines/ferre/re_process.py
@@ -0,0 +1,89 @@
import numpy as np
import os
from collections import Counter

def get_suffix(input_nml_path):
try:
_, suffix = input_nml_path.split(".nml.")
except ValueError:
return 0
else:
return int(suffix)

def get_new_path(existing_path, new_suffix):
if new_suffix == 1:
return f"{existing_path}.{new_suffix}"
else:
return ".".join(existing_path.split(".")[:-1]) + f".{new_suffix}"


def re_process_partial_ferre(existing_input_nml_path, pwd=None, exclude_indices=None):

if pwd is None:
pwd = os.path.dirname(existing_input_nml_path)

existing_suffix = get_suffix(existing_input_nml_path)
new_suffix = existing_suffix + 1

new_input_nml_path = get_new_path(existing_input_nml_path, new_suffix)

with open(existing_input_nml_path, "r") as f:
lines = f.readlines()

keys = ("PFILE", "OFFILE", "ERFILE", "OPFILE", "FFILE", "SFFILE")
paths = {}
for i, line in enumerate(lines):
key = line.split("=")[0].strip()
if key in keys:
existing_relative_path = line.split("=")[1].strip("' \n")
new_relative_path = get_new_path(existing_relative_path, new_suffix)
lines[i] = line[:line.index("=")] + f"= '{new_relative_path}'\n"
paths[key] = (existing_relative_path, new_relative_path)

with open(new_input_nml_path, "w") as fp:
fp.write("".join(lines))

# Find the things that are already written in all three output files.
output_path_keys = ["OFFILE", "OPFILE"]
if os.path.exists(os.path.join(pwd, paths["SFFILE"][0])):
output_path_keys.append("SFFILE")

counts = []
for key in output_path_keys:
names = set(np.loadtxt(os.path.join(pwd, paths[key][0]), usecols=(0, ), dtype=str))
counts.extend(names)

completed_names = [k for k, v in Counter(counts).items() if v == len(output_path_keys)]
input_names = np.loadtxt(os.path.join(pwd, paths["PFILE"][0]), usecols=(0, ), dtype=str)

ignore_names = [] + completed_names
if exclude_indices is not None:
ignore_names.extend([input_names[int(idx)] for idx in exclude_indices])

mask = [(name not in ignore_names) for name in input_names]
if not any(mask):
return (None, None)

# Create new input files that ignore specific names.
for key in ("PFILE", "ERFILE", "FFILE"):
existing_path, new_path = paths[key]
with open(os.path.join(pwd, existing_path), "r") as f:
lines = f.readlines()

with open(os.path.join(pwd, new_path), "w") as f:
for line, m in zip(lines, mask):
if m:
f.write(line)

# Clean up the output files to only include things that are written in all three files.
for key in output_path_keys:
existing_path, new_path = paths[key]
with open(os.path.join(pwd, existing_path), "r") as f:
lines = f.readlines()

lines = [line for line in lines if line.split() and line.split()[0] in completed_names]
with open(os.path.join(pwd, existing_path) + ".cleaned", "w") as fp:
fp.write("".join(lines))

ignore_names = list(ignore_names)
return (new_input_nml_path, ignore_names)
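Worked from the definitions above, the suffix helpers behave as:

assert get_suffix("input.nml") == 0
assert get_suffix("input.nml.2") == 2
assert get_new_path("input.nml", 1) == "input.nml.1"
assert get_new_path("input.nml.1", 2) == "input.nml.2"

And a hedged usage sketch for restarting a partially completed execution (paths and indices hypothetical; launching FERRE on the new control file is left to the operator):

new_nml_path, ignore_names = re_process_partial_ferre(
    "20250116/params/input.nml",
    exclude_indices=[3, 17],  # e.g. rows known to stall FERRE
)
if new_nml_path is None:
    print("all names completed; nothing to re-run")
else:
    # Re-run FERRE on input.nml.1, which now lists only the spectra that
    # are not present in all of the OFFILE/OPFILE(/SFFILE) outputs.
    print(f"re-running with {new_nml_path}; ignoring {len(ignore_names)} names")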
