Skip to content

Commit

Permalink
Allow specifying task config via environment and CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
bennybp committed Apr 4, 2023
1 parent 1887e71 commit 608433b
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 10 deletions.
53 changes: 49 additions & 4 deletions qcengine/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,40 @@ def parse_args():
parser = argparse.ArgumentParser(description="A CLI for the QCEngine.")
parser.add_argument("--version", action="version", version=f"{__version__}")

parent_parser = argparse.ArgumentParser(add_help=False)
task_group = parent_parser.add_argument_group(
"Task Configuration", "Extra configuration related to running the computation"
)
task_group.add_argument("--ncores", type=int, help="The number of cores to use for the task")
task_group.add_argument("--nnodes", type=int, help="The number of nodes to use")
task_group.add_argument("--memory", type=float, help="The amount of memory (in GiB) to use")
task_group.add_argument("--scratch-directory", type=str, help="Where to store temporary files")
task_group.add_argument("--retries", type=int, help="Number of retries for random failures")
task_group.add_argument("--mpiexec-command", type=str, help="Command used to launch MPI tasks")
task_group.add_argument(
"--use-mpiexec",
action="store_true",
default=None,
help="Whether it is necessary to use MPI to run an executable",
)
task_group.add_argument("--cores-per-rank", type=int, help="Number of cores per MPI rank")
task_group.add_argument(
"--scratch-messy",
action="store_true",
default=None,
help="Leave the scratch directory and contents on disk after completion",
)

subparsers = parser.add_subparsers(dest="command")

info = subparsers.add_parser("info", help="Print information about QCEngine setup, version, and environment.")
info.add_argument(
"category", nargs="*", default="all", choices=info_choices, help="The information categories to show."
)

run = subparsers.add_parser("run", help="Run a program on a given task. Output is printed as a JSON blob.")
run = subparsers.add_parser(
"run", parents=[parent_parser], help="Run a program on a given task. Output is printed as a JSON blob."
)
run.add_argument("program", type=str, help="The program to run.")
run.add_argument(
"data",
Expand All @@ -49,7 +75,9 @@ def parse_args():
)

run_procedure = subparsers.add_parser(
"run-procedure", help="Run a procedure on a given task. " "Output is printed as a JSON blob."
"run-procedure",
parents=[parent_parser],
help="Run a procedure on a given task. " "Output is printed as a JSON blob.",
)
run_procedure.add_argument("procedure", type=str, help="The procedure to run.")
run_procedure.add_argument(
Expand Down Expand Up @@ -164,12 +192,29 @@ def main(args=None):
# Grab CLI args if not present
if args is None:
args = parse_args()

# Break out a task config
task_config = {
"ncores": args.pop("ncores", None),
"memory": args.pop("memory", None),
"nnodes": args.pop("nnodes", None),
"scratch_directory": args.pop("scratch_directory", None),
"retries": args.pop("retries", None),
"mpiexec_command": args.pop("mpiexec_command", None),
"use_mpiexec": args.pop("use_mpiexec", None),
"cores_per_rank": args.pop("cores_per_rank", None),
"scratch_messy": args.pop("scratch_messy", None),
}

# Prune None values and let other config functions handle defaults
task_config = {k: v for k, v in task_config.items() if v is not None}

command = args.pop("command")
if command == "info":
info_cli(args)
elif command == "run":
ret = compute(data_arg_helper(args["data"]), args["program"])
ret = compute(data_arg_helper(args["data"]), args["program"], task_config=task_config)
print(ret.json())
elif command == "run-procedure":
ret = compute_procedure(data_arg_helper(args["data"]), args["procedure"])
ret = compute_procedure(data_arg_helper(args["data"]), args["procedure"], task_config=task_config)
print(ret.json())
28 changes: 22 additions & 6 deletions qcengine/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class Config:
extra = "forbid"


class TaskConfig(pydantic.BaseModel):
class TaskConfig(pydantic.BaseSettings):
"""Description of the configuration used to launch a task."""

# Specifications
Expand All @@ -159,8 +159,9 @@ class TaskConfig(pydantic.BaseModel):
False, description="Leave scratch directory and contents on disk after completion."
)

class Config:
class Config(pydantic.BaseSettings.Config):
extra = "forbid"
env_prefix = "QCENGINE_"


def _load_defaults() -> None:
Expand Down Expand Up @@ -267,6 +268,19 @@ def parse_environment(data: Dict[str, Any]) -> Dict[str, Any]:
return ret


def read_qcengine_task_environment() -> Dict[str, Any]:
"""
Reads the qcengine task-related environment variables and returns a dictionary of the values.
"""

ret = {}
for k, v in os.environ.items():
if k.startswith("QCENGINE_"):
ret[k[9:].lower()] = v

return ret


def get_config(*, hostname: Optional[str] = None, task_config: Dict[str, Any] = None) -> TaskConfig:
"""
Returns the configuration key for qcengine.
Expand All @@ -275,7 +289,9 @@ def get_config(*, hostname: Optional[str] = None, task_config: Dict[str, Any] =
if task_config is None:
task_config = {}

task_config = parse_environment(task_config)
task_config_env = read_qcengine_task_environment()
task_config = {**task_config_env, **task_config}

config = {}

# Node data
Expand All @@ -285,7 +301,7 @@ def get_config(*, hostname: Optional[str] = None, task_config: Dict[str, Any] =
config["retries"] = task_config.pop("retries", node.retries)

# Jobs per node
jobs_per_node = task_config.pop("jobs_per_node", None) or node.jobs_per_node
jobs_per_node = int(task_config.pop("jobs_per_node", None) or node.jobs_per_node)

# Handle memory
memory = task_config.pop("memory", None)
Expand All @@ -297,12 +313,12 @@ def get_config(*, hostname: Optional[str] = None, task_config: Dict[str, Any] =
config["memory"] = memory

# Get the number of cores available to each task
ncores = task_config.pop("ncores", int(ncores / jobs_per_node))
ncores = int(task_config.pop("ncores", int(ncores / jobs_per_node)))
if ncores < 1:
raise KeyError("Number of jobs per node exceeds the number of available cores.")

config["ncores"] = ncores
config["nnodes"] = task_config.pop("nnodes", 1)
config["nnodes"] = int(task_config.pop("nnodes", 1))

# Add in the MPI launch command template
config["mpiexec_command"] = node.mpiexec_command
Expand Down

0 comments on commit 608433b

Please sign in to comment.