Skip to content

Commit

Permalink
Merge pull request #5 from armosec/middleware
Browse files Browse the repository at this point in the history
add utils
  • Loading branch information
avrahams authored Aug 23, 2023
2 parents b1fc845 + 54ba753 commit 38dabab
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 174 deletions.
63 changes: 42 additions & 21 deletions server/handler_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ type RequestHandler func(w http.ResponseWriter, r *http.Request, reqBody string)

type RequestHandlerOption func(opts *requestHandlerOptions) error

//WithMethod option sets the method for the handler
// WithMethod option sets the method for the handler
var WithMethod = func(method string) RequestHandlerOption {
return func(o *requestHandlerOptions) error {
o.method = method
return nil
}
}

//WithPath option sets the path for the handler
// WithPath option sets the path for the handler
var WithPath = func(path string) RequestHandlerOption {
return func(o *requestHandlerOptions) error {
if path != "" && (o.pathPrefix != "" || o.pathSuffix != "") {
Expand All @@ -31,7 +31,7 @@ var WithPath = func(path string) RequestHandlerOption {
}
}

//WithPathPrefix option sets the path prefix for the handler
// WithPathPrefix option sets the path prefix for the handler
var WithPathPrefix = func(pathPrefix string) RequestHandlerOption {
return func(o *requestHandlerOptions) error {
if pathPrefix != "" && o.path != "" {
Expand All @@ -42,7 +42,7 @@ var WithPathPrefix = func(pathPrefix string) RequestHandlerOption {
}
}

//WithPathSuffix option sets the path suffix for the handler
// WithPathSuffix option sets the path suffix for the handler
var WithPathSuffix = func(pathSuffix string) RequestHandlerOption {
return func(o *requestHandlerOptions) error {
if pathSuffix != "" && o.path != "" {
Expand All @@ -53,7 +53,7 @@ var WithPathSuffix = func(pathSuffix string) RequestHandlerOption {
}
}

//WithResponse option sets the response for the handler
// WithResponse option sets the response for the handler
var WithResponse = func(response []byte) RequestHandlerOption {
return func(o *requestHandlerOptions) error {
if len(response) != 0 && (o.handler != nil || len(o.responses) != 0) {
Expand All @@ -74,7 +74,7 @@ var WithResponses = func(responses [][]byte) RequestHandlerOption {
}
}

//WithHandler option sets the handler for the handler
// WithHandler option sets the handler for the handler
var WithHandler = func(handler RequestHandler) RequestHandlerOption {
return func(o *requestHandlerOptions) error {
if handler != nil && len(o.response) != 0 {
Expand All @@ -85,16 +85,34 @@ var WithHandler = func(handler RequestHandler) RequestHandlerOption {
}
}

//WithRequestNumber option sets the request number for the handler
// WithRequestNumber option sets the request number for the handler
var WithRequestNumber = func(reqNum int) RequestHandlerOption {
return func(o *requestHandlerOptions) error {
o.reqNum = reqNum
return nil
}
}

//WithUpdateExpected option sets the update expected flag for the handler
// Deprecated: Use WithTestRequestV1 - keep only for elastic tests
var WithTestRequest = func(t *testing.T, updateExpected bool, expectedRequest []byte, expectedRequestFile string) RequestHandlerOption {
return func(o *requestHandlerOptions) error {
if expectedRequest == nil || t == nil {
return fmt.Errorf("test, expected request must be provided")
}
if updateExpected && expectedRequestFile == "" {
return fmt.Errorf("expectedRequestFile must be provided when update expected is true")

}
o.t = t
o.updateExpected = updateExpected
o.expectedRequest = expectedRequest
o.expectedRequestFile = expectedRequestFile
o.deprecatedTestResponse = true
return nil
}
}

var WithTestRequestV1 = func(t *testing.T, updateExpected bool, expectedRequest []byte, expectedRequestFile string) RequestHandlerOption {
return func(o *requestHandlerOptions) error {
if expectedRequest == nil || t == nil {
return fmt.Errorf("test, expected request must be provided")
Expand All @@ -112,18 +130,19 @@ var WithTestRequest = func(t *testing.T, updateExpected bool, expectedRequest []
}

type requestHandlerOptions struct {
method string
path string
response []byte
responses [][]byte
expectedRequest []byte
expectedRequestFile string
updateExpected bool
reqNum int
pathPrefix string
pathSuffix string
handler RequestHandler
t *testing.T
method string
path string
response []byte
responses [][]byte
expectedRequest []byte
expectedRequestFile string
updateExpected bool
reqNum int
pathPrefix string
pathSuffix string
handler RequestHandler
t *testing.T
deprecatedTestResponse bool
}

func (o *requestHandlerOptions) validate() error {
Expand All @@ -140,8 +159,10 @@ func (o *requestHandlerOptions) getOrCreateHandler() RequestHandler {
return o.handler
}
return func(w http.ResponseWriter, r *http.Request, reqBody string) {
if len(o.expectedRequest) != 0 {
if o.deprecatedTestResponse && len(o.expectedRequest) != 0 {
utils.DeepEqualOrUpdate(o.t, []byte(reqBody), o.expectedRequest, o.expectedRequestFile, o.updateExpected)
} else if len(o.expectedRequest) != 0 {
utils.CompareAndUpdate(o.t, []byte(reqBody), o.expectedRequest, o.expectedRequestFile, o.updateExpected)
}
if len(o.response) != 0 {
w.Write(o.response)
Expand Down
171 changes: 171 additions & 0 deletions utils/deprecated.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package utils

import (
"encoding/json"
"sort"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
)

// Deprecated: use CompareAndUpdate
func CompareOrUpdate[T any](actual T, expectedBytes []byte, expectedFileName string, t *testing.T, update bool) {
if !Equal(t, actual, expectedBytes) && update {
SaveExpected(t, expectedFileName, actual)
} else if update {
assert.False(t, true, "update expected is true, set to false and rerun test")
}
}

// Deprecated: use only for elastic response
func Equal[T any](t *testing.T, actual T, expectedBytes []byte) bool {
var expected T
err := json.Unmarshal(expectedBytes, &expected)
assert.NoError(t, err)
return assert.Equal(t, expected, actual)
}

// Deprecated: use only for elastic response
func DeepEqualOrUpdate(t *testing.T, actualBytes, expectedBytes []byte, expectedFileName string, update bool) {
var actual, expected interface{}
err := json.Unmarshal(actualBytes, &actual)
assert.NoError(t, err)
err = json.Unmarshal(expectedBytes, &expected)
assert.NoError(t, err)

sortSlices := cmpopts.SortSlices(func(a, b interface{}) bool {
if less, ok := isLess(a, b); ok {
return less
}
t.Fatal("don't know how to sort these types")
return true
})
diff := cmp.Diff(expected, actual, sortSlices)
if update {
if diff != "" {
var actual interface{}
err := json.Unmarshal(actualBytes, &actual)
assert.NoError(t, err)
SaveExpected(t, expectedFileName, actual)
} else {
assert.False(t, true, "update expected is true, set to false and rerun test")
}
} else {
assert.Empty(t, diff, "actual compare with expected should not have diffs")
}
}

//TODO improve this compare - works for Elastic response

func isLess(a, b interface{}) (bool, bool) {
strA, okA := a.(string)
strB, okB := b.(string)
if okA && okB {
return strA < strB, true
}
if less, ok := lessSlice(a, b); ok {
return less, true
} else if less, ok := lessMapStr2Interface(a, b); ok {
return less, true
}

return false, false

}

func lessSlice(a, b interface{}) (less bool, ok bool) {
if less, ok := lessStringSlice(a, b); ok {
return less, true
} else if less, ok := lessInterfaceSlice(a, b); ok {
return less, true
}
return false, false
}

func lessStringSlice(a, b interface{}) (less bool, ok bool) {
strSliceA, okA := a.([]string)
strSliceB, okB := b.([]string)
if !okA || !okB {
return false, false
}
if len(strSliceA) != len(strSliceB) {
return len(strSliceA) < len(strSliceB), true
}
for i := range strSliceA {
if strSliceA[i] != strSliceB[i] {
return strings.Compare(strSliceA[i], strSliceB[i]) == -1, true
}
}
return false, true
}

func lessInterfaceSlice(a, b interface{}) (less bool, ok bool) {
interSliceA, okA := a.([]interface{})
interSliceB, okB := b.([]interface{})
if !okA || !okB {
return false, false
}
if len(interSliceA) != len(interSliceB) {
return len(interSliceA) < len(interSliceB), true
}
for i := range interSliceA {
if less, ok := isLess(interSliceA[i], interSliceB[i]); !ok {
return false, false
} else if less {
return true, true
}
}
return false, true
}

func lessMapStr2Interface(a, b interface{}) (bool, bool) {
mapA, okA := a.(map[string]interface{})
mapB, okB := b.(map[string]interface{})

if !okA || !okB {
return false, false
}

if len(mapA) != len(mapB) {
return len(mapA) < len(mapB), true
}

keysA := make([]string, 0, len(mapA))
for k := range mapA {
keysA = append(keysA, k)
}
keysB := make([]string, 0, len(mapB))
for k := range mapB {
keysB = append(keysB, k)
}

sort.StringSlice(keysA).Sort()
sort.StringSlice(keysB).Sort()

for i := range keysA {
if keysA[i] != keysB[i] {
return keysA[i] < keysB[i], true
}
}
for i := range keysA {
less, ok := isLess(mapA[keysA[i]], mapB[keysA[i]])
if !ok {
return false, false
}
if less {
return true, true
}
}
return false, true
}

func loadJson[T any](jsonBytes []byte) T {
var obj T
if err := json.Unmarshal(jsonBytes, &obj); err != nil {
panic(err)
}
return obj
}
Loading

0 comments on commit 38dabab

Please sign in to comment.