Skip to content

Commit

Permalink
Use a zero-change patch to do find scheme caused fields removal.
Browse files Browse the repository at this point in the history
  • Loading branch information
trasc committed Oct 14, 2024
1 parent 13c4991 commit 5e7dd60
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 19 deletions.
52 changes: 41 additions & 11 deletions pkg/webhook/admission/defaulter_custom.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ import (
"encoding/json"
"errors"
"net/http"
"slices"

"gomodules.xyz/jsonpatch/v2"
admissionv1 "k8s.io/api/admission/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/sets"
)

// CustomDefaulter defines functions for setting defaults on resources.
Expand Down Expand Up @@ -71,32 +74,59 @@ func (h *defaulterForType) Handle(ctx context.Context, req Request) Response {
ctx = NewContextWithRequest(ctx, req)

// Get the object in the request
original := h.object.DeepCopyObject()
if err := h.decoder.Decode(req, original); err != nil {
obj := h.object.DeepCopyObject()
if err := h.decoder.Decode(req, obj); err != nil {
return Errored(http.StatusBadRequest, err)
}

// Keep a copy of the object
originalObj := obj.DeepCopyObject()

// Default the object
updated := original.DeepCopyObject()
if err := h.defaulter.Default(ctx, updated); err != nil {
if err := h.defaulter.Default(ctx, obj); err != nil {
var apiStatus apierrors.APIStatus
if errors.As(err, &apiStatus) {
return validationResponseFromStatus(false, apiStatus.Status())
}
return Denied(err.Error())
}

// Create the patch.
// We need to decode and marshall the original because the type registered in the
// decoder might not match the latest version of the API.
// Creating a diff from the raw object might cause new fields to be dropped.
marshalledOriginal, err := json.Marshal(original)
// Create the patch
marshalled, err := json.Marshal(obj)
if err != nil {
return Errored(http.StatusInternalServerError, err)
}
marshalledUpdated, err := json.Marshal(updated)
handlerResponse := PatchResponseFromRaw(req.Object.Raw, marshalled)

return h.dropSchemeRemovals(handlerResponse, originalObj, req.Object.Raw)
}

func (h *defaulterForType) dropSchemeRemovals(r Response, original runtime.Object, raw []byte) Response {
const opRemove = "remove"
if !r.Allowed || r.PatchType == nil {
return r
}

// If we don't have removals in the patch.
if !slices.ContainsFunc(r.Patches, func(o jsonpatch.JsonPatchOperation) bool { return o.Operation == opRemove }) {
return r
}

// Get the raw to original patch
marshalledOriginal, err := json.Marshal(original)
if err != nil {
return Errored(http.StatusInternalServerError, err)
}
return PatchResponseFromRaw(marshalledOriginal, marshalledUpdated)

patchOriginal := PatchResponseFromRaw(raw, marshalledOriginal).Patches
removedByScheme := sets.New(slices.DeleteFunc(patchOriginal, func(p jsonpatch.JsonPatchOperation) bool { return p.Operation != opRemove })...)

r.Patches = slices.DeleteFunc(r.Patches, func(p jsonpatch.JsonPatchOperation) bool {
return removedByScheme.Has(p)
})

if len(r.Patches) == 0 {
r.PatchType = nil
}
return r
}
25 changes: 17 additions & 8 deletions pkg/webhook/admission/defaulter_custom_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,22 @@ var _ = Describe("Defaulter Handler", func() {
AdmissionRequest: admissionv1.AdmissionRequest{
Operation: admissionv1.Create,
Object: runtime.RawExtension{
Raw: []byte(`{"newField":"foo"}`),
Raw: []byte(`{"newField":"foo", "totalReplicas":5}`),
},
},
})
Expect(resp.Allowed).Should(BeTrue())
Expect(resp.Patches).To(Equal([]jsonpatch.JsonPatchOperation{{
Operation: "add",
Path: "/replica",
Value: 2.0,
}}))
Expect(resp.Patches).To(Equal([]jsonpatch.JsonPatchOperation{
{
Operation: "add",
Path: "/replica",
Value: 2.0,
},
{
Operation: "remove",
Path: "/totalReplicas",
},
}))
Expect(resp.Result.Code).Should(Equal(int32(http.StatusOK)))
})

Expand All @@ -70,15 +76,17 @@ var _ = Describe("Defaulter Handler", func() {
var _ runtime.Object = &TestDefaulter{}

type TestDefaulter struct {
Replica int `json:"replica,omitempty"`
Replica int `json:"replica,omitempty"`
TotalReplicas int `json:"totalReplicas,omitempty"`
}

var testDefaulterGVK = schema.GroupVersionKind{Group: "foo.test.org", Version: "v1", Kind: "TestDefaulter"}

func (d *TestDefaulter) GetObjectKind() schema.ObjectKind { return d }
func (d *TestDefaulter) DeepCopyObject() runtime.Object {
return &TestDefaulter{
Replica: d.Replica,
Replica: d.Replica,
TotalReplicas: d.TotalReplicas,
}
}

Expand All @@ -103,5 +111,6 @@ func (d *TestCustomDefaulter) Default(ctx context.Context, obj runtime.Object) e
if o.Replica < 2 {
o.Replica = 2
}
o.TotalReplicas = 0
return nil
}

0 comments on commit 5e7dd60

Please sign in to comment.