From 8b275799bc72695bcc003d3a4de62608c3f8e273 Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Mon, 3 Jul 2023 12:00:57 +0200 Subject: [PATCH] Support gRPC config for agent-service plugin Signed-off-by: Hongxin Liang --- go/tasks/plugins/webapi/agent/config.go | 27 +++++++-- go/tasks/plugins/webapi/agent/config_test.go | 15 +++++ .../plugins/webapi/agent/integration_test.go | 4 +- go/tasks/plugins/webapi/agent/plugin.go | 59 +++++++++++++++---- go/tasks/plugins/webapi/agent/plugin_test.go | 30 +++++++--- 5 files changed, 111 insertions(+), 24 deletions(-) diff --git a/go/tasks/plugins/webapi/agent/config.go b/go/tasks/plugins/webapi/agent/config.go index 14993b240..d0baff692 100644 --- a/go/tasks/plugins/webapi/agent/config.go +++ b/go/tasks/plugins/webapi/agent/config.go @@ -10,6 +10,8 @@ import ( ) var ( + defaultTimeout = config.Duration{Duration: 10 * time.Second} + defaultConfig = Config{ WebAPI: webapi.PluginConfig{ ResourceQuotas: map[core.ResourceNamespace]int{ @@ -39,8 +41,11 @@ var ( Value: 50, }, }, - DefaultGrpcEndpoint: "dns:///flyte-agent.flyte.svc.cluster.local:80", - SupportedTaskTypes: []string{"task_type_1", "task_type_2"}, + DefaultGrpcEndpoint: GrpcEndpoint{ + Endpoint: "dns:///flyte-agent.flyte.svc.cluster.local:80", + Insecure: true, + }, + SupportedTaskTypes: []string{"task_type_1", "task_type_2"}, } configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig) @@ -54,15 +59,29 @@ type Config struct { // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` - DefaultGrpcEndpoint string `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of agent service."` + DefaultGrpcEndpoint GrpcEndpoint `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of agent service."` // Maps endpoint to their plugin handler. {TaskType: Endpoint} - EndpointForTaskTypes map[string]string `json:"endpointForTaskTypes" pflag:"-,"` + EndpointForTaskTypes map[string]GrpcEndpoint `json:"endpointForTaskTypes" pflag:"-,"` // SupportedTaskTypes is a list of task types that are supported by this plugin. SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."` } +type GrpcEndpoint struct { + // Endpoint points to a gRPC service + Endpoint string `json:"endpoint"` + + // Insecure indicates whether the communication with the gRPC service is insecure + Insecure bool `json:"insecure"` + + // DefaultServiceConfig sets default gRPC service config; check https://github.com/grpc/grpc/blob/master/doc/service_config.md for more details + DefaultServiceConfig string `json:"defaultServiceConfig"` + + // Timeouts defines various RPC timeout values for different plugin operations: CreateTask, GetTask, DeleteTask; if not configured, defaults to 10s + Timeouts map[string]config.Duration `json:"timeouts"` +} + func GetConfig() *Config { return configSection.GetConfig().(*Config) } diff --git a/go/tasks/plugins/webapi/agent/config_test.go b/go/tasks/plugins/webapi/agent/config_test.go index e7201a2b9..a69cfc92f 100644 --- a/go/tasks/plugins/webapi/agent/config_test.go +++ b/go/tasks/plugins/webapi/agent/config_test.go @@ -4,6 +4,8 @@ import ( "testing" "time" + "github.com/flyteorg/flytestdlib/config" + "github.com/stretchr/testify/assert" ) @@ -11,6 +13,19 @@ func TestGetAndSetConfig(t *testing.T) { cfg := defaultConfig cfg.WebAPI.Caching.Workers = 1 cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + cfg.DefaultGrpcEndpoint.Insecure = false + cfg.DefaultGrpcEndpoint.DefaultServiceConfig = "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}" + cfg.DefaultGrpcEndpoint.Timeouts = map[string]config.Duration{ + "CreateTask": { + Duration: 1 * time.Millisecond, + }, + "GetTask": { + Duration: 2 * time.Millisecond, + }, + "DeleteTask": { + Duration: 3 * time.Millisecond, + }, + } err := SetConfig(&cfg) assert.NoError(t, err) assert.Equal(t, &cfg, GetConfig()) diff --git a/go/tasks/plugins/webapi/agent/integration_test.go b/go/tasks/plugins/webapi/agent/integration_test.go index 30036ede7..4fff5f977 100644 --- a/go/tasks/plugins/webapi/agent/integration_test.go +++ b/go/tasks/plugins/webapi/agent/integration_test.go @@ -55,11 +55,11 @@ func (m *MockClient) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ return &admin.DeleteTaskResponse{}, nil } -func mockGetClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { +func mockGetClientFunc(_ context.Context, _ GrpcEndpoint, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { return &MockClient{}, nil } -func mockGetBadClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { +func mockGetBadClientFunc(_ context.Context, _ GrpcEndpoint, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { return nil, fmt.Errorf("error") } diff --git a/go/tasks/plugins/webapi/agent/plugin.go b/go/tasks/plugins/webapi/agent/plugin.go index 70a335021..f0578225e 100644 --- a/go/tasks/plugins/webapi/agent/plugin.go +++ b/go/tasks/plugins/webapi/agent/plugin.go @@ -2,10 +2,14 @@ package agent import ( "context" + "crypto/x509" "encoding/gob" "fmt" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flytestdlib/config" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/grpclog" @@ -21,7 +25,7 @@ import ( "google.golang.org/grpc" ) -type GetClientFunc func(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) +type GetClientFunc func(ctx context.Context, endpoint GrpcEndpoint, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) type Plugin struct { metricScope promutils.Scope @@ -72,7 +76,10 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err) } - res, err := client.CreateTask(ctx, &admin.CreateTaskRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix}) + newCtx, cancel := context.WithTimeout(ctx, getFinalTimeout("CreateTask", endpoint.Timeouts).Duration) + defer cancel() + + res, err := client.CreateTask(newCtx, &admin.CreateTaskRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix}) if err != nil { return nil, nil, err } @@ -94,7 +101,10 @@ func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest weba return nil, fmt.Errorf("failed to connect to agent with error: %v", err) } - res, err := client.GetTask(ctx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) + newCtx, cancel := context.WithTimeout(ctx, getFinalTimeout("GetTask", endpoint.Timeouts).Duration) + defer cancel() + + res, err := client.GetTask(newCtx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) if err != nil { return nil, err } @@ -117,7 +127,10 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return fmt.Errorf("failed to connect to agent with error: %v", err) } - _, err = client.DeleteTask(ctx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) + newCtx, cancel := context.WithTimeout(ctx, getFinalTimeout("DeleteTask", endpoint.Timeouts).Duration) + defer cancel() + + _, err = client.DeleteTask(newCtx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) return err } @@ -144,7 +157,7 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State) } -func getFinalEndpoint(taskType, defaultEndpoint string, endpointForTaskTypes map[string]string) string { +func getFinalEndpoint(taskType string, defaultEndpoint GrpcEndpoint, endpointForTaskTypes map[string]GrpcEndpoint) GrpcEndpoint { if t, exists := endpointForTaskTypes[taskType]; exists { return t } @@ -152,20 +165,36 @@ func getFinalEndpoint(taskType, defaultEndpoint string, endpointForTaskTypes map return defaultEndpoint } -func getClientFunc(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - conn, ok := connectionCache[endpoint] +func getClientFunc(ctx context.Context, endpoint GrpcEndpoint, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { + conn, ok := connectionCache[endpoint.Endpoint] if ok { return service.NewAsyncAgentServiceClient(conn), nil } + var opts []grpc.DialOption - var err error - opts = append(opts, grpc.WithInsecure()) - conn, err = grpc.Dial(endpoint, opts...) + if endpoint.Insecure { + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } else { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + + creds := credentials.NewClientTLSFromCert(pool, "") + opts = append(opts, grpc.WithTransportCredentials(creds)) + } + + if endpoint.DefaultServiceConfig != "" { + opts = append(opts, grpc.WithDefaultServiceConfig(endpoint.DefaultServiceConfig)) + } + + var err error + conn, err = grpc.Dial(endpoint.Endpoint, opts...) if err != nil { return nil, err } - connectionCache[endpoint] = conn + connectionCache[endpoint.Endpoint] = conn defer func() { if err != nil { if cerr := conn.Close(); cerr != nil { @@ -183,6 +212,14 @@ func getClientFunc(ctx context.Context, endpoint string, connectionCache map[str return service.NewAsyncAgentServiceClient(conn), nil } +func getFinalTimeout(operation string, timeouts map[string]config.Duration) config.Duration { + if t, exists := timeouts[operation]; exists { + return t + } + + return defaultTimeout +} + func newAgentPlugin() webapi.PluginEntry { supportedTaskTypes := GetConfig().SupportedTaskTypes diff --git a/go/tasks/plugins/webapi/agent/plugin_test.go b/go/tasks/plugins/webapi/agent/plugin_test.go index 174115eea..238641a40 100644 --- a/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/go/tasks/plugins/webapi/agent/plugin_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/flyteorg/flytestdlib/config" + "google.golang.org/grpc" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -25,8 +27,8 @@ func TestPlugin(t *testing.T) { cfg := defaultConfig cfg.WebAPI.Caching.Workers = 1 cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second - cfg.DefaultGrpcEndpoint = "test-agent.flyte.svc.cluster.local:80" - cfg.EndpointForTaskTypes = map[string]string{"spark": "localhost:80"} + cfg.DefaultGrpcEndpoint = GrpcEndpoint{Endpoint: "test-agent.flyte.svc.cluster.local:80"} + cfg.EndpointForTaskTypes = map[string]GrpcEndpoint{"spark": {Endpoint: "localhost:80"}} err := SetConfig(&cfg) assert.NoError(t, err) assert.Equal(t, cfg.WebAPI, plugin.GetConfig()) @@ -46,15 +48,29 @@ func TestPlugin(t *testing.T) { }) t.Run("test getFinalEndpoint", func(t *testing.T) { - endpoint := getFinalEndpoint("spark", "localhost:8080", map[string]string{"spark": "localhost:80"}) - assert.Equal(t, endpoint, "localhost:80") - endpoint = getFinalEndpoint("spark", "localhost:8080", map[string]string{}) - assert.Equal(t, endpoint, "localhost:8080") + defaultGrpcEndpoint := GrpcEndpoint{Endpoint: "localhost:8080"} + endpoint := getFinalEndpoint("spark", defaultGrpcEndpoint, map[string]GrpcEndpoint{"spark": {Endpoint: "localhost:80"}}) + assert.Equal(t, endpoint.Endpoint, "localhost:80") + endpoint = getFinalEndpoint("spark", defaultGrpcEndpoint, map[string]GrpcEndpoint{}) + assert.Equal(t, endpoint.Endpoint, "localhost:8080") }) t.Run("test getClientFunc", func(t *testing.T) { - client, err := getClientFunc(context.Background(), "localhost:80", map[string]*grpc.ClientConn{}) + client, err := getClientFunc(context.Background(), GrpcEndpoint{Endpoint: "localhost:80"}, map[string]*grpc.ClientConn{}) assert.NoError(t, err) assert.NotNil(t, client) }) + + t.Run("test getClientFunc more config", func(t *testing.T) { + client, err := getClientFunc(context.Background(), GrpcEndpoint{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[string]*grpc.ClientConn{}) + assert.NoError(t, err) + assert.NotNil(t, client) + }) + + t.Run("test getFinalTimeout", func(t *testing.T) { + timeout := getFinalTimeout("CreateTask", map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}) + assert.Equal(t, timeout.Duration, 1*time.Millisecond) + timeout = getFinalTimeout("DeleteTask", map[string]config.Duration{}) + assert.Equal(t, timeout.Duration, 10*time.Second) + }) }