diff --git a/go/tasks/plugins/webapi/agent/config.go b/go/tasks/plugins/webapi/agent/config.go index 0adc12cbd..d8f7c10e8 100644 --- a/go/tasks/plugins/webapi/agent/config.go +++ b/go/tasks/plugins/webapi/agent/config.go @@ -58,10 +58,14 @@ type Config struct { // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` + // The default grpc endpoint if there does not exist a more specific matching against task types DefaultGrpcEndpoint GrpcEndpoint `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of agent service."` + // The grpc endpoints of agent services, which are used to match against specific task types + GrpcEndpoints map[string]*GrpcEndpoint `json:"grpcEndpoints" pflag:",The grpc endpoints of agent services."` + // Maps endpoint to their plugin handler. {TaskType: Endpoint} - EndpointForTaskTypes map[string]GrpcEndpoint `json:"endpointForTaskTypes" pflag:"-,"` + EndpointForTaskTypes map[string]string `json:"endpointForTaskTypes" pflag:"-,"` // SupportedTaskTypes is a list of task types that are supported by this plugin. SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."` diff --git a/go/tasks/plugins/webapi/agent/config_test.go b/go/tasks/plugins/webapi/agent/config_test.go index c7f8634fb..76d7ddd44 100644 --- a/go/tasks/plugins/webapi/agent/config_test.go +++ b/go/tasks/plugins/webapi/agent/config_test.go @@ -27,6 +27,14 @@ func TestGetAndSetConfig(t *testing.T) { }, } cfg.DefaultGrpcEndpoint.DefaultTimeout = config.Duration{Duration: 10 * time.Second} + cfg.GrpcEndpoints = map[string]*GrpcEndpoint{ + "endpoint_1": { + Insecure: cfg.DefaultGrpcEndpoint.Insecure, + DefaultServiceConfig: cfg.DefaultGrpcEndpoint.DefaultServiceConfig, + Timeouts: cfg.DefaultGrpcEndpoint.Timeouts, + }, + } + cfg.EndpointForTaskTypes = map[string]string{"task_type_1": "endpoint_1"} err := SetConfig(&cfg) assert.NoError(t, err) assert.Equal(t, &cfg, GetConfig()) diff --git a/go/tasks/plugins/webapi/agent/integration_test.go b/go/tasks/plugins/webapi/agent/integration_test.go index 4fff5f977..83978aae2 100644 --- a/go/tasks/plugins/webapi/agent/integration_test.go +++ b/go/tasks/plugins/webapi/agent/integration_test.go @@ -55,11 +55,11 @@ func (m *MockClient) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ return &admin.DeleteTaskResponse{}, nil } -func mockGetClientFunc(_ context.Context, _ GrpcEndpoint, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { +func mockGetClientFunc(_ context.Context, _ *GrpcEndpoint, _ map[*GrpcEndpoint]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { return &MockClient{}, nil } -func mockGetBadClientFunc(_ context.Context, _ GrpcEndpoint, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { +func mockGetBadClientFunc(_ context.Context, _ *GrpcEndpoint, _ map[*GrpcEndpoint]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { return nil, fmt.Errorf("error") } diff --git a/go/tasks/plugins/webapi/agent/plugin.go b/go/tasks/plugins/webapi/agent/plugin.go index 711a82d40..c87655b6a 100644 --- a/go/tasks/plugins/webapi/agent/plugin.go +++ b/go/tasks/plugins/webapi/agent/plugin.go @@ -25,13 +25,13 @@ import ( "google.golang.org/grpc" ) -type GetClientFunc func(ctx context.Context, endpoint GrpcEndpoint, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) +type GetClientFunc func(ctx context.Context, endpoint *GrpcEndpoint, connectionCache map[*GrpcEndpoint]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) type Plugin struct { metricScope promutils.Scope cfg *Config getClient GetClientFunc - connectionCache map[string]*grpc.ClientConn + connectionCache map[*GrpcEndpoint]*grpc.ClientConn } type ResourceWrapper struct { @@ -70,7 +70,10 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() - endpoint := getFinalEndpoint(taskTemplate.Type, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes) + endpoint, err := getFinalEndpoint(taskTemplate.Type, p.cfg) + if err != nil { + return nil, nil, fmt.Errorf("failed to find agent endpoint with error: %v", err) + } client, err := p.getClient(ctx, endpoint, p.connectionCache) if err != nil { return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err) @@ -96,7 +99,10 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { metadata := taskCtx.ResourceMeta().(*ResourceMetaWrapper) - endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes) + endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg) + if err != nil { + return nil, fmt.Errorf("failed to find agent endpoint with error: %v", err) + } client, err := p.getClient(ctx, endpoint, p.connectionCache) if err != nil { return nil, fmt.Errorf("failed to connect to agent with error: %v", err) @@ -122,7 +128,10 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error } metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes) + endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg) + if err != nil { + return fmt.Errorf("failed to find agent endpoint with error: %v", err) + } client, err := p.getClient(ctx, endpoint, p.connectionCache) if err != nil { return fmt.Errorf("failed to connect to agent with error: %v", err) @@ -158,16 +167,19 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State) } -func getFinalEndpoint(taskType string, defaultEndpoint GrpcEndpoint, endpointForTaskTypes map[string]GrpcEndpoint) GrpcEndpoint { - if t, exists := endpointForTaskTypes[taskType]; exists { - return t +func getFinalEndpoint(taskType string, cfg *Config) (*GrpcEndpoint, error) { + if id, exists := cfg.EndpointForTaskTypes[taskType]; exists { + if endpoint, exists := cfg.GrpcEndpoints[id]; exists { + return endpoint, nil + } + return nil, fmt.Errorf("no endpoint definition found for ID %s that matches task type %s", id, taskType) } - return defaultEndpoint + return &cfg.DefaultGrpcEndpoint, nil } -func getClientFunc(ctx context.Context, endpoint GrpcEndpoint, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - conn, ok := connectionCache[endpoint.Endpoint] +func getClientFunc(ctx context.Context, endpoint *GrpcEndpoint, connectionCache map[*GrpcEndpoint]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { + conn, ok := connectionCache[endpoint] if ok { return service.NewAsyncAgentServiceClient(conn), nil } @@ -195,7 +207,7 @@ func getClientFunc(ctx context.Context, endpoint GrpcEndpoint, connectionCache m if err != nil { return nil, err } - connectionCache[endpoint.Endpoint] = conn + connectionCache[endpoint] = conn defer func() { if err != nil { if cerr := conn.Close(); cerr != nil { @@ -233,7 +245,7 @@ func getFinalTimeout(operation string, endpoint *GrpcEndpoint) config.Duration { return endpoint.DefaultTimeout } -func getFinalContext(ctx context.Context, operation string, endpoint GrpcEndpoint) (context.Context, context.CancelFunc) { +func getFinalContext(ctx context.Context, operation string, endpoint *GrpcEndpoint) (context.Context, context.CancelFunc) { timeout := getFinalTimeout(operation, endpoint).Duration if timeout == 0 { return ctx, func() {} @@ -252,7 +264,7 @@ func newAgentPlugin() webapi.PluginEntry { metricScope: iCtx.MetricsScope(), cfg: GetConfig(), getClient: getClientFunc, - connectionCache: make(map[string]*grpc.ClientConn), + connectionCache: make(map[*GrpcEndpoint]*grpc.ClientConn), }, nil }, } diff --git a/go/tasks/plugins/webapi/agent/plugin_test.go b/go/tasks/plugins/webapi/agent/plugin_test.go index 0a0d3f94a..75c879023 100644 --- a/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/go/tasks/plugins/webapi/agent/plugin_test.go @@ -19,16 +19,18 @@ func TestPlugin(t *testing.T) { fakeSetupContext := pluginCoreMocks.SetupContext{} fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) + cfg := defaultConfig + cfg.WebAPI.Caching.Workers = 1 + cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + cfg.DefaultGrpcEndpoint = GrpcEndpoint{Endpoint: "test-agent.flyte.svc.cluster.local:80"} + cfg.GrpcEndpoints = map[string]*GrpcEndpoint{"spark_agent": {Endpoint: "localhost:80"}} + cfg.EndpointForTaskTypes = map[string]string{"spark": "spark_agent", "bar": "bar_agent"} + plugin := Plugin{ metricScope: fakeSetupContext.MetricsScope(), cfg: GetConfig(), } t.Run("get config", func(t *testing.T) { - cfg := defaultConfig - cfg.WebAPI.Caching.Workers = 1 - cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second - cfg.DefaultGrpcEndpoint = GrpcEndpoint{Endpoint: "test-agent.flyte.svc.cluster.local:80"} - cfg.EndpointForTaskTypes = map[string]GrpcEndpoint{"spark": {Endpoint: "localhost:80"}} err := SetConfig(&cfg) assert.NoError(t, err) assert.Equal(t, cfg.WebAPI, plugin.GetConfig()) @@ -48,37 +50,53 @@ func TestPlugin(t *testing.T) { }) t.Run("test getFinalEndpoint", func(t *testing.T) { - defaultGrpcEndpoint := GrpcEndpoint{Endpoint: "localhost:8080"} - endpoint := getFinalEndpoint("spark", defaultGrpcEndpoint, map[string]GrpcEndpoint{"spark": {Endpoint: "localhost:80"}}) - assert.Equal(t, "localhost:80", endpoint.Endpoint) - endpoint = getFinalEndpoint("spark", defaultGrpcEndpoint, map[string]GrpcEndpoint{}) - assert.Equal(t, "localhost:8080", endpoint.Endpoint) + endpoint, _ := getFinalEndpoint("spark", &cfg) + assert.Equal(t, cfg.GrpcEndpoints["spark_agent"].Endpoint, endpoint.Endpoint) + endpoint, _ = getFinalEndpoint("foo", &cfg) + assert.Equal(t, cfg.DefaultGrpcEndpoint.Endpoint, endpoint.Endpoint) + _, err := getFinalEndpoint("bar", &cfg) + assert.NotNil(t, err) }) t.Run("test getClientFunc", func(t *testing.T) { - client, err := getClientFunc(context.Background(), GrpcEndpoint{Endpoint: "localhost:80"}, map[string]*grpc.ClientConn{}) + client, err := getClientFunc(context.Background(), &GrpcEndpoint{Endpoint: "localhost:80"}, map[*GrpcEndpoint]*grpc.ClientConn{}) assert.NoError(t, err) assert.NotNil(t, client) }) t.Run("test getClientFunc more config", func(t *testing.T) { - client, err := getClientFunc(context.Background(), GrpcEndpoint{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[string]*grpc.ClientConn{}) + client, err := getClientFunc(context.Background(), &GrpcEndpoint{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[*GrpcEndpoint]*grpc.ClientConn{}) + assert.NoError(t, err) + assert.NotNil(t, client) + }) + + t.Run("test getClientFunc cache hit", func(t *testing.T) { + connectionCache := make(map[*GrpcEndpoint]*grpc.ClientConn) + endpoint := &GrpcEndpoint{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"} + + client, err := getClientFunc(context.Background(), endpoint, connectionCache) assert.NoError(t, err) assert.NotNil(t, client) + assert.NotNil(t, client, connectionCache[endpoint]) + + cachedClient, err := getClientFunc(context.Background(), endpoint, connectionCache) + assert.NoError(t, err) + assert.NotNil(t, cachedClient) + assert.Equal(t, client, cachedClient) }) t.Run("test getFinalTimeout", func(t *testing.T) { - timeout := getFinalTimeout("CreateTask", GrpcEndpoint{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) + timeout := getFinalTimeout("CreateTask", &GrpcEndpoint{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) assert.Equal(t, 1*time.Millisecond, timeout.Duration) - timeout = getFinalTimeout("DeleteTask", GrpcEndpoint{Endpoint: "localhost:8080", DefaultTimeout: config.Duration{Duration: 10 * time.Second}}) + timeout = getFinalTimeout("DeleteTask", &GrpcEndpoint{Endpoint: "localhost:8080", DefaultTimeout: config.Duration{Duration: 10 * time.Second}}) assert.Equal(t, 10*time.Second, timeout.Duration) }) t.Run("test getFinalContext", func(t *testing.T) { - ctx, _ := getFinalContext(context.TODO(), "DeleteTask", GrpcEndpoint{}) + ctx, _ := getFinalContext(context.TODO(), "DeleteTask", &GrpcEndpoint{}) assert.Equal(t, context.TODO(), ctx) - ctx, _ = getFinalContext(context.TODO(), "CreateTask", GrpcEndpoint{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) + ctx, _ = getFinalContext(context.TODO(), "CreateTask", &GrpcEndpoint{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) assert.NotEqual(t, context.TODO(), ctx) }) }