Improve yaml models: ConfigTasks #122

Merged: 11 commits, Feb 13, 2025
22 changes: 20 additions & 2 deletions src/sirocco/core/_tasks/icon_task.py
@@ -3,15 +3,19 @@
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self

import f90nml

from sirocco.core.graph_items import Task
from sirocco.parsing._yaml_data_models import ConfigIconTaskSpecs
from sirocco.parsing import _yaml_data_models as models

if TYPE_CHECKING:
from sirocco.parsing._yaml_data_models import ConfigTask


@dataclass(kw_only=True)
class IconTask(ConfigIconTaskSpecs, Task):
class IconTask(models.ConfigIconTaskSpecs, Task):
core_namelists: dict[str, f90nml.Namelist] = field(default_factory=dict)

def init_core_namelists(self):
@@ -93,3 +97,17 @@ def section_index(section_name) -> tuple[str, int | None]:
if m := multi_section_pattern.match(section_name):
return m.group(1), int(m.group(2)) - 1
return section_name, None

@classmethod
def build_from_config(cls: type[Self], config: ConfigTask, **kwargs: Any) -> Self:
Collaborator:

Also, just for my understanding: We're supporting Python >= 3.10, but the Self type was introduced in 3.11. Is that an issue, or can we use newer Python features in type hinting? According to ChatGPT:

"""
You're correct—Self was introduced in Python 3.11, not 3.12. Since your package claims to support Python 3.10+, using Self directly would be a problem for users on 3.10.

Can you use newer Python concepts in type annotations?
Yes! Starting with PEP 563 (Python 3.7) and PEP 649 (Python 3.10), Python supports postponed evaluation of annotations, meaning type hints are stored as strings instead of being evaluated immediately. This allows you to use newer type hints while keeping compatibility with older versions, as long as they aren't used at runtime.

Solution:
If you add this import at the top of your module:
from __future__ import annotations
Then the Self annotation will be treated as a string ("Self") and won't cause syntax errors in Python 3.10. However, tools like mypy will still understand it.
"""

@DropD (Collaborator, Author), Feb 13, 2025:

GPT is correct that Self was introduced later, and that in our case it works at runtime because it is treated as a string. Here's how that applies to Sirocco:

For now, we have set the hatch project to run type analysis on 3.12 and we use the future annotations import. Therefore the code is type-checked as if it supported 3.12+ (OK at analysis time), while the interpreter treats all the type hints as plain strings (OK at run time).

If we were to lower that to 3.10, we would have to import Self from typing_extensions (making that a dependency), if we choose to keep using it. It would also affect other things (not touched in this PR) like class Store[T], which would have to turn into class Store(typing.Generic[T]). The backports in typing_extensions can cover some of this, but not syntax changes like class Store[T] (which I am honestly surprised works at runtime). I am also wondering why from typing import Self does not cause an import error in 3.10, but it seems it does not.
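
For illustration only, a minimal standalone sketch of that behaviour (hypothetical module, not the PR's code; guarding the Self import behind TYPE_CHECKING is just one way to keep it off the 3.10 runtime path):

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Only evaluated by type checkers; the 3.10 interpreter never resolves Self.
    from typing import Self


class Task:
    @classmethod
    def build_from_config(cls: type[Self], config: Any, **kwargs: Any) -> Self:
        # With the future import, "type[Self]" and "Self" are stored as plain
        # strings at runtime, so this also runs on Python 3.10 while mypy
        # still resolves the annotations during analysis.
        return cls()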

config_kwargs = dict(config)
del config_kwargs["parameters"]
if not isinstance(config, models.ConfigIconTask):
raise TypeError
Comment on lines +109 to +110 (Collaborator Author):

This check is here to satisfy mypy, because we cannot put this as a type hint in the signature (it would break the Liskov substitution principle, so mypy complains). It is not necessary for ShellTask because we actually check in Task that the subclasses get the correct types.
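
A minimal sketch of the constraint (hypothetical, simplified classes, not Sirocco's actual ones): narrowing the parameter annotation in the subclass override is what mypy would reject, while an isinstance check narrows the type inside the body instead.

class ConfigTask: ...


class ConfigIconTask(ConfigTask): ...


class Task:
    @classmethod
    def build_from_config(cls, config: ConfigTask) -> "Task":
        return cls()


class IconTask(Task):
    # Annotating `config: ConfigIconTask` here would accept a narrower type
    # than the base method, violating the Liskov substitution principle, and
    # mypy would report the override as incompatible.
    @classmethod
    def build_from_config(cls, config: ConfigTask) -> "IconTask":
        if not isinstance(config, ConfigIconTask):
            raise TypeError
        # After the check, mypy treats `config` as ConfigIconTask, so
        # ICON-specific attributes become visible to the type checker.
        return cls()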

config_kwargs["namelists"] = {
nml.path.name: models.NamelistInfo(**nml.model_dump()) for nml in config.namelists
}
return cls(
**kwargs,
**config_kwargs,
)
3 changes: 1 addition & 2 deletions src/sirocco/core/_tasks/shell_task.py
@@ -7,5 +7,4 @@


@dataclass(kw_only=True)
class ShellTask(ConfigShellTaskSpecs, Task):
pass
class ShellTask(ConfigShellTaskSpecs, Task): ...
19 changes: 11 additions & 8 deletions src/sirocco/core/graph_items.py
@@ -66,7 +66,7 @@ def from_config(cls, config: ConfigBaseData, coordinates: dict) -> Self:
class Task(ConfigBaseTaskSpecs, GraphItem):
"""Internal representation of a task node"""

plugin_classes: ClassVar[dict[str, type]] = field(default={}, repr=False)
plugin_classes: ClassVar[dict[str, type[Self]]] = field(default={}, repr=False)
color: ClassVar[Color] = field(default="light_red", repr=False)

inputs: list[BoundData] = field(default_factory=list)
@@ -87,7 +87,7 @@ def __init_subclass__(cls, **kwargs):

@classmethod
def from_config(
cls,
cls: type[Self],
config: ConfigTask,
config_rootdir: Path,
start_date: datetime | None,
@@ -102,22 +102,19 @@ def from_config(
for data_node in datastore.iter_from_cycle_spec(input_spec, coordinates)
]
outputs = [datastore[output_spec.name, coordinates] for output_spec in graph_spec.outputs]
# use the fact that pydantic models can be turned into dicts easily
cls_config = dict(config)
del cls_config["parameters"]
if (plugin_cls := Task.plugin_classes.get(type(config).plugin, None)) is None:
msg = f"Plugin {type(config).plugin!r} is not supported."
raise ValueError(msg)

new = plugin_cls(
new = plugin_cls.build_from_config(
config,
config_rootdir=config_rootdir,
coordinates=coordinates,
start_date=start_date,
end_date=end_date,
inputs=inputs,
outputs=outputs,
**cls_config,
) # this works because dataclass has generated this init for us
)

# Store for actual linking in link_wait_on_tasks() once all tasks are created
new._wait_on_specs = graph_spec.wait_on # noqa: SLF001 we don't have access to self in a dataclass
@@ -126,6 +123,12 @@

return new

@classmethod
def build_from_config(cls: type[Self], config: ConfigTask, **kwargs: Any) -> Self:
config_kwargs = dict(config)
del config_kwargs["parameters"]
return cls(**kwargs, **config_kwargs)

def link_wait_on_tasks(self, taskstore: Store[Task]) -> None:
self.wait_on = list(
chain(
172 changes: 117 additions & 55 deletions src/sirocco/parsing/_yaml_data_models.py
@@ -37,6 +37,29 @@ def list_not_empty(value: list[ITEM_T]) -> list[ITEM_T]:
return value


def extract_merge_key_as_value(data: Any, new_key: str = "name") -> Any:
if not isinstance(data, dict):
return data
if len(data) == 1:
key, value = next(iter(data.items()))
match key:
case str():
match value:
case str() if key == new_key:
pass
case dict() if new_key not in value:
data = value | {new_key: key}
case None:
data = {new_key: key}
case _:
msg = f"Expected a mapping, not a value (got {data})."
raise TypeError(msg)
case _:
msg = f"{new_key} must be a string (got {key})."
raise TypeError(msg)
return data


class _NamedBaseModel(BaseModel):
"""
Base model for reading names from yaml keys *or* keyword args to the constructor.
@@ -76,30 +99,7 @@ class _NamedBaseModel(BaseModel):
@model_validator(mode="before")
@classmethod
def reformat_named_object(cls, data: Any) -> Any:
return cls.extract_merge_name(data)

@classmethod
def extract_merge_name(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
if len(data) == 1:
key, value = next(iter(data.items()))
match key:
case str():
match value:
case str() if key == "name":
pass
case dict() if "name" not in value:
data = value | {"name": key}
case None:
data = {"name": key}
case _:
msg = f"{cls.__name__} may only be used for named objects, not values (got {data})."
raise TypeError(msg)
case _:
msg = f"{cls.__name__} requires name to be a str (got {key})."
raise TypeError(msg)
return data
return extract_merge_key_as_value(data)


class _WhenBaseModel(BaseModel):
@@ -288,6 +288,12 @@ def check_period_is_not_negative_or_zero(self) -> ConfigCycle:

@dataclass(kw_only=True)
class ConfigBaseTaskSpecs:
"""
Common information for tasks.

Any of these keys can be None, in which case they are inherited from the root task.
"""

computer: str | None = None
host: str | None = None
account: str | None = None
@@ -298,7 +304,7 @@

class ConfigBaseTask(_NamedBaseModel, ConfigBaseTaskSpecs):
"""
config for genric task, no plugin specifics
Config for generic task, no plugin specifics.
"""

parameters: list[str] = Field(default_factory=list)
@@ -369,6 +375,38 @@ class ConfigShellTaskSpecs:


class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs):
"""
Represent a shell script to be run as part of the workflow.

Examples:

>>> import textwrap
>>> my_task = validate_yaml_content(
... ConfigShellTask,
... textwrap.dedent(
... '''
... my_task:
... plugin: shell
... command: my_script.sh
... src: post_run_scripts
... cli_arguments: "-n 1024 {current_sim_output}"
... env_source_files: "env.sh"
... walltime: 00:01:00
... '''
... ),
... )
>>> my_task.cli_arguments[0]
ShellCliArgument(name='-n', references_data_item=False, cli_option_of_data_item=None)
>>> my_task.cli_arguments[1]
ShellCliArgument(name='1024', references_data_item=False, cli_option_of_data_item=None)
>>> my_task.cli_arguments[2]
ShellCliArgument(name='current_sim_output', references_data_item=True, cli_option_of_data_item=None)
>>> my_task.env_source_files
['env.sh']
>>> my_task.walltime.tm_min
1
"""

command: str = ""
cli_arguments: list[ShellCliArgument] = Field(default_factory=list)
env_source_files: list[str] = Field(default_factory=list)
@@ -420,7 +458,7 @@ def parse_cli_arguments(cli_arguments: str) -> list[ShellCliArgument]:


@dataclass(kw_only=True)
class ConfigNamelist:
class NamelistInfo:
"""Class for namelist specifications

- path is the path to the namelist file considered as template
@@ -434,20 +472,60 @@ class ConfigNamelist:
... "first_nml_block": {"first_param": "a string value", "second_param": 0},
... "second_nml_block": {"third_param": False},
... }
>>> config_nml = ConfigNamelist(path=path, specs=specs)
>>> nml_info = NamelistInfo(path=path, specs=specs)
"""

path: Path
specs: dict | None = None
specs: dict[str, Any] = field(default_factory=dict)


class ConfigNamelist(BaseModel, NamelistInfo):
"""
Validated namelist specifications.

Example:

>>> import textwrap
>>> from_init = ConfigNamelist(
... path="/path/to/some.nml", specs={"block": {"key": "value"}}
... )
>>> from_yml = validate_yaml_content(
... ConfigNamelist,
... textwrap.dedent(
... '''
... /path/to/some.nml:
... block:
... key: value
... '''
... ),
... )
>>> from_init == from_yml
True
>>> no_spec = ConfigNamelist(path="/path/to/some.nml")
>>> no_spec_yml = validate_yaml_content(ConfigNamelist, "/path/to/some.nml")
"""

specs: dict[str, Any] = {}

@model_validator(mode="before")
@classmethod
def merge_path_key(cls, data: Any) -> dict[str, Any]:
if isinstance(data, str):
return {"path": data}
merged = extract_merge_key_as_value(data, new_key="path")
if "specs" in merged:
return merged
path = merged.pop("path")
return {"path": path, "specs": merged or {}}


@dataclass(kw_only=True)
class ConfigIconTaskSpecs:
plugin: ClassVar[Literal["icon"]] = "icon"
namelists: dict[str, ConfigNamelist]
namelists: dict[str, NamelistInfo]


class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs):
class ConfigIconTask(ConfigBaseTask):
"""Class representing an ICON task configuration from a workflow file

Examples:
@@ -469,34 +547,18 @@ class ConfigIconTask(ConfigBaseTask, ConfigIconTaskSpecs):
>>> icon_task_cfg = validate_yaml_content(ConfigIconTask, snippet)
"""

@field_validator("namelists", mode="before")
plugin: ClassVar[Literal["icon"]] = "icon"
namelists: list[ConfigNamelist]

@field_validator("namelists", mode="after")
@classmethod
def check_nmls(cls, nmls: dict[str, ConfigNamelist] | list[Any]) -> dict[str, ConfigNamelist]:
def check_nmls(cls, nmls: list[ConfigNamelist]) -> list[ConfigNamelist]:
# Make validator idempotent even if not used yet
if isinstance(nmls, dict):
return nmls
if not isinstance(nmls, list):
msg = f"expected a list got type {type(nmls).__name__}"
raise TypeError(msg)
namelists = {}
master_found = False
for nml in nmls:
msg = f"was expecting a dict of length 1 or a string, got {nml}"
if not isinstance(nml, str | dict):
raise TypeError(msg)
if isinstance(nml, dict) and len(nml) > 1:
raise TypeError(msg)
if isinstance(nml, str):
path, specs = Path(nml), None
else:
path, specs = next(iter(nml.items()))
path = Path(path)
namelists[path.name] = ConfigNamelist(path=path, specs=specs)
master_found = master_found or (path.name == "icon_master.namelist")
if not master_found:
names = [nml.path.name for nml in nmls]
if "icon_master.namelist" not in names:
msg = "icon_master.namelist not found"
raise ValueError(msg)
return namelists
return nmls


class DataType(enum.StrEnum):
Expand Down Expand Up @@ -595,7 +657,7 @@ def get_plugin_from_named_base_model(
) -> str:
if isinstance(data, ConfigRootTask | ConfigShellTask | ConfigIconTask):
return data.plugin
name_and_specs = ConfigBaseTask.extract_merge_name(data)
name_and_specs = extract_merge_key_as_value(data)
if name_and_specs.get("name", None) == "ROOT":
return ConfigRootTask.plugin
plugin = name_and_specs.get("plugin", None)