From 9408d36ff31e0a96f948fd3c1bcce11eb02e5aef Mon Sep 17 00:00:00 2001 From: Giulio Frasca Date: Mon, 13 May 2024 18:30:58 -0400 Subject: [PATCH] feat(backend): Add intermediate Template for iterator Tasks in ArgoCompiler Signed-off-by: Giulio Frasca --- backend/src/v2/compiler/argocompiler/dag.go | 92 ++++++++++++++++++--- 1 file changed, 81 insertions(+), 11 deletions(-) diff --git a/backend/src/v2/compiler/argocompiler/dag.go b/backend/src/v2/compiler/argocompiler/dag.go index 77c1fa0272a5..25885b7b6e79 100644 --- a/backend/src/v2/compiler/argocompiler/dag.go +++ b/backend/src/v2/compiler/argocompiler/dag.go @@ -72,14 +72,6 @@ func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.Componen if err != nil { return err } - // Add Parallelism Limit if present - parallel := int64(kfpTask.GetIteratorPolicy().GetParallelismLimit()) - if parallel > 0 { - currentParallelism := dag.Parallelism - if currentParallelism == nil || parallel > *currentParallelism { - dag.Parallelism = ¶llel - } - } dag.DAG.Tasks = append(dag.DAG.Tasks, tasks...) } @@ -278,6 +270,82 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline if err != nil { return nil, err } + + // Set up Loop Control Template + loopDriverArgoName := name + "-loop-driver" + loopDriverInputs := dagDriverInputs{ + component: componentSpecPlaceholder, + parentDagID: parentDagID, + task: taskJson, // TODO(Bobgy): avoid duplicating task JSON twice in the template. + } + loopDriver, loopDriverOutputs, err := c.dagDriverTask(loopDriverArgoName, loopDriverInputs) + if err != nil { + return nil, err + } + loopDriver.Depends = depends(task.GetDependentTasks()) + + iteratorTasks, err := c.iterationItemTask("iteration", task, taskJson, parentDagID) + if err != nil { + return nil, err + } + + loopTmpl := &wfapi.Template{ + Inputs: wfapi.Inputs{ + Parameters: []wfapi.Parameter{ + {Name: paramParentDagID}, + }, + }, + DAG: &wfapi.DAGTemplate{ + Tasks: iteratorTasks, + }, + } + parallellism_limit := int64(task.GetIteratorPolicy().GetParallelismLimit()) + if parallellism_limit > 0 { + loopTmpl.Parallelism = ¶llellism_limit + } + + loopTmplName, err := c.addTemplate(loopTmpl, componentName+"-loop-"+name) + if err != nil { + return nil, err + } + when := "" + if task.GetTriggerPolicy().GetCondition() != "" { + when = loopDriverOutputs.condition + " != false" + } + + tasks = []wfapi.DAGTask{ + *loopDriver, + { + Name: name + "-loop", + Template: loopTmplName, + Depends: depends([]string{loopDriverArgoName}), + When: when, + Arguments: wfapi.Arguments{ + Parameters: []wfapi.Parameter{ + { + Name: paramParentDagID, + Value: wfapi.AnyStringPtr(loopDriverOutputs.executionID), + }, + }, + }, + }, + } + return tasks, nil +} + +func (c *workflowCompiler) iterationItemTask(name string, task *pipelinespec.PipelineTaskSpec, taskJson string, parentDagID string) (tasks []wfapi.DAGTask, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("iterationItem task: %w", err) + } + }() + componentName := task.GetComponentRef().GetName() + componentSpecPlaceholder, err := c.useComponentSpec(componentName) + if err != nil { + return nil, err + } + + // Set up Iteration (Single Task) Template driverArgoName := name + "-driver" driverInputs := dagDriverInputs{ component: componentSpecPlaceholder, @@ -289,9 +357,10 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline return nil, err } driver.Depends = depends(task.GetDependentTasks()) + iterationCount := intstr.FromString(driverOutputs.iterationCount) iterationTasks, err := c.task( - "iteration", + "iteration-item", task, taskInputs{ parentDagID: inputParameter(paramParentDagID), @@ -320,7 +389,8 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline if task.GetTriggerPolicy().GetCondition() != "" { when = driverOutputs.condition + " != false" } - tasks = []wfapi.DAGTask{ + + iteratorTasks := []wfapi.DAGTask{ *driver, { Name: name + "-iterations", @@ -339,7 +409,7 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline WithSequence: &wfapi.Sequence{Count: &iterationCount}, }, } - return tasks, nil + return iteratorTasks, nil } type dagDriverOutputs struct {