diff --git a/mysqlcluster/syncer/statefulset.go b/mysqlcluster/syncer/statefulset.go index aa45ffa4..b29512a6 100644 --- a/mysqlcluster/syncer/statefulset.go +++ b/mysqlcluster/syncer/statefulset.go @@ -25,7 +25,6 @@ import ( "github.com/go-logr/logr" "github.com/iancoleman/strcase" "github.com/imdario/mergo" - "github.com/presslabs/controller-util/pkg/mergo/transformers" "github.com/presslabs/controller-util/pkg/syncer" appsv1 "k8s.io/api/apps/v1" @@ -367,43 +366,46 @@ func (s *StatefulSetSyncer) updatePod(ctx context.Context) error { // mutate set the statefulset. func (s *StatefulSetSyncer) mutate() error { - s.sfs.Spec.ServiceName = s.GetNameForResource(utils.StatefulSet) - s.sfs.Spec.Replicas = s.Spec.Replicas - s.sfs.Spec.Selector = metav1.SetAsLabelSelector(s.GetSelectorLabels()) - s.sfs.Spec.UpdateStrategy = appsv1.StatefulSetUpdateStrategy{ - Type: appsv1.OnDeleteStatefulSetStrategyType, - } - - s.sfs.Spec.Template.ObjectMeta.Labels = s.GetLabels() + // build lables. + podLables := s.GetLabels() for k, v := range s.Spec.PodPolicy.Labels { - s.sfs.Spec.Template.ObjectMeta.Labels[k] = v + podLables[k] = v } - s.sfs.Spec.Template.ObjectMeta.Labels["role"] = string(utils.Candidate) - s.sfs.Spec.Template.ObjectMeta.Labels["healthy"] = "no" - - s.sfs.Spec.Template.Annotations = s.Spec.PodPolicy.Annotations - if len(s.sfs.Spec.Template.ObjectMeta.Annotations) == 0 { - s.sfs.Spec.Template.ObjectMeta.Annotations = make(map[string]string) + podLables["role"] = string(utils.Follower) + podLables["healthy"] = "no" + // build annotations. + podAnnotations := make(map[string]string) + if len(s.Spec.PodPolicy.Annotations) > 0 { + podAnnotations = s.Spec.PodPolicy.Annotations } if s.Spec.MetricsOpts.Enabled { - s.sfs.Spec.Template.ObjectMeta.Annotations["prometheus.io/scrape"] = "true" - s.sfs.Spec.Template.ObjectMeta.Annotations["prometheus.io/port"] = fmt.Sprintf("%d", utils.MetricsPort) - } - s.sfs.Spec.Template.ObjectMeta.Annotations["config_rev"] = s.cmRev - s.sfs.Spec.Template.ObjectMeta.Annotations["secret_rev"] = s.sctRev - - err := mergo.Merge(&s.sfs.Spec.Template.Spec, s.ensurePodSpec(), mergo.WithTransformers(transformers.PodSpec)) - if err != nil { - return err + podAnnotations["prometheus.io/scrape"] = "true" + podAnnotations["prometheus.io/port"] = fmt.Sprintf("%d", utils.MetricsPort) + } + podAnnotations["config_rev"] = s.cmRev + podAnnotations["secret_rev"] = s.sctRev + + templateSpec := appsv1.StatefulSetSpec{ + Replicas: s.Spec.Replicas, + ServiceName: s.GetNameForResource(utils.StatefulSet), + Selector: metav1.SetAsLabelSelector(s.GetSelectorLabels()), + UpdateStrategy: appsv1.StatefulSetUpdateStrategy{ + Type: appsv1.OnDeleteStatefulSetStrategyType, + }, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: podLables, + Annotations: podAnnotations, + }, + Spec: ensurePodSpec(s.MysqlCluster), + }, } - s.sfs.Spec.Template.Spec.Tolerations = s.Spec.PodPolicy.Tolerations - if s.Spec.Persistence.Enabled { - if s.sfs.Spec.VolumeClaimTemplates, err = s.EnsureVolumeClaimTemplates(s.cli.Scheme()); err != nil { + var err error + if templateSpec.VolumeClaimTemplates, err = s.EnsureVolumeClaimTemplates(s.cli.Scheme()); err != nil { return err } } - // Set owner reference only if owner resource is not being deleted, otherwise the owner // reference will be reset in case of deleting with cascade=false. if s.Unwrap().GetDeletionTimestamp().IsZero() { @@ -415,38 +417,38 @@ func (s *StatefulSetSyncer) mutate() error { // will not delete it again because has no owner reference set. return fmt.Errorf("owner is deleted") } - return nil + return mergo.Merge(&s.sfs.Spec, templateSpec, mergo.WithTransformers(utils.StsSpec)) } // ensurePodSpec used to ensure the podspec. -func (s *StatefulSetSyncer) ensurePodSpec() corev1.PodSpec { - initSidecar := container.EnsureContainer(utils.ContainerInitSidecarName, s.MysqlCluster) - initMysql := container.EnsureContainer(utils.ContainerInitMysqlName, s.MysqlCluster) +func ensurePodSpec(c *mysqlcluster.MysqlCluster) corev1.PodSpec { + initSidecar := container.EnsureContainer(utils.ContainerInitSidecarName, c) + initMysql := container.EnsureContainer(utils.ContainerInitMysqlName, c) initContainers := []corev1.Container{initSidecar, initMysql} - mysql := container.EnsureContainer(utils.ContainerMysqlName, s.MysqlCluster) - xenon := container.EnsureContainer(utils.ContainerXenonName, s.MysqlCluster) - backup := container.EnsureContainer(utils.ContainerBackupName, s.MysqlCluster) + mysql := container.EnsureContainer(utils.ContainerMysqlName, c) + xenon := container.EnsureContainer(utils.ContainerXenonName, c) + backup := container.EnsureContainer(utils.ContainerBackupName, c) containers := []corev1.Container{mysql, xenon, backup} - if s.Spec.MetricsOpts.Enabled { - containers = append(containers, container.EnsureContainer(utils.ContainerMetricsName, s.MysqlCluster)) + if c.Spec.MetricsOpts.Enabled { + containers = append(containers, container.EnsureContainer(utils.ContainerMetricsName, c)) } - if s.Spec.PodPolicy.SlowLogTail { - containers = append(containers, container.EnsureContainer(utils.ContainerSlowLogName, s.MysqlCluster)) + if c.Spec.PodPolicy.SlowLogTail { + containers = append(containers, container.EnsureContainer(utils.ContainerSlowLogName, c)) } - if s.Spec.PodPolicy.AuditLogTail { - containers = append(containers, container.EnsureContainer(utils.ContainerAuditLogName, s.MysqlCluster)) + if c.Spec.PodPolicy.AuditLogTail { + containers = append(containers, container.EnsureContainer(utils.ContainerAuditLogName, c)) } return corev1.PodSpec{ InitContainers: initContainers, Containers: containers, - Volumes: s.EnsureVolumes(), - SchedulerName: s.Spec.PodPolicy.SchedulerName, - ServiceAccountName: s.GetNameForResource(utils.ServiceAccount), - Affinity: s.Spec.PodPolicy.Affinity, - PriorityClassName: s.Spec.PodPolicy.PriorityClassName, - Tolerations: s.Spec.PodPolicy.Tolerations, + Volumes: c.EnsureVolumes(), + SchedulerName: c.Spec.PodPolicy.SchedulerName, + ServiceAccountName: c.GetNameForResource(utils.ServiceAccount), + Affinity: c.Spec.PodPolicy.Affinity, + PriorityClassName: c.Spec.PodPolicy.PriorityClassName, + Tolerations: c.Spec.PodPolicy.Tolerations, } } @@ -589,6 +591,9 @@ func (s *StatefulSetSyncer) backupIsRunning(ctx context.Context) (bool, error) { // Updates to statefulset spec for fields other than 'replicas', 'template', and 'updateStrategy' are forbidden. func (s *StatefulSetSyncer) sfsUpdated(existing *appsv1.StatefulSet) bool { + if s.sfs.Status.UpdateRevision != s.sfs.Status.CurrentRevision { + return true + } var resizeVolume = false // TODO: this is a temporary workaround until we figure out a better way to do this. if len(existing.Spec.VolumeClaimTemplates) > 0 && len(s.sfs.Spec.VolumeClaimTemplates) > 0 { diff --git a/utils/transform_suite_test.go b/utils/transform_suite_test.go new file mode 100644 index 00000000..c29566e3 --- /dev/null +++ b/utils/transform_suite_test.go @@ -0,0 +1,13 @@ +package utils_test + +import ( + "testing" + + ginkgo "github.com/onsi/ginkgo/v2" + gomega "github.com/onsi/gomega" +) + +func TestTrasformer(t *testing.T) { + gomega.RegisterFailHandler(ginkgo.Fail) + ginkgo.RunSpecs(t, "transformer") +} diff --git a/utils/transformer_test.go b/utils/transformer_test.go new file mode 100644 index 00000000..25edefc5 --- /dev/null +++ b/utils/transformer_test.go @@ -0,0 +1,138 @@ +package utils_test + +import ( + "github.com/imdario/mergo" + ginkgo "github.com/onsi/ginkgo/v2" + "github.com/onsi/gomega" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + + "github.com/radondb/radondb-mysql-kubernetes/utils" +) + +var _ = ginkgo.Describe("transformer", func() { + var two int32 = 2 + var three int32 = 3 + + templatePodSpec := corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "mysql", + Image: "percona:latest", + Command: []string{ + "cmd1", + }, + Env: []corev1.EnvVar{ + { + Name: "MYSQL_ROOT_PASSWORD", + ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: "sample-password", + }, + Key: "MYSQL_ROOT_PASSWORD", + }, + }, + }, + }, + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("1")}, + }, + }, + { + Name: "xenon", + Image: "xenon:latest", + Env: []corev1.EnvVar{ + { + Name: "oldenv", + Value: "oldenv", + }, + }, + }, + }, + } + templateStsSpec := appsv1.StatefulSetSpec{ + Replicas: &two, + UpdateStrategy: appsv1.StatefulSetUpdateStrategy{ + Type: appsv1.OnDeleteStatefulSetStrategyType, + }, + Template: corev1.PodTemplateSpec{ + Spec: templatePodSpec, + }, + } + + ginkgo.It("init should successfully", func() { + actualStsSpec := appsv1.StatefulSetSpec{} + mergo.Merge(&actualStsSpec, templateStsSpec, mergo.WithTransformers(utils.StsSpec)) + + gomega.Expect(actualStsSpec).Should(gomega.Equal(templateStsSpec)) + }) + + ginkgo.It("merge n times should successfully", func() { + actualStsSpec := appsv1.StatefulSetSpec{} + mergo.Merge(&actualStsSpec, templateStsSpec, mergo.WithTransformers(utils.StsSpec)) + gomega.Expect(actualStsSpec.Template.Spec.Containers[0].Command[0]).Should(gomega.Equal("cmd1")) + + actualStsSpec.Template.Spec.Containers[0].Command = append(actualStsSpec.Template.Spec.Containers[0].Command, "cmd2") + mergo.Merge(&actualStsSpec, templateStsSpec, mergo.WithTransformers(utils.StsSpec)) + mergo.Merge(&actualStsSpec, templateStsSpec, mergo.WithTransformers(utils.StsSpec)) + + gomega.Expect(len(actualStsSpec.Template.Spec.Containers[0].Command)).Should(gomega.Equal(2)) + gomega.Expect(actualStsSpec.Template.Spec.Containers[0].Command[0]).Should(gomega.Equal("cmd1")) + gomega.Expect(actualStsSpec.Template.Spec.Containers[0].Command[1]).Should(gomega.Equal("cmd2")) + }) + + ginkgo.It("add containers should successfully", func() { + actualStsSpec := *templateStsSpec.DeepCopy() + actualStsSpec.Template.Spec.Containers = append(actualStsSpec.Template.Spec.Containers, corev1.Container{Name: "test"}) + mergo.Merge(&actualStsSpec, templateStsSpec, mergo.WithTransformers(utils.StsSpec)) + + gomega.Expect(len(actualStsSpec.Template.Spec.Containers)).Should(gomega.Equal(3)) + gomega.Expect(actualStsSpec.Template.Spec.Containers[2].Name).Should(gomega.Equal("test")) + }) + + ginkgo.It("modify envs should successfully", func() { + actualStsSpec := *templateStsSpec.DeepCopy() + actualStsSpec.Template.Spec.Containers[0].Env[0].ValueFrom = &corev1.EnvVarSource{} + actualStsSpec.Template.Spec.Containers[1].Env = append(actualStsSpec.Template.Spec.Containers[1].Env, corev1.EnvVar{Name: "newenv", Value: "newenv"}) + mergo.Merge(&actualStsSpec, templateStsSpec, mergo.WithTransformers(utils.StsSpec)) + + gomega.Expect(actualStsSpec.Template.Spec.Containers[0].Env[0].ValueFrom).ShouldNot(gomega.BeNil()) + gomega.Expect(len(actualStsSpec.Template.Spec.Containers[1].Env)).Should(gomega.Equal(2)) + }) + + ginkgo.It("modify container image should not successfully", func() { + actualStsSpec := *templateStsSpec.DeepCopy() + actualStsSpec.Template.Spec.Containers[0].Image = "mysql:latest" + mergo.Merge(&actualStsSpec, templateStsSpec, mergo.WithTransformers(utils.StsSpec)) + + gomega.Expect(len(actualStsSpec.Template.Spec.Containers)).Should(gomega.Equal(2)) + gomega.Expect(actualStsSpec.Template.Spec.Containers[0].Image).ShouldNot(gomega.Equal("mysql:latest")) + }) + + ginkgo.It("merge replicas,updateStrategy should not successfully", func() { + actualStsSpec := *templateStsSpec.DeepCopy() + actualStsSpec.Replicas = &three + actualStsSpec.UpdateStrategy = appsv1.StatefulSetUpdateStrategy{ + Type: appsv1.RollingUpdateStatefulSetStrategyType, + } + mergo.Merge(&actualStsSpec, templateStsSpec, mergo.WithTransformers(utils.StsSpec)) + + gomega.Expect(actualStsSpec.Replicas).Should(gomega.Equal(&two)) + gomega.Expect(actualStsSpec.UpdateStrategy.Type).Should(gomega.Equal(appsv1.OnDeleteStatefulSetStrategyType)) + }) + + ginkgo.It("modify resources should not successfully", func() { + actualStsSpec := *templateStsSpec.DeepCopy() + testResources := corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("3"), + }, + } + actualStsSpec.Template.Spec.Containers[0].Resources = testResources + mergo.Merge(&actualStsSpec, templateStsSpec, mergo.WithTransformers(utils.StsSpec)) + + gomega.Expect(actualStsSpec.Template.Spec.Containers[0].Resources.Limits).ShouldNot(gomega.Equal(testResources)) + }) +}) diff --git a/utils/transfromer.go b/utils/transfromer.go new file mode 100644 index 00000000..70ba981b --- /dev/null +++ b/utils/transfromer.go @@ -0,0 +1,213 @@ +package utils + +import ( + "errors" + "reflect" + + "github.com/imdario/mergo" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" +) + +// TransformerMap is a mergo.Transformers implementation. +type TransformerMap map[reflect.Type]func(dst, src reflect.Value) error + +// StsSpec mergo transformers for corev1.StsSpec. +var StsSpec TransformerMap + +var errCannotMerge = errors.New("cannot merge when key type differs") + +func init() { // nolint: gochecknoinits + StsSpec = TransformerMap{ + reflect.TypeOf([]corev1.Container{}): StsSpec.MergeListByKey("Name", mergo.WithAppendSlice), + reflect.TypeOf([]corev1.ContainerPort{}): StsSpec.MergeListByKey("ContainerPort", mergo.WithOverride), + reflect.TypeOf([]corev1.EnvVar{}): StsSpec.MergeListByKey("Name", mergo.WithAppendSlice), + reflect.TypeOf(corev1.EnvVar{}): StsSpec.OverrideFields("Value", "ValueFrom"), + reflect.TypeOf(corev1.VolumeSource{}): StsSpec.NilOtherFields(), + reflect.TypeOf([]corev1.Toleration{}): StsSpec.MergeListByKey("Key", mergo.WithOverride), + reflect.TypeOf([]corev1.Volume{}): StsSpec.MergeListByKey("Name", mergo.WithOverride), + reflect.TypeOf([]corev1.LocalObjectReference{}): StsSpec.MergeListByKey("Name", mergo.WithOverride), + reflect.TypeOf([]corev1.HostAlias{}): StsSpec.MergeListByKey("IP", mergo.WithOverride), + reflect.TypeOf([]corev1.VolumeMount{}): StsSpec.MergeListByKey("MountPath", mergo.WithOverride), + reflect.TypeOf(corev1.Affinity{}): StsSpec.OverrideFields("NodeAffinity", "PodAffinity", "PodAntiAffinity"), + reflect.TypeOf(""): overwrite, + reflect.TypeOf(new(string)): overwrite, + reflect.TypeOf(new(int32)): overwrite, + reflect.TypeOf(new(int64)): overwrite, + reflect.TypeOf([]string{}): overwriteIfEmpty, // Command, Args, etc. + reflect.TypeOf(appsv1.StatefulSetUpdateStrategy{}): overwrite, + } +} + +// Transformer implements mergo.Tansformers interface for TransformenrMap. +func (s TransformerMap) Transformer(t reflect.Type) func(dst, src reflect.Value) error { + if fn, ok := s[t]; ok { + return fn + } + return nil +} + +// overwrite just overrites the dst value with the source. +func overwrite(dst, src reflect.Value) error { + if !src.IsZero() { + if dst.CanSet() { + dst.Set(src) + } + } + + return nil +} + +func overwriteIfEmpty(dst, src reflect.Value) error { + if src.IsZero() { + return nil + } + if dst.IsZero() { + if dst.CanSet() { + dst.Set(src) + } + } + + return nil +} + +func (s *TransformerMap) mergeByKey(key string, dst, elem reflect.Value, opts ...func(*mergo.Config)) error { + elemKey := elem.FieldByName(key) + + for i := 0; i < dst.Len(); i++ { + dstKey := dst.Index(i).FieldByName(key) + + if elemKey.Kind() != dstKey.Kind() { + return errCannotMerge + } + + eq := eq(key, elem, dst.Index(i)) + if eq { + opts = append(opts, mergo.WithTransformers(s)) + + return mergo.Merge(dst.Index(i).Addr().Interface(), elem.Interface(), opts...) + } + } + + dst.Set(reflect.Append(dst, elem)) + + return nil +} + +func eq(key string, a, b reflect.Value) bool { + aKey := a.FieldByName(key) + bKey := b.FieldByName(key) + + if aKey.Kind() != bKey.Kind() { + return false + } + + eq := false + + // nolint: exhaustive + switch aKey.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + eq = aKey.Int() == bKey.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + eq = aKey.Uint() == bKey.Uint() + case reflect.String: + eq = aKey.String() == bKey.String() + case reflect.Float32, reflect.Float64: + eq = aKey.Float() == bKey.Float() + case reflect.Bool: + eq = aKey.Bool() == bKey.Bool() + case reflect.Complex128, reflect.Complex64: + eq = aKey.Complex() == bKey.Complex() + case reflect.Interface: + eq = aKey.Interface() == bKey.Interface() + case reflect.Map: + eq = aKey.MapRange() == bKey.MapRange() + } + + return eq +} + +func indexByKey(key string, v reflect.Value, list reflect.Value) (int, bool) { + for i := 0; i < list.Len(); i++ { + if eq(key, v, list.Index(i)) { + return i, true + } + } + + return -1, false +} + +// MergeListByKey merges two list by element key (eg. merge []corev1.Container +// by name). If mergo.WithAppendSlice options is passed, the list is extended, +// while elemnts with same name are merged. If not, the list is filtered to +// elements in src. +func (s *TransformerMap) MergeListByKey(key string, opts ...func(*mergo.Config)) func(_, _ reflect.Value) error { + conf := &mergo.Config{} + + for _, opt := range opts { + opt(conf) + } + + return func(dst, src reflect.Value) error { + // entries := reflect.MakeSlice(src.Type(), src.Len(), src.Len()) + entries := reflect.MakeSlice(dst.Type(), dst.Len(), dst.Len()) + + for i := 0; i < src.Len(); i++ { + elem := src.Index(i) + + if err := s.mergeByKey(key, dst, elem, opts...); err != nil { + return err + } + + j, found := indexByKey(key, elem, dst) + if found { + entries.Index(i).Set(dst.Index(j)) + } + } + + if !conf.AppendSlice { + dst.SetLen(entries.Len()) + dst.SetCap(entries.Cap()) + dst.Set(entries) + } + + return nil + } +} + +// NilOtherFields nils all fields not defined in src. +func (s *TransformerMap) NilOtherFields(opts ...func(*mergo.Config)) func(_, _ reflect.Value) error { + return func(dst, src reflect.Value) error { + for i := 0; i < dst.NumField(); i++ { + dstField := dst.Type().Field(i) + srcValue := src.FieldByName(dstField.Name) + dstValue := dst.FieldByName(dstField.Name) + + if srcValue.Kind() == reflect.Ptr && srcValue.IsNil() { + dstValue.Set(srcValue) + } else { + if dstValue.Kind() == reflect.Ptr && dstValue.IsNil() { + dstValue.Set(srcValue) + } else { + opts = append(opts, mergo.WithTransformers(s)) + + return mergo.Merge(dstValue.Interface(), srcValue.Interface(), opts...) + } + } + } + + return nil + } +} + +// OverrideFields when merging override fields even if they are zero values (eg. nil or empty list). +func (s *TransformerMap) OverrideFields(fields ...string) func(_, _ reflect.Value) error { + return func(dst, src reflect.Value) error { + for _, field := range fields { + srcValue := src.FieldByName(field) + dst.FieldByName(field).Set(srcValue) + } + + return nil + } +}