Skip to content

Commit

Permalink
propagate job pod template updates to suspended jobs when resuming
Browse files Browse the repository at this point in the history
  • Loading branch information
danielvegamyhre committed Jun 3, 2024
1 parent 0436215 commit d809200
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 25 deletions.
39 changes: 31 additions & 8 deletions pkg/controllers/jobset_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,10 @@ func (r *JobSetReconciler) suspendJobs(ctx context.Context, js *jobset.JobSet, a
// resumeJobsIfNecessary iterates through each replicatedJob, resuming any suspended jobs if the JobSet
// is not suspended.
func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset.JobSet, activeJobs []*batchv1.Job, replicatedJobStatuses []jobset.ReplicatedJobStatus, updateStatusOpts *statusUpdateOpts) error {
// Store node selector for each replicatedJob template.
nodeAffinities := map[string]map[string]string{}
// Store pod template for each replicatedJob.
replicatedJobTemplateMap := map[string]corev1.PodTemplateSpec{}
for _, replicatedJob := range js.Spec.ReplicatedJobs {
nodeAffinities[replicatedJob.Name] = replicatedJob.Template.Spec.Template.Spec.NodeSelector
replicatedJobTemplateMap[replicatedJob.Name] = replicatedJob.Template.Spec.Template
}

// Map each replicatedJob to a list of its active jobs.
Expand All @@ -415,7 +415,7 @@ func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset
if !jobSuspended(job) {
continue
}
if err := r.resumeJob(ctx, job, nodeAffinities); err != nil {
if err := r.resumeJob(ctx, job, replicatedJobTemplateMap); err != nil {
return err
}
}
Expand All @@ -433,7 +433,7 @@ func (r *JobSetReconciler) resumeJobsIfNecessary(ctx context.Context, js *jobset
return nil
}

func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, nodeAffinities map[string]map[string]string) error {
func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, replicatedJobTemplateMap map[string]corev1.PodTemplateSpec) error {
log := ctrl.LoggerFrom(ctx)
// Kubernetes validates that a job template is immutable
// so if the job has started i.e., startTime != nil), we must set it to nil first.
Expand All @@ -443,10 +443,33 @@ func (r *JobSetReconciler) resumeJob(ctx context.Context, job *batchv1.Job, node
return err
}
}

// Get name of parent replicated job and use it to look up the pod template.
replicatedJobName := job.Labels[jobset.ReplicatedJobNameKey]
replicatedJobPodTemplate := replicatedJobTemplateMap[replicatedJobName]
if job.Labels != nil && job.Labels[jobset.ReplicatedJobNameKey] != "" {
// When resuming a job, its nodeSelectors should match that of the replicatedJob template
// that it was created from, which may have been updated while it was suspended.
job.Spec.Template.Spec.NodeSelector = nodeAffinities[job.Labels[jobset.ReplicatedJobNameKey]]
// Certain fields on the Job pod template may be mutated while a JobSet is suspended,
// for integration with Kueue. Ensure these updates are propagated to the child Jobs
// when the JobSet is resumed.
// Merge values rather than overwriting them, since a different controller
// (e.g., the Job controller) may have added labels/annotations/etc to the
// Job that do not exist in the ReplicatedJob pod template.
job.Spec.Template.Labels = collections.MergeMaps(
job.Spec.Template.Labels,
replicatedJobPodTemplate.Labels,
)
job.Spec.Template.Annotations = collections.MergeMaps(
job.Spec.Template.Annotations,
replicatedJobPodTemplate.Annotations,
)
job.Spec.Template.Spec.NodeSelector = collections.MergeMaps(
job.Spec.Template.Spec.NodeSelector,
replicatedJobPodTemplate.Spec.NodeSelector,
)
job.Spec.Template.Spec.Tolerations = collections.MergeSlices(
job.Spec.Template.Spec.Tolerations,
replicatedJobPodTemplate.Spec.Tolerations,
)
} else {
log.Error(nil, "job missing ReplicatedJobName label")
}
Expand Down
38 changes: 38 additions & 0 deletions pkg/util/collections/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,41 @@ func IndexOf[T comparable](slice []T, item T) int {
}
return -1
}

// MergeMaps will merge the `old` and `new` maps and return the
// merged map. If a key appears in both maps, the key-value pair
// in the `new` map will overwrite the value in the `old` map.
func MergeMaps[K comparable, V any](old, new map[K]V) map[K]V {
merged := make(map[K]V)
for k, v := range old {
merged[k] = v
}
for k, v := range new {
merged[k] = v // Overwrite if duplicate
}
return merged
}

func MergeSlices[T comparable](s1, s2 []T) []T {
mergedSet := make(map[T]bool)

// Add elements from s1 to the set
for _, item := range s1 {
mergedSet[item] = true
}

// Add elements from s2, only if they are not already in the set
for _, item := range s2 {
if _, exists := mergedSet[item]; !exists {
mergedSet[item] = true
}
}

// Convert the set back into a slice
mergedSlice := make([]T, 0, len(mergedSet))
for item := range mergedSet {
mergedSlice = append(mergedSlice, item)
}

return mergedSlice
}
88 changes: 88 additions & 0 deletions pkg/util/collections/collections_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"golang.org/x/exp/slices"
)

func TestConcat(t *testing.T) {
Expand Down Expand Up @@ -151,3 +152,90 @@ func TestContains(t *testing.T) {
})
}
}

func TestMergeMaps(t *testing.T) {
testCases := []struct {
name string
m1 map[string]int
m2 map[string]int
expected map[string]int
}{
{
name: "Basic merge",
m1: map[string]int{"a": 1, "b": 2},
m2: map[string]int{"c": 3, "d": 4},
expected: map[string]int{"a": 1, "b": 2, "c": 3, "d": 4},
},
{
name: "Overlapping keys",
m1: map[string]int{"a": 1, "b": 2},
m2: map[string]int{"b": 3, "c": 4},
expected: map[string]int{"a": 1, "b": 3, "c": 4}, // m2 value for 'b' overwrites
},
{
name: "Empty maps",
m1: map[string]int{},
m2: map[string]int{},
expected: map[string]int{},
},
{
name: "One empty map",
m1: map[string]int{"a": 1, "b": 2},
m2: map[string]int{},
expected: map[string]int{"a": 1, "b": 2},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
merged := MergeMaps(tc.m1, tc.m2)

if !reflect.DeepEqual(merged, tc.expected) {
t.Errorf("expected %v, got %v", tc.expected, merged)
}
})
}
}

func TestMergeSlices(t *testing.T) {
testCases := []struct {
name string
s1 []int
s2 []int
expected []int
}{
{
name: "merge with overlapping elements should not result in duplicates",
s1: []int{1, 2, 3},
s2: []int{3, 4, 5},
expected: []int{1, 2, 3, 4, 5},
},
{
name: "empty slices",
s1: []int{},
s2: []int{},
expected: []int{},
},
{
name: "one empty slice",
s1: []int{1, 2},
s2: []int{},
expected: []int{1, 2},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
merged := MergeSlices(tc.s1, tc.s2)

// Sort before comparison so slices with the same elements
// should be the same.
slices.Sort(merged)
slices.Sort(tc.expected)

if !reflect.DeepEqual(merged, tc.expected) {
t.Errorf("Expected %v, got %v", tc.expected, merged)
}
})
}
}
80 changes: 63 additions & 17 deletions test/integration/controller/jobset_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,16 @@ var _ = ginkgo.Describe("JobSet controller", func() {
updates []*update
}

nodeSelectors := map[string]map[string]string{
"replicated-job-a": {"node-selector-test-a": "node-selector-test-a"},
"replicated-job-b": {"node-selector-test-b": "node-selector-test-b"},
var podTemplateUpdates = &updatePodTemplateOpts{
labels: map[string]string{"label": "value"},
annotations: map[string]string{"annotation": "value"},
nodeSelector: map[string]string{"node-selector-test-a": "node-selector-test-a"},
tolerations: []corev1.Toleration{
{
Key: "key",
Operator: corev1.TolerationOpExists,
},
},
}

ginkgo.DescribeTable("jobset is created and its jobs go through a series of updates",
Expand Down Expand Up @@ -514,7 +521,7 @@ var _ = ginkgo.Describe("JobSet controller", func() {
},
{
jobSetUpdateFn: func(js *jobset.JobSet) {
updateJobSetNodeSelectors(js, nodeSelectors)
updatePodTemplates(js, podTemplateUpdates)
},
checkJobSetState: func(js *jobset.JobSet) {
ginkgo.By("Check ReplicatedJobStatus for suspend")
Expand Down Expand Up @@ -542,7 +549,7 @@ var _ = ginkgo.Describe("JobSet controller", func() {
{
checkJobSetState: func(js *jobset.JobSet) {
ginkgo.By("checking jobs have expected node selectors")
gomega.Eventually(matchJobsNodeSelectors, timeout, interval).WithArguments(js, nodeSelectors).Should(gomega.Equal(true))
gomega.Eventually(checkPodTemplateUpdates, timeout, interval).WithArguments(js, podTemplateUpdates).Should(gomega.Equal(true))
},
jobUpdateFn: completeAllJobs,
checkJobSetCondition: testutil.JobSetCompleted,
Expand Down Expand Up @@ -1464,15 +1471,35 @@ func suspendJobSet(js *jobset.JobSet, suspend bool) {
}, timeout, interval).Should(gomega.Succeed())
}

func updateJobSetNodeSelectors(js *jobset.JobSet, nodeSelectors map[string]map[string]string) {
// updatePodTemplateOpts contains pod template values
// which can be mutated on a ReplicatedJob template
// while a JobSet is suspended.
type updatePodTemplateOpts struct {
labels map[string]string
annotations map[string]string
nodeSelector map[string]string
tolerations []corev1.Toleration
}

func updatePodTemplates(js *jobset.JobSet, opts *updatePodTemplateOpts) {
gomega.Eventually(func() error {
var jsGet jobset.JobSet
if err := k8sClient.Get(ctx, types.NamespacedName{Name: js.Name, Namespace: js.Namespace}, &jsGet); err != nil {
return err
}
for index := range jsGet.Spec.ReplicatedJobs {
jsGet.Spec.ReplicatedJobs[index].
Template.Spec.Template.Spec.NodeSelector = nodeSelectors[jsGet.Spec.ReplicatedJobs[index].Name]
podTemplate := &jsGet.Spec.ReplicatedJobs[index].Template.Spec.Template
// Update labels.
podTemplate.Labels = opts.labels

// Update annotations.
podTemplate.Annotations = opts.annotations

// Update node selector.
podTemplate.Spec.NodeSelector = opts.nodeSelector

// Update tolerations.
podTemplate.Spec.Tolerations = opts.tolerations
}
return k8sClient.Update(ctx, &jsGet)
}, timeout, interval).Should(gomega.Succeed())
Expand All @@ -1496,29 +1523,48 @@ func matchJobsSuspendState(js *jobset.JobSet, suspend bool) (bool, error) {
return true, nil
}

func matchJobsNodeSelectors(js *jobset.JobSet, nodeSelectors map[string]map[string]string) (bool, error) {
func checkPodTemplateUpdates(js *jobset.JobSet, podTemplateUpdates *updatePodTemplateOpts) (bool, error) {
var jobList batchv1.JobList
if err := k8sClient.List(ctx, &jobList, client.InNamespace(js.Namespace)); err != nil {
return false, err
}
// Count number of updated jobs
jobsUpdated := 0
for _, job := range jobList.Items {
rjobName, ok := job.Labels[jobset.ReplicatedJobNameKey]
if !ok {
return false, fmt.Errorf(fmt.Sprintf("%s job missing ReplicatedJobName label", job.Name))
// Check label was added.
for label, value := range podTemplateUpdates.labels {
if job.Spec.Template.Labels[label] != value {
return false, fmt.Errorf("%s != %s", job.Spec.Template.Labels[label], value)
}
}
if !apiequality.Semantic.DeepEqual(job.Spec.Template.Spec.NodeSelector, nodeSelectors[rjobName]) {
return false, nil

// Check annotation was added.
for annotation, value := range podTemplateUpdates.annotations {
if job.Spec.Template.Annotations[annotation] != value {
return false, fmt.Errorf("%s != %s", job.Spec.Template.Labels[annotation], value)
}
}

// Check nodeSelector was updated.
for label, value := range podTemplateUpdates.nodeSelector {
if job.Spec.Template.Spec.NodeSelector[label] != value {
return false, fmt.Errorf("%s != %s", job.Spec.Template.Spec.NodeSelector[label], value)
}
}

// Check tolerations were updated.
for _, toleration := range podTemplateUpdates.tolerations {
if !collections.Contains(job.Spec.Template.Spec.Tolerations, toleration) {
return false, fmt.Errorf("missing toleration %v", toleration)
}
}

jobsUpdated++
}
// Calculate expected number of updated jobs
wantJobsUpdated := 0
for _, rjob := range js.Spec.ReplicatedJobs {
if _, exists := nodeSelectors[rjob.Name]; exists {
wantJobsUpdated += int(rjob.Replicas)
}
wantJobsUpdated += int(rjob.Replicas)
}
return wantJobsUpdated == jobsUpdated, nil
}
Expand Down

0 comments on commit d809200

Please sign in to comment.