diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index 7ad4fbbd5..e881ef10e 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -444,8 +444,11 @@ func (m *ExecutionManager) getClusterAssignment(ctx context.Context, request *ad if resource != nil && resource.Attributes.GetClusterAssignment() != nil { return resource.Attributes.GetClusterAssignment(), nil } - // Defaults to empty assignment with no selectors - return &admin.ClusterAssignment{}, nil + clusterPoolAssignment := m.config.ClusterPoolAssignmentConfiguration().GetClusterPoolAssignments()[request.GetDomain()] + + return &admin.ClusterAssignment{ + ClusterPoolName: clusterPoolAssignment.Pool, + }, nil } func (m *ExecutionManager) launchSingleTaskExecution( diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index 73fabbc40..34d436ee9 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -5440,6 +5440,30 @@ func TestGetClusterAssignment(t *testing.T) { assert.NoError(t, err) assert.True(t, proto.Equal(ca, &reqClusterAssignment)) }) + t.Run("value from config", func(t *testing.T) { + customCP := "my_cp" + clusterPoolAsstProvider := &runtimeIFaceMocks.ClusterPoolAssignmentConfiguration{} + clusterPoolAsstProvider.OnGetClusterPoolAssignments().Return(runtimeInterfaces.ClusterPoolAssignments{ + workflowIdentifier.GetDomain(): runtimeInterfaces.ClusterPoolAssignment{ + Pool: customCP, + }, + }) + mockConfig := getMockExecutionsConfigProvider() + mockConfig.(*runtimeMocks.MockConfigurationProvider).AddClusterPoolAssignmentConfiguration(clusterPoolAsstProvider) + + executionManager := ExecutionManager{ + resourceManager: &managerMocks.MockResourceManager{}, + config: mockConfig, + } + + ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Spec: &admin.ExecutionSpec{}, + }) + assert.NoError(t, err) + assert.Equal(t, customCP, ca.GetClusterPoolName()) + }) } func TestResolvePermissions(t *testing.T) { diff --git a/pkg/runtime/cluster_pool_assignment_provider.go b/pkg/runtime/cluster_pool_assignment_provider.go new file mode 100644 index 000000000..b6e2a406a --- /dev/null +++ b/pkg/runtime/cluster_pool_assignment_provider.go @@ -0,0 +1,24 @@ +package runtime + +import ( + "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + + "github.com/flyteorg/flytestdlib/config" +) + +const clusterPoolsKey = "clusterPools" + +var clusterPoolsConfig = config.MustRegisterSection(clusterPoolsKey, &interfaces.ClusterPoolAssignmentConfig{ + ClusterPoolAssignments: make(interfaces.ClusterPoolAssignments), +}) + +// Implementation of an interfaces.ClusterPoolAssignmentConfiguration +type ClusterPoolAssignmentConfigurationProvider struct{} + +func (p *ClusterPoolAssignmentConfigurationProvider) GetClusterPoolAssignments() interfaces.ClusterPoolAssignments { + return clusterPoolsConfig.GetConfig().(*interfaces.ClusterPoolAssignmentConfig).ClusterPoolAssignments +} + +func NewClusterPoolAssignmentConfigurationProvider() interfaces.ClusterPoolAssignmentConfiguration { + return &ClusterPoolAssignmentConfigurationProvider{} +} diff --git a/pkg/runtime/configuration_provider.go b/pkg/runtime/configuration_provider.go index 73a5c8d81..217d20dc5 100644 --- a/pkg/runtime/configuration_provider.go +++ b/pkg/runtime/configuration_provider.go @@ -15,6 +15,7 @@ type ConfigurationProvider struct { clusterResourceConfiguration interfaces.ClusterResourceConfiguration namespaceMappingConfiguration interfaces.NamespaceMappingConfiguration qualityOfServiceConfiguration interfaces.QualityOfServiceConfiguration + clusterPoolAssignmentConfiguration interfaces.ClusterPoolAssignmentConfiguration } func (p *ConfigurationProvider) ApplicationConfiguration() interfaces.ApplicationConfiguration { @@ -53,6 +54,10 @@ func (p *ConfigurationProvider) QualityOfServiceConfiguration() interfaces.Quali return p.qualityOfServiceConfiguration } +func (p *ConfigurationProvider) ClusterPoolAssignmentConfiguration() interfaces.ClusterPoolAssignmentConfiguration { + return p.clusterPoolAssignmentConfiguration +} + func NewConfigurationProvider() interfaces.Configuration { return &ConfigurationProvider{ applicationConfiguration: NewApplicationConfigurationProvider(), @@ -64,5 +69,6 @@ func NewConfigurationProvider() interfaces.Configuration { clusterResourceConfiguration: NewClusterResourceConfigurationProvider(), namespaceMappingConfiguration: NewNamespaceMappingConfigurationProvider(), qualityOfServiceConfiguration: NewQualityOfServiceConfigProvider(), + clusterPoolAssignmentConfiguration: NewClusterPoolAssignmentConfigurationProvider(), } } diff --git a/pkg/runtime/interfaces/cluster_pools.go b/pkg/runtime/interfaces/cluster_pools.go new file mode 100644 index 000000000..0392c8254 --- /dev/null +++ b/pkg/runtime/interfaces/cluster_pools.go @@ -0,0 +1,17 @@ +package interfaces + +//go:generate mockery -name ClusterPoolAssignmentConfiguration -output=mocks -case=underscore + +type ClusterPoolAssignment struct { + Pool string `json:"pool"` +} + +type ClusterPoolAssignments = map[DomainName]ClusterPoolAssignment + +type ClusterPoolAssignmentConfig struct { + ClusterPoolAssignments ClusterPoolAssignments `json:"clusterPoolAssignments"` +} + +type ClusterPoolAssignmentConfiguration interface { + GetClusterPoolAssignments() ClusterPoolAssignments +} diff --git a/pkg/runtime/interfaces/configuration.go b/pkg/runtime/interfaces/configuration.go index 7fe694be0..a3272cbcb 100644 --- a/pkg/runtime/interfaces/configuration.go +++ b/pkg/runtime/interfaces/configuration.go @@ -11,4 +11,5 @@ type Configuration interface { ClusterResourceConfiguration() ClusterResourceConfiguration NamespaceMappingConfiguration() NamespaceMappingConfiguration QualityOfServiceConfiguration() QualityOfServiceConfiguration + ClusterPoolAssignmentConfiguration() ClusterPoolAssignmentConfiguration } diff --git a/pkg/runtime/interfaces/mocks/cluster_pool_assignment_configuration.go b/pkg/runtime/interfaces/mocks/cluster_pool_assignment_configuration.go new file mode 100644 index 000000000..6e0c1719d --- /dev/null +++ b/pkg/runtime/interfaces/mocks/cluster_pool_assignment_configuration.go @@ -0,0 +1,47 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + interfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + mock "github.com/stretchr/testify/mock" +) + +// ClusterPoolAssignmentConfiguration is an autogenerated mock type for the ClusterPoolAssignmentConfiguration type +type ClusterPoolAssignmentConfiguration struct { + mock.Mock +} + +type ClusterPoolAssignmentConfiguration_GetClusterPoolAssignments struct { + *mock.Call +} + +func (_m ClusterPoolAssignmentConfiguration_GetClusterPoolAssignments) Return(_a0 map[string]interfaces.ClusterPoolAssignment) *ClusterPoolAssignmentConfiguration_GetClusterPoolAssignments { + return &ClusterPoolAssignmentConfiguration_GetClusterPoolAssignments{Call: _m.Call.Return(_a0)} +} + +func (_m *ClusterPoolAssignmentConfiguration) OnGetClusterPoolAssignments() *ClusterPoolAssignmentConfiguration_GetClusterPoolAssignments { + c_call := _m.On("GetClusterPoolAssignments") + return &ClusterPoolAssignmentConfiguration_GetClusterPoolAssignments{Call: c_call} +} + +func (_m *ClusterPoolAssignmentConfiguration) OnGetClusterPoolAssignmentsMatch(matchers ...interface{}) *ClusterPoolAssignmentConfiguration_GetClusterPoolAssignments { + c_call := _m.On("GetClusterPoolAssignments", matchers...) + return &ClusterPoolAssignmentConfiguration_GetClusterPoolAssignments{Call: c_call} +} + +// GetClusterPoolAssignments provides a mock function with given fields: +func (_m *ClusterPoolAssignmentConfiguration) GetClusterPoolAssignments() map[string]interfaces.ClusterPoolAssignment { + ret := _m.Called() + + var r0 map[string]interfaces.ClusterPoolAssignment + if rf, ok := ret.Get(0).(func() map[string]interfaces.ClusterPoolAssignment); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interfaces.ClusterPoolAssignment) + } + } + + return r0 +} diff --git a/pkg/runtime/mocks/mock_configuration_provider.go b/pkg/runtime/mocks/mock_configuration_provider.go index 7af3e2e35..2e8577736 100644 --- a/pkg/runtime/mocks/mock_configuration_provider.go +++ b/pkg/runtime/mocks/mock_configuration_provider.go @@ -16,6 +16,7 @@ type MockConfigurationProvider struct { clusterResourceConfiguration interfaces.ClusterResourceConfiguration namespaceMappingConfiguration interfaces.NamespaceMappingConfiguration qualityOfServiceConfiguration interfaces.QualityOfServiceConfiguration + clusterPoolAssignmentConfiguration interfaces.ClusterPoolAssignmentConfiguration } func (p *MockConfigurationProvider) ApplicationConfiguration() interfaces.ApplicationConfiguration { @@ -70,6 +71,14 @@ func (p *MockConfigurationProvider) AddQualityOfServiceConfiguration(config inte p.qualityOfServiceConfiguration = config } +func (p *MockConfigurationProvider) ClusterPoolAssignmentConfiguration() interfaces.ClusterPoolAssignmentConfiguration { + return p.clusterPoolAssignmentConfiguration +} + +func (p *MockConfigurationProvider) AddClusterPoolAssignmentConfiguration(cfg interfaces.ClusterPoolAssignmentConfiguration) { + p.clusterPoolAssignmentConfiguration = cfg +} + func NewMockConfigurationProvider( applicationConfiguration interfaces.ApplicationConfiguration, queueConfiguration interfaces.QueueConfiguration, @@ -82,13 +91,17 @@ func NewMockConfigurationProvider( mockQualityOfServiceConfiguration.OnGetDefaultTiers().Return(make(map[string]core.QualityOfService_Tier)) mockQualityOfServiceConfiguration.OnGetTierExecutionValues().Return(make(map[core.QualityOfService_Tier]core.QualityOfServiceSpec)) + mockClusterPoolAssignmentConfiguration := &ifaceMocks.ClusterPoolAssignmentConfiguration{} + mockClusterPoolAssignmentConfiguration.OnGetClusterPoolAssignments().Return(make(map[string]interfaces.ClusterPoolAssignment)) + return &MockConfigurationProvider{ - applicationConfiguration: applicationConfiguration, - queueConfiguration: queueConfiguration, - clusterConfiguration: clusterConfiguration, - taskResourceConfiguration: taskResourceConfiguration, - whitelistConfiguration: whitelistConfiguration, - namespaceMappingConfiguration: namespaceMappingConfiguration, - qualityOfServiceConfiguration: mockQualityOfServiceConfiguration, + applicationConfiguration: applicationConfiguration, + queueConfiguration: queueConfiguration, + clusterConfiguration: clusterConfiguration, + taskResourceConfiguration: taskResourceConfiguration, + whitelistConfiguration: whitelistConfiguration, + namespaceMappingConfiguration: namespaceMappingConfiguration, + qualityOfServiceConfiguration: mockQualityOfServiceConfiguration, + clusterPoolAssignmentConfiguration: mockClusterPoolAssignmentConfiguration, } }