Skip to content

Commit

Permalink
fix: tests and ICON task
Browse files Browse the repository at this point in the history
also rename target_date -> target_cycle
  • Loading branch information
leclairm committed Feb 10, 2025
1 parent 04928b7 commit ec24b5d
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 246 deletions.
8 changes: 6 additions & 2 deletions src/sirocco/core/_tasks/icon_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import f90nml

from sirocco.core.graph_items import Task
from sirocco.parsing.cycling import DateCyclePoint
from sirocco.parsing.yaml_data_models import ConfigIconTaskSpecs


Expand Down Expand Up @@ -48,10 +49,13 @@ def update_core_namelists_from_config(self):
nml_section.update(params)

def update_core_namelists_from_workflow(self):
if not isinstance(self.cycle_point, DateCyclePoint):
msg = "ICON task must have a DateCyclePoint"
raise TypeError(msg)
self.core_namelists["icon_master.namelist"]["master_time_control_nml"].update(
{
"experimentStartDate": self.start_date.isoformat() + "Z",
"experimentStopDate": self.stop_date.isoformat() + "Z",
"experimentStartDate": self.cycle_point.start_date.isoformat() + "Z",
"experimentStopDate": self.cycle_point.stop_date.isoformat() + "Z",
}
)
self.core_namelists["icon_master.namelist"]["master_nml"]["lrestart"] = any(
Expand Down
20 changes: 9 additions & 11 deletions src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from itertools import chain, product
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeAlias, TypeVar, cast

from sirocco.parsing.target_date import DateList, LagList, SameDate
from sirocco.parsing.when import WhenSpec
from sirocco.parsing.target_cycle import DateList, LagList, NoTargetCycle
from sirocco.parsing.yaml_data_models import (
ConfigAvailableData,
ConfigBaseDataSpecs,
Expand Down Expand Up @@ -188,10 +187,10 @@ def __getitem__(self, coordinates: dict) -> GRAPH_ITEM_T:

def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, ref_coordinates: dict) -> Iterator[GRAPH_ITEM_T]:
# Check date references
if "date" not in self._dims and isinstance(spec.target_date, DateList | LagList):
if "date" not in self._dims and isinstance(spec.target_cycle, DateList | LagList):
msg = f"Array {self._name} has no date dimension, cannot be referenced by dates"
raise ValueError(msg)
if "date" in self._dims and ref_coordinates.get("date") is None and not isinstance(spec.target_date, DateList):
if "date" in self._dims and ref_coordinates.get("date") is None and not isinstance(spec.target_cycle, DateList):
msg = f"Array {self._name} has a date dimension, must be referenced by dates"
raise ValueError(msg)

Expand All @@ -200,13 +199,14 @@ def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, ref_coordinates: dict

def _resolve_target_dim(self, spec: TargetNodesBaseModel, dim: str, ref_coordinates: Any) -> Iterator[Any]:
if dim == "date":
match spec.target_date:
case SameDate():
match spec.target_cycle:
case NoTargetCycle():
yield ref_coordinates["date"]
case DateList():
yield from spec.target_date.dates
yield from spec.target_cycle.dates
case LagList():
yield from spec.target_date.lags
for lag in spec.target_cycle.lags:
yield ref_coordinates["date"] + lag
elif spec.parameters.get(dim) == "single":
yield ref_coordinates[dim]
else:
Expand Down Expand Up @@ -239,10 +239,8 @@ def __getitem__(self, key: tuple[str, dict]) -> GRAPH_ITEM_T:
return self._dict[name][coordinates]

def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, ref_coordinates: dict) -> Iterator[GRAPH_ITEM_T]:
# Check if we need to skip this querry
if isinstance(spec.when, WhenSpec) and not spec.when.is_active(ref_coordinates.get("date")):
if not spec.when.is_active(ref_coordinates.get("date")):
return
# Yield items
yield from self._dict[spec.name].iter_from_cycle_spec(spec, ref_coordinates)

def __iter__(self) -> Iterator[GRAPH_ITEM_T]:
Expand Down
10 changes: 5 additions & 5 deletions src/sirocco/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ def __init__(
task_dict: dict[str, ConfigTask] = {task.name: task for task in tasks}

# Function to iterate over date and parameter combinations
def iter_coordinates(param_refs: list[str], cycle_point: CyclePoint) -> Iterator[dict]:
def iter_coordinates(cycle_point: CyclePoint, param_refs: list[str]) -> Iterator[dict]:
axes = {k: parameters[k] for k in param_refs}
if isinstance(cycle_point, DateCyclePoint):
axes["date"] = [cycle_point.begin_date]
yield from (dict(zip(axes.keys(), x, strict=False)) for x in product(*axes.values()))

# 1 - create availalbe data nodes
for available_data_config in data.available:
for coordinates in iter_coordinates(param_refs=available_data_config.parameters, cycle_point=OneOffPoint()):
for coordinates in iter_coordinates(OneOffPoint(), available_data_config.parameters):
self.data.add(Data.from_config(config=available_data_config, coordinates=coordinates))

# 2 - create output data nodes
Expand All @@ -63,7 +63,7 @@ def iter_coordinates(param_refs: list[str], cycle_point: CyclePoint) -> Iterator
for data_ref in task_ref.outputs:
data_name = data_ref.name
data_config = data_dict[data_name]
for coordinates in iter_coordinates(param_refs=data_config.parameters, cycle_point=cycle_point):
for coordinates in iter_coordinates(cycle_point, data_config.parameters):
self.data.add(Data.from_config(config=data_config, coordinates=coordinates))

# 3 - create cycles and tasks
Expand All @@ -74,7 +74,7 @@ def iter_coordinates(param_refs: list[str], cycle_point: CyclePoint) -> Iterator
for task_graph_spec in cycle_config.tasks:
task_name = task_graph_spec.name
task_config = task_dict[task_name]
for coordinates in iter_coordinates(param_refs=task_config.parameters, cycle_point=cycle_point):
for coordinates in iter_coordinates(cycle_point, task_config.parameters):
task = Task.from_config(
config=task_config,
config_rootdir=self.config_rootdir,
Expand All @@ -89,7 +89,7 @@ def iter_coordinates(param_refs: list[str], cycle_point: CyclePoint) -> Iterator
Cycle(
name=cycle_name,
tasks=cycle_tasks,
coordinates={} if isinstance(cycle_point, OneOffPoint) else {"date": cycle_point.begin_date},
coordinates={"date": cycle_point.begin_date} if isinstance(cycle_point, DateCyclePoint) else {},
)
)

Expand Down
26 changes: 16 additions & 10 deletions tests/cases/large/config/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ cycles:
inputs: [obs_data]
outputs: [extpar_file]
- icon_bimonthly:
start_date: *root_start_date
stop_date: *root_stop_date
period: 'P2M'
cycling:
start_date: *root_start_date
stop_date: *root_stop_date
period: 'P2M'
tasks:
- preproc:
inputs: [grid_file, extpar_file, ERA5]
Expand All @@ -19,15 +20,17 @@ cycles:
- icon:
when:
after: '2025-03-01T00:00'
lag: '-P4M'
target_cycle:
lag: '-P4M'
- icon:
inputs:
- grid_file
- icon_input
- icon_restart:
when:
after: *root_start_date
lag: '-P2M'
target_cycle:
lag: '-P2M'
port: restart
outputs: [stream_1, stream_2, icon_restart]
- postproc_1:
Expand All @@ -37,20 +40,23 @@ cycles:
inputs: [postout_1, stream_1, icon_input]
outputs: [stored_data_1]
- yearly:
start_date: *root_start_date
stop_date: *root_stop_date
period: 'P1Y'
cycling:
start_date: *root_start_date
stop_date: *root_stop_date
period: 'P1Y'
tasks:
- postproc_2:
inputs:
- stream_2:
lag: ['P0M', 'P2M', 'P4M', 'P6M', 'P8M', 'P10M']
target_cycle:
lag: ['P0M', 'P2M', 'P4M', 'P6M', 'P8M', 'P10M']
outputs: [postout_2]
- store_and_clean_2:
inputs:
- postout_2
- stream_2:
lag: ['P0M', 'P2M', 'P4M', 'P6M', 'P8M', 'P10M']
target_cycle:
lag: ['P0M', 'P2M', 'P4M', 'P6M', 'P8M', 'P10M']
outputs:
- stored_data_2
# Each task and piece of data (input and output of tasks) used to
Expand Down
Loading

0 comments on commit ec24b5d

Please sign in to comment.