-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve yaml models: ConfigTasks #122
Changes from all commits
56d911a
5fcd72c
a690f43
38a8cb9
daf8430
696ca6d
534efc3
f773837
4604dde
46db910
cc62bc7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,29 @@ def list_not_empty(value: list[ITEM_T]) -> list[ITEM_T]: | |
return value | ||
|
||
|
||
def extract_merge_key_as_value(data: Any, new_key: str = "name") -> Any: | ||
if not isinstance(data, dict): | ||
return data | ||
if len(data) == 1: | ||
key, value = next(iter(data.items())) | ||
match key: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting. Haven't really used such There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree. Now, if there was a |
||
case str(): | ||
match value: | ||
case str() if key == new_key: | ||
pass | ||
case dict() if new_key not in value: | ||
data = value | {new_key: key} | ||
case None: | ||
data = {new_key: key} | ||
case _: | ||
msg = f"Expected a mapping, not a value (got {data})." | ||
raise TypeError(msg) | ||
case _: | ||
msg = f"{new_key} must be a string (got {key})." | ||
raise TypeError(msg) | ||
return data | ||
|
||
|
||
class _NamedBaseModel(BaseModel): | ||
""" | ||
Base model for reading names from yaml keys *or* keyword args to the constructor. | ||
|
@@ -75,30 +98,7 @@ class _NamedBaseModel(BaseModel): | |
@model_validator(mode="before") | ||
@classmethod | ||
def reformat_named_object(cls, data: Any) -> Any: | ||
return cls.extract_merge_name(data) | ||
|
||
@classmethod | ||
def extract_merge_name(cls, data: Any) -> Any: | ||
if not isinstance(data, dict): | ||
return data | ||
if len(data) == 1: | ||
key, value = next(iter(data.items())) | ||
match key: | ||
case str(): | ||
match value: | ||
case str() if key == "name": | ||
pass | ||
case dict() if "name" not in value: | ||
data = value | {"name": key} | ||
case None: | ||
data = {"name": key} | ||
case _: | ||
msg = f"{cls.__name__} may only be used for named objects, not values (got {data})." | ||
raise TypeError(msg) | ||
case _: | ||
msg = f"{cls.__name__} requires name to be a str (got {key})." | ||
raise TypeError(msg) | ||
return data | ||
return extract_merge_key_as_value(data) | ||
|
||
|
||
def select_when(spec: Any) -> When: | ||
|
@@ -241,6 +241,12 @@ class ConfigCycle(_NamedBaseModel): | |
|
||
@dataclass(kw_only=True) | ||
class ConfigBaseTaskSpecs: | ||
""" | ||
Common information for tasks. | ||
|
||
Any of these keys can be None, in which case they are inherited from the root task. | ||
""" | ||
|
||
computer: str | None = None | ||
host: str | None = None | ||
account: str | None = None | ||
|
@@ -251,7 +257,7 @@ class ConfigBaseTaskSpecs: | |
|
||
class ConfigBaseTask(_NamedBaseModel, ConfigBaseTaskSpecs): | ||
""" | ||
config for genric task, no plugin specifics | ||
Config for generic task, no plugin specifics. | ||
""" | ||
|
||
parameters: list[str] = Field(default_factory=list) | ||
|
@@ -322,6 +328,38 @@ class ConfigShellTaskSpecs: | |
|
||
|
||
class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs): | ||
""" | ||
Represent a shell script to be run as part of the workflow. | ||
|
||
Examples: | ||
|
||
>>> import textwrap | ||
>>> my_task = validate_yaml_content( | ||
... ConfigShellTask, | ||
... textwrap.dedent( | ||
... ''' | ||
... my_task: | ||
... plugin: shell | ||
... command: my_script.sh | ||
... src: post_run_scripts | ||
... cli_arguments: "-n 1024 {current_sim_output}" | ||
... env_source_files: "env.sh" | ||
... walltime: 00:01:00 | ||
... ''' | ||
... ), | ||
... ) | ||
>>> my_task.cli_arguments[0] | ||
ShellCliArgument(name='-n', references_data_item=False, cli_option_of_data_item=None) | ||
>>> my_task.cli_arguments[1] | ||
ShellCliArgument(name='1024', references_data_item=False, cli_option_of_data_item=None) | ||
>>> my_task.cli_arguments[2] | ||
ShellCliArgument(name='current_sim_output', references_data_item=True, cli_option_of_data_item=None) | ||
>>> my_task.env_source_files | ||
['env.sh'] | ||
>>> my_task.walltime.tm_min | ||
1 | ||
""" | ||
|
||
command: str = "" | ||
cli_arguments: list[ShellCliArgument] = Field(default_factory=list) | ||
env_source_files: list[str] = Field(default_factory=list) | ||
|
@@ -373,7 +411,7 @@ def parse_cli_arguments(cli_arguments: str) -> list[ShellCliArgument]: | |
|
||
|
||
@dataclass(kw_only=True) | ||
class ConfigNamelist: | ||
class NamelistSpec: | ||
"""Class for namelist specifications | ||
|
||
- path is the path to the namelist file considered as template | ||
|
@@ -387,20 +425,60 @@ class ConfigNamelist: | |
... "first_nml_block": {"first_param": "a string value", "second_param": 0}, | ||
... "second_nml_block": {"third_param": False}, | ||
... } | ||
>>> config_nml = ConfigNamelist(path=path, specs=specs) | ||
>>> nml_info = NamelistSpec(path=path, specs=specs) | ||
""" | ||
|
||
path: Path | ||
specs: dict | None = None | ||
specs: dict[str, Any] = field(default_factory=dict) | ||
|
||
|
||
class ConfigNamelist(BaseModel, NamelistSpec): | ||
""" | ||
Validated namelist specifications. | ||
|
||
Example: | ||
|
||
>>> import textwrap | ||
>>> from_init = ConfigNamelist( | ||
... path="/path/to/some.nml", specs={"block": {"key": "value"}} | ||
... ) | ||
>>> from_yml = validate_yaml_content( | ||
... ConfigNamelist, | ||
... textwrap.dedent( | ||
... ''' | ||
... /path/to/some.nml: | ||
... block: | ||
... key: value | ||
... ''' | ||
... ), | ||
... ) | ||
>>> from_init == from_yml | ||
True | ||
>>> no_spec = ConfigNamelist(path="/path/to/some.nml") | ||
>>> no_spec_yml = validate_yaml_content(ConfigNamelist, "/path/to/some.nml") | ||
""" | ||
|
||
specs: dict[str, Any] = {} | ||
|
||
@model_validator(mode="before") | ||
@classmethod | ||
def merge_path_key(cls, data: Any) -> dict[str, Any]: | ||
if isinstance(data, str): | ||
return {"path": data} | ||
merged = extract_merge_key_as_value(data, new_key="path") | ||
if "specs" in merged: | ||
return merged | ||
path = merged.pop("path") | ||
return {"path": path, "specs": merged or {}} | ||
|
||
|
||
@dataclass(kw_only=True) | ||
class ConfigIconTaskSpecs: | ||
plugin: ClassVar[Literal["icon"]] = "icon" | ||
namelists: dict[str, ConfigNamelist] | ||
namelists: dict[str, NamelistSpec] | ||
|
||
|
||
class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs): | ||
class ConfigIconTask(ConfigBaseTask): | ||
"""Class representing an ICON task configuration from a workflow file | ||
|
||
Examples: | ||
|
@@ -422,34 +500,18 @@ class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs): | |
>>> icon_task_cfg = validate_yaml_content(ConfigIconTask, snippet) | ||
""" | ||
|
||
@field_validator("namelists", mode="before") | ||
plugin: ClassVar[Literal["icon"]] = "icon" | ||
namelists: list[ConfigNamelist] | ||
|
||
@field_validator("namelists", mode="after") | ||
@classmethod | ||
def check_nmls(cls, nmls: dict[str, ConfigNamelist] | list[Any]) -> dict[str, ConfigNamelist]: | ||
def check_nmls(cls, nmls: list[ConfigNamelist]) -> list[ConfigNamelist]: | ||
# Make validator idempotent even if not used yet | ||
if isinstance(nmls, dict): | ||
return nmls | ||
if not isinstance(nmls, list): | ||
msg = f"expected a list got type {type(nmls).__name__}" | ||
raise TypeError(msg) | ||
namelists = {} | ||
master_found = False | ||
for nml in nmls: | ||
msg = f"was expecting a dict of length 1 or a string, got {nml}" | ||
if not isinstance(nml, str | dict): | ||
raise TypeError(msg) | ||
if isinstance(nml, dict) and len(nml) > 1: | ||
raise TypeError(msg) | ||
if isinstance(nml, str): | ||
path, specs = Path(nml), None | ||
else: | ||
path, specs = next(iter(nml.items())) | ||
path = Path(path) | ||
namelists[path.name] = ConfigNamelist(path=path, specs=specs) | ||
master_found = master_found or (path.name == "icon_master.namelist") | ||
if not master_found: | ||
names = [nml.path.name for nml in nmls] | ||
if "icon_master.namelist" not in names: | ||
msg = "icon_master.namelist not found" | ||
raise ValueError(msg) | ||
return namelists | ||
return nmls | ||
|
||
|
||
class DataType(enum.StrEnum): | ||
|
@@ -548,7 +610,7 @@ def get_plugin_from_named_base_model( | |
) -> str: | ||
if isinstance(data, ConfigRootTask | ConfigShellTask | ConfigIconTask): | ||
return data.plugin | ||
name_and_specs = ConfigBaseTask.extract_merge_name(data) | ||
name_and_specs = extract_merge_key_as_value(data) | ||
if name_and_specs.get("name", None) == "ROOT": | ||
return ConfigRootTask.plugin | ||
plugin = name_and_specs.get("plugin", None) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check is here to satisfy mypy, because we can not put this as a type hint in the signature (would break Liskof substitution principle -> mypy complains). Not necessary for
ShellTask
because we really check inTask
that the subclasses get the correct types.