Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
enable secrets in non-python k8s tasks (#401)
Browse files Browse the repository at this point in the history
* enabling secrets in non-python k8s tasks

Signed-off-by: Daniel Rammer <[email protected]>

* added unit tests

Signed-off-by: Daniel Rammer <[email protected]>

---------

Signed-off-by: Daniel Rammer <[email protected]>
  • Loading branch information
hamersaw authored Sep 11, 2023
1 parent c978638 commit 8eddca3
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 9 deletions.
6 changes: 5 additions & 1 deletion go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import (

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"

flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
Expand Down Expand Up @@ -157,6 +157,10 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
return nil, fmt.Errorf("number of launch worker should be more then 0")
}

cfg := config.GetK8sPluginConfig()
objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))
objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))

jobSpec := kubeflowv1.MPIJobSpec{
SlotsPerWorker: &slots,
RunPolicy: runPolicy,
Expand Down
23 changes: 21 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ var (
"test-args",
}

dummyAnnotations = map[string]string{
"annotation-key": "annotation-value",
}
dummyLabels = map[string]string{
"label-key": "label-value",
}

resourceRequirements = &corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1000m"),
Expand Down Expand Up @@ -150,8 +157,8 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecut
taskExecutionMetadata := &mocks.TaskExecutionMetadata{}
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"})
taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetAnnotations().Return(dummyAnnotations)
taskExecutionMetadata.OnGetLabels().Return(dummyLabels)
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Name: "blah",
Expand Down Expand Up @@ -304,6 +311,18 @@ func TestBuildResourceMPI(t *testing.T) {
assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas)
assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker)

// verify TaskExecutionMetadata labels and annotations are copied to the MPIJob
for k, v := range dummyAnnotations {
for _, replicaSpec := range mpiJob.Spec.MPIReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Annotations[k])
}
}
for k, v := range dummyLabels {
for _, replicaSpec := range mpiJob.Spec.MPIReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Labels[k])
}
}

for _, replicaSpec := range mpiJob.Spec.MPIReplicaSpecs {
for _, container := range replicaSpec.Template.Spec.Containers {
assert.Equal(t, resourceRequirements.Requests, container.Resources.Requests)
Expand Down
5 changes: 5 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
Expand Down Expand Up @@ -148,6 +149,10 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
return nil, fmt.Errorf("number of worker should be more then 0")
}

cfg := config.GetK8sPluginConfig()
objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))
objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))

jobSpec := kubeflowv1.PyTorchJobSpec{
PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{
kubeflowv1.PyTorchJobReplicaTypeMaster: {
Expand Down
35 changes: 33 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ var (
"test-args",
}

dummyAnnotations = map[string]string{
"annotation-key": "annotation-value",
}
dummyLabels = map[string]string{
"label-key": "label-value",
}

resourceRequirements = &corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1000m"),
Expand Down Expand Up @@ -170,8 +177,8 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskEx
taskExecutionMetadata := &mocks.TaskExecutionMetadata{}
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"})
taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetAnnotations().Return(dummyAnnotations)
taskExecutionMetadata.OnGetLabels().Return(dummyLabels)
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Name: "blah",
Expand Down Expand Up @@ -339,6 +346,18 @@ func TestBuildResourcePytorchElastic(t *testing.T) {
}

assert.True(t, hasContainerWithDefaultPytorchName)

// verify TaskExecutionMetadata labels and annotations are copied to the PyTorchJob
for k, v := range dummyAnnotations {
for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Annotations[k])
}
}
for k, v := range dummyLabels {
for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Labels[k])
}
}
}

func TestBuildResourcePytorch(t *testing.T) {
Expand All @@ -356,6 +375,18 @@ func TestBuildResourcePytorch(t *testing.T) {
assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
assert.Nil(t, pytorchJob.Spec.ElasticPolicy)

// verify TaskExecutionMetadata labels and annotations are copied to the TensorFlowJob
for k, v := range dummyAnnotations {
for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Annotations[k])
}
}
for k, v := range dummyLabels {
for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Labels[k])
}
}

for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
var hasContainerWithDefaultPytorchName = false

Expand Down
5 changes: 5 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
Expand Down Expand Up @@ -163,6 +164,10 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
return nil, fmt.Errorf("number of worker should be more then 0")
}

cfg := config.GetK8sPluginConfig()
objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))
objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))

jobSpec := kubeflowv1.TFJobSpec{
TFReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{},
}
Expand Down
23 changes: 21 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ var (
"test-args",
}

dummyAnnotations = map[string]string{
"annotation-key": "annotation-value",
}
dummyLabels = map[string]string{
"label-key": "label-value",
}

resourceRequirements = &corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1000m"),
Expand Down Expand Up @@ -152,8 +159,8 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.Tas
taskExecutionMetadata := &mocks.TaskExecutionMetadata{}
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"})
taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetAnnotations().Return(dummyAnnotations)
taskExecutionMetadata.OnGetLabels().Return(dummyLabels)
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Name: "blah",
Expand Down Expand Up @@ -306,6 +313,18 @@ func TestBuildResourceTensorFlow(t *testing.T) {
assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas)
assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas)

// verify TaskExecutionMetadata labels and annotations are copied to the TensorFlowJob
for k, v := range dummyAnnotations {
for _, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Annotations[k])
}
}
for k, v := range dummyLabels {
for _, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs {
assert.Equal(t, v, replicaSpec.Template.ObjectMeta.Labels[k])
}
}

for _, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs {
var hasContainerWithDefaultTensorFlowName = false

Expand Down
6 changes: 4 additions & 2 deletions go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
rayv1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1"
Expand Down Expand Up @@ -207,8 +208,9 @@ func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMe
Spec: *headPodSpec,
ObjectMeta: *objectMeta,
}
podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
cfg := config.GetK8sPluginConfig()
podTemplateSpec.SetLabels(utils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
podTemplateSpec.SetAnnotations(utils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
return podTemplateSpec
}

Expand Down

0 comments on commit 8eddca3

Please sign in to comment.