Skip to content

Commit

Permalink
fix: Better node overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
vijayvammi committed Jan 25, 2025
1 parent 186cf36 commit 2d86d86
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions extensions/pipeline_executor/argo.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ class Resources(BaseModel):

# Lets construct this from UserDefaults
class ArgoTemplateDefaults(BaseModelWIthConfig):
image: str
active_deadline_seconds: Optional[int] = Field(default=86400) # 1 day
fail_fast: bool = Field(default=True)
node_selector: dict[str, str] = Field(default_factory=dict)
Expand All @@ -264,10 +263,12 @@ class CommonDefaults(BaseModelWIthConfig):
env: list[EnvVar | SecretEnvVar] = Field(default_factory=list, exclude=True)


# The user provided defaults at the top level
class UserDefaults(CommonDefaults):
image: str


# Overrides need not have image
class Overrides(CommonDefaults):
image: Optional[str] = Field(default=None)

Expand Down Expand Up @@ -392,11 +393,9 @@ class ArgoExecutor(GenericPipelineExecutor):
default="INFO"
)

defaults: UserDefaults # A similar structure to template defaults
defaults: UserDefaults
argo_workflow: ArgoWorkflow

# Lets use a generic one

overrides: dict[str, Overrides] = Field(default_factory=dict)

# This should be used when we refer to run_id or log_level in the containers
Expand All @@ -414,7 +413,7 @@ class ArgoExecutor(GenericPipelineExecutor):

def model_post_init(self, __context: Any) -> None:
self.argo_workflow.spec.template_defaults = ArgoTemplateDefaults(
image=self.defaults.image, **self.defaults.model_dump()
**self.defaults.model_dump()
)

def sanitize_name(self, name: str) -> str:
Expand Down Expand Up @@ -516,7 +515,10 @@ def _create_container_template(
node_override = None
if hasattr(node, "overrides"):
override_key = node.overrides.get(self.service_name, "")
node_override = self.overrides.get(override_key, None)
try:
node_override = self.overrides.get(override_key)
except: # noqa
raise Exception("Override not found for: ", override_key)

effective_settings = self.defaults.model_dump()
if node_override:
Expand Down Expand Up @@ -563,7 +565,7 @@ def _create_container_template(
]
),
volumes=[volume_pair.volume for volume_pair in self.volume_pairs],
**effective_settings,
**node_override.model_dump() if node_override else {},
)

return container_template
Expand Down Expand Up @@ -841,11 +843,6 @@ def execute_graph(

argo_workflow_dump = self.argo_workflow.model_dump(
by_alias=True,
# exclude={
# "spec": {
# "template_defaults": {"image_pull_policy", "image", "resources"}
# }
# },
exclude_none=True,
round_trip=False,
)
Expand Down

0 comments on commit 2d86d86

Please sign in to comment.