Skip to content

Commit

Permalink
Add site option and clean up mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
eacharles committed Dec 11, 2024
1 parent 19021e1 commit 785375e
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 7 deletions.
14 changes: 13 additions & 1 deletion src/rail/cli/rail_pipe/pipe_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from rail.core import __version__

from ...utils.project import RailProject
from rail.utils.project import RailProject
from . import pipe_options, pipe_scripts
from .reduce_roman_rubin_data import reduce_roman_rubin_data

Expand Down Expand Up @@ -43,6 +43,7 @@ def build_command(config_file: str, **kwargs: Any) -> int:
@pipe_options.flavor()
@pipe_options.label()
@pipe_options.run_mode()
@pipe_options.site()
def subsample_command(config_file: str, **kwargs: Any) -> int:
"""Make a training or test data set by randomly selecting objects"""
"""Make a training data set by randomly selecting objects"""
Expand All @@ -67,6 +68,7 @@ def reduce_group() -> None:
@pipe_options.input_selection()
@pipe_options.selection()
@pipe_options.run_mode()
@pipe_options.site()
def reduce_roman_rubin(config_file: str, **kwargs: Any) -> int:
"""Reduce the roman rubin simulations for PZ analysis"""
project = RailProject.load_config(config_file)
Expand All @@ -90,6 +92,7 @@ def run_group() -> None:
@pipe_options.selection()
@pipe_options.flavor()
@pipe_options.run_mode()
@pipe_options.site()
def photmetric_errors_pipeline(config_file: str, **kwargs: Any) -> int:
"""Run the photometric errors analysis pipeline"""
project = RailProject.load_config(config_file)
Expand Down Expand Up @@ -118,6 +121,7 @@ def photmetric_errors_pipeline(config_file: str, **kwargs: Any) -> int:
@pipe_options.selection()
@pipe_options.flavor()
@pipe_options.run_mode()
@pipe_options.site()
def truth_to_observed_pipeline(config_file: str, **kwargs: Any) -> int:
"""Run the truth-to-observed data pipeline"""
project = RailProject.load_config(config_file)
Expand Down Expand Up @@ -149,6 +153,7 @@ def truth_to_observed_pipeline(config_file: str, **kwargs: Any) -> int:
@pipe_options.selection()
@pipe_options.flavor()
@pipe_options.run_mode()
@pipe_options.site()
def blending_pipeline(config_file: str, **kwargs: Any) -> int:
"""Run the blending analysis pipeline"""
project = RailProject.load_config(config_file)
Expand Down Expand Up @@ -179,6 +184,7 @@ def blending_pipeline(config_file: str, **kwargs: Any) -> int:
@pipe_options.selection()
@pipe_options.flavor()
@pipe_options.run_mode()
@pipe_options.site()
def spectroscopic_selection_pipeline(config_file: str, **kwargs: Any) -> int:
"""Run the spectroscopic selection data pipeline"""
project = RailProject.load_config(config_file)
Expand Down Expand Up @@ -210,6 +216,7 @@ def spectroscopic_selection_pipeline(config_file: str, **kwargs: Any) -> int:
@pipe_options.flavor()
@pipe_options.selection()
@pipe_options.run_mode()
@pipe_options.site()
def inform_single(config_file: str, **kwargs: Any) -> int:
"""Run the inform pipeline"""
pipeline_name = "inform"
Expand All @@ -232,6 +239,7 @@ def inform_single(config_file: str, **kwargs: Any) -> int:
@pipe_options.flavor()
@pipe_options.selection()
@pipe_options.run_mode()
@pipe_options.site()
def estimate_single(config_file: str, **kwargs: Any) -> int:
"""Run the estimation pipeline"""
pipeline_name = "estimate"
Expand All @@ -254,6 +262,7 @@ def estimate_single(config_file: str, **kwargs: Any) -> int:
@pipe_options.flavor()
@pipe_options.selection()
@pipe_options.run_mode()
@pipe_options.site()
def evaluate_single(config_file: str, **kwargs: Any) -> int:
"""Run the evaluation pipeline"""
pipeline_name = "evaluate"
Expand All @@ -276,6 +285,7 @@ def evaluate_single(config_file: str, **kwargs: Any) -> int:
@pipe_options.flavor()
@pipe_options.selection()
@pipe_options.run_mode()
@pipe_options.site()
def pz_single(config_file: str, **kwargs: Any) -> int:
"""Run the pz pipeline"""
pipeline_name = "pz"
Expand All @@ -298,6 +308,7 @@ def pz_single(config_file: str, **kwargs: Any) -> int:
@pipe_options.flavor()
@pipe_options.selection()
@pipe_options.run_mode()
@pipe_options.site()
def tomography_single(config_file : str, **kwargs: Any) -> int:
"""Run the tomography pipeline"""
pipeline_name = "tomography"
Expand All @@ -320,6 +331,7 @@ def tomography_single(config_file : str, **kwargs: Any) -> int:
@pipe_options.flavor()
@pipe_options.selection()
@pipe_options.run_mode()
@pipe_options.site()
def sompz_single(config_file: str, **kwargs: Any) -> int:
"""Run the sompz pipeline"""
pipeline_name = "sompz"
Expand Down
8 changes: 8 additions & 0 deletions src/rail/cli/rail_pipe/pipe_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"pdf_path",
"run_mode",
"selection",
"site",
"output_dir",
"output_file",
"truth_path",
Expand Down Expand Up @@ -84,6 +85,13 @@ class RunMode(enum.Enum):
)


site = PartialOption(
"--site",
help="site for slurm submission",
default="s3df",
)


input_dir = PartialOption(
"--input_dir",
help="Input Directory",
Expand Down
36 changes: 34 additions & 2 deletions src/rail/cli/rail_pipe/pipe_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@
from rail.cli.rail_pipe.pipe_options import RunMode


S3DF_SLURM_OPTIONS: list[str] = [
"-p",
"milano",
"--account",
"rubin:commissioning@milano",
"--mem",
"16448",
"--parsable",
]
NERSC_SLURM_OPTIONS: list[str] = [
]

SLURM_OPTIONS = {
"s3df":S3DF_SLURM_OPTIONS,
"nersc":NERSC_SLURM_OPTIONS,
}


def handle_command(
run_mode: RunMode,
command_line: list[str],
Expand Down Expand Up @@ -60,6 +78,7 @@ def handle_commands(
run_mode: RunMode,
command_lines: list[list[str]],
script_path:str | None=None,
site:str="s3df",
) -> int:
""" Run a multiple commands in the mode requested
Expand All @@ -74,6 +93,9 @@ def handle_commands(
script_path: str | None
Path to write the slurm submit script to
site: str
Site to use for running slurm commands
Returns
-------
returncode: int
Expand Down Expand Up @@ -103,9 +125,14 @@ def handle_commands(
fout.write(f"{com_line}\n")

script_log = script_path.replace('.sh', '.log')

command_line = ["sbatch", "-o", script_log]
command_line += SLURM_OPTIONS[site]
command_line += [script_path]

try:
with subprocess.Popen(
["sbatch", "-o", script_log, "--mem", "16448", "-p", "milano", "--parsable", script_path],
command_line,
stdout=subprocess.PIPE,
) as sbatch:
assert sbatch.stdout
Expand Down Expand Up @@ -288,6 +315,8 @@ def run_pipeline_on_catalog(
input_catalog_name = pipeline_info['InputCatalogTag']
input_catalog = project.get_catalogs().get(input_catalog_name, {})

site = kwargs.get('site', 's3df')

# Loop through all possible combinations of the iteration variables that are
# relevant to this pipeline
if (iteration_vars := input_catalog.get("IterationVars", {})) is not None:
Expand Down Expand Up @@ -332,6 +361,7 @@ def run_pipeline_on_catalog(
*convert_commands,
],
script_path,
site=site,
)
except Exception as msg:
print(msg)
Expand Down Expand Up @@ -389,8 +419,10 @@ def run_pipeline_on_single_input(
log_dir=f"{sink_dir}/logs",
)

site = kwargs.get('site', 's3df')

try:
statuscode = handle_commands(run_mode, [command_line], script_path)
statuscode = handle_commands(run_mode, [command_line], script_path, site=site)
except Exception as msg:
print(msg)
statuscode = 1
Expand Down
2 changes: 1 addition & 1 deletion src/rail/cli/rail_pipe/reduce_roman_rubin_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pyarrow import acero

from .pipe_options import RunMode
from ...utils.project import RailProject
from rail.utils.project import RailProject


COLUMNS = [
Expand Down
6 changes: 3 additions & 3 deletions src/rail/utils/name_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@


def update_include_dict(
orig_dict: dict[str, Any],
include_dict: dict[str, Any],
orig_dict: dict[Any, Any],
include_dict: dict[Any, Any],
) -> None:
"""Update a dict by updating (instead of replacing) sub-dicts
Expand All @@ -39,7 +39,7 @@ def update_include_dict(
Dict used to update the original
"""
for key, val in include_dict.items():
if isinstance(val, Mapping) and key in orig_dict:
if isinstance(val, dict) and key in orig_dict:
update_include_dict(orig_dict[key], val)
else:
orig_dict[key] = val
Expand Down

0 comments on commit 785375e

Please sign in to comment.