From 5fec087ceb6c71b76ee01610bac6f861ac64d521 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Wed, 27 Nov 2024 16:43:18 -0800 Subject: [PATCH 1/6] code cleanup from pylint and mypy --- pyproject.toml | 1 + src/rail/_pipelines/__init__.py | 1 - src/rail/cli/pipe_commands.py | 30 ++--- src/rail/cli/pipe_options.py | 2 +- src/rail/cli/pipe_scripts.py | 54 +++++---- src/rail/cli/reduce_roman_rubin_data.py | 13 +- src/rail/pipelines/estimation/estimate_all.py | 2 +- src/rail/pipelines/estimation/inform_all.py | 2 +- src/rail/pipelines/estimation/pz_all.py | 4 +- src/rail/pipelines/estimation/tomography.py | 22 ++-- src/rail/pipelines/evaluation/evaluate_all.py | 2 +- src/rail/utils/name_utils.py | 46 +++---- src/rail/utils/project.py | 113 +++++++++--------- 13 files changed, 155 insertions(+), 137 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3417eba..d4cbb15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ disable = [ "abstract-method", "invalid-name", "too-many-statements", + "too-many-instance-attributes", "missing-module-docstring", "missing-class-docstring", "missing-function-docstring", diff --git a/src/rail/_pipelines/__init__.py b/src/rail/_pipelines/__init__.py index 26bed30..8dee4bf 100644 --- a/src/rail/_pipelines/__init__.py +++ b/src/rail/_pipelines/__init__.py @@ -1,2 +1 @@ from ._version import __version__ - diff --git a/src/rail/cli/pipe_commands.py b/src/rail/cli/pipe_commands.py index c96c594..5d70c82 100644 --- a/src/rail/cli/pipe_commands.py +++ b/src/rail/cli/pipe_commands.py @@ -1,3 +1,5 @@ +from typing import Any + import click from rail.core import __version__ @@ -15,7 +17,7 @@ def pipe_cli() -> None: @pipe_cli.command(name="inspect") @pipe_options.config_file() -def inspect(config_file): +def inspect(config_file: str) -> int: """Inspect a rail pipeline project config""" return pipe_scripts.inspect(config_file) @@ -23,7 +25,7 @@ def inspect(config_file): @pipe_cli.command() @pipe_options.config_file() @pipe_options.flavor() -def build_pipelines(config_file, **kwargs): +def build_pipelines(config_file: str, **kwargs: Any) -> int: """Build the ceci pipeline configuraiton files""" project = RailProject.load_config(config_file) flavors = project.get_flavor_args(kwargs.pop('flavor')) @@ -38,7 +40,7 @@ def build_pipelines(config_file, **kwargs): @pipe_options.config_file() @pipe_options.selection() @pipe_options.run_mode() -def reduce_roman_rubin(config_file, **kwargs): +def reduce_roman_rubin(config_file: str, **kwargs: Any) -> int: """Reduce the roman rubin simulations for PZ analysis""" project = RailProject.load_config(config_file) selections = project.get_selection_args(kwargs.pop('selection')) @@ -54,7 +56,7 @@ def reduce_roman_rubin(config_file, **kwargs): @pipe_options.selection() @pipe_options.flavor() @pipe_options.run_mode() -def truth_to_observed_pipeline(config_file, **kwargs): +def truth_to_observed_pipeline(config_file: str, **kwargs: Any) -> int: """Run the truth-to-observed analysis pipeline""" project = RailProject.load_config(config_file) flavors = project.get_flavor_args(kwargs.pop('flavor')) @@ -62,7 +64,7 @@ def truth_to_observed_pipeline(config_file, **kwargs): iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) ok = 0 pipeline_name = "truth_to_observed" - + pipeline_catalog_config = pipe_scripts.TruthToObservedPipelineCatalogConfiguration( project, source_catalog_tag='reduced', sink_catalog_tag='degraded', ) @@ -80,7 +82,7 @@ def truth_to_observed_pipeline(config_file, **kwargs): @pipe_options.selection() @pipe_options.flavor() @pipe_options.run_mode() -def spectroscopic_selection_pipeline(config_file, **kwargs): +def spectroscopic_selection_pipeline(config_file: str, **kwargs: Any) -> int: """Run the spectroscopic selection analysis pipeline""" project = RailProject.load_config(config_file) flavors = project.get_flavor_args(kwargs.pop('flavor')) @@ -112,7 +114,7 @@ def spectroscopic_selection_pipeline(config_file, **kwargs): @pipe_options.selection() @pipe_options.flavor() @pipe_options.run_mode() -def blending_pipeline(config_file, **kwargs): +def blending_pipeline(config_file: str, **kwargs: Any) -> int: """Run the blending analysis pipeline""" project = RailProject.load_config(config_file) flavors = project.get_flavor_args(kwargs.pop('flavor')) @@ -143,7 +145,7 @@ def blending_pipeline(config_file, **kwargs): @pipe_options.selection() @pipe_options.label() @pipe_options.run_mode() -def subsample_data(config_file, **kwargs): +def subsample_data(config_file: str, **kwargs: Any) -> int: """Make a training data set by randomly selecting objects""" project = RailProject.load_config(config_file) flavors = project.get_flavor_args(kwargs.pop('flavor')) @@ -160,7 +162,7 @@ def subsample_data(config_file, **kwargs): @pipe_options.flavor() @pipe_options.selection() @pipe_options.run_mode() -def inform(config_file, **kwargs): +def inform(config_file: str, **kwargs: Any) -> int: """Run the inform pipeline""" pipeline_name = "inform" project = RailProject.load_config(config_file) @@ -182,7 +184,7 @@ def inform(config_file, **kwargs): @pipe_options.flavor() @pipe_options.selection() @pipe_options.run_mode() -def estimate_single(config_file, **kwargs): +def estimate_single(config_file: str, **kwargs: Any) -> int: """Run the estimation pipeline""" pipeline_name = "estimate" project = RailProject.load_config(config_file) @@ -204,7 +206,7 @@ def estimate_single(config_file, **kwargs): @pipe_options.flavor() @pipe_options.selection() @pipe_options.run_mode() -def evaluate_single(config_file, **kwargs): +def evaluate_single(config_file: str, **kwargs: Any) -> int: """Run the evaluation pipeline""" pipeline_name = "evaluate" project = RailProject.load_config(config_file) @@ -226,7 +228,7 @@ def evaluate_single(config_file, **kwargs): @pipe_options.flavor() @pipe_options.selection() @pipe_options.run_mode() -def pz_single(config_file, **kwargs): +def pz_single(config_file: str, **kwargs: Any) -> int: """Run the pz pipeline""" pipeline_name = "pz" project = RailProject.load_config(config_file) @@ -248,7 +250,7 @@ def pz_single(config_file, **kwargs): @pipe_options.flavor() @pipe_options.selection() @pipe_options.run_mode() -def tomography_single(config_file, **kwargs): +def tomography_single(config_file : str, **kwargs: Any) -> int: """Run the tomography pipeline""" pipeline_name = "tomography" project = RailProject.load_config(config_file) @@ -270,7 +272,7 @@ def tomography_single(config_file, **kwargs): @pipe_options.flavor() @pipe_options.selection() @pipe_options.run_mode() -def sompz_single(config_file, **kwargs): +def sompz_single(config_file: str, **kwargs: Any) -> int: """Run the sompz pipeline""" pipeline_name = "sompz" project = RailProject.load_config(config_file) diff --git a/src/rail/cli/pipe_options.py b/src/rail/cli/pipe_options.py index 72359f5..8091228 100644 --- a/src/rail/cli/pipe_options.py +++ b/src/rail/cli/pipe_options.py @@ -9,7 +9,7 @@ ) -__all__ = [ +__all__: list[str] = [ "RunMode", "config_path", "flavor", diff --git a/src/rail/cli/pipe_scripts.py b/src/rail/cli/pipe_scripts.py index a6504ab..20b2d75 100644 --- a/src/rail/cli/pipe_scripts.py +++ b/src/rail/cli/pipe_scripts.py @@ -41,10 +41,10 @@ def handle_command( if run_mode == RunMode.dry_run: # print(command_line) command_line.insert(0, "echo") - finished = subprocess.run(command_line) + finished = subprocess.run(command_line, check=False) elif run_mode == RunMode.bash: # return os.system(command_line) - finished = subprocess.run(command_line) + finished = subprocess.run(command_line, check=False) elif run_mode == RunMode.slurm: raise RuntimeError("handle_command should not be called with run_mode == RunMode.slurm") @@ -94,9 +94,9 @@ def handle_commands( try: os.makedirs(os.path.dirname(script_path)) - except: + except FileExistsError: pass - with open(script_path, 'w') as fout: + with open(script_path, 'w', encoding='utf-8') as fout: fout.write("#!/usr/bin/bash\n\n") for command_ in command_lines: com_line = ' '.join(command_) @@ -110,7 +110,7 @@ def handle_commands( ) as sbatch: assert sbatch.stdout line = sbatch.stdout.read().decode().strip() - return line.split("|")[0] + return int(line.split("|")[0]) except TypeError as msg: raise TypeError(f"Bad slurm submit: {msg}") from msg return 0 @@ -228,7 +228,7 @@ def run_pipeline_on_catalog( project: RailProject, pipeline_name: str, pipeline_catalog_configuration: PipelineCatalogConfiguration, - run_mode: RunMode==RunMode.bash, + run_mode: RunMode=RunMode.bash, **kwargs: Any, ) -> int: """ Run a pipeline on an entire catalog @@ -261,14 +261,14 @@ def run_pipeline_on_catalog( pipeline_path = project.get_path('pipeline_path', pipeline=pipeline_name, **kwargs) input_catalog_name = pipeline_info['InputCatalogTag'] - input_catalog = project.get_catalogs().get(input_catalog_name) + input_catalog = project.get_catalogs().get(input_catalog_name, {}) # Loop through all possible combinations of the iteration variables that are # relevant to this pipeline - if (iteration_vars := input_catalog.get("IterationVars")) is not None: + if (iteration_vars := input_catalog.get("IterationVars", {})) is not None: iterations = itertools.product( *[ - project.config.get("IterationVars").get(iteration_var) + project.config.get("IterationVars", {}).get(iteration_var, "") for iteration_var in iteration_vars ] ) @@ -281,7 +281,11 @@ def run_pipeline_on_catalog( source_catalog = pipeline_catalog_configuration.get_source_catalog(**kwargs, **iteration_kwargs) sink_catalog = pipeline_catalog_configuration.get_sink_catalog(**kwargs, **iteration_kwargs) sink_dir = os.path.dirname(sink_catalog) - script_path = pipeline_catalog_configuration.get_script_path(pipeline_name, sink_dir, **kwargs, **iteration_kwargs) + script_path = pipeline_catalog_configuration.get_script_path( + pipeline_name, + sink_dir, + **kwargs, **iteration_kwargs, + ) convert_commands = pipeline_catalog_configuration.get_convert_commands(sink_dir) ceci_command = project.generate_ceci_command( @@ -371,7 +375,7 @@ def run_pipeline_on_single_input( def inform_input_callback( project: RailProject, pipeline_name: str, - sink_dir: str | None, + sink_dir: str, # pylint: disable=unused-argument **kwargs: Any, ) -> dict[str, str]: """Make dict of input tags and paths for the inform pipeline @@ -384,7 +388,7 @@ def inform_input_callback( pipeline_name: str Name of the pipeline to run - sink_dir: str | None + sink_dir: str Path to output directory kwargs: Any @@ -408,7 +412,7 @@ def inform_input_callback( def estimate_input_callback( project: RailProject, pipeline_name: str, - sink_dir: str | None, + sink_dir: str, **kwargs: Any, ) -> dict[str, str]: """Make dict of input tags and paths for the estimate pipeline @@ -421,7 +425,7 @@ def estimate_input_callback( pipeline_name: str Name of the pipeline to run - sink_dir: str | None + sink_dir: str Path to output directory kwargs: Any @@ -449,7 +453,7 @@ def estimate_input_callback( def evaluate_input_callback( project: RailProject, pipeline_name: str, - sink_dir: str | None, + sink_dir: str, **kwargs: Any, ) -> dict[str, str]: """Make dict of input tags and paths for the evalute pipeline @@ -462,7 +466,7 @@ def evaluate_input_callback( pipeline_name: str Name of the pipeline to run - sink_dir: str | None + sink_dir: str Path to output directory kwargs: Any @@ -491,7 +495,7 @@ def evaluate_input_callback( def pz_input_callback( project: RailProject, pipeline_name: str, - sink_dir: str | None, + sink_dir: str, # pylint: disable=unused-argument **kwargs: Any, ) -> dict[str, str]: """Make dict of input tags and paths for the pz pipeline @@ -504,7 +508,7 @@ def pz_input_callback( pipeline_name: str Name of the pipeline to run - sink_dir: str | None + sink_dir: str Path to output directory kwargs: Any @@ -528,7 +532,7 @@ def pz_input_callback( def tomography_input_callback( project: RailProject, pipeline_name: str, - sink_dir: str | None, + sink_dir: str, **kwargs: Any, ) -> dict[str, str]: """Make dict of input tags and paths for the tomography pipeline @@ -541,7 +545,7 @@ def tomography_input_callback( pipeline_name: str Name of the pipeline to run - sink_dir: str | None + sink_dir: str Path to output directory kwargs: Any @@ -572,7 +576,7 @@ def tomography_input_callback( def sompz_input_callback( project: RailProject, pipeline_name: str, - sink_dir: str | None, + sink_dir: str, # pylint: disable=unused-argument **kwargs: Any, ) -> dict[str, str]: """Make dict of input tags and paths for the sompz pipeline @@ -585,7 +589,7 @@ def sompz_input_callback( pipeline_name: str Name of the pipeline to run - sink_dir: str | None + sink_dir: str Path to output directory kwargs: Any @@ -664,7 +668,7 @@ def subsample_data( iterations = itertools.product( *[ - project.config.get("IterationVars").get(iteration_var) + project.config.get("IterationVars", {}).get(iteration_var, "") for iteration_var in iteration_vars ] ) @@ -749,7 +753,7 @@ def build_pipelines( try: os.makedirs(pipe_out_dir) - except: + except FileExistsError: pass overrides = pipeline_overrides.get('default', {}) @@ -780,7 +784,7 @@ def build_pipelines( } pipeline_kwargs.update(**pipe_ctor_kwargs) stages_config = os.path.join(pipe_out_dir, f"{pipeline_name}_{flavor}_overrides.yml") - with open(stages_config, 'w') as fout: + with open(stages_config, 'w', encoding='utf-8') as fout: yaml.dump(overrides, fout) else: stages_config = None diff --git a/src/rail/cli/reduce_roman_rubin_data.py b/src/rail/cli/reduce_roman_rubin_data.py index 458068e..843afb7 100644 --- a/src/rail/cli/reduce_roman_rubin_data.py +++ b/src/rail/cli/reduce_roman_rubin_data.py @@ -8,6 +8,7 @@ from pyarrow import acero from rail.cli.pipe_options import RunMode +from rail.utils.project import RailProject COLUMNS = [ @@ -87,10 +88,10 @@ def reduce_roman_rubin_data( - project, - selection, - run_mode=RunMode.bash, -): + project: RailProject, + selection: str | None, + run_mode: RunMode=RunMode.bash, +) -> int: source_catalogs = [] sink_catalogs = [] @@ -104,11 +105,11 @@ def reduce_roman_rubin_data( # FIXME - iteration_vars = list(project.config.get("IterationVars").keys()) + iteration_vars = list(project.config.get("IterationVars", {}).keys()) if iteration_vars is not None: iterations = itertools.product( *[ - project.config.get("IterationVars").get(iteration_var) + project.config.get("IterationVars", {}).get(iteration_var, "") for iteration_var in iteration_vars ] ) diff --git a/src/rail/pipelines/estimation/estimate_all.py b/src/rail/pipelines/estimation/estimate_all.py index 849f6d7..7be3ff6 100644 --- a/src/rail/pipelines/estimation/estimate_all.py +++ b/src/rail/pipelines/estimation/estimate_all.py @@ -8,7 +8,7 @@ # Various rail modules from rail.core.stage import RailStage, RailPipeline -from rail.utils.project import PZ_ALGORITHMS +from rail.utils.algo_library import PZ_ALGORITHMS input_file = 'rubin_dm_dc2_example.pq' diff --git a/src/rail/pipelines/estimation/inform_all.py b/src/rail/pipelines/estimation/inform_all.py index a5a9ec8..e34a591 100644 --- a/src/rail/pipelines/estimation/inform_all.py +++ b/src/rail/pipelines/estimation/inform_all.py @@ -4,7 +4,7 @@ import ceci from rail.core.stage import RailStage, RailPipeline -from rail.utils.project import PZ_ALGORITHMS +from rail.utils.algo_library import PZ_ALGORITHMS input_file = 'rubin_dm_dc2_example.pq' diff --git a/src/rail/pipelines/estimation/pz_all.py b/src/rail/pipelines/estimation/pz_all.py index 41c5852..ccc590d 100644 --- a/src/rail/pipelines/estimation/pz_all.py +++ b/src/rail/pipelines/estimation/pz_all.py @@ -6,7 +6,7 @@ # Various rail modules from rail.core.stage import RailStage, RailPipeline from rail.evaluation.single_evaluator import SingleEvaluator -from rail.utils.project import PZ_ALGORITHMS +from rail.utils.algo_library import PZ_ALGORITHMS input_file = 'rubin_dm_dc2_example.pq' @@ -29,7 +29,7 @@ class PzPipeline(RailPipeline): 'input_test':'dummy.in', } - def __init__(self, algorithms=None): + def __init__(self, algorithms: dict|None=None): RailPipeline.__init__(self) DS = RailStage.data_store diff --git a/src/rail/pipelines/estimation/tomography.py b/src/rail/pipelines/estimation/tomography.py index 82cc692..927adf8 100644 --- a/src/rail/pipelines/estimation/tomography.py +++ b/src/rail/pipelines/estimation/tomography.py @@ -6,7 +6,6 @@ # Various rail modules from rail.core.stage import RailStage, RailPipeline from rail.estimation.algos.true_nz import TrueNZHistogrammer -from rail.evaluation.single_evaluator import SingleEvaluator from rail.utils.algo_library import PZ_ALGORITHMS, CLASSIFIERS, SUMMARIZERS @@ -17,7 +16,13 @@ class TomographyPipeline(RailPipeline): truth='dummy.in', ) - def __init__(self, algorithms=None, classifiers=None, summarizers=None, n_tomo_bins=5): + def __init__( + self, + algorithms: dict | None=None, + classifiers: dict | None=None, + summarizers: dict | None=None, + n_tomo_bins: int=5, + ): RailPipeline.__init__(self) DS = RailStage.data_store @@ -25,7 +30,7 @@ def __init__(self, algorithms=None, classifiers=None, summarizers=None, n_tomo_b if algorithms is None: algorithms = PZ_ALGORITHMS - + if classifiers is None: classifiers = CLASSIFIERS @@ -33,10 +38,13 @@ def __init__(self, algorithms=None, classifiers=None, summarizers=None, n_tomo_b summarizers = SUMMARIZERS for pz_algo_name_ in algorithms: - + for classifier_name_, classifier_info_ in classifiers.items(): - classifier_class = ceci.PipelineStage.get_stage(classifier_info_['Classify'], classifier_info_['Module']) + classifier_class = ceci.PipelineStage.get_stage( + classifier_info_['Classify'], + classifier_info_['Module'], + ) the_classifier = classifier_class.make_and_connect( aliases=dict(input=f"input_{pz_algo_name_}"), name=f'classify_{pz_algo_name_}_{classifier_name_}', @@ -56,7 +64,7 @@ def __init__(self, algorithms=None, classifiers=None, summarizers=None, n_tomo_b aliases=dict(input='truth'), ) self.add_stage(true_nz) - + for summarizer_name_, summarize_info_ in summarizers.items(): summarizer_class = ceci.PipelineStage.get_stage( summarize_info_['Summarize'], @@ -66,7 +74,7 @@ def __init__(self, algorithms=None, classifiers=None, summarizers=None, n_tomo_b name=f'summarize_{pz_algo_name_}_{classifier_name_}_bin{ibin}_{summarizer_name_}', aliases=dict(input=f"input_{pz_algo_name_}"), connections=dict( - tomography_bins=the_classifier.io.output, + tomography_bins=the_classifier.io.output, ), selected_bin=ibin, nsamples=20, diff --git a/src/rail/pipelines/evaluation/evaluate_all.py b/src/rail/pipelines/evaluation/evaluate_all.py index 20ef4de..2def946 100644 --- a/src/rail/pipelines/evaluation/evaluate_all.py +++ b/src/rail/pipelines/evaluation/evaluate_all.py @@ -6,7 +6,7 @@ from rail.core.stage import RailStage, RailPipeline from rail.evaluation.single_evaluator import SingleEvaluator -from rail.utils.project import PZ_ALGORITHMS +from rail.utils.algo_library import PZ_ALGORITHMS shared_stage_opts = dict( diff --git a/src/rail/utils/name_utils.py b/src/rail/utils/name_utils.py index 1ef3f4c..4a1b9d9 100644 --- a/src/rail/utils/name_utils.py +++ b/src/rail/utils/name_utils.py @@ -6,8 +6,7 @@ # import enum import re from functools import partial - -import yaml +from typing import Any CommonPaths = dict( @@ -27,7 +26,7 @@ ) -def _get_required_interpolants(template): +def _get_required_interpolants(template: str) -> list[str]: """ Get the list of interpolants required to format a template string Notes @@ -38,7 +37,7 @@ def _get_required_interpolants(template): return re.findall('{.*?}', template) -def _format_template(template, **kwargs): +def _format_template(template: str, **kwargs: Any) -> str: """ Resolve a specific template This is fault-tolerant and will not raise KeyError if some @@ -54,7 +53,7 @@ def _format_template(template, **kwargs): return template.format(**interpolants) -def _resolve_dict(source, interpolants): +def _resolve_dict(source: dict, interpolants: dict) -> dict: """ Recursively resolve a dictionary using interpolants Parameters @@ -73,10 +72,11 @@ def _resolve_dict(source, interpolants): if source is not None: sink = copy.deepcopy(source) for k, v in source.items(): + v_interpolated: list | dict | str = "" match v: case dict(): v_interpolated = _resolve_dict(source[k], interpolants) - case list(): + case list(): v_interpolated = [_resolve_dict(_v, interpolants) for _v in v] case str(): v_interpolated = v.format(**interpolants) @@ -90,7 +90,7 @@ def _resolve_dict(source, interpolants): return sink -def _resolve(templates, source, interpolants): +def _resolve(templates: dict, source: dict, interpolants: dict) -> dict: """ Resolve a set of templates using interpolants and allow for overrides Parameters @@ -134,7 +134,12 @@ class NameFactory: PathTemplates = PathTemplates, ) - def __init__(self, config=None, templates=None, interpolants=None): + def __init__( + self, + config: dict | None=None, + templates: dict | None=None, + interpolants: dict | None=None, + ): """ C'tor """ @@ -151,7 +156,7 @@ def __init__(self, config=None, templates=None, interpolants=None): self._config[key].update(**config[key]) self._templates = copy.deepcopy(self._config['PathTemplates']) self._templates.update(**templates) - self._interpolants = {} + self._interpolants: dict = {} self.templates = {} for k, v in templates.items(): @@ -160,31 +165,30 @@ def __init__(self, config=None, templates=None, interpolants=None): self.interpolants = self._config['CommonPaths'] self.interpolants = interpolants - def get_path_templates(self): + def get_path_templates(self) -> dict: return self._config['PathTemplates'] - def get_common_paths(self): + def get_common_paths(self) -> dict: return self._config['CommonPaths'] @property - def interpolants(self): + def interpolants(self) -> dict: """ Return the dict of interpolants that are used to resolve templates """ return self._interpolants @interpolants.setter - def interpolants(self, config): + def interpolants(self, config: dict) -> None: """ Update the dict of interpolants that are used to resolve templates """ for key, value in config.items(): new_value = value.format(**self.interpolants) self.interpolants[key] = new_value @interpolants.deleter - def interpolants(self): + def interpolants(self) -> None: """ Reset the dict of interpolants that are used to resolve templates""" self._interpolants = {} - - def resolve_from_config(self, config): + def resolve_from_config(self, config: dict) -> dict: """ Resolve all the templates in a dict Parameters @@ -206,7 +210,7 @@ def resolve_from_config(self, config): return resolved - def resolve_path(self, config, path_key, **kwargs): + def resolve_path(self, config: dict, path_key: str, **kwargs: Any) -> str: """ Resolve a particular template in a config dict Parameters @@ -229,7 +233,7 @@ def resolve_path(self, config, path_key, **kwargs): return formatted - def get_template(self, section_key, path_key): + def get_template(self, section_key: str, path_key: str) -> str: """ Return the template for a particular file type Parameters @@ -261,7 +265,7 @@ def get_template(self, section_key, path_key): f"available paths are {list(section.keys())}", ) from msg - def resolve_template(self, section_key, path_key, **kwargs): + def resolve_template(self, section_key: str, path_key: str, **kwargs: Any) -> str: """ Return the template for a particular file type Parameters @@ -281,7 +285,7 @@ def resolve_template(self, section_key, path_key, **kwargs): template = self.get_template(section_key, path_key) return _format_template(template, **self.interpolants, **kwargs) - def resolve_path_template(self, path_key, **kwargs): + def resolve_path_template(self, path_key: str, **kwargs: Any) -> str: """ Return a particular path templated Parameters @@ -300,7 +304,7 @@ def resolve_path_template(self, path_key, **kwargs): return _format_template(template, **interp_dict) - def resolve_common_path(self, path_key, **kwargs): + def resolve_common_path(self, path_key: str, **kwargs: Any) -> str: """ Return a particular common path template Parameters diff --git a/src/rail/utils/project.py b/src/rail/utils/project.py index ec4bf2c..61cb9c0 100644 --- a/src/rail/utils/project.py +++ b/src/rail/utils/project.py @@ -1,16 +1,15 @@ import copy from pathlib import Path import itertools +from typing import Any import yaml from rail.utils import name_utils -from .algo_library import PZ_ALGORITHMS, CLASSIFIERS, SUMMARIZERS, SPEC_SELECTIONS - class RailProject: - config_template = { + config_template: dict[str, dict] = { "IterationVars": {}, "CommonPaths": {}, "PathTemplates": {}, @@ -27,7 +26,7 @@ class RailProject: "Summarizers": {}, } - def __init__(self, name, config_dict): + def __init__(self, name: str, config_dict: dict): self.name = name self._config_dict = config_dict self.config = copy.deepcopy(self.config_template) @@ -41,43 +40,43 @@ def __init__(self, name, config_dict): interpolants=self.config.get("CommonPaths", {}), ) self.name_factory.resolve_from_config( - self.config.get("CommonPaths") + self.config.get("CommonPaths", {}) ) def __repr__(self): return f"{self.name}" @staticmethod - def load_config(config_file): + def load_config(config_file: str) -> RailProject: """ Create and return a RailProject from a yaml config file""" project_name = Path(config_file).stem - with open(config_file, "r") as fp: + with open(config_file, "r", encoding='utf-8') as fp: config_dict = yaml.safe_load(fp) project = RailProject(project_name, config_dict) # project.resolve_common() return project - def get_path_templates(self): + def get_path_templates(self) -> dict: """ Return the dictionary of templates used to construct paths """ return self.name_factory.get_path_templates() - def get_path(self, path_key, **kwargs): + def get_path(self, path_key: str, **kwargs: Any) -> str: """ Resolve and return a path using the kwargs as interopolants """ return self.name_factory.resolve_path_template(path_key, **kwargs) - def get_common_paths(self): + def get_common_paths(self) -> dict: """ Return the dictionary of common paths """ return self.name_factory.get_common_paths() - def get_common_path(self, path_key, **kwargs): + def get_common_path(self, path_key: str, **kwargs: Any) -> str: """ Resolve and return a common path using the kwargs as interopolants """ return self.name_factory.resolve_common_path(path_key, **kwargs) - def get_files(self): + def get_files(self) -> dict: """ Return the dictionary of specific files """ - return self.config.get("Files") + return self.config.get("Files", {}) - def get_file(self, name, **kwargs): + def get_file(self, name: str, **kwargs: Any) -> str: """ Resolve and return a file using the kwargs as interpolants """ files = self.get_files() file_dict = files.get(name, None) @@ -86,9 +85,9 @@ def get_file(self, name, **kwargs): path = self.name_factory.resolve_path(file_dict, "PathTemplate", **kwargs) return path - def get_flavors(self): + def get_flavors(self) -> dict: """ Return the dictionary of analysis flavor variants """ - flavors = self.config.get("Flavors") + flavors = self.config.get("Flavors", {}) baseline = flavors.get("baseline", {}) for k, v in flavors.items(): if k != "baseline": @@ -96,7 +95,7 @@ def get_flavors(self): return flavors - def get_flavor(self, name): + def get_flavor(self, name: str) -> dict: """ Resolve the configuration for a particular analysis flavor variant """ flavors = self.get_flavors() flavor = flavors.get(name, None) @@ -104,7 +103,7 @@ def get_flavor(self, name): raise KeyError(f"flavor '{name}' not found in {self}") return flavor - def get_file_for_flavor(self, flavor, label, **kwargs): + def get_file_for_flavor(self, flavor: str, label: str, **kwargs: Any) -> str: """ Resolve the file associated to a particular flavor and label E.g., flavor=baseline and label=train would give the baseline training file @@ -116,7 +115,7 @@ def get_file_for_flavor(self, flavor, label, **kwargs): raise KeyError(f"Label '{label}' not found in flavor '{flavor}'") from msg return self.get_file(file_alias, flavor=flavor, label=label, **kwargs) - def get_file_metadata_for_flavor(self, flavor, label): + def get_file_metadata_for_flavor(self, flavor: str, label: str) -> dict: """ Resolve the metadata associated to a particular flavor and label E.g., flavor=baseline and label=train would give the baseline training metadata @@ -128,11 +127,11 @@ def get_file_metadata_for_flavor(self, flavor, label): raise KeyError(f"Label '{label}' not found in flavor '{flavor}'") from msg return self.get_files()[file_alias] - def get_selections(self): + def get_selections(self) -> dict: """ Get the dictionary describing all the selections""" - return self.config.get("Selections") + return self.config.get("Selections", {}) - def get_selection(self, name): + def get_selection(self, name: str) -> dict: """ Get a particular selection by name""" selections = self.get_selections() selection = selections.get(name, None) @@ -140,11 +139,11 @@ def get_selection(self, name): raise KeyError(f"selection '{name}' not found in {self}") return selection - def get_error_models(self): + def get_error_models(self) -> dict: """ Get the dictionary describing all the photometric error model algorithms""" - return self.config.get("ErrorModels") + return self.config.get("ErrorModels", {}) - def get_error_model(self, name): + def get_error_model(self, name: str) -> dict: """ Get the information about a particular photometric error model algorithms""" error_models = self.get_error_models() error_model = error_models.get(name, None) @@ -152,11 +151,11 @@ def get_error_model(self, name): raise KeyError(f"error_models '{name}' not found in {self}") return error_model - def get_pzalgorithms(self): + def get_pzalgorithms(self) -> dict: """ Get the dictionary describing all the PZ estimation algorithms""" - return self.config.get("PZAlgorithms") + return self.config.get("PZAlgorithms", {}) - def get_pzalgorithm(self, name): + def get_pzalgorithm(self, name: str) -> dict: """ Get the information about a particular PZ estimation algorithm""" pzalgorithms = self.get_pzalgorithms() pzalgorithm = pzalgorithms.get(name, None) @@ -164,11 +163,11 @@ def get_pzalgorithm(self, name): raise KeyError(f"pz algorithm '{name}' not found in {self}") return pzalgorithm - def get_nzalgorithms(self): + def get_nzalgorithms(self) -> dict: """ Get the dictionary describing all the PZ estimation algorithms""" - return self.config.get("NZAlgorithms") + return self.config.get("NZAlgorithms", {}) - def get_nzalgorithm(self, name): + def get_nzalgorithm(self, name: str) -> dict: """ Get the information about a particular NZ estimation algorithm""" nzalgorithms = self.get_nzalgorithms() nzalgorithm = nzalgorithms.get(name, None) @@ -176,11 +175,11 @@ def get_nzalgorithm(self, name): raise KeyError(f"nz algorithm '{name}' not found in {self}") return nzalgorithm - def get_spec_selections(self): + def get_spec_selections(self) -> dict: """ Get the dictionary describing all the spectroscopic selection algorithms""" - return self.config.get("SpecSelections") + return self.config.get("SpecSelections", {}) - def get_spec_selection(self, name): + def get_spec_selection(self, name: str) -> dict: """ Get the information about a particular spectroscopic selection algorithm""" spec_selections = self.get_spec_selections() spec_selection = spec_selections.get(name, None) @@ -188,11 +187,11 @@ def get_spec_selection(self, name): raise KeyError(f"spectroscopic selection '{name}' not found in {self}") return spec_selection - def get_classifiers(self): + def get_classifiers(self) -> dict: """ Get the dictionary describing all the tomographic bin classification""" - return self.config.get("Classifiers") + return self.config.get("Classifiers", {}) - def get_classifier(self, name): + def get_classifier(self, name: str) -> dict: """ Get the information about a particular tomographic bin classification""" classifiers = self.get_classifiers() classifier = classifiers.get(name, None) @@ -200,11 +199,11 @@ def get_classifier(self, name): raise KeyError(f"tomographic bin classifier '{name}' not found in {self}") return classifier - def get_summarizers(self): + def get_summarizers(self) -> dict: """ Get the dictionary describing all the NZ summarization algorithms""" - return self.config.get("Summarizers") + return self.config.get("Summarizers", {}) - def get_summarizer(self, name): + def get_summarizer(self, name: str) -> dict: """ Get the information about a particular NZ summarization algorithms""" summarizers = self.get_summarizers() summarizer = summarizers.get(name, None) @@ -212,11 +211,11 @@ def get_summarizer(self, name): raise KeyError(f"NZ summarizer '{name}' not found in {self}") return summarizer - def get_catalogs(self): + def get_catalogs(self) -> dict: """ Get the dictionary describing all the types of data catalogs""" - return self.config['Catalogs'] + return self.config.get('Catalogs', {}) - def get_catalog(self, catalog, **kwargs): + def get_catalog(self, catalog: str, **kwargs: Any) -> str: """ Resolve the path for a particular catalog file""" catalog_dict = self.config['Catalogs'].get(catalog, {}) try: @@ -225,11 +224,11 @@ def get_catalog(self, catalog, **kwargs): except KeyError as msg: raise KeyError(f"PathTemplate not found in {catalog}") from msg - def get_pipelines(self): + def get_pipelines(self) -> dict: """ Get the dictionary describing all the types of ceci pipelines""" - return self.config.get("Pipelines") + return self.config.get("Pipelines", {}) - def get_pipeline(self, name): + def get_pipeline(self, name: str) -> dict: """ Get the information about a particular ceci pipeline""" pipelines = self.get_pipelines() pipeline = pipelines.get(name, None) @@ -237,7 +236,7 @@ def get_pipeline(self, name): raise KeyError(f"pipeline '{name}' not found in {self}") return pipeline - def get_flavor_args(self, flavors): + def get_flavor_args(self, flavors: list[str]) -> list[str]: """ Get the 'flavors' to iterate a particular command over Notes @@ -250,7 +249,7 @@ def get_flavor_args(self, flavors): return list(flavor_dict.keys()) return flavors - def get_selection_args(self, selections): + def get_selection_args(self, selections: list[str]) -> list[str]: """ Get the 'selections' to iterate a particular command over Notes @@ -263,11 +262,11 @@ def get_selection_args(self, selections): return list(selection_dict.keys()) return selections - def generate_kwargs_iterable(self, **iteration_dict): + def generate_kwargs_iterable(self, **iteration_dict: Any) -> list[dict]: iteration_vars = list(iteration_dict.keys()) iterations = itertools.product( *[ - iteration_dict.get(key) for key in iteration_vars + iteration_dict.get(key, []) for key in iteration_vars ] ) iteration_kwarg_list = [] @@ -281,13 +280,13 @@ def generate_kwargs_iterable(self, **iteration_dict): def generate_ceci_command( self, - pipeline_path, - config=None, - inputs=None, - output_dir='.', - log_dir='.', - **kwargs, - ): + pipeline_path: str, + config: str|None, + inputs: dict, + output_dir: str='.', + log_dir: str='.', + **kwargs: Any, + ) -> list[str]: if config is None: config = pipeline_path.replace('.yaml', '_config.yml') From aa72b3bd760668b52ca7e76146271243d081a559 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Wed, 27 Nov 2024 17:05:10 -0800 Subject: [PATCH 2/6] Added __future__ import annotations --- src/rail/utils/project.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/rail/utils/project.py b/src/rail/utils/project.py index 61cb9c0..7b45dbd 100644 --- a/src/rail/utils/project.py +++ b/src/rail/utils/project.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy from pathlib import Path import itertools From 7b943f63227f43692ad59a36b30f06e0a6720086 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Wed, 27 Nov 2024 17:05:22 -0800 Subject: [PATCH 3/6] docstrings --- src/rail/cli/pipe_scripts.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/rail/cli/pipe_scripts.py b/src/rail/cli/pipe_scripts.py index 20b2d75..748307e 100644 --- a/src/rail/cli/pipe_scripts.py +++ b/src/rail/cli/pipe_scripts.py @@ -139,6 +139,10 @@ def inspect(config_file: str) -> int: class PipelineCatalogConfiguration: + """Small plugin class to handle configuring a pipeline to run on a catalog + + Sub-classes will have to implment "get_convert_commands" function + """ def __init__( self, @@ -155,16 +159,19 @@ def __init__( self._sink_catalog_basename = sink_catalog_basename def get_source_catalog(self, **kwargs: Any) -> str: + """Get the name of the source (i.e. input) catalog file""" return self._project.get_catalog( self._source_catalog_tag, basename=self._source_catalog_basename, **kwargs, ) def get_sink_catalog(self, **kwargs: Any) -> str: + """Get the name of the sink (i.e., output) catalog file""" return self._project.get_catalog( self._sink_catalog_tag, basename=self._sink_catalog_basename, **kwargs, ) def get_script_path(self, pipeline_name: str, sink_dir: str, **kwargs: Any) -> str: + """Get path to use for the slurm batch submit script""" selection = kwargs['selection'] flavor = kwargs['flavor'] return os.path.join( @@ -173,6 +180,9 @@ def get_script_path(self, pipeline_name: str, sink_dir: str, **kwargs: Any) -> s ) def get_convert_commands(self, sink_dir: str) -> list[list[str]]: + """Get the set of commands to run after the pipeline to + convert output files + """ raise NotImplementedError() From 58d84ac8706bdd0f15f3d5d077ba29d5ee2c4b38 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Wed, 27 Nov 2024 17:05:42 -0800 Subject: [PATCH 4/6] pylint generated-members for pyarrow.compute --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index d4cbb15..eda0cbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,9 @@ disable = [ "use-dict-literal", "broad-exception-caught", ] +generated-members = ["add", "multiply", "subtract", "divide", "sqrt", "floor", "atan2"] max-line-length = 110 max-locals = 50 max-branches = 25 max-public-methods = 50 +max-args = 7 \ No newline at end of file From c72d065cf90d496babc17d8a4df942afc2e0aa0c Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Mon, 2 Dec 2024 12:41:40 -0800 Subject: [PATCH 5/6] mypy fixes --- pyproject.toml | 16 +++++++++++++++- src/rail/cli/pipe_scripts.py | 4 ++-- src/rail/pipelines/estimation/estimate_all.py | 2 +- src/rail/pipelines/estimation/inform_all.py | 2 +- src/rail/pipelines/evaluation/evaluate_all.py | 2 +- .../examples/goldenspike/goldenspike.py | 2 +- .../survey_nonuniformity/survey_nonuniformity.py | 2 +- src/rail/utils/name_utils.py | 4 ++-- src/rail/utils/project.py | 2 +- 9 files changed, 25 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index eda0cbe..6726137 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,4 +79,18 @@ max-line-length = 110 max-locals = 50 max-branches = 25 max-public-methods = 50 -max-args = 7 \ No newline at end of file +max-args = 7 + + +[tool.mypy] +disallow_untyped_defs = true +disallow_incomplete_defs = true +ignore_missing_imports = true +local_partial_types = true +no_implicit_reexport = true +show_error_codes = true +strict_equality = true +warn_redundant_casts = true +warn_unreachable = true +warn_unused_ignores = true + diff --git a/src/rail/cli/pipe_scripts.py b/src/rail/cli/pipe_scripts.py index 748307e..e31d875 100644 --- a/src/rail/cli/pipe_scripts.py +++ b/src/rail/cli/pipe_scripts.py @@ -110,10 +110,10 @@ def handle_commands( ) as sbatch: assert sbatch.stdout line = sbatch.stdout.read().decode().strip() - return int(line.split("|")[0]) + ret_val = int(line.split("|")[0]) except TypeError as msg: raise TypeError(f"Bad slurm submit: {msg}") from msg - return 0 + return ret_val def inspect(config_file: str) -> int: diff --git a/src/rail/pipelines/estimation/estimate_all.py b/src/rail/pipelines/estimation/estimate_all.py index 7be3ff6..b6d0a87 100644 --- a/src/rail/pipelines/estimation/estimate_all.py +++ b/src/rail/pipelines/estimation/estimate_all.py @@ -18,7 +18,7 @@ class EstimatePipeline(RailPipeline): default_input_dict={'input':'dummy.in'} - def __init__(self, algorithms=None, models_dir='.'): + def __init__(self, algorithms: dict|None=None, models_dir: str='.'): RailPipeline.__init__(self) diff --git a/src/rail/pipelines/estimation/inform_all.py b/src/rail/pipelines/estimation/inform_all.py index e34a591..8c8a533 100644 --- a/src/rail/pipelines/estimation/inform_all.py +++ b/src/rail/pipelines/estimation/inform_all.py @@ -14,7 +14,7 @@ class InformPipeline(RailPipeline): default_input_dict={'input':'dummy.in'} - def __init__(self, algorithms=None): + def __init__(self, algorithms: dict | None=None): RailPipeline.__init__(self) DS = RailStage.data_store diff --git a/src/rail/pipelines/evaluation/evaluate_all.py b/src/rail/pipelines/evaluation/evaluate_all.py index 2def946..68f916c 100644 --- a/src/rail/pipelines/evaluation/evaluate_all.py +++ b/src/rail/pipelines/evaluation/evaluate_all.py @@ -23,7 +23,7 @@ class EvaluationPipeline(RailPipeline): default_input_dict=dict(truth='dummy.in') - def __init__(self, algorithms=None, pdfs_dir='.'): + def __init__(self, algorithms:dict | None = None, pdfs_dir: str='.') -> None: RailPipeline.__init__(self) DS = RailStage.data_store diff --git a/src/rail/pipelines/examples/goldenspike/goldenspike.py b/src/rail/pipelines/examples/goldenspike/goldenspike.py index 96363f0..4f60157 100644 --- a/src/rail/pipelines/examples/goldenspike/goldenspike.py +++ b/src/rail/pipelines/examples/goldenspike/goldenspike.py @@ -30,7 +30,7 @@ class GoldenspikePipeline(RailPipeline): model=flow_file, ) - def __init__(self): + def __init__(self) -> None: RailPipeline.__init__(self) DS = RailStage.data_store diff --git a/src/rail/pipelines/examples/survey_nonuniformity/survey_nonuniformity.py b/src/rail/pipelines/examples/survey_nonuniformity/survey_nonuniformity.py index b0bf355..5a2dc3a 100644 --- a/src/rail/pipelines/examples/survey_nonuniformity/survey_nonuniformity.py +++ b/src/rail/pipelines/examples/survey_nonuniformity/survey_nonuniformity.py @@ -20,7 +20,7 @@ class SurveyNonuniformDegraderPipeline(RailPipeline): default_input_dict = dict(model=flow_file) - def __init__(self): + def __init__(self) -> None: RailPipeline.__init__(self) DS = RailStage.data_store diff --git a/src/rail/utils/name_utils.py b/src/rail/utils/name_utils.py index 4a1b9d9..ca250cd 100644 --- a/src/rail/utils/name_utils.py +++ b/src/rail/utils/name_utils.py @@ -69,7 +69,7 @@ def _resolve_dict(source: dict, interpolants: dict) -> dict: sink : dict Dictionary of resolved templates """ - if source is not None: + if source: sink = copy.deepcopy(source) for k, v in source.items(): v_interpolated: list | dict | str = "" @@ -85,7 +85,7 @@ def _resolve_dict(source: dict, interpolants: dict) -> dict: sink[k] = v_interpolated else: - sink = None + sink = {} return sink diff --git a/src/rail/utils/project.py b/src/rail/utils/project.py index 7b45dbd..775ec7b 100644 --- a/src/rail/utils/project.py +++ b/src/rail/utils/project.py @@ -45,7 +45,7 @@ def __init__(self, name: str, config_dict: dict): self.config.get("CommonPaths", {}) ) - def __repr__(self): + def __repr__(self) -> str: return f"{self.name}" @staticmethod From 873dbf9ec6edc59286335a8eebf1d70d0db0ec4b Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Mon, 2 Dec 2024 12:54:37 -0800 Subject: [PATCH 6/6] fix unit test --- tests/test_name_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_name_utils.py b/tests/test_name_utils.py index 53fc3fe..8c11e86 100644 --- a/tests/test_name_utils.py +++ b/tests/test_name_utils.py @@ -17,7 +17,7 @@ def test_name_utils(): name_utils._resolve_dict(test_dict, dict(alice='x', bob='y')) - assert name_utils._resolve_dict(None, {}) is None + assert not name_utils._resolve_dict(None, {}) with pytest.raises(ValueError): name_utils._resolve_dict(dict(a=('s','d',)), dict(alice='x', bob='y'))