Skip to content

Commit

Permalink
fix(backend): implement subdag output resolution
Browse files Browse the repository at this point in the history
Signed-off-by: droctothorpe <[email protected]>
Co-authored-by: zazulam <[email protected]>
Co-authored-by: CarterFendley <[email protected]>
  • Loading branch information
3 people committed Sep 11, 2024
1 parent 1cded35 commit af3c3e1
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 10 deletions.
111 changes: 104 additions & 7 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
"strconv"
"time"

"github.com/kubeflow/pipelines/backend/src/v2/objectstore"

"github.com/golang/glog"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/google/uuid"
Expand Down Expand Up @@ -125,6 +126,8 @@ func RootDAG(ctx context.Context, opts Options, mlmd *metadata.Client) (executio
err = fmt.Errorf("driver.RootDAG(%s) failed: %w", opts.info(), err)
}
}()
b, _ := json.Marshal(opts)
glog.V(4).Info("RootDAG opts: ", string(b))
err = validateRootDAG(opts)
if err != nil {
return nil, err
Expand Down Expand Up @@ -230,6 +233,8 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl
err = fmt.Errorf("driver.Container(%s) failed: %w", opts.info(), err)
}
}()
b, _ := json.Marshal(opts)
glog.V(4).Info("Container opts: ", string(b))
err = validateContainer(opts)
if err != nil {
return nil, err
Expand Down Expand Up @@ -699,6 +704,8 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
err = fmt.Errorf("driver.DAG(%s) failed: %w", opts.info(), err)
}
}()
b, _ := json.Marshal(opts)
glog.V(4).Info("DAG opts: ", string(b))
err = validateDAG(opts)
if err != nil {
return nil, err
Expand Down Expand Up @@ -749,6 +756,27 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
ecfg.ParentDagID = dag.Execution.GetID()
ecfg.IterationIndex = iterationIndex
ecfg.NotTriggered = !execution.WillTrigger()

// Record, for every DAG output parameter that is sourced from a subtask
// parameter, which subtask produces it and under which key. The entries are
// stored on the execution config so they are written with the DAG execution.
outputParameters := opts.Component.GetDag().GetOutputs().GetParameters()
glog.V(4).Info("outputParameters: ", outputParameters)
for _, value := range outputParameters {
	selector := value.GetValueFromParameter()
	if selector == nil {
		// This output is not backed by a single subtask parameter, so there
		// is no producer/key pair to record. Dereferencing the selector's
		// fields directly would panic here.
		glog.V(4).Info("skipping output parameter without value_from_parameter: ", value)
		continue
	}
	outputParameterKey := selector.GetOutputParameterKey()
	producerSubTask := selector.GetProducerSubtask()
	glog.V(4).Info("outputParameterKey: ", outputParameterKey)
	glog.V(4).Info("producerSubtask: ", producerSubTask)

	outputParameterStruct, err := structpb.NewValue(map[string]interface{}{
		"output_parameter_key": outputParameterKey,
		"producer_subtask":     producerSubTask,
	})
	if err != nil {
		return execution, fmt.Errorf("converting output parameter %q to a structpb value: %w", outputParameterKey, err)
	}

	// Allocate the map lazily and add to it, rather than replacing it on
	// every iteration, so that all output parameters are kept. It stays nil
	// when there is nothing to record.
	if ecfg.OutputParameters == nil {
		ecfg.OutputParameters = make(map[string]*structpb.Value, len(outputParameters))
	}
	ecfg.OutputParameters[outputParameterKey] = outputParameterStruct
}

if opts.Task.GetArtifactIterator() != nil {
return execution, fmt.Errorf("ArtifactIterator is not implemented")
}
Expand Down Expand Up @@ -793,6 +821,12 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
ecfg.IterationCount = &count
execution.IterationCount = &count
}

glog.V(4).Info("pipeline: ", pipeline)
b, _ = json.Marshal(*ecfg)
glog.V(4).Info("ecfg: ", string(b))
glog.V(4).Infof("dag: %v", dag)

// TODO(Bobgy): change execution state to pending, because this is driver, execution hasn't started.
createdExecution, err := mlmd.CreateExecution(ctx, pipeline, ecfg)
if err != nil {
Expand Down Expand Up @@ -939,6 +973,8 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
err = fmt.Errorf("failed to resolve inputs: %w", err)
}
}()
glog.V(4).Infof("dag: %v", dag)
glog.V(4).Infof("task: %v", task)
inputParams, _, err := dag.Execution.GetParameters()
if err != nil {
return nil, err
Expand Down Expand Up @@ -1112,10 +1148,31 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
if err != nil {
return nil, err
}
// TODO: Make this recursive.
// Expand exactly one level of sub-DAGs: first collect the IDs of the tasks
// that are DAG executions, then merge each sub-DAG's tasks into the task
// list. Inserting into tasks while ranging over it would make it unspecified
// whether the newly added entries are visited, so the expansion depth would
// vary from run to run.
var subdagIDs []int64
for _, task := range tasks {
	if task.GetExecution().GetType() != "system.DAGExecution" {
		continue
	}
	glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. Adding its tasks to the task list.", task.TaskName())
	subdagIDs = append(subdagIDs, task.GetExecution().GetId())
}
for _, subdagID := range subdagIDs {
	subdag, err := mlmd.GetDAG(ctx, subdagID)
	if err != nil {
		return nil, err
	}
	subdagTasks, err := mlmd.GetExecutionsInDAG(ctx, subdag, pipeline)
	if err != nil {
		return nil, err
	}
	for name, subdagTask := range subdagTasks {
		tasks[name] = subdagTask
	}
}
tasksCache = tasks

return tasks, nil
}

for name, paramSpec := range task.GetInputs().GetParameters() {
glog.V(4).Infof("name: %v", name)
glog.V(4).Infof("paramSpec: %v", paramSpec)
paramError := func(err error) error {
return fmt.Errorf("resolving input parameter %s with spec %s: %w", name, paramSpec, err)
}
Expand All @@ -1131,8 +1188,11 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
}
inputs.ParameterValues[name] = v

// This is the case where we are consuming an output parameter from an
// upstream task. That task can be a container or a DAG.
case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter:
taskOutput := paramSpec.GetTaskOutputParameter()
glog.V(4).Info("taskOutput: ", taskOutput)
if taskOutput.GetProducerTask() == "" {
return nil, paramError(fmt.Errorf("producer task is empty"))
}
Expand All @@ -1143,19 +1203,56 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
if err != nil {
return nil, paramError(err)
}

// The producer is the task that produces the output that we need to
// consume.
producer, ok := tasks[taskOutput.GetProducerTask()]
if !ok {
return nil, paramError(fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask()))
}
_, outputs, err := producer.GetParameters()

glog.V(4).Info("producer: ", producer)

// Get the producer's outputs.
_, producerOutputs, err := producer.GetParameters()
if err != nil {
return nil, paramError(fmt.Errorf("get producer output parameters: %w", err))
}
param, ok := outputs[taskOutput.GetOutputParameterKey()]
glog.V(4).Info("producer output parameters: ", producerOutputs)
// Deserialize them.
var producerOutputsMap map[string]string
b, err := producerOutputs["Output"].GetStructValue().MarshalJSON()
if err != nil {
return nil, err
}
json.Unmarshal(b, &producerOutputsMap)
glog.V(4).Info("producerOutputsMap: ", producerOutputsMap)

// If the producer's output includes a producer subtask, which means
// that the producer is a DAG that is getting its output from one of
// the tasks in the DAG, then, we want to roll up the output
// from the producer subtask to the producer, so that the downstream
// logic can retrieve it appropriately.
if producerSubTask, ok := producerOutputsMap["producer_subtask"]; ok {
glog.V(4).Infof(
"Overriding producer task, %v, output with producer_subtask, %v, output.",
producer.TaskName(),
producerSubTask,
)
_, producerOutputs, err = tasks[producerSubTask].GetParameters()
if err != nil {
return nil, err
}
glog.V(4).Info("producerSubTask output parameters: ", producerOutputs)
// The only reason we're updating this is to make the downstream
// logging more accurate.
taskOutput.ProducerTask = producerOutputsMap["producer_subtask"]
}

// Grab the value of the producer output.
producerOutputValue, ok := producerOutputs[taskOutput.GetOutputParameterKey()]
if !ok {
return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask()))
}
inputs.ParameterValues[name] = param
// Update the input to be the producer output value.
inputs.ParameterValues[name] = producerOutputValue
case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue:
runtimeValue := paramSpec.GetRuntimeValue()
switch t := runtimeValue.Value.(type) {
Expand Down
16 changes: 13 additions & 3 deletions backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/kubeflow/pipelines/backend/src/common/util"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
"path"
"strconv"
"strings"
"sync"
"time"

"github.com/kubeflow/pipelines/backend/src/common/util"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"

"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"

"github.com/golang/glog"
Expand Down Expand Up @@ -134,6 +135,7 @@ type ExecutionConfig struct {
NotTriggered bool // optional, not triggered executions will have CANCELED state.
ParentDagID int64 // parent DAG execution ID. Only the root DAG does not have a parent DAG.
InputParameters map[string]*structpb.Value
OutputParameters map[string]*structpb.Value
InputArtifactIDs map[string][]int64
IterationIndex *int // Index of the iteration.

Expand Down Expand Up @@ -448,6 +450,8 @@ func getArtifactName(eventPath *pb.Event_Path) (string, error) {
func (c *Client) PublishExecution(ctx context.Context, execution *Execution, outputParameters map[string]*structpb.Value, outputArtifacts []*OutputArtifact, state pb.Execution_State) error {
e := execution.execution
e.LastKnownState = state.Enum()
glog.V(4).Infof("outputParameters: %v", outputParameters)
glog.V(4).Infof("outputArtifacts: %v", outputArtifacts)

if outputParameters != nil {
// Record output parameters.
Expand Down Expand Up @@ -576,7 +580,13 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config
},
}}
}

// When the config carries output parameters, store them on the execution as
// a struct-valued custom property under keyOutputs.
if outputs := config.OutputParameters; outputs != nil {
	outputsStruct := &structpb.Struct{Fields: outputs}
	e.CustomProperties[keyOutputs] = &pb.Value{
		Value: &pb.Value_StructValue{StructValue: outputsStruct},
	}
}
req := &pb.PutExecutionRequest{
Execution: e,
Contexts: []*pb.Context{pipeline.pipelineCtx, pipeline.pipelineRunCtx},
Expand Down

0 comments on commit af3c3e1

Please sign in to comment.