Skip to content

Commit

Permalink
feat(backend): Add intermediate Template for iterator Tasks in ArgoCo…
Browse files Browse the repository at this point in the history
…mpiler

Signed-off-by: Giulio Frasca <[email protected]>
  • Loading branch information
gmfrasca committed May 13, 2024
1 parent 6837f77 commit 9408d36
Showing 1 changed file with 81 additions and 11 deletions.
92 changes: 81 additions & 11 deletions backend/src/v2/compiler/argocompiler/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,6 @@ func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.Componen
if err != nil {
return err
}
// Add Parallelism Limit if present
parallel := int64(kfpTask.GetIteratorPolicy().GetParallelismLimit())
if parallel > 0 {
currentParallelism := dag.Parallelism
if currentParallelism == nil || parallel > *currentParallelism {
dag.Parallelism = &parallel
}
}

dag.DAG.Tasks = append(dag.DAG.Tasks, tasks...)
}
Expand Down Expand Up @@ -278,6 +270,82 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline
if err != nil {
return nil, err
}

// Set up Loop Control Template
loopDriverArgoName := name + "-loop-driver"
loopDriverInputs := dagDriverInputs{
component: componentSpecPlaceholder,
parentDagID: parentDagID,
task: taskJson, // TODO(Bobgy): avoid duplicating task JSON twice in the template.
}
loopDriver, loopDriverOutputs, err := c.dagDriverTask(loopDriverArgoName, loopDriverInputs)
if err != nil {
return nil, err
}
loopDriver.Depends = depends(task.GetDependentTasks())

iteratorTasks, err := c.iterationItemTask("iteration", task, taskJson, parentDagID)
if err != nil {
return nil, err
}

loopTmpl := &wfapi.Template{
Inputs: wfapi.Inputs{
Parameters: []wfapi.Parameter{
{Name: paramParentDagID},
},
},
DAG: &wfapi.DAGTemplate{
Tasks: iteratorTasks,
},
}
parallellism_limit := int64(task.GetIteratorPolicy().GetParallelismLimit())
if parallellism_limit > 0 {
loopTmpl.Parallelism = &parallellism_limit
}

loopTmplName, err := c.addTemplate(loopTmpl, componentName+"-loop-"+name)
if err != nil {
return nil, err
}
when := ""
if task.GetTriggerPolicy().GetCondition() != "" {
when = loopDriverOutputs.condition + " != false"
}

tasks = []wfapi.DAGTask{
*loopDriver,
{
Name: name + "-loop",
Template: loopTmplName,
Depends: depends([]string{loopDriverArgoName}),
When: when,
Arguments: wfapi.Arguments{
Parameters: []wfapi.Parameter{
{
Name: paramParentDagID,
Value: wfapi.AnyStringPtr(loopDriverOutputs.executionID),
},
},
},
},
}
return tasks, nil
}

func (c *workflowCompiler) iterationItemTask(name string, task *pipelinespec.PipelineTaskSpec, taskJson string, parentDagID string) (tasks []wfapi.DAGTask, err error) {
defer func() {
if err != nil {
err = fmt.Errorf("iterationItem task: %w", err)
}
}()
componentName := task.GetComponentRef().GetName()
componentSpecPlaceholder, err := c.useComponentSpec(componentName)
if err != nil {
return nil, err
}

// Set up Iteration (Single Task) Template
driverArgoName := name + "-driver"
driverInputs := dagDriverInputs{
component: componentSpecPlaceholder,
Expand All @@ -289,9 +357,10 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline
return nil, err
}
driver.Depends = depends(task.GetDependentTasks())

iterationCount := intstr.FromString(driverOutputs.iterationCount)
iterationTasks, err := c.task(
"iteration",
"iteration-item",
task,
taskInputs{
parentDagID: inputParameter(paramParentDagID),
Expand Down Expand Up @@ -320,7 +389,8 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline
if task.GetTriggerPolicy().GetCondition() != "" {
when = driverOutputs.condition + " != false"
}
tasks = []wfapi.DAGTask{

iteratorTasks := []wfapi.DAGTask{
*driver,
{
Name: name + "-iterations",
Expand All @@ -339,7 +409,7 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline
WithSequence: &wfapi.Sequence{Count: &iterationCount},
},
}
return tasks, nil
return iteratorTasks, nil
}

type dagDriverOutputs struct {
Expand Down

0 comments on commit 9408d36

Please sign in to comment.