diff --git a/src/astra/models/aspcap.py b/src/astra/models/aspcap.py index 9f33e25..eef3fd9 100644 --- a/src/astra/models/aspcap.py +++ b/src/astra/models/aspcap.py @@ -10,36 +10,74 @@ DateTimeField, BooleanField, ) + from astra.models.pipeline import PipelineOutputModel from astra.models.ferre import FerreCoarse, FerreStellarParameters, FerreChemicalAbundances from astra.models.source import Source from astra.models.spectrum import Spectrum from astra.glossary import Glossary from playhouse.hybrid import hybrid_property +from functools import cached_property +from astra.pipelines.ferre.utils import (get_apogee_pixel_mask, parse_ferre_spectrum_name) + +APOGEE_FERRE_MASK = get_apogee_pixel_mask() + +""" + @cached_property + def ferre_flux(self): + return self._get_pixel_array("params/flux.input") + + @cached_property + def ferre_e_flux(self): + return self._get_pixel_array("params/e_flux.input") + + #@cached_property + #def model_flux(self): + # return self._get_pixel_array("params/model_flux.output") + + @cached_property + def rectified_model_flux(self): + return self._get_pixel_array("params/rectified_model_flux.output") + + @cached_property + def rectified_flux(self): + return self._get_pixel_array("params/rectified_flux.output") + @cached_property + def e_rectified_flux(self): + continuum = self.ferre_flux / self.rectified_flux + return self.ferre_e_flux / continuum +""" -class StellarParameterPixelAccessor(BasePixelArrayAccessor): +class ASPCAPPixelArrayAccessor(BasePixelArrayAccessor): def __get__(self, instance, instance_type=None): if instance is not None: - try: - return instance.__pixel_data__[self.name] - except (AttributeError, KeyError): - # Load them all. + if not hasattr(instance, "__pixel_data__"): instance.__pixel_data__ = {} - upstream = FerreStellarParameters.get(instance.stellar_parameters_task_pk) - continuum = upstream.unmask( - (upstream.rectified_model_flux/upstream.model_flux) - / (upstream.rectified_flux/upstream.ferre_flux) - ) - - instance.__pixel_data__.setdefault("continuum", continuum) - instance.__pixel_data__.setdefault("model_flux", upstream.unmask(upstream.model_flux)) - - return instance.__pixel_data__[self.name] - + # Stellar parameter case first, since we have to load a bunch of stuff. + if self.name not in instance.__pixel_data__: + if self.name in ("model_flux", "continuum"): + rectified_model_flux = instance._get_output_pixel_array("params", "rectified_model_flux.output") + model_flux = instance._get_output_pixel_array("params", "model_flux.output") + rectified_flux = instance._get_output_pixel_array("params", "rectified_flux.output") + ferre_flux = instance._get_input_pixel_array("params", "flux.input") + + continuum = instance._unmask_pixel_array( + (rectified_model_flux/model_flux) / (rectified_flux/ferre_flux) + ) + instance.__pixel_data__.setdefault("continuum", continuum) + instance.__pixel_data__.setdefault("model_flux", instance._unmask_pixel_array(model_flux)) + + else: + # Chemical abundance pixel array. + x_h = self.name[len("model_flux_"):] + #isntance._get_output_pixel_array("abundances", "") + raise NotImplementedError + + return instance.__pixel_data__[self.name] return self.field @@ -84,6 +122,7 @@ class ASPCAP(PipelineOutputModel): """ APOGEE Stellar Parameter and Chemical Abundances Pipeline (ASPCAP) """ + #> Spectral Data wavelength = PixelArray( accessor_class=LogLambdaArrayAccessor, @@ -94,11 +133,11 @@ class ASPCAP(PipelineOutputModel): ), ) model_flux = PixelArray( - accessor_class=StellarParameterPixelAccessor, + accessor_class=ASPCAPPixelArrayAccessor, help_text="Model flux at optimized stellar parameters" ) continuum = PixelArray( - accessor_class=StellarParameterPixelAccessor, + accessor_class=ASPCAPPixelArrayAccessor, help_text="Continuum" ) @@ -624,6 +663,9 @@ def flag_bad(self): coarse_rchi2 = FloatField(null=True, help_text=Glossary.coarse_rchi2) coarse_penalized_rchi2 = FloatField(null=True, help_text="Penalized reduced chi-squared for coarse grid") + pwd = TextField(null=True, help_text="Working directory") + ferre_index = IntegerField(null=True, help_text="Index of the FERRE run") + """ #> Task Primary Keys stellar_parameters_task_pk = ForeignKeyField(FerreStellarParameters, unique=True, null=True, lazy_load=False, help_text="Task primary key for stellar parameters") @@ -730,6 +772,34 @@ def flag_bad(self): raw_e_v_h = FloatField(null=True, help_text=Glossary.raw_e_v_h) + def _unmask_pixel_array(self, array, fill_value=np.nan): + unmasked_array = fill_value * np.ones(APOGEE_FERRE_MASK.shape) + unmasked_array[APOGEE_FERRE_MASK] = array + return unmasked_array + + + def _get_pixel_array_kwds(self, stage, name, **kwargs): + kwds = dict( + fname=f"{self.pwd}/{stage}/{self.short_grid_name}/{name}", + skiprows=int(self.ferre_index), + max_rows=1, + ) + return kwds + + def _get_input_pixel_array(self, stage, name): + return np.loadtxt(**self._get_pixel_array_kwds(stage, name)) + + def _get_output_pixel_array(self, stage, name, P=7514): + kwds = self._get_pixel_array_kwds(stage, name) + name, = np.atleast_1d(np.loadtxt(usecols=(0, ), dtype=str, **kwds)) + array = np.loadtxt(usecols=range(1, 1+P), **kwds) + meta = parse_ferre_spectrum_name(name) + assert int(meta["source_pk"]) == self.source_pk + assert int(meta["spectrum_pk"]) == self.spectrum_pk + assert int(meta["index"]) == self.ferre_index + return array + + def apply_noise_model(): diff --git a/src/astra/models/ferre.py b/src/astra/models/ferre.py index fbaa1f6..5f5bf38 100644 --- a/src/astra/models/ferre.py +++ b/src/astra/models/ferre.py @@ -60,18 +60,18 @@ def unmask(self, array, fill_value=np.nan): def _get_input_pixel_array(self, basename): return np.loadtxt( fname=f"{self.pwd}/{basename}", - skiprows=int(self.ferre_input_index), + skiprows=int(self.ferre_index), max_rows=1, ) def _get_output_pixel_array(self, basename, P=7514): - #assert self.ferre_input_index >= 0 + #assert self.ferre_index >= 0 kwds = dict( fname=f"{self.pwd}/{basename}", - skiprows=int(self.ferre_output_index), + skiprows=int(self.ferre_index), max_rows=1, ) ''' @@ -83,17 +83,17 @@ def _get_output_pixel_array(self, basename, P=7514): if ( (int(meta["source_pk"]) != self.source_pk) or (int(meta["spectrum_pk"]) != self.spectrum_pk) - or (int(meta["index"]) != self.ferre_input_index) + or (int(meta["index"]) != self.ferre_index) ): raise a except: del kwds["skiprows"] del kwds["max_rows"] - name = get_ferre_spectrum_name(self.ferre_input_index, self.source_pk, self.spectrum_pk, self.initial_flags, self.upstream_id) + name = get_ferre_spectrum_name(self.ferre_index, self.source_pk, self.spectrum_pk, self.initial_flags, self.upstream_id) index = list(np.loadtxt(usecols=(0, ), dtype=str, **kwds)).index(name) - self.ferre_output_index = index + self.ferre_index = index self.save() print("saved!") kwds["skiprows"] = index @@ -108,7 +108,7 @@ def _get_output_pixel_array(self, basename, P=7514): meta = parse_ferre_spectrum_name(name) assert int(meta["source_pk"]) == self.source_pk assert int(meta["spectrum_pk"]) == self.spectrum_pk - assert int(meta["index"]) == self.ferre_input_index + assert int(meta["index"]) == self.ferre_index return array @@ -180,8 +180,7 @@ class FerreCoarse(PipelineOutputModel, FerreOutputMixin): #> FERRE Access Fields ferre_name = TextField(default="") - ferre_input_index = IntegerField(default=-1) - ferre_output_index = IntegerField(default=-1) + ferre_index = IntegerField(default=-1) ferre_n_obj = IntegerField(default=-1) #> Summary Statistics @@ -311,8 +310,8 @@ class FerreStellarParameters(PipelineOutputModel, FerreOutputMixin): # TODO: flag definitions for each dimension (DRY) #> FERRE Access Fields ferre_name = TextField(default="") - ferre_input_index = IntegerField(default=-1) - ferre_output_index = IntegerField(default=-1) + ferre_index = IntegerField(default=-1) + ferre_index = IntegerField(default=-1) ferre_n_obj = IntegerField(default=-1) #> Summary Statistics @@ -450,8 +449,8 @@ def ferre_e_flux(self): # TODO: flag definitions for each dimension (DRY) #> FERRE Access Fields ferre_name = TextField(default="") - ferre_input_index = IntegerField(default=-1) - ferre_output_index = IntegerField(default=-1) + ferre_index = IntegerField(default=-1) + ferre_index = IntegerField(default=-1) ferre_n_obj = IntegerField(default=-1) #> Summary Statistics diff --git a/src/astra/pipelines/aspcap/abundances.py b/src/astra/pipelines/aspcap/abundances.py index 19ced18..90a5c7b 100644 --- a/src/astra/pipelines/aspcap/abundances.py +++ b/src/astra/pipelines/aspcap/abundances.py @@ -84,7 +84,7 @@ def pre_abundances( The parent directory where these FERRE executions will be planned. """ - ferre_kwds, spectra_with_no_stellar_parameters = plan_abundances( + ferre_kwds, spectra_with_no_stellar_parameters = plan_abundances_stage( spectra, parent_dir, element_weight_paths, @@ -145,9 +145,9 @@ def post_abundances(parent_dir, relative_mode=True, skip_pixel_arrays=True, **kw yield FerreChemicalAbundances(**kwds) -def plan_abundances( +def plan_abundances_stage( spectra: Iterable[Spectrum], - parent_dir: str, + stellar_parameter_results, element_weight_paths: str, continuum_order: Optional[int] = -1, continuum_flag: Optional[int] = 0, @@ -155,108 +155,15 @@ def plan_abundances( **kwargs, ): """ - Plan abundance executions with FERRE for some given spectra, which are assumed to already have `FerreStellarParameter` results. - + Plan abundance executions with FERRE for some given spectra. + In the abundances stage we keep the continuum fixed to what was found from the stellar parameter stage. That's why the defaults are set for `continuum_order`, `continuum_flag`, and `continuum_observations_flag`. - - :param spectra: - The spectra to be processed. - - :param parent_dir: - The parent directory where these FERRE executions will be planned. - - :param element_weight_paths: - A path containing the masks to supply per element. """ with open(expand_path(element_weight_paths), "r") as fp: weight_paths = list(map(str.strip, fp.readlines())) - - if spectra is None: - log.info(f"Retrieving spectra") - - # Get spectrum ids from params stage in parent dir. - spectrum_pks = list(get_input_spectrum_primary_keys(f"{parent_dir}/params")) - if len(spectrum_pks) == 0: - log.warning(f"No spectrum identifiers found in {parent_dir}/params") - return ([], []) - - # TODO: assuming all spectra are the same model type.. - model_class = Spectrum.get(spectrum_pks[0]).resolve().__class__ - spectra = ( - model_class - .select() - .where(model_class.spectrum_pk << spectrum_pks) - ) - else: - spectrum_pks = [s.spectrum_pk for s in spectra] - - parent_dir = sanitise_parent_dir(parent_dir) - - ''' - Alias = FerreStellarParameters.alias() - sq = ( - Alias - .select( - Alias.spectrum_pk.alias("spectrum_pk"), - fn.MIN(Alias.penalized_rchi2).alias("min_penalized_rchi2"), - ) - .where(Alias.spectrum_pk << spectrum_pks) - .where(Alias.pwd.startswith(expand_path(parent_dir))) - .group_by(Alias.spectrum_pk) - .alias("sq") - ) - - q = ( - FerreStellarParameters - .select() - # Only get one result per spectrum. - .where( - FerreStellarParameters.penalized_rchi2.is_null(False) - # Don't calculate abundances for things that failed SPECTACULARLY - & (~FerreStellarParameters.flag_ferre_fail) - & (~FerreStellarParameters.flag_spectrum_io_error) - & (~FerreStellarParameters.flag_no_suitable_initial_guess) - & (~FerreStellarParameters.flag_missing_model_flux) - & (FerreStellarParameters.pwd.startswith(expand_path(parent_dir))) - ) - .join( - sq, - on=( - (FerreStellarParameters.spectrum_pk == sq.c.spectrum_pk) & - (FerreStellarParameters.penalized_rchi2 == sq.c.min_penalized_rchi2) - ) - ) - # We will only get one result per spectrum, but we'll do it by recency. - .order_by(FerreStellarParameters.task_pk.desc()) - ) - ''' - - q = ( - FerreStellarParameters - .select() - .where( - FerreStellarParameters.penalized_rchi2.is_null(False) - # Don't calculate abundances for things that failed SPECTACULARLY - & (~FerreStellarParameters.flag_ferre_fail) - & (~FerreStellarParameters.flag_spectrum_io_error) - & (~FerreStellarParameters.flag_no_suitable_initial_guess) - & (~FerreStellarParameters.flag_missing_model_flux) - & (FerreStellarParameters.pwd.startswith(expand_path(parent_dir))) - ) - ) - - # You can do this by a SQL join, but it gets heavy and we kind of want to know about duplicate results. - all_results = {} - for result in q: - all_results.setdefault(result.spectrum_pk, []) - all_results[result.spectrum_pk].append(result) - - best_results = {} - for spectrum_pk, results in all_results.items(): - best_results[spectrum_pk] = sorted(results, key=lambda r: r.penalized_rchi2)[0] - + # Load abundance keywords on demand. ferre_headers, abundance_keywords = ({}, {}) lookup_spectrum_by_primary_key = { s.spectrum_pk: s for s in spectra } @@ -264,97 +171,77 @@ def plan_abundances( mask = get_apogee_pixel_mask() continuum_cache, continuum_cache_names = ({}, {}) - shown_BA_lsfcombo5_warning = False - done, group_task_kwds, pre_computed_continuum = ([], {}, {}) - for result in tqdm(best_results.values(), total=len(best_results), desc="Planning for abundances"): - if result.spectrum_pk in done: - # We have a more recent FerreStellarParameters result which we will use instead of this one. - log.warning(f"Ignoring stellar parameter result {result} because we have a more recent result for this spectrum_pk={result.spectrum_pk}") - continue - - # TODO: make a better check for this - if result.header_path.find("BA_lsfcombo5") > 0: - if not shown_BA_lsfcombo5_warning: - log.warning(f"Not doing abundances on BA_lsfcombo5 grid") - shown_BA_lsfcombo5_warning = True + t_check = 0 + + group_task_kwds, pre_computed_continuum = ({}, {}) + for result in stellar_parameter_results: + + if result["short_grid_name"].find("combo5_BA") > 0: + # Not doing abundances for BA_lsfcombo5 grids continue - done.append(result.spectrum_pk) - group_task_kwds.setdefault(result.header_path, []) - if result.header_path not in abundance_keywords: - abundance_keywords[result.header_path] = {} + group_task_kwds.setdefault(result["header_path"], []) + if result["header_path"] not in abundance_keywords: + abundance_keywords[result["header_path"]] = {} try: - headers, *segment_headers = ferre_headers[result.header_path] + headers, *segment_headers = ferre_headers[result["header_path"]] except KeyError: - headers, *segment_headers = ferre_headers[result.header_path] = read_ferre_headers(result.header_path) + headers, *segment_headers = ferre_headers[result["header_path"]] = read_ferre_headers(result["header_path"]) for weight_path in weight_paths: species = get_species(weight_path) frozen_parameters, ferre_kwds = get_abundance_keywords(species, headers["LABEL"]) - abundance_keywords[result.header_path][species] = (weight_path, frozen_parameters, ferre_kwds) - - try: - spectrum = lookup_spectrum_by_primary_key[result.spectrum_pk] - except KeyError: - log.warning(f"Could not find spectrum {result.spectrum_pk} in the input list. Were previous analyses run in this same folder? Skipping..") - continue + abundance_keywords[result["header_path"]][species] = (weight_path, frozen_parameters, ferre_kwds) + + spectrum = lookup_spectrum_by_primary_key[result["spectrum_pk"]] - """ - # Apply continuum normalization, where we are just going to fix the observed - # spectrum to the best-fitting model spectrum from the upstream task. - try: - pre_computed_continuum[result.spectrum_pk] - except KeyError: - pre_computed_continuum[result.spectrum_pk] = result.unmask( - (result.rectified_model_flux/result.model_flux) - / (result.rectified_flux/result.ferre_flux) - ) - """ + prefix = f"{result['pwd']}/params/{result['short_grid_name']}" - # This does the same as the intent above, but it's faster. - # TODO: Consider rewriting the FerreOutputMixin to cache flux arrays and then only return index - # as it is called. That would be the same as the intent above, but it would be faster. try: - continuum_cache[result.pwd] + continuum_cache[prefix] except: P = 7514 - rectified_model_flux = np.atleast_2d(np.loadtxt(f"{result.pwd}/rectified_model_flux.output", usecols=range(1, 1+P))) - model_flux = np.atleast_2d(np.loadtxt(f"{result.pwd}/model_flux.output", usecols=range(1, 1+P))) - rectified_flux = np.atleast_2d(np.loadtxt(f"{result.pwd}/rectified_flux.output", usecols=range(1, 1+P))) - ferre_flux = np.atleast_2d(np.loadtxt(f"{result.pwd}/flux.input", usecols=range(P))) + rectified_model_flux = np.atleast_2d(np.loadtxt(f"{prefix}/rectified_model_flux.output", usecols=range(1, 1+P))) + model_flux = np.atleast_2d(np.loadtxt(f"{prefix}/model_flux.output", usecols=range(1, 1+P))) + rectified_flux = np.atleast_2d(np.loadtxt(f"{prefix}/rectified_flux.output", usecols=range(1, 1+P))) + ferre_flux = np.atleast_2d(np.loadtxt(f"{prefix}/flux.input", usecols=range(P))) continuum = (rectified_model_flux/model_flux) / (rectified_flux/ferre_flux) - continuum_cache[result.pwd] = np.nan * np.ones((continuum.shape[0], 8575)) - continuum_cache[result.pwd][:, mask] = continuum + continuum_cache[prefix] = np.nan * np.ones((continuum.shape[0], 8575)) + continuum_cache[prefix][:, mask] = continuum # Check names - continuum_cache_names[result.pwd] = [ - np.atleast_1d(np.loadtxt(f"{result.pwd}/model_flux.output", usecols=(0, ), dtype=str)), - np.atleast_1d(np.loadtxt(f"{result.pwd}/rectified_flux.output", usecols=(0, ), dtype=str)), - np.atleast_1d(np.loadtxt(f"{result.pwd}/rectified_model_flux.output", usecols=(0, ), dtype=str)), + # TODO: This is a sanity check. if it is expensive, we can remove it later. + t = -time() + continuum_cache_names[prefix] = [ + np.atleast_1d(np.loadtxt(f"{prefix}/model_flux.output", usecols=(0, ), dtype=str)), + np.atleast_1d(np.loadtxt(f"{prefix}/rectified_flux.output", usecols=(0, ), dtype=str)), + np.atleast_1d(np.loadtxt(f"{prefix}/rectified_model_flux.output", usecols=(0, ), dtype=str)), ] + t_check += (time() + t) finally: - pre_computed_continuum[result.spectrum_pk] = continuum_cache[result.pwd][int(result.ferre_output_index)] - for each in continuum_cache_names[result.pwd]: - meta = parse_ferre_spectrum_name(each[int(result.ferre_output_index)]) - assert int(meta["source_pk"]) == result.source_pk - assert int(meta["spectrum_pk"]) == result.spectrum_pk - assert int(meta["index"]) == result.ferre_input_index + t = -time() + pre_computed_continuum[result["spectrum_pk"]] = continuum_cache[prefix][int(result["ferre_index"])] + for each in continuum_cache_names[prefix]: + meta = parse_ferre_spectrum_name(each[int(result["ferre_index"])]) + assert int(meta["source_pk"]) == int(result["source_pk"]) + assert int(meta["spectrum_pk"]) == int(result["spectrum_pk"]) + assert int(meta["index"]) == int(result["ferre_index"]) + t_check += (time() + t) - group_task_kwds[result.header_path].append( + group_task_kwds[result["header_path"]].append( dict( spectra=spectrum, - pre_computed_continuum=pre_computed_continuum[result.spectrum_pk], - initial_teff=result.teff, - initial_logg=result.logg, - initial_m_h=result.m_h, - initial_log10_v_sini=result.log10_v_sini, - initial_log10_v_micro=result.log10_v_micro, - initial_alpha_m=result.alpha_m, - initial_c_m=result.c_m, - initial_n_m=result.n_m, - upstream_pk=result.task_pk, + pre_computed_continuum=pre_computed_continuum[result["spectrum_pk"]], + initial_teff=result["teff"], + initial_logg=result["logg"], + initial_m_h=result["m_h"], + initial_log10_v_sini=result.get("log10_v_sini", np.nan), + initial_log10_v_micro=result.get("log10_v_micro", np.nan), + initial_alpha_m=result.get("alpha_m", np.nan), + initial_c_m=result.get("c_m", np.nan), + initial_n_m=result.get("n_m", np.nan), ) ) @@ -382,8 +269,8 @@ def plan_abundances( ) extra_kwds.update(kwargs) - kwds_list = [] - spectra_with_no_stellar_parameters = set(spectra) + plans = [] + #spectra_with_no_stellar_parameters = set(spectra) for header_path in group_task_kwds.keys(): grid_kwds = list_to_dict(group_task_kwds[header_path]) @@ -391,10 +278,9 @@ def plan_abundances( for i, (species, details) in enumerate(abundance_keywords[header_path].items()): weight_path, frozen_parameters, ferre_kwds = details - pwd = os.path.join(parent_dir, STAGE, short_grid_name, species) kwds = grid_kwds.copy() kwds.update( - pwd=pwd, + relative_dir=os.path.join(STAGE, short_grid_name, species), header_path=header_path, weight_path=weight_path, frozen_parameters=frozen_parameters, @@ -409,13 +295,14 @@ def plan_abundances( if all(frozen_parameters.get(ln, False) for ln in ferre_headers[kwds["header_path"]][0]["LABEL"]): log.warning(f"Ignoring {species} species on grid {short_grid_name} because all parameters are frozen") continue - kwds_list.append(kwds) + plans.append(kwds) - spectra_with_no_stellar_parameters -= set(grid_kwds["spectra"]) + #spectra_with_no_stellar_parameters -= set(grid_kwds["spectra"]) - spectra_with_no_stellar_parameters = tuple(spectra_with_no_stellar_parameters) + #spectra_with_no_stellar_parameters = tuple(spectra_with_no_stellar_parameters) + #return (plans, spectra_with_no_stellar_parameters) + return plans - return (kwds_list, spectra_with_no_stellar_parameters) def get_species(weight_path): diff --git a/src/astra/pipelines/aspcap/coarse.py b/src/astra/pipelines/aspcap/coarse.py index 7177e36..60745f6 100644 --- a/src/astra/pipelines/aspcap/coarse.py +++ b/src/astra/pipelines/aspcap/coarse.py @@ -54,7 +54,7 @@ def penalize_coarse_stellar_parameter_result(result: FerreCoarse, warn_multiplie return penalized_rchi2 -def plan_coarse_stellar_parameters( +def plan_coarse_stellar_parameters_stage( spectra: Iterable[Spectrum], header_paths: Optional[Union[List[str], Tuple[str], str]] = "$MWM_ASTRA/pipelines/aspcap/synspec_dr17_marcs_header_paths.list", initial_guess_callable: Optional[Callable] = None, diff --git a/src/astra/pipelines/ferre/post_process.py b/src/astra/pipelines/ferre/post_process.py index e31918c..849a9fa 100644 --- a/src/astra/pipelines/ferre/post_process.py +++ b/src/astra/pipelines/ferre/post_process.py @@ -230,10 +230,10 @@ def _post_process_ferre(dir, pwd=None, skip_pixel_arrays=False, **kwargs) -> Ite source_pk=name_meta["source_pk"], spectrum_pk=name_meta["spectrum_pk"], initial_flags=name_meta["initial_flags"] or 0, - upstream_pk=name_meta["upstream_pk"], + #upstream_pk=name_meta["upstream_pk"], ferre_name=name, - ferre_input_index=name_meta["index"], - ferre_output_index=i, + ferre_index=name_meta["index"], + #ferre_output_index=i, rchi2=10**ferre_log_chi_sq[i], penalized_rchi2=10**ferre_log_chi_sq[i], ferre_log_snr_sq=ferre_log_snr_sq[i], @@ -241,6 +241,7 @@ def _post_process_ferre(dir, pwd=None, skip_pixel_arrays=False, **kwargs) -> Ite flag_potential_ferre_timeout=flag_potential_ferre_timeout[i], flag_missing_model_flux=flag_missing_model_flux[i], ) + assert i == name_meta["index"] # Add correlation coefficients. #meta["cov"] diff --git a/src/astra/pipelines/ferre/pre_process.py b/src/astra/pipelines/ferre/pre_process.py index d27cce8..9d09ef2 100644 --- a/src/astra/pipelines/ferre/pre_process.py +++ b/src/astra/pipelines/ferre/pre_process.py @@ -61,6 +61,8 @@ def pre_process_ferre( if kwargs: log.warning(f"astra.pipelines.ferre.pre_process.pre_process ignoring kwargs: {kwargs}") + n_threads = min(n_threads, len(spectra)) + # Validate the control file keywords. ( control_kwds, @@ -85,7 +87,7 @@ def pre_process_ferre( full_covariance=full_covariance, pca_project=pca_project, pca_chi=pca_chi, - n_threads=min(n_threads, len(spectra)), # Limit threads to the number of objects + n_threads=n_threads, # Limit threads to the number of objects f_access=f_access, f_format=f_format, ) @@ -238,7 +240,7 @@ def pre_process_ferre( np.savetxt(e_flux_path, batch_e_flux, **savetxt_kwds) n_obj = len(batch_names) - return (pwd, n_obj, n_threads, skipped) + return (pwd, n_obj, min(n_threads, n_obj), skipped) ''' bad_pixel_flux_value: float = 1e-4,