From 151cdbac3ff82dd093748309ecea043fd53278e0 Mon Sep 17 00:00:00 2001 From: Yi Tao Date: Wed, 21 Feb 2024 18:18:49 +0800 Subject: [PATCH] add licenses in kong state and sync diff --- pkg/diff/diff.go | 2 +- pkg/state/license.go | 146 +++++++++++++++++++++++++ pkg/state/license_test.go | 220 ++++++++++++++++++++++++++++++++++++++ pkg/state/state.go | 3 + pkg/state/types.go | 37 +++++++ pkg/types/core.go | 20 +++- pkg/types/license.go | 156 +++++++++++++++++++++++++++ pkg/types/postProcess.go | 16 +++ 8 files changed, 598 insertions(+), 2 deletions(-) create mode 100644 pkg/state/license.go create mode 100644 pkg/state/license_test.go create mode 100644 pkg/types/license.go diff --git a/pkg/diff/diff.go b/pkg/diff/diff.go index 7820642..ab8a901 100644 --- a/pkg/diff/diff.go +++ b/pkg/diff/diff.go @@ -172,7 +172,7 @@ func (sc *Syncer) init() error { types.HMACAuth, types.JWTAuth, types.OAuth2Cred, types.MTLSAuth, - types.Vault, + types.Vault, types.License, types.RBACRole, types.RBACEndpointPermission, diff --git a/pkg/state/license.go b/pkg/state/license.go new file mode 100644 index 0000000..4fd63fd --- /dev/null +++ b/pkg/state/license.go @@ -0,0 +1,146 @@ +package state + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" + "github.com/kong/go-database-reconciler/pkg/utils" +) + +const ( + licenseTableName = "license" +) + +var licenseTableSchema = &memdb.TableSchema{ + Name: licenseTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "ID"}, + }, + all: allIndex, + }, +} + +type LicensesCollection collection + +func getLicense(txn *memdb.Txn, id string) (*License, error) { + res, err := multiIndexLookupUsingTxn(txn, licenseTableName, []string{"id"}, id) + if err != nil { + return nil, err + } + l, ok := res.(*License) + if !ok { + panic(unexpectedType) + } + return &License{License: *l.DeepCopy()}, nil +} + +func (k *LicensesCollection) Add(l License) error { + if utils.Empty(l.ID) { + return errIDRequired + } + txn := k.db.Txn(true) + defer txn.Abort() + + _, err := getLicense(txn, *l.ID) + if err == nil { + return fmt.Errorf("inserting license %v: %w", l.Console(), ErrAlreadyExists) + } + if !errors.Is(err, ErrNotFound) { + return err + } + err = txn.Insert(licenseTableName, &l) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func (k *LicensesCollection) Get(id string) (*License, error) { + if id == "" { + return nil, errIDRequired + } + txn := k.db.Txn(false) + defer txn.Abort() + + l, err := getLicense(txn, id) + + if err != nil { + if errors.Is(err, ErrNotFound) { + return nil, ErrNotFound + } + return nil, err + } + txn.Commit() + return l, nil +} + +func deleteLicense(txn *memdb.Txn, id string) error { + l, err := getLicense(txn, id) + if err != nil { + return err + } + + return txn.Delete(licenseTableName, l) +} + +func (k *LicensesCollection) Update(l License) error { + if utils.Empty(l.ID) { + return errIDRequired + } + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteLicense(txn, *l.ID) + if err != nil { + return err + } + + err = txn.Insert(licenseTableName, &l) + if err != nil { + return err + } + txn.Commit() + return nil +} + +func (k *LicensesCollection) Delete(id string) error { + if id == "" { + return errIDRequired + } + txn := k.db.Txn(true) + defer txn.Abort() + + err := deleteLicense(txn, id) + if err != nil { + return err + } + + txn.Commit() + return err +} + +func (k *LicensesCollection) GetAll() ([]*License, error) { + txn := k.db.Txn(false) + defer txn.Abort() + iter, err := txn.Get(licenseTableName, all, true) + if err != nil { + return nil, err + } + + var res []*License + for el := iter.Next(); el != nil; el = iter.Next() { + l, ok := el.(*License) + if !ok { + panic(unexpectedType) + } + res = append(res, &License{License: *l.DeepCopy()}) + } + txn.Commit() + return res, nil + +} diff --git a/pkg/state/license_test.go b/pkg/state/license_test.go new file mode 100644 index 0000000..cb8d335 --- /dev/null +++ b/pkg/state/license_test.go @@ -0,0 +1,220 @@ +package state + +import ( + "testing" + + "github.com/kong/go-kong/kong" + "github.com/stretchr/testify/assert" +) + +var presetLicense = License{ + License: kong.License{ + ID: kong.String("license-preset"), + Payload: kong.String("preset-license-payload"), + }, +} + +func TestLicenseCollection_Add(t *testing.T) { + testCases := []struct { + name string + license *License + expectedError error + }{ + { + name: "insert with no ID", + license: &License{ + License: kong.License{}, + }, + expectedError: errIDRequired, + }, + { + name: "insert with ID and payload", + license: &License{ + License: kong.License{ + ID: kong.String("1234"), + Payload: kong.String("license-test"), + }, + }, + }, + { + name: "insert a license with existing ID", + license: &License{ + License: kong.License{ + ID: kong.String("license-preset"), + Payload: kong.String("license-test"), + }, + }, + expectedError: ErrAlreadyExists, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + initialState := state() + c := initialState.Licenses + err := c.Add(presetLicense) + assert.NoError(t, err) + + err = c.Add(*tc.license) + if tc.expectedError != nil { + assert.ErrorAs(t, err, &tc.expectedError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestLicenseCollection_Get(t *testing.T) { + testCases := []struct { + name string + id string + expectedPayload string + expectedError error + }{ + { + name: "get existing license", + id: "license-preset", + expectedPayload: "preset-license-payload", + }, + { + name: "get non existing license", + id: "license-non-exist", + expectedError: ErrNotFound, + }, + { + name: "get with empty ID", + id: "", + expectedError: errIDRequired, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + initialState := state() + c := initialState.Licenses + err := c.Add(presetLicense) + assert.NoError(t, err) + + l, err := c.Get(tc.id) + if tc.expectedError == nil { + assert.NoError(t, err) + assert.Equal(t, tc.id, *l.ID) + assert.Equal(t, tc.expectedPayload, *l.Payload) + } else { + assert.ErrorAs(t, err, &tc.expectedError) + } + }) + } +} + +func TestLicenseCollection_Update(t *testing.T) { + testCases := []struct { + name string + license License + expectedError error + }{ + { + name: "update with no ID", + license: License{}, + expectedError: errIDRequired, + }, + { + name: "update non existing license", + license: License{ + License: kong.License{ + ID: kong.String("license-non-exist"), + Payload: kong.String("updated-payload"), + }, + }, + expectedError: ErrNotFound, + }, + { + name: "update existing license", + license: License{ + License: kong.License{ + ID: kong.String("license-preset"), + Payload: kong.String("updated-payload"), + }, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + initialState := state() + c := initialState.Licenses + err := c.Add(presetLicense) + assert.NoError(t, err) + + err = c.Update(tc.license) + if tc.expectedError == nil { + assert.NoError(t, err) + updatedLicense, err := c.Get(*tc.license.ID) + assert.NoError(t, err) + assert.Equal(t, *tc.license.Payload, *updatedLicense.Payload) + } else { + assert.ErrorAs(t, err, &tc.expectedError) + } + }) + } +} + +func TestLicenseCollection_Delete(t *testing.T) { + testCases := []struct { + name string + id string + expectedError error + }{ + { + name: "delete with no ID", + id: "", + expectedError: errIDRequired, + }, + { + name: "delete non existing license", + id: "license-non-exist", + expectedError: ErrNotFound, + }, + { + name: "delete existing license", + id: "license-preset", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + initialState := state() + c := initialState.Licenses + err := c.Add(presetLicense) + assert.NoError(t, err) + + err = c.Delete(tc.id) + if tc.expectedError == nil { + assert.NoError(t, err) + } else { + assert.ErrorAs(t, err, &tc.expectedError) + } + }) + } +} + +func TestLicenseCollection_GetAll(t *testing.T) { + initialState := state() + c := initialState.Licenses + licenses, err := c.GetAll() + assert.NoError(t, err) + assert.Len(t, licenses, 0, "Should have no licenses") + + err = c.Add(presetLicense) + assert.NoError(t, err) + + licenses, err = c.GetAll() + assert.NoError(t, err) + assert.Len(t, licenses, 1, "Should have 1 license after adding") + +} diff --git a/pkg/state/state.go b/pkg/state/state.go index 7972b77..1c5e255 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -24,6 +24,7 @@ type KongState struct { Plugins *PluginsCollection Consumers *ConsumersCollection Vaults *VaultsCollection + Licenses *LicensesCollection ConsumerGroups *ConsumerGroupsCollection ConsumerGroupConsumers *ConsumerGroupConsumersCollection ConsumerGroupPlugins *ConsumerGroupPluginsCollection @@ -71,6 +72,7 @@ func NewKongState() (*KongState, error) { rbacRoleTableName: rbacRoleTableSchema, rbacEndpointPermissionTableName: rbacEndpointPermissionTableSchema, vaultTableName: vaultTableSchema, + licenseTableName: licenseTableSchema, keyAuthTemp.TableName(): keyAuthTemp.Schema(), hmacAuthTemp.TableName(): hmacAuthTemp.Schema(), @@ -112,6 +114,7 @@ func NewKongState() (*KongState, error) { state.RBACRoles = (*RBACRolesCollection)(&state.common) state.RBACEndpointPermissions = (*RBACEndpointPermissionsCollection)(&state.common) state.Vaults = (*VaultsCollection)(&state.common) + state.Licenses = (*LicensesCollection)(&state.common) state.KeyAuths = newKeyAuthsCollection(state.common) state.HMACAuths = newHMACAuthsCollection(state.common) diff --git a/pkg/state/types.go b/pkg/state/types.go index 8328eae..1bb9dc4 100644 --- a/pkg/state/types.go +++ b/pkg/state/types.go @@ -1602,3 +1602,40 @@ func (v1 *Vault) EqualWithOpts(v2 *Vault, ignoreID, ignoreTS bool) bool { } return reflect.DeepEqual(v1Copy, v2Copy) } + +type License struct { + kong.License `yaml:",inline"` + Meta +} + +// TODO: add a variable definition to notate that License (and definition of other entities) should satisfy crud.Event interface? + +func (l *License) Identifier() string { + return *l.ID +} + +func (l *License) Console() string { + return l.FriendlyName() +} + +func (l1 *License) Equal(l2 *License) bool { + return l1.EqualWithOpts(l2, false, false) +} + +func (l1 *License) EqualWithOpts(l2 *License, ignoreID, ignoreTS bool) bool { + l1Copy := l1.License.DeepCopy() + l2Copy := l2.License.DeepCopy() + + if ignoreID { + l1Copy.ID = nil + l2Copy.ID = nil + } + if ignoreTS { + l1Copy.CreatedAt = nil + l2Copy.CreatedAt = nil + + l1Copy.UpdatedAt = nil + l2Copy.UpdatedAt = nil + } + return reflect.DeepEqual(l1Copy, l2Copy) +} diff --git a/pkg/types/core.go b/pkg/types/core.go index 5c6d661..feff8d1 100644 --- a/pkg/types/core.go +++ b/pkg/types/core.go @@ -119,6 +119,8 @@ const ( // Vault identifies a Vault in Kong. Vault EntityType = "vault" + // License identifies a License in Kong Enterprise. + License EntityType = "license" ) // AllTypes represents all types defined in the @@ -140,7 +142,7 @@ var AllTypes = []EntityType{ ServicePackage, ServiceVersion, Document, - Vault, + Vault, License, } func entityTypeToKind(t EntityType) crud.Kind { @@ -529,6 +531,22 @@ func NewEntity(t EntityType, opts EntityOpts) (Entity, error) { targetState: opts.TargetState, }, }, nil + case License: + return entityImpl{ + typ: License, + crudActions: &licenseCRUD{ + client: opts.KongClient, + isKonnect: opts.IsKonnect, + }, + postProcessActions: &licensePostAction{ + currentState: opts.CurrentState, + }, + differ: &licenseDiffer{ + kind: entityTypeToKind(License), + currentState: opts.CurrentState, + targetState: opts.TargetState, + }, + }, nil default: return nil, fmt.Errorf("unknown type: %q", t) } diff --git a/pkg/types/license.go b/pkg/types/license.go new file mode 100644 index 0000000..a0c3b33 --- /dev/null +++ b/pkg/types/license.go @@ -0,0 +1,156 @@ +package types + +import ( + "context" + "errors" + + "github.com/kong/go-database-reconciler/pkg/crud" + "github.com/kong/go-database-reconciler/pkg/state" + "github.com/kong/go-kong/kong" +) + +type licenseCRUD struct { + client *kong.Client + // TODO: disable CRUDs when Konnect is enabled? + isKonnect bool +} + +var _ crud.Actions = &licenseCRUD{} + +func licenseFromEventStruct(arg crud.Event) *state.License { + license, ok := arg.Obj.(*state.License) + if !ok { + panic("unexpected type, expected *state.License") + } + return license +} + +// Create creates a License in Kong. +func (s *licenseCRUD) Create(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + license := licenseFromEventStruct(event) + createdLicense, err := s.client.Licenses.Create(ctx, &license.License) + if err != nil { + return nil, err + } + return &state.License{License: *createdLicense}, nil +} + +// Delete deletes a License in Kong. +func (s *licenseCRUD) Delete(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + license := licenseFromEventStruct(event) + err := s.client.Licenses.Delete(ctx, license.ID) + if err != nil { + return nil, err + } + return license, nil +} + +// Update updates a license in Kong. +func (s *licenseCRUD) Update(ctx context.Context, arg ...crud.Arg) (crud.Arg, error) { + event := crud.EventFromArg(arg[0]) + license := licenseFromEventStruct(event) + + updatedLicense, err := s.client.Licenses.Create(ctx, &license.License) + if err != nil { + return nil, err + } + return &state.License{License: *updatedLicense}, nil +} + +type licenseDiffer struct { + kind crud.Kind + + currentState, targetState *state.KongState +} + +var _ Differ = &licenseDiffer{} + +func (d *licenseDiffer) maybeCreateOrUpdateLicense(targetLicense *state.License) (*crud.Event, error) { + licenseCopy := &state.License{License: *targetLicense.License.DeepCopy()} + currentLicense, err := d.currentState.Licenses.Get(*targetLicense.ID) + + if err != nil { + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Create, + // TODO: define consts for used `crud.Kind`s? + Kind: "license", + Obj: licenseCopy, + }, nil + } + return nil, err + } + + if !currentLicense.EqualWithOpts(licenseCopy, false, true) { + return &crud.Event{ + Op: crud.Update, + Kind: "license", + Obj: licenseCopy, + OldObj: currentLicense, + }, nil + } + + return nil, nil +} + +func (d *licenseDiffer) CreateAndUpdates(handler func(crud.Event) error) error { + targetLicenses, err := d.targetState.Licenses.GetAll() + if err != nil { + return err + } + + for _, license := range targetLicenses { + event, err := d.maybeCreateOrUpdateLicense(license) + if err != nil { + return err + } + + if event != nil { + err := handler(*event) + if err != nil { + return err + } + } + } + return nil +} + +func (d *licenseDiffer) maybeDeleteLicense(currentLicense *state.License) (*crud.Event, error) { + _, err := d.targetState.Licenses.Get(*currentLicense.ID) + if err != nil { + if errors.Is(err, state.ErrNotFound) { + return &crud.Event{ + Op: crud.Delete, + Kind: "license", + Obj: currentLicense, + }, nil + } + + return nil, err + } + return nil, nil +} + +func (d *licenseDiffer) Deletes(handler func(crud.Event) error) error { + currentLicenses, err := d.currentState.Licenses.GetAll() + if err != nil { + return err + } + + for _, license := range currentLicenses { + event, err := d.maybeDeleteLicense(license) + if err != nil { + return err + } + if event != nil { + err := handler(*event) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/pkg/types/postProcess.go b/pkg/types/postProcess.go index 9d0bcf7..56ea79a 100644 --- a/pkg/types/postProcess.go +++ b/pkg/types/postProcess.go @@ -455,3 +455,19 @@ func (crud vaultPostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Ar func (crud vaultPostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { return nil, crud.currentState.Vaults.Update(*args[0].(*state.Vault)) } + +type licensePostAction struct { + currentState *state.KongState +} + +func (crud licensePostAction) Create(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Licenses.Add(*args[0].(*state.License)) +} + +func (crud licensePostAction) Delete(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Licenses.Delete(*((args[0].(*state.License)).ID)) +} + +func (crud licensePostAction) Update(_ context.Context, args ...crud.Arg) (crud.Arg, error) { + return nil, crud.currentState.Licenses.Update(*args[0].(*state.License)) +}