Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve yaml models: ConfigTasks #122

Merged
merged 11 commits into from
Feb 13, 2025
22 changes: 20 additions & 2 deletions src/sirocco/core/_tasks/icon_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Self

import f90nml

from sirocco.core.graph_items import Task
from sirocco.parsing import yaml_data_models as models
from sirocco.parsing.cycling import DateCyclePoint
from sirocco.parsing.yaml_data_models import ConfigIconTaskSpecs


@dataclass(kw_only=True)
class IconTask(ConfigIconTaskSpecs, Task):
class IconTask(models.ConfigIconTaskSpecs, Task):
core_namelists: dict[str, f90nml.Namelist] = field(default_factory=dict)

def init_core_namelists(self):
Expand Down Expand Up @@ -97,3 +98,20 @@ def section_index(section_name) -> tuple[str, int | None]:
if m := multi_section_pattern.match(section_name):
return m.group(1), int(m.group(2)) - 1
return section_name, None

@classmethod
def build_from_config(cls: type[Self], config: models.ConfigTask, **kwargs: Any) -> Self:
config_kwargs = dict(config)
del config_kwargs["parameters"]
# The following check is here for type checkers.
# We don't want to narrow the type in the signature, as that would break liskov substitution.
# We guarantee elsewhere this is called with the correct type at runtime
if not isinstance(config, models.ConfigIconTask):
raise TypeError
Comment on lines +109 to +110
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is here to satisfy mypy, because we can not put this as a type hint in the signature (would break Liskof substitution principle -> mypy complains). Not necessary for ShellTask because we really check in Task that the subclasses get the correct types.

config_kwargs["namelists"] = {
nml.path.name: models.NamelistSpec(**nml.model_dump()) for nml in config.namelists
}
return cls(
**kwargs,
**config_kwargs,
)
3 changes: 1 addition & 2 deletions src/sirocco/core/_tasks/shell_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,4 @@


@dataclass(kw_only=True)
class ShellTask(ConfigShellTaskSpecs, Task):
pass
class ShellTask(ConfigShellTaskSpecs, Task): ...
19 changes: 11 additions & 8 deletions src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def from_config(cls, config: ConfigBaseData, coordinates: dict) -> Self:
class Task(ConfigBaseTaskSpecs, GraphItem):
"""Internal representation of a task node"""

plugin_classes: ClassVar[dict[str, type]] = field(default={}, repr=False)
plugin_classes: ClassVar[dict[str, type[Self]]] = field(default={}, repr=False)
color: ClassVar[Color] = field(default="light_red", repr=False)

inputs: list[BoundData] = field(default_factory=list)
Expand All @@ -87,7 +87,7 @@ def __init_subclass__(cls, **kwargs):

@classmethod
def from_config(
cls,
cls: type[Self],
config: ConfigTask,
config_rootdir: Path,
cycle_point: CyclePoint,
Expand All @@ -101,21 +101,18 @@ def from_config(
for data_node in datastore.iter_from_cycle_spec(input_spec, coordinates)
]
outputs = [datastore[output_spec.name, coordinates] for output_spec in graph_spec.outputs]
# use the fact that pydantic models can be turned into dicts easily
cls_config = dict(config)
del cls_config["parameters"]
if (plugin_cls := Task.plugin_classes.get(type(config).plugin, None)) is None:
msg = f"Plugin {type(config).plugin!r} is not supported."
raise ValueError(msg)

new = plugin_cls(
new = plugin_cls.build_from_config(
config,
config_rootdir=config_rootdir,
coordinates=coordinates,
cycle_point=cycle_point,
inputs=inputs,
outputs=outputs,
**cls_config,
) # this works because dataclass has generated this init for us
)

# Store for actual linking in link_wait_on_tasks() once all tasks are created
new._wait_on_specs = graph_spec.wait_on # noqa: SLF001 we don't have access to self in a dataclass
Expand All @@ -124,6 +121,12 @@ def from_config(

return new

@classmethod
def build_from_config(cls: type[Self], config: ConfigTask, **kwargs: Any) -> Self:
config_kwargs = dict(config)
del config_kwargs["parameters"]
return cls(**kwargs, **config_kwargs)

def link_wait_on_tasks(self, taskstore: Store[Task]) -> None:
self.wait_on = list(
chain(
Expand Down
172 changes: 117 additions & 55 deletions src/sirocco/parsing/yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@ def list_not_empty(value: list[ITEM_T]) -> list[ITEM_T]:
return value


def extract_merge_key_as_value(data: Any, new_key: str = "name") -> Any:
if not isinstance(data, dict):
return data
if len(data) == 1:
key, value = next(iter(data.items()))
match key:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. Haven't really used such match/case constructs before, but makes the structure clearer than using isinstance.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Now, if there was a match expression, the result of which can be assigned to a variable (like in Rust), I would love it even more...

case str():
match value:
case str() if key == new_key:
pass
case dict() if new_key not in value:
data = value | {new_key: key}
case None:
data = {new_key: key}
case _:
msg = f"Expected a mapping, not a value (got {data})."
raise TypeError(msg)
case _:
msg = f"{new_key} must be a string (got {key})."
raise TypeError(msg)
return data


class _NamedBaseModel(BaseModel):
"""
Base model for reading names from yaml keys *or* keyword args to the constructor.
Expand Down Expand Up @@ -75,30 +98,7 @@ class _NamedBaseModel(BaseModel):
@model_validator(mode="before")
@classmethod
def reformat_named_object(cls, data: Any) -> Any:
return cls.extract_merge_name(data)

@classmethod
def extract_merge_name(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
if len(data) == 1:
key, value = next(iter(data.items()))
match key:
case str():
match value:
case str() if key == "name":
pass
case dict() if "name" not in value:
data = value | {"name": key}
case None:
data = {"name": key}
case _:
msg = f"{cls.__name__} may only be used for named objects, not values (got {data})."
raise TypeError(msg)
case _:
msg = f"{cls.__name__} requires name to be a str (got {key})."
raise TypeError(msg)
return data
return extract_merge_key_as_value(data)


def select_when(spec: Any) -> When:
Expand Down Expand Up @@ -241,6 +241,12 @@ class ConfigCycle(_NamedBaseModel):

@dataclass(kw_only=True)
class ConfigBaseTaskSpecs:
"""
Common information for tasks.

Any of these keys can be None, in which case they are inherited from the root task.
"""

computer: str | None = None
host: str | None = None
account: str | None = None
Expand All @@ -251,7 +257,7 @@ class ConfigBaseTaskSpecs:

class ConfigBaseTask(_NamedBaseModel, ConfigBaseTaskSpecs):
"""
config for genric task, no plugin specifics
Config for generic task, no plugin specifics.
"""

parameters: list[str] = Field(default_factory=list)
Expand Down Expand Up @@ -322,6 +328,38 @@ class ConfigShellTaskSpecs:


class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs):
"""
Represent a shell script to be run as part of the workflow.

Examples:

>>> import textwrap
>>> my_task = validate_yaml_content(
... ConfigShellTask,
... textwrap.dedent(
... '''
... my_task:
... plugin: shell
... command: my_script.sh
... src: post_run_scripts
... cli_arguments: "-n 1024 {current_sim_output}"
... env_source_files: "env.sh"
... walltime: 00:01:00
... '''
... ),
... )
>>> my_task.cli_arguments[0]
ShellCliArgument(name='-n', references_data_item=False, cli_option_of_data_item=None)
>>> my_task.cli_arguments[1]
ShellCliArgument(name='1024', references_data_item=False, cli_option_of_data_item=None)
>>> my_task.cli_arguments[2]
ShellCliArgument(name='current_sim_output', references_data_item=True, cli_option_of_data_item=None)
>>> my_task.env_source_files
['env.sh']
>>> my_task.walltime.tm_min
1
"""

command: str = ""
cli_arguments: list[ShellCliArgument] = Field(default_factory=list)
env_source_files: list[str] = Field(default_factory=list)
Expand Down Expand Up @@ -373,7 +411,7 @@ def parse_cli_arguments(cli_arguments: str) -> list[ShellCliArgument]:


@dataclass(kw_only=True)
class ConfigNamelist:
class NamelistSpec:
"""Class for namelist specifications

- path is the path to the namelist file considered as template
Expand All @@ -387,20 +425,60 @@ class ConfigNamelist:
... "first_nml_block": {"first_param": "a string value", "second_param": 0},
... "second_nml_block": {"third_param": False},
... }
>>> config_nml = ConfigNamelist(path=path, specs=specs)
>>> nml_info = NamelistSpec(path=path, specs=specs)
"""

path: Path
specs: dict | None = None
specs: dict[str, Any] = field(default_factory=dict)


class ConfigNamelist(BaseModel, NamelistSpec):
"""
Validated namelist specifications.

Example:

>>> import textwrap
>>> from_init = ConfigNamelist(
... path="/path/to/some.nml", specs={"block": {"key": "value"}}
... )
>>> from_yml = validate_yaml_content(
... ConfigNamelist,
... textwrap.dedent(
... '''
... /path/to/some.nml:
... block:
... key: value
... '''
... ),
... )
>>> from_init == from_yml
True
>>> no_spec = ConfigNamelist(path="/path/to/some.nml")
>>> no_spec_yml = validate_yaml_content(ConfigNamelist, "/path/to/some.nml")
"""

specs: dict[str, Any] = {}

@model_validator(mode="before")
@classmethod
def merge_path_key(cls, data: Any) -> dict[str, Any]:
if isinstance(data, str):
return {"path": data}
merged = extract_merge_key_as_value(data, new_key="path")
if "specs" in merged:
return merged
path = merged.pop("path")
return {"path": path, "specs": merged or {}}


@dataclass(kw_only=True)
class ConfigIconTaskSpecs:
plugin: ClassVar[Literal["icon"]] = "icon"
namelists: dict[str, ConfigNamelist]
namelists: dict[str, NamelistSpec]


class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs):
class ConfigIconTask(ConfigBaseTask):
"""Class representing an ICON task configuration from a workflow file

Examples:
Expand All @@ -422,34 +500,18 @@ class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs):
>>> icon_task_cfg = validate_yaml_content(ConfigIconTask, snippet)
"""

@field_validator("namelists", mode="before")
plugin: ClassVar[Literal["icon"]] = "icon"
namelists: list[ConfigNamelist]

@field_validator("namelists", mode="after")
@classmethod
def check_nmls(cls, nmls: dict[str, ConfigNamelist] | list[Any]) -> dict[str, ConfigNamelist]:
def check_nmls(cls, nmls: list[ConfigNamelist]) -> list[ConfigNamelist]:
# Make validator idempotent even if not used yet
if isinstance(nmls, dict):
return nmls
if not isinstance(nmls, list):
msg = f"expected a list got type {type(nmls).__name__}"
raise TypeError(msg)
namelists = {}
master_found = False
for nml in nmls:
msg = f"was expecting a dict of length 1 or a string, got {nml}"
if not isinstance(nml, str | dict):
raise TypeError(msg)
if isinstance(nml, dict) and len(nml) > 1:
raise TypeError(msg)
if isinstance(nml, str):
path, specs = Path(nml), None
else:
path, specs = next(iter(nml.items()))
path = Path(path)
namelists[path.name] = ConfigNamelist(path=path, specs=specs)
master_found = master_found or (path.name == "icon_master.namelist")
if not master_found:
names = [nml.path.name for nml in nmls]
if "icon_master.namelist" not in names:
msg = "icon_master.namelist not found"
raise ValueError(msg)
return namelists
return nmls


class DataType(enum.StrEnum):
Expand Down Expand Up @@ -548,7 +610,7 @@ def get_plugin_from_named_base_model(
) -> str:
if isinstance(data, ConfigRootTask | ConfigShellTask | ConfigIconTask):
return data.plugin
name_and_specs = ConfigBaseTask.extract_merge_name(data)
name_and_specs = extract_merge_key_as_value(data)
if name_and_specs.get("name", None) == "ROOT":
return ConfigRootTask.plugin
plugin = name_and_specs.get("plugin", None)
Expand Down
Loading