Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Render task template in the agent client #384

Merged
merged 5 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"k8s.io/apimachinery/pkg/util/rand"
"k8s.io/utils/strings/slices"
)

type MockPlugin struct {
Expand All @@ -39,7 +40,11 @@ type MockPlugin struct {
type MockClient struct {
}

func (m *MockClient) CreateTask(_ context.Context, _ *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) {
func (m *MockClient) CreateTask(_ context.Context, createTaskRequest *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) {
expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"}
if slices.Equal(createTaskRequest.Template.GetContainer().Args, expectedArgs) {
return nil, fmt.Errorf("args not as expected")
}
return &admin.CreateTaskResponse{ResourceMeta: []byte{1, 2, 3, 4}}, nil
}

Expand Down Expand Up @@ -95,6 +100,9 @@ func TestEndToEnd(t *testing.T) {
template := flyteIdlCore.TaskTemplate{
Type: "bigquery_query_job_task",
Custom: st,
Target: &flyteIdlCore.TaskTemplate_Container{
Container: &flyteIdlCore.Container{Args: []string{"pyflyte-fast-execute", "--output-prefix", "{{.outputPrefix}}"}},
},
}
basePrefix := storage.DataReference("fake://bucket/prefix/")

Expand Down
21 changes: 17 additions & 4 deletions go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi"
"github.com/flyteorg/flytestdlib/promutils"
Expand Down Expand Up @@ -68,6 +68,19 @@
return nil, nil, err
}

if taskTemplate.GetContainer() != nil {
templateParameters := template.Parameters{
TaskExecMetadata: taskCtx.TaskExecutionMetadata(),
Inputs: taskCtx.InputReader(),
OutputPath: taskCtx.OutputWriter(),
Task: taskCtx.TaskReader(),
}
modifiedArgs, err := template.Render(ctx, taskTemplate.GetContainer().Args, templateParameters)
if err != nil {
return nil, nil, err
}

Check warning on line 81 in go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

go/tasks/plugins/webapi/agent/plugin.go#L80-L81

Added lines #L80 - L81 were not covered by tests
taskTemplate.GetContainer().Args = modifiedArgs
}
outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

agent, err := getFinalAgent(taskTemplate.Type, p.cfg)
Expand Down Expand Up @@ -150,7 +163,7 @@

switch resource.State {
case admin.State_RUNNING:
return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil
return core.PhaseInfoRunning(core.DefaultPhaseVersion, taskInfo), nil
case admin.State_PERMANENT_FAILURE:
return core.PhaseInfoFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil
case admin.State_RETRYABLE_FAILURE:
Expand All @@ -164,7 +177,7 @@
}
return core.PhaseInfoSuccess(taskInfo), nil
}
return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State)
return core.PhaseInfoUndefined, pluginErrors.Errorf(core.SystemErrorCode, "unknown execution phase [%v].", resource.State)

Check warning on line 180 in go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

go/tasks/plugins/webapi/agent/plugin.go#L180

Added line #L180 was not covered by tests
}

func getFinalAgent(taskType string, cfg *Config) (*Agent, error) {
Expand Down Expand Up @@ -225,7 +238,7 @@
return service.NewAsyncAgentServiceClient(conn), nil
}

func buildTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionMetadata) admin.TaskExecutionMetadata {
func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata) admin.TaskExecutionMetadata {
taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetID()
return admin.TaskExecutionMetadata{
TaskExecutionId: &taskExecutionID,
Expand Down
Loading