diff --git a/argo/kfp-compiler/acceptance/pipeline_conf.yaml b/argo/kfp-compiler/acceptance/pipeline_conf.yaml index 76ca9b987..fbfe4f2cf 100644 --- a/argo/kfp-compiler/acceptance/pipeline_conf.yaml +++ b/argo/kfp-compiler/acceptance/pipeline_conf.yaml @@ -1,4 +1,4 @@ -name: test +name: namespace/test image: test-pipeline tfxComponents: pipeline.create_components env: @@ -6,4 +6,4 @@ env: value: bar beamArgs: - name: anArg - value: aValue + value: aValue diff --git a/argo/kfp-compiler/acceptance/test_compiler.py b/argo/kfp-compiler/acceptance/test_compiler.py index a2ed8d4c3..57d7356b0 100644 --- a/argo/kfp-compiler/acceptance/test_compiler.py +++ b/argo/kfp-compiler/acceptance/test_compiler.py @@ -46,6 +46,7 @@ def test_cli_v2(): f = open(output_file_path, "r") pipeline = yaml.safe_load(f.read()) assert pipeline['pipelineSpec']['schemaVersion'] == '2.0.0' + assert pipeline['pipelineSpec']['pipelineInfo']['name'] == "namespace-test" def test_failure(): diff --git a/argo/kfp-compiler/kfp_compiler/compiler.py b/argo/kfp-compiler/kfp_compiler/compiler.py index 7f23c53c4..3c679f44a 100644 --- a/argo/kfp-compiler/kfp_compiler/compiler.py +++ b/argo/kfp-compiler/kfp_compiler/compiler.py @@ -68,6 +68,10 @@ def load_fn(tfx_components: str, env: list): return fn +def sanitise_namespaced_pipeline_name(namespaced_name: str) -> str: + return namespaced_name.replace("/", "-") + + @click.command() @click.option('--pipeline_config', help='Pipeline configuration in yaml format', required=True) @click.option('--provider_config', help='Provider configuration in yaml format', required=True) @@ -78,9 +82,11 @@ def compile(pipeline_config: str, provider_config: str, output_file: str): pipeline_config_contents = yaml.safe_load(pipeline_stream) provider_config_contents = yaml.safe_load(provider_stream) - click.secho(f'Compiling with pipeline: {pipeline_config_contents} and provider {provider_config_contents} ', fg='green') + click.secho(f'Compiling with pipeline: {pipeline_config_contents} and provider {provider_config_contents} ', + fg='green') - pipeline_root, serving_model_directory, temp_location = pipeline_paths_for_config(pipeline_config_contents, provider_config_contents) + pipeline_root, serving_model_directory, temp_location = pipeline_paths_for_config(pipeline_config_contents, + provider_config_contents) beam_args = provider_config_contents.get('defaultBeamArgs', []) beam_args.extend(pipeline_config_contents.get('beamArgs', [])) @@ -94,7 +100,7 @@ def compile(pipeline_config: str, provider_config: str, output_file: str): compile_fn(pipeline_config_contents, output_file).run( pipeline.Pipeline( - pipeline_name=pipeline_config_contents['name'], + pipeline_name=sanitise_namespaced_pipeline_name(pipeline_config_contents['name']), pipeline_root=pipeline_root, components=expanded_components, enable_cache=False, diff --git a/argo/kfp-compiler/tests/test_compiler.py b/argo/kfp-compiler/tests/test_compiler.py index ecc589ee9..921ed98a7 100644 --- a/argo/kfp-compiler/tests/test_compiler.py +++ b/argo/kfp-compiler/tests/test_compiler.py @@ -24,3 +24,10 @@ def test_pipeline_paths_for_config(): assert pipeline_root == "pipeline_root/pipeline" assert serving_model_directory == "pipeline_root/pipeline/serving" assert temp_directory == "pipeline_root/pipeline/tmp" + + +def test_sanitise_namespaced_pipeline_name(): + assert compiler.sanitise_namespaced_pipeline_name("pipeline-name") == "pipeline-name" + assert compiler.sanitise_namespaced_pipeline_name("/pipeline-name") == "-pipeline-name" + assert compiler.sanitise_namespaced_pipeline_name("mlops/pipeline-name") == "mlops-pipeline-name" + assert compiler.sanitise_namespaced_pipeline_name("") == "" diff --git a/argo/providers/base/provider.go b/argo/providers/base/provider.go index 28f0910cd..7d62181ba 100644 --- a/argo/providers/base/provider.go +++ b/argo/providers/base/provider.go @@ -9,12 +9,12 @@ import ( ) type PipelineDefinition struct { - Name string `yaml:"name"` - Version string `yaml:"version"` - Image string `yaml:"image"` - TfxComponents string `yaml:"tfxComponents"` - Env []apis.NamedValue `yaml:"env"` - BeamArgs []apis.NamedValue `yaml:"beamArgs"` + Name common.NamespacedName `yaml:"name"` + Version string `yaml:"version"` + Image string `yaml:"image"` + TfxComponents string `yaml:"tfxComponents"` + Env []apis.NamedValue `yaml:"env"` + BeamArgs []apis.NamedValue `yaml:"beamArgs"` } type ExperimentDefinition struct { @@ -24,7 +24,7 @@ type ExperimentDefinition struct { } type RunScheduleDefinition struct { - Name string `yaml:"name"` + Name common.NamespacedName `yaml:"name"` Version string `yaml:"version"` PipelineName common.NamespacedName `yaml:"pipelineName"` PipelineVersion string `yaml:"pipelineVersion"` diff --git a/argo/providers/kfp/provider.go b/argo/providers/kfp/provider.go index 2306c0134..a95722493 100644 --- a/argo/providers/kfp/provider.go +++ b/argo/providers/kfp/provider.go @@ -43,7 +43,7 @@ func (kfpp KfpProvider) CreatePipeline(ctx context.Context, providerConfig KfpPr } result, err := pipelineUploadService.UploadPipeline(&pipeline_upload_service.UploadPipelineParams{ - Name: &pipelineDefinition.Name, + Name: &pipelineDefinition.Name.Name, Uploadfile: runtime.NamedReader(pipelineFileName, reader), Context: ctx, }, nil) @@ -234,7 +234,7 @@ func (kfpp KfpProvider) CreateRunSchedule(ctx context.Context, providerConfig Kf Parameters: jobParameters, }, Description: string(runScheduleAsDescription), - Name: runScheduleDefinition.Name, + Name: runScheduleDefinition.Name.Name, MaxConcurrency: 1, Enabled: true, NoCatchup: true, diff --git a/argo/providers/stub/provider.go b/argo/providers/stub/provider.go index 92e7dc8f1..f282faff4 100644 --- a/argo/providers/stub/provider.go +++ b/argo/providers/stub/provider.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/argoproj/argo-events/eventsources/sources/generic" + "github.com/sky-uk/kfp-operator/argo/common" "github.com/sky-uk/kfp-operator/argo/providers/base" ) @@ -14,8 +15,8 @@ type StubProviderConfig struct { } type ResourceDefinition struct { - Name string `yaml:"name"` - Version string `yaml:"version"` + Name common.NamespacedName `yaml:"name"` + Version string `yaml:"version"` } type ExpectedInput struct { @@ -71,7 +72,7 @@ func (s StubProvider) DeletePipeline(_ context.Context, providerConfig StubProvi } func (s StubProvider) CreateRun(_ context.Context, providerConfig StubProviderConfig, resourceDefinition base.RunDefinition) (string, error) { - return verifyCreateCall(providerConfig, ResourceDefinition{resourceDefinition.Name.Name, resourceDefinition.Version}) + return verifyCreateCall(providerConfig, ResourceDefinition{resourceDefinition.Name, resourceDefinition.Version}) } func (s StubProvider) DeleteRun(_ context.Context, providerConfig StubProviderConfig, id string) error { @@ -79,11 +80,11 @@ func (s StubProvider) DeleteRun(_ context.Context, providerConfig StubProviderCo } func (s StubProvider) CreateRunSchedule(_ context.Context, providerConfig StubProviderConfig, resourceDefinition base.RunScheduleDefinition) (string, error) { - return verifyCreateCall(providerConfig, ResourceDefinition{resourceDefinition.Name, resourceDefinition.Version}) + return verifyCreateCall(providerConfig, ResourceDefinition{Name: resourceDefinition.Name, Version: resourceDefinition.Version}) } func (s StubProvider) UpdateRunSchedule(_ context.Context, providerConfig StubProviderConfig, resourceDefinition base.RunScheduleDefinition, id string) (string, error) { - return verifyUpdateCall(providerConfig, ResourceDefinition{resourceDefinition.Name, resourceDefinition.Version}, id) + return verifyUpdateCall(providerConfig, ResourceDefinition{Name: resourceDefinition.Name, Version: resourceDefinition.Version}, id) } func (s StubProvider) DeleteRunSchedule(_ context.Context, providerConfig StubProviderConfig, id string) error { @@ -91,11 +92,11 @@ func (s StubProvider) DeleteRunSchedule(_ context.Context, providerConfig StubPr } func (s StubProvider) CreateExperiment(_ context.Context, providerConfig StubProviderConfig, resourceDefinition base.ExperimentDefinition) (string, error) { - return verifyCreateCall(providerConfig, ResourceDefinition{resourceDefinition.Name, resourceDefinition.Version}) + return verifyCreateCall(providerConfig, ResourceDefinition{common.NamespacedName{Name: resourceDefinition.Name}, resourceDefinition.Version}) } func (s StubProvider) UpdateExperiment(_ context.Context, providerConfig StubProviderConfig, resourceDefinition base.ExperimentDefinition, id string) (string, error) { - return verifyUpdateCall(providerConfig, ResourceDefinition{resourceDefinition.Name, resourceDefinition.Version}, id) + return verifyUpdateCall(providerConfig, ResourceDefinition{common.NamespacedName{Name: resourceDefinition.Name}, resourceDefinition.Version}, id) } func (s StubProvider) DeleteExperiment(_ context.Context, providerConfig StubProviderConfig, id string) error { diff --git a/argo/providers/vai/config.go b/argo/providers/vai/config.go index 76406d065..6421bbdd2 100644 --- a/argo/providers/vai/config.go +++ b/argo/providers/vai/config.go @@ -2,7 +2,7 @@ package vai import ( "fmt" - "strings" + "github.com/sky-uk/kfp-operator/argo/common" ) type VAIProviderConfig struct { @@ -28,16 +28,20 @@ func (vaipc VAIProviderConfig) pipelineJobName(name string) string { return fmt.Sprintf("%s/pipelineJobs/%s", vaipc.parent(), name) } -func (vaipc VAIProviderConfig) pipelineStorageObject(pipelineName string, pipelineVersion string) string { - return fmt.Sprintf("%s/%s", pipelineName, pipelineVersion) -} - -func (vaipc VAIProviderConfig) gcsUri(bucket string, pathSegments ...string) string { - return fmt.Sprintf("gs://%s/%s", bucket, strings.Join(pathSegments, "/")) +func (vaipc VAIProviderConfig) pipelineStorageObject(pipelineName common.NamespacedName, pipelineVersion string) (string, error) { + namespaceName, err := pipelineName.String() + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", namespaceName, pipelineVersion), nil } -func (vaipc VAIProviderConfig) pipelineUri(pipelineName string, pipelineVersion string) string { - return vaipc.gcsUri(vaipc.PipelineBucket, vaipc.pipelineStorageObject(pipelineName, pipelineVersion)) +func (vaipc VAIProviderConfig) pipelineUri(pipelineName common.NamespacedName, pipelineVersion string) (string, error) { + pipelineUri, err := vaipc.pipelineStorageObject(pipelineName, pipelineVersion) + if err != nil { + return "", err + } + return fmt.Sprintf("gs://%s/%s", vaipc.PipelineBucket, pipelineUri), nil } func (vaipc VAIProviderConfig) getMaxConcurrentRunCountOrDefault() int64 { diff --git a/argo/providers/vai/config_unit_test.go b/argo/providers/vai/config_unit_test.go index 34da40e7d..af6dfa61f 100644 --- a/argo/providers/vai/config_unit_test.go +++ b/argo/providers/vai/config_unit_test.go @@ -26,4 +26,52 @@ var _ = Context("VAI Config", func() { Entry("", -common.RandomInt64(), true), Entry("", common.RandomInt64()+1, false), ) + + DescribeTable("pipelineStorageObject", func(pipelineName common.NamespacedName, pipelineVersion string, expectedStorageObject string) { + storageObject, err := config.pipelineStorageObject(pipelineName, pipelineVersion) + if expectedStorageObject == "" { + Expect(err).To(HaveOccurred()) + } else { + Expect(err).NotTo(HaveOccurred()) + Expect(storageObject).To(Equal(expectedStorageObject)) + } + }, + Entry("", common.NamespacedName{ + Name: "myName", + Namespace: "myNamespace", + }, "version", "myNamespace/myName/version"), + Entry("", common.NamespacedName{ + Name: "myName", + Namespace: "", + }, "version", "myName/version"), + Entry("", common.NamespacedName{ + Name: "", + Namespace: "myNamespace", + }, "version", ""), + ) + + DescribeTable("pipelineUri", func(bucket string, pipelineName common.NamespacedName, pipelineVersion string, expectedStorageObject string) { + config.PipelineBucket = bucket + + storageObject, err := config.pipelineUri(pipelineName, pipelineVersion) + if expectedStorageObject == "" { + Expect(err).To(HaveOccurred()) + } else { + Expect(err).NotTo(HaveOccurred()) + Expect(storageObject).To(Equal(expectedStorageObject)) + } + }, + Entry("", "bucket", common.NamespacedName{ + Name: "myName", + Namespace: "myNamespace", + }, "version", "gs://bucket/myNamespace/myName/version"), + Entry("", "", common.NamespacedName{ + Name: "myName", + Namespace: "myNamespace", + }, "version", "gs:///myNamespace/myName/version"), + Entry("", "bucket", common.NamespacedName{ + Name: "", + Namespace: "myNamespace", + }, "version", ""), + ) }) diff --git a/argo/providers/vai/provider.go b/argo/providers/vai/provider.go index 9eb385f97..ce9ccdbcd 100644 --- a/argo/providers/vai/provider.go +++ b/argo/providers/vai/provider.go @@ -129,41 +129,50 @@ type VAIProvider struct { } func (vaip VAIProvider) CreatePipeline(ctx context.Context, providerConfig VAIProviderConfig, pipelineDefinition PipelineDefinition, pipelineFile string) (string, error) { - if _, err := vaip.UpdatePipeline(ctx, providerConfig, pipelineDefinition, pipelineDefinition.Name, pipelineFile); err != nil { + if _, err := vaip.UpdatePipeline(ctx, providerConfig, pipelineDefinition, "", pipelineFile); err != nil { return "", err } - return pipelineDefinition.Name, nil + return pipelineDefinition.Name.String() } -func (vaip VAIProvider) UpdatePipeline(ctx context.Context, providerConfig VAIProviderConfig, pipelineDefinition PipelineDefinition, id string, pipelineFile string) (string, error) { +func (vaip VAIProvider) UpdatePipeline(ctx context.Context, providerConfig VAIProviderConfig, pipelineDefinition PipelineDefinition, _ string, pipelineFile string) (string, error) { + pipelineId, err := pipelineDefinition.Name.String() + if err != nil { + return "", err + } client, err := gcsClient(ctx, providerConfig) if err != nil { - return id, err + return pipelineId, err } reader, err := os.Open(pipelineFile) if err != nil { - return id, err + return pipelineId, err } - writer := client.Bucket(providerConfig.PipelineBucket).Object(providerConfig.pipelineStorageObject(id, pipelineDefinition.Version)).NewWriter(ctx) + storageObject, err := providerConfig.pipelineStorageObject(pipelineDefinition.Name, pipelineDefinition.Version) + if err != nil { + return pipelineId, err + } + writer := client.Bucket(providerConfig.PipelineBucket).Object(storageObject).NewWriter(ctx) + _, err = io.Copy(writer, reader) if err != nil { - return id, err + return pipelineId, err } err = writer.Close() if err != nil { - return id, err + return pipelineId, err } err = reader.Close() if err != nil { - return id, err + return pipelineId, err } - return id, nil + return pipelineId, nil } func (vaip VAIProvider) DeletePipeline(ctx context.Context, providerConfig VAIProviderConfig, id string) error { @@ -193,8 +202,30 @@ func (vaip VAIProvider) DeletePipeline(ctx context.Context, providerConfig VAIPr return nil } +func extractFromStruct(pbStruct *structpb.Struct, fieldName string) (value *structpb.Value, err error) { + value, ok := pbStruct.Fields[fieldName] + if !ok { + err = fmt.Errorf("failed extracting field %s from the given struct", fieldName) + } + return value, err +} + +func extractPipelineNameFromPipelineSpec(ctx context.Context, pipelineSpec *structpb.Struct) (string, error) { + logger := common.LoggerFromContext(ctx) + pipelineInfo, err := extractFromStruct(pipelineSpec, "pipelineInfo") + if err != nil { + logger.Error(err, "Failed to extract pipelineInfo from pipeline spec") + return "", err + } + pipelineInfoName, err := extractFromStruct(pipelineInfo.GetStructValue(), "name") + if err != nil { + logger.Error(err, "Failed to extract name from pipelineInfo") + return "", err + } + return pipelineInfoName.GetStringValue(), nil +} + func (vaip VAIProvider) CreateRun(ctx context.Context, providerConfig VAIProviderConfig, runDefinition RunDefinition) (string, error) { - runId := runDefinition.Name.Name pipelineClient, err := aiplatform.NewPipelineClient(ctx, option.WithEndpoint(providerConfig.vaiEndpoint())) if err != nil { @@ -211,9 +242,13 @@ func (vaip VAIProvider) CreateRun(ctx context.Context, providerConfig VAIProvide } } + templateUri, err := providerConfig.pipelineUri(runDefinition.PipelineName, runDefinition.PipelineVersion) + if err != nil { + return "", err + } pipelineJob := &aiplatformpb.PipelineJob{ Labels: runLabelsFromRunDefinition(runDefinition), - TemplateUri: providerConfig.pipelineUri(runDefinition.PipelineName.Name, runDefinition.PipelineVersion), + TemplateUri: templateUri, ServiceAccount: providerConfig.VaiJobServiceAccount, RuntimeConfig: &aiplatformpb.PipelineJob_RuntimeConfig{ Parameters: parameters, @@ -225,9 +260,14 @@ func (vaip VAIProvider) CreateRun(ctx context.Context, providerConfig VAIProvide return "", err } + runId, err := extractPipelineNameFromPipelineSpec(ctx, pipelineJob.PipelineSpec) + if err != nil { + return "", err + } + req := &aiplatformpb.CreatePipelineJobRequest{ Parent: providerConfig.parent(), - PipelineJobId: runId, + PipelineJobId: fmt.Sprintf("%s-%s", runId, runDefinition.Version), PipelineJob: pipelineJob, } @@ -255,9 +295,13 @@ func (vaip VAIProvider) buildPipelineJob(providerConfig VAIProviderConfig, runSc }) // Note: unable to migrate from `Parameters` to `ParameterValues` at this point as `PipelineJob.pipeline_spec.schema_version` used by TFX is 2.0.0 see deprecated comment + templateUri, err := providerConfig.pipelineUri(runScheduleDefinition.PipelineName, runScheduleDefinition.PipelineVersion) + if err != nil { + return nil, err + } pipelineJob := &aiplatformpb.PipelineJob{ Labels: runLabelsFromSchedule(runScheduleDefinition), - TemplateUri: providerConfig.pipelineUri(runScheduleDefinition.PipelineName.Name, runScheduleDefinition.PipelineVersion), + TemplateUri: templateUri, ServiceAccount: providerConfig.VaiJobServiceAccount, RuntimeConfig: &aiplatformpb.PipelineJob_RuntimeConfig{ Parameters: parameters, @@ -281,7 +325,7 @@ func (vaip VAIProvider) buildVaiScheduleFromPipelineJob(providerConfig VAIProvid PipelineJob: pipelineJob, }, }, - DisplayName: fmt.Sprintf("rc-%s", runScheduleDefinition.Name), + DisplayName: fmt.Sprintf("rc-%s-%s", runScheduleDefinition.Name.Namespace, runScheduleDefinition.Name.Name), MaxConcurrentRunCount: providerConfig.getMaxConcurrentRunCountOrDefault(), AllowQueueing: true, }, nil diff --git a/argo/providers/vai/provider_unit_test.go b/argo/providers/vai/provider_unit_test.go index 10db6b21e..758f66057 100644 --- a/argo/providers/vai/provider_unit_test.go +++ b/argo/providers/vai/provider_unit_test.go @@ -4,11 +4,13 @@ package vai import ( "cloud.google.com/go/aiplatform/apiv1/aiplatformpb" + "context" "errors" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/sky-uk/kfp-operator/argo/common" . "github.com/sky-uk/kfp-operator/argo/providers/base" + "google.golang.org/protobuf/types/known/structpb" ) func randomBasicRunDefinition() RunDefinition { @@ -22,7 +24,7 @@ func randomBasicRunDefinition() RunDefinition { func randomRunScheduleDefinition() RunScheduleDefinition { return RunScheduleDefinition{ - Name: common.RandomString(), + Name: common.RandomNamespacedName(), Version: common.RandomString(), PipelineName: common.RandomNamespacedName(), PipelineVersion: common.RandomString(), @@ -194,4 +196,60 @@ var _ = Context("VAI Provider", func() { })) }) }) + + Describe("extractFromStruct", func() { + It("should extract a value from a given struct", func() { + pipelineSpec := map[string]interface{}{ + "myKey": "myValue", + } + pbStruct, err := structpb.NewStruct(pipelineSpec) + Expect(err).NotTo(HaveOccurred()) + + result, err := extractFromStruct(pbStruct, "myKey") + Expect(err).NotTo(HaveOccurred()) + + Expect(result.GetStringValue()).To(Equal("myValue")) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should error if struct is missing a required field", func() { + pipelineSpec, err := structpb.NewStruct(map[string]interface{}{ + "myKey": "myValue", + }) + Expect(err).NotTo(HaveOccurred()) + + _, err = extractFromStruct(pipelineSpec, "myOtherKey") + Expect(err).To(HaveOccurred()) + }) + }) + + Describe("extractPipelineNameFromPipelineSpec", func() { + ctx := context.Background() + + It("should extract pipelineInfo name from a given pipelineSpec", func() { + pipelineName := "myPipelineName" + pipelineSpec := map[string]interface{}{ + "pipelineInfo": map[string]interface{}{ + "name": pipelineName, + }, + } + pbStruct, err := structpb.NewStruct(pipelineSpec) + Expect(err).NotTo(HaveOccurred()) + + result, err := extractPipelineNameFromPipelineSpec(ctx, pbStruct) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal(pipelineName)) + }) + + It("should error if pipelineSpec is missing a required field", func() { + pbStruct, err := structpb.NewStruct(map[string]interface{}{ + "pipelineInfo": map[string]interface{}{ + "other_name": "other_value", + }, + }) + + _, err = extractPipelineNameFromPipelineSpec(ctx, pbStruct) + Expect(err).To(HaveOccurred()) + }) + }) }) diff --git a/controllers/pipelines/experiment_workflow_integration_test.go b/controllers/pipelines/experiment_workflow_integration_test.go index 91b7511d1..88a92a5bf 100644 --- a/controllers/pipelines/experiment_workflow_integration_test.go +++ b/controllers/pipelines/experiment_workflow_integration_test.go @@ -17,7 +17,11 @@ var _ = Context("Resource Workflows", Serial, func() { }) var newExperiment = func() *pipelinesv1.Experiment { - return withIntegrationTestFields(pipelinesv1.RandomExperiment()) + resource := pipelinesv1.RandomExperiment() + resourceStatus := resource.GetStatus() + resourceStatus.ProviderId.Provider = TestProvider + resource.SetStatus(resourceStatus) + return resource } DescribeTable("Experiment Workflows", AssertWorkflow[*pipelinesv1.Experiment], diff --git a/controllers/pipelines/pipeline_workflow_factory.go b/controllers/pipelines/pipeline_workflow_factory.go index 2b37ed09d..75111aaa6 100644 --- a/controllers/pipelines/pipeline_workflow_factory.go +++ b/controllers/pipelines/pipeline_workflow_factory.go @@ -3,6 +3,7 @@ package pipelines import ( config "github.com/sky-uk/kfp-operator/apis/config/v1alpha5" pipelinesv1 "github.com/sky-uk/kfp-operator/apis/pipelines/v1alpha5" + "github.com/sky-uk/kfp-operator/argo/common" providers "github.com/sky-uk/kfp-operator/argo/providers/base" ) @@ -12,7 +13,7 @@ type PipelineDefinitionCreator struct { func (pdc PipelineDefinitionCreator) pipelineDefinition(pipeline *pipelinesv1.Pipeline) (providers.PipelineDefinition, error) { return providers.PipelineDefinition{ - Name: pipeline.ObjectMeta.Name, + Name: common.NamespacedName{Name: pipeline.ObjectMeta.Name, Namespace: pipeline.ObjectMeta.Namespace}, Version: pipeline.ComputeVersion(), Image: pipeline.Spec.Image, TfxComponents: pipeline.Spec.TfxComponents, diff --git a/controllers/pipelines/pipeline_workflow_factory_test.go b/controllers/pipelines/pipeline_workflow_factory_test.go index 2a387877e..4f1fc41d3 100644 --- a/controllers/pipelines/pipeline_workflow_factory_test.go +++ b/controllers/pipelines/pipeline_workflow_factory_test.go @@ -7,6 +7,7 @@ import ( . "github.com/onsi/gomega" "github.com/sky-uk/kfp-operator/apis" pipelinesv1 "github.com/sky-uk/kfp-operator/apis/pipelines/v1alpha5" + "github.com/sky-uk/kfp-operator/argo/common" providers "github.com/sky-uk/kfp-operator/argo/providers/base" "gopkg.in/yaml.v2" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -26,7 +27,8 @@ var _ = Describe("PipelineDefinition", func() { pipeline := &pipelinesv1.Pipeline{ ObjectMeta: metav1.ObjectMeta{ - Name: "pipelineName", + Name: "pipelineName", + Namespace: "pipelineNamespace", }, Spec: pipelinesv1.PipelineSpec{ Image: "pipelineImage", @@ -38,7 +40,10 @@ var _ = Describe("PipelineDefinition", func() { compilerConfig, _ := wf.pipelineDefinition(pipeline) - Expect(compilerConfig.Name).To(Equal("pipelineName")) + Expect(compilerConfig.Name).To(Equal(common.NamespacedName{ + Name: "pipelineName", + Namespace: "pipelineNamespace", + })) Expect(compilerConfig.Image).To(Equal("pipelineImage")) Expect(compilerConfig.TfxComponents).To(Equal("pipelineTfxComponents")) Expect(compilerConfig.Env).To(Equal(expectedEnv)) @@ -47,7 +52,10 @@ var _ = Describe("PipelineDefinition", func() { It("Creates a valid YAML", func() { config := providers.PipelineDefinition{ - Name: "pipelineName", + Name: common.NamespacedName{ + Name: "pipelineName", + Namespace: "pipelineNamespace", + }, Image: "pipelineImage", TfxComponents: "pipelineTfxComponents", Env: []apis.NamedValue{ @@ -64,7 +72,7 @@ var _ = Describe("PipelineDefinition", func() { m := make(map[interface{}]interface{}) yaml.Unmarshal(configYaml, m) - Expect(m["name"]).To(Equal("pipelineName")) + Expect(m["name"]).To(Equal("pipelineNamespace/pipelineName")) Expect(m["image"]).To(Equal("pipelineImage")) Expect(m["tfxComponents"]).To(Equal("pipelineTfxComponents")) env := m["env"].([]interface{}) diff --git a/controllers/pipelines/runschedule_workflow_factory.go b/controllers/pipelines/runschedule_workflow_factory.go index 513ceea8f..878c240e4 100644 --- a/controllers/pipelines/runschedule_workflow_factory.go +++ b/controllers/pipelines/runschedule_workflow_factory.go @@ -25,7 +25,7 @@ func (rcdc RunScheduleDefinitionCreator) runScheduleDefinition(runSchedule *pipe } return providers.RunScheduleDefinition{ - Name: runSchedule.ObjectMeta.Name, + Name: common.NamespacedName{Name: runSchedule.ObjectMeta.Name, Namespace: runSchedule.Namespace}, RunConfigurationName: runConfigurationNameForRunSchedule(runSchedule), Version: runSchedule.ComputeVersion(), PipelineName: common.NamespacedName{Name: runSchedule.Spec.Pipeline.Name, Namespace: runSchedule.Namespace}, diff --git a/controllers/pipelines/suite_integration_test.go b/controllers/pipelines/suite_integration_test.go index 7d6de049e..9d1ae9be8 100644 --- a/controllers/pipelines/suite_integration_test.go +++ b/controllers/pipelines/suite_integration_test.go @@ -10,6 +10,7 @@ import ( . "github.com/onsi/gomega" "github.com/sky-uk/kfp-operator/apis" pipelinesv1 "github.com/sky-uk/kfp-operator/apis/pipelines/v1alpha5" + "github.com/sky-uk/kfp-operator/argo/common" "github.com/sky-uk/kfp-operator/argo/providers/base" "github.com/sky-uk/kfp-operator/argo/providers/stub" "github.com/sky-uk/kfp-operator/external" @@ -24,7 +25,9 @@ import ( ) const ( - TestTimeout = 120 + TestTimeout = 120 + TestNamespace = "argo" + TestProvider = "stub" ) var ( @@ -51,7 +54,7 @@ var _ = BeforeEach(func() { k8sClient.Delete(ctx, &v1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{ Name: "kfp-operator-integration-tests-providers", - Namespace: "argo", + Namespace: TestNamespace, }}) }) @@ -61,7 +64,10 @@ func StubProvider[R pipelinesv1.Resource](stubbedOutput base.Output, resource R) ExpectedInput: stub.ExpectedInput{ Id: resource.GetStatus().ProviderId.Id, ResourceDefinition: stub.ResourceDefinition{ - Name: resource.GetName(), + Name: common.NamespacedName{ + Name: resource.GetName(), + Namespace: resource.GetNamespace(), + }, Version: resource.ComputeVersion(), }, }, @@ -73,10 +79,10 @@ func StubProvider[R pipelinesv1.Resource](stubbedOutput base.Output, resource R) Expect(k8sClient.Create(ctx, &v1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{ Name: "kfp-operator-integration-tests-providers", - Namespace: "argo", + Namespace: TestNamespace, }, Data: map[string]string{ - "stub": fmt.Sprintf("%s\nserviceAccount: default\nimage: kfp-operator-stub-provider\nexecutionMode: none", configYaml), + TestProvider: fmt.Sprintf("%s\nserviceAccount: default\nimage: kfp-operator-stub-provider\nexecutionMode: none", configYaml), }, })).To(Succeed()) @@ -111,7 +117,7 @@ func AssertWorkflow[R pipelinesv1.Resource]( } expectedOutput := setUp(testCtx.Resource) - workflow, err := constructWorkflow("stub", testCtx.Resource) + workflow, err := constructWorkflow(TestProvider, testCtx.Resource) Expect(err).NotTo(HaveOccurred()) Expect(k8sClient.Create(ctx, workflow)).To(Succeed()) @@ -127,9 +133,9 @@ func AssertWorkflow[R pipelinesv1.Resource]( } func withIntegrationTestFields[T pipelinesv1.Resource](resource T) T { - resource.SetNamespace("argo") + resource.SetNamespace(TestNamespace) resourceStatus := resource.GetStatus() - resourceStatus.ProviderId.Provider = "stub" + resourceStatus.ProviderId.Provider = TestProvider resource.SetStatus(resourceStatus) return resource