diff --git a/go/tasks/plugins/webapi/agent/integration_test.go b/go/tasks/plugins/webapi/agent/integration_test.go index a3990e631..0aeed67f5 100644 --- a/go/tasks/plugins/webapi/agent/integration_test.go +++ b/go/tasks/plugins/webapi/agent/integration_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/mock" "google.golang.org/grpc" "k8s.io/apimachinery/pkg/util/rand" + "k8s.io/utils/strings/slices" ) type MockPlugin struct { @@ -39,7 +40,11 @@ type MockPlugin struct { type MockClient struct { } -func (m *MockClient) CreateTask(_ context.Context, _ *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) { +func (m *MockClient) CreateTask(_ context.Context, createTaskRequest *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) { + expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"} + if slices.Equal(createTaskRequest.Template.GetContainer().Args, expectedArgs) { + return nil, fmt.Errorf("args not as expected") + } return &admin.CreateTaskResponse{ResourceMeta: []byte{1, 2, 3, 4}}, nil } @@ -95,6 +100,9 @@ func TestEndToEnd(t *testing.T) { template := flyteIdlCore.TaskTemplate{ Type: "bigquery_query_job_task", Custom: st, + Target: &flyteIdlCore.TaskTemplate_Container{ + Container: &flyteIdlCore.Container{Args: []string{"pyflyte-fast-execute", "--output-prefix", "{{.outputPrefix}}"}}, + }, } basePrefix := storage.DataReference("fake://bucket/prefix/") diff --git a/go/tasks/plugins/webapi/agent/plugin.go b/go/tasks/plugins/webapi/agent/plugin.go index fd497e11c..128753b74 100644 --- a/go/tasks/plugins/webapi/agent/plugin.go +++ b/go/tasks/plugins/webapi/agent/plugin.go @@ -18,7 +18,7 @@ import ( pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" "github.com/flyteorg/flytestdlib/promutils" @@ -68,6 +68,19 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR return nil, nil, err } + if taskTemplate.GetContainer() != nil { + templateParameters := template.Parameters{ + TaskExecMetadata: taskCtx.TaskExecutionMetadata(), + Inputs: taskCtx.InputReader(), + OutputPath: taskCtx.OutputWriter(), + Task: taskCtx.TaskReader(), + } + modifiedArgs, err := template.Render(ctx, taskTemplate.GetContainer().Args, templateParameters) + if err != nil { + return nil, nil, err + } + taskTemplate.GetContainer().Args = modifiedArgs + } outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() agent, err := getFinalAgent(taskTemplate.Type, p.cfg) @@ -150,7 +163,7 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase switch resource.State { case admin.State_RUNNING: - return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil + return core.PhaseInfoRunning(core.DefaultPhaseVersion, taskInfo), nil case admin.State_PERMANENT_FAILURE: return core.PhaseInfoFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil case admin.State_RETRYABLE_FAILURE: @@ -164,7 +177,7 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase } return core.PhaseInfoSuccess(taskInfo), nil } - return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State) + return core.PhaseInfoUndefined, pluginErrors.Errorf(core.SystemErrorCode, "unknown execution phase [%v].", resource.State) } func getFinalAgent(taskType string, cfg *Config) (*Agent, error) { @@ -225,7 +238,7 @@ func getClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent return service.NewAsyncAgentServiceClient(conn), nil } -func buildTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionMetadata) admin.TaskExecutionMetadata { +func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata) admin.TaskExecutionMetadata { taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetID() return admin.TaskExecutionMetadata{ TaskExecutionId: &taskExecutionID,