From 460a84f81fb31f2a542813e38b8280b91dfb8099 Mon Sep 17 00:00:00 2001 From: poteto0 Date: Mon, 16 Sep 2024 16:46:55 +0900 Subject: [PATCH 1/3] not tested cors --- constant/constant.go | 2 + context.go | 16 ++++++++ handler_func.go | 2 + middleware/cors.go | 97 ++++++++++++++++++++++++++++++++++++++++++++ middleware/util.go | 53 ++++++++++++++++++++++++ 5 files changed, 170 insertions(+) create mode 100644 middleware/cors.go create mode 100644 middleware/util.go diff --git a/constant/constant.go b/constant/constant.go index 028e6f6..97bb10f 100644 --- a/constant/constant.go +++ b/constant/constant.go @@ -3,4 +3,6 @@ package constant const ( HEADER_CONTENT_TYPE string = "Content-Type" APPLICATION_JSON string = "application/json" + HEADER_ORIGIN string = "Origin" + HEADER_VARY string = "vary" ) diff --git a/context.go b/context.go index 654aa36..7dcd02d 100644 --- a/context.go +++ b/context.go @@ -14,7 +14,10 @@ type Context interface { writeContentType(value string) SetPath(path string) GetResponse() *response + GetRequest() *http.Request + GetRequestHeaderValue(key string) string JsonSerialize(value any) error + NoContent() error } type context struct { @@ -57,7 +60,20 @@ func (ctx *context) GetResponse() *response { return ctx.response.(*response) } +func (ctx *context) GetRequest() *http.Request { + return ctx.request +} + +func (ctx *context) GetRequestHeaderValue(key string) string { + return ctx.request.Header.Get(key) +} + func (ctx *context) JsonSerialize(value any) error { encoder := json.NewEncoder(ctx.GetResponse()) return encoder.Encode(value) } + +func (c *context) NoContent() error { + c.response.WriteHeader(http.StatusNoContent) + return nil +} diff --git a/handler_func.go b/handler_func.go index 4a63aa5..869c76c 100644 --- a/handler_func.go +++ b/handler_func.go @@ -1,3 +1,5 @@ package poteto type HandlerFunc func(ctx Context) error + +type MiddlewareFunc func(next HandlerFunc) HandlerFunc diff --git a/middleware/cors.go b/middleware/cors.go new file mode 100644 index 0000000..ef4dd62 --- /dev/null +++ b/middleware/cors.go @@ -0,0 +1,97 @@ +package middleware + +import ( + "net/http" + "regexp" + "strings" + + "github.com/poteto0/poteto" + "github.com/poteto0/poteto/constant" +) + +type CORSConfig struct { + AllowOrigins []string `yaml:"allow_origins"` + AllowMethods []string `yaml:"allow_methods"` +} + +var DefaultCORSConfig = CORSConfig{ + AllowOrigins: []string{"*"}, + AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete}, +} + +func CORSWithConfig(config CORSConfig) poteto.MiddlewareFunc { + if len(config.AllowOrigins) == 0 { + config.AllowOrigins = DefaultCORSConfig.AllowOrigins + } + + if len(config.AllowMethods) == 0 { + config.AllowMethods = DefaultCORSConfig.AllowMethods + } + + allowOriginPatterns := []string{} + for _, origin := range config.AllowOrigins { + pattern := regexp.QuoteMeta(origin) + pattern = strings.ReplaceAll(pattern, "\\*", ".*") + pattern = strings.ReplaceAll(pattern, "\\?", ".") + pattern = "^" + pattern + "$" + allowOriginPatterns = append(allowOriginPatterns, pattern) + } + + return func(next poteto.HandlerFunc) poteto.HandlerFunc { + return func(ctx poteto.Context) error { + req := ctx.GetRequest() + res := ctx.GetResponse() + origin := req.Header.Get(constant.HEADER_ORIGIN) + + res.Header().Add(constant.HEADER_VARY, constant.HEADER_ORIGIN) + preflight := req.Method == http.MethodOptions + + // Not From Browser + if origin == "" { + if !preflight { + return next(ctx) + } + return ctx.NoContent() + } + + allowOrigin := getAllowOrigin(origin, allowOriginPatterns) + + // Origin not allowed + if allowOrigin == "" { + if !preflight { + return next(ctx) + } + return ctx.NoContent() + } + + if matchMethod(req.Method, config.AllowMethods) { + return next(ctx) + } + + return ctx.NoContent() + } + } +} + +func getAllowOrigin(origin string, allowOrigins []string) string { + for _, o := range allowOrigins { + if o == "*" || o == origin { + return origin + } + if matchSubdomain(origin, o) { + return origin + } + } + + return "" +} + +func matchMethod(method string, allowMethods []string) bool { + for _, m := range allowMethods { + if m == method { + return true + } + } + + return false +} diff --git a/middleware/util.go b/middleware/util.go new file mode 100644 index 0000000..d8f8599 --- /dev/null +++ b/middleware/util.go @@ -0,0 +1,53 @@ +package middleware + +import "strings" + +func matchScheme(domain, pattern string) bool { + didx := strings.Index(domain, ":") + pidx := strings.Index(pattern, ":") + return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx] +} + +func matchSubdomain(domain, pattern string) bool { + if !matchScheme(domain, pattern) { + return false + } + + didx := strings.Index(domain, "://") + pidx := strings.Index(pattern, "://") + if didx == -1 || pidx == -1 { + return false + } + + domAuth := domain[didx+3:] + // to avoid long loop by invalid long domain + if len(domAuth) > 253 { + return false + } + patAuth := pattern[pidx+3:] + + domComp := strings.Split(domAuth, ".") + patComp := strings.Split(patAuth, ".") + for i := len(domComp)/2 - 1; i >= 0; i-- { + opp := len(domComp) - 1 - i + domComp[i], domComp[opp] = domComp[opp], domComp[i] + } + for i := len(patComp)/2 - 1; i >= 0; i-- { + opp := len(patComp) - 1 - i + patComp[i], patComp[opp] = patComp[opp], patComp[i] + } + + for i, v := range domComp { + if len(patComp) <= i { + return false + } + p := patComp[i] + if p == "*" { + return true + } + if p != v { + return false + } + } + return false +} From 0084d11b2b4f811a915d870eca5c6e20a23d9a70 Mon Sep 17 00:00:00 2001 From: poteto0 Date: Thu, 19 Sep 2024 13:56:06 +0900 Subject: [PATCH 2/3] test utils --- .github/workflows/test.yaml | 2 +- middleware/cors.go | 8 ++--- middleware/{util.go => utils.go} | 15 ++++++++- middleware/utils_test.go | 53 ++++++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 8 deletions(-) rename middleware/{util.go => utils.go} (75%) create mode 100644 middleware/utils_test.go diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 18065ad..506ba8f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -22,4 +22,4 @@ jobs: with: go-version: "1.21.x" - name: Run Test - run: go test -cover -bench . -benchmem + run: go test ./... -cover -bench . -benchmem diff --git a/middleware/cors.go b/middleware/cors.go index ef4dd62..3042e72 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -2,8 +2,6 @@ package middleware import ( "net/http" - "regexp" - "strings" "github.com/poteto0/poteto" "github.com/poteto0/poteto/constant" @@ -30,10 +28,7 @@ func CORSWithConfig(config CORSConfig) poteto.MiddlewareFunc { allowOriginPatterns := []string{} for _, origin := range config.AllowOrigins { - pattern := regexp.QuoteMeta(origin) - pattern = strings.ReplaceAll(pattern, "\\*", ".*") - pattern = strings.ReplaceAll(pattern, "\\?", ".") - pattern = "^" + pattern + "$" + pattern := wrapRegExp(origin) allowOriginPatterns = append(allowOriginPatterns, pattern) } @@ -64,6 +59,7 @@ func CORSWithConfig(config CORSConfig) poteto.MiddlewareFunc { return ctx.NoContent() } + // allowed method if matchMethod(req.Method, config.AllowMethods) { return next(ctx) } diff --git a/middleware/util.go b/middleware/utils.go similarity index 75% rename from middleware/util.go rename to middleware/utils.go index d8f8599..4b2d22d 100644 --- a/middleware/util.go +++ b/middleware/utils.go @@ -1,7 +1,20 @@ package middleware -import "strings" +import ( + "regexp" + "strings" +) +// EX: https://example.com:* => ^https://example\.com:.*$ +func wrapRegExp(target string) string { + pattern := regexp.QuoteMeta(target) // .をescapeする + pattern = strings.ReplaceAll(pattern, "\\*", ".*") + pattern = strings.ReplaceAll(pattern, "\\?", ".") + pattern = "^" + pattern + "$" + return pattern +} + +// http vs https func matchScheme(domain, pattern string) bool { didx := strings.Index(domain, ":") pidx := strings.Index(pattern, ":") diff --git a/middleware/utils_test.go b/middleware/utils_test.go new file mode 100644 index 0000000..8bbc30a --- /dev/null +++ b/middleware/utils_test.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "fmt" + "testing" +) + +func TestWrapRegExp(t *testing.T) { + tests := []struct { + name string + target string + expected string + }{ + {"test * url", "https://example.com:*", `^https://example\.com:.*$`}, + {"test ? url", "https://example.com:300?", `^https://example\.com:300.$`}, + } + + for _, it := range tests { + t.Run(it.name, func(t *testing.T) { + result := wrapRegExp(it.target) + if result != it.expected { + t.Errorf("Not matched") + } + }) + } +} + +func TestMatchScheme(t *testing.T) { + tests := []struct { + name string + domain string + pattern string + expected bool + }{ + {"If : is not existed return false(both)", "example.com", "example.com", false}, + {"If : is not existed return false(pattern)", "example.com", "http://example.com", false}, + {"If : is not existed return false(domain)", "http://example.com", "example.com", false}, + {"matched", "http://example1.com", "http://example2.com", true}, + {"not matched", "http://example1.com", "https://example2.com", false}, + } + + for _, it := range tests { + t.Run((it.name), func(t *testing.T) { + result := matchScheme(it.domain, it.pattern) + + if result != it.expected { + t.Errorf("Not matched") + t.Errorf(fmt.Sprintf("expected: %t", it.expected)) + t.Errorf(fmt.Sprintf("actual: %t", result)) + } + }) + } +} From 890e345d89960e1fddbdc650816b982ab7473ea9 Mon Sep 17 00:00:00 2001 From: poteto0 Date: Mon, 23 Sep 2024 17:13:36 +0900 Subject: [PATCH 3/3] test cors --- constant/constant.go | 1 + middleware/cors.go | 18 ++++---- middleware/cors_test.go | 92 ++++++++++++++++++++++++++++++++++++++++ middleware/utils.go | 67 +++++++++++++++++++---------- middleware/utils_test.go | 61 ++++++++++++++++++++++++++ utils/math.go | 15 +++++++ 6 files changed, 224 insertions(+), 30 deletions(-) create mode 100644 middleware/cors_test.go create mode 100644 utils/math.go diff --git a/constant/constant.go b/constant/constant.go index 97bb10f..8ff5f70 100644 --- a/constant/constant.go +++ b/constant/constant.go @@ -5,4 +5,5 @@ const ( APPLICATION_JSON string = "application/json" HEADER_ORIGIN string = "Origin" HEADER_VARY string = "vary" + MAX_DOMAIN_LENGTH int = 255 ) diff --git a/middleware/cors.go b/middleware/cors.go index 3042e72..065fa19 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -2,6 +2,7 @@ package middleware import ( "net/http" + "regexp" "github.com/poteto0/poteto" "github.com/poteto0/poteto/constant" @@ -49,7 +50,9 @@ func CORSWithConfig(config CORSConfig) poteto.MiddlewareFunc { return ctx.NoContent() } - allowOrigin := getAllowOrigin(origin, allowOriginPatterns) + allowSubDomain := getAllowSubDomain(origin, config.AllowOrigins) + // allowed origin path + allowOrigin := getAllowOrigin(allowSubDomain, allowOriginPatterns) // Origin not allowed if allowOrigin == "" { @@ -69,7 +72,7 @@ func CORSWithConfig(config CORSConfig) poteto.MiddlewareFunc { } } -func getAllowOrigin(origin string, allowOrigins []string) string { +func getAllowSubDomain(origin string, allowOrigins []string) string { for _, o := range allowOrigins { if o == "*" || o == origin { return origin @@ -82,12 +85,11 @@ func getAllowOrigin(origin string, allowOrigins []string) string { return "" } -func matchMethod(method string, allowMethods []string) bool { - for _, m := range allowMethods { - if m == method { - return true +func getAllowOrigin(origin string, allowOriginPatterns []string) string { + for _, pattern := range allowOriginPatterns { + if match, _ := regexp.MatchString(pattern, origin); match { + return origin } } - - return false + return "" } diff --git a/middleware/cors_test.go b/middleware/cors_test.go new file mode 100644 index 0000000..9f108f6 --- /dev/null +++ b/middleware/cors_test.go @@ -0,0 +1,92 @@ +package middleware + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/poteto0/poteto" +) + +type TestVal struct { + Name string `json:"name"` + Val string `json:"val"` +} + +func TestCORSWithConfigByDefault(t *testing.T) { + config := CORSConfig{ + AllowOrigins: []string{}, + AllowMethods: []string{}, + } + + t.Run("allow all origins", func(t *testing.T) { + cors := CORSWithConfig(config) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "https://example.com/test", nil) + context := poteto.NewContext(w, req) + + handler := func(ctx poteto.Context) error { + return ctx.JSON(http.StatusOK, TestVal{Name: "test", Val: "val"}) + } + + cors_handler := cors(handler) + cors_handler(context) + result := w.Body.String() + expected := `{"name":"test","val":"val"}` + if result[0:27] != expected[0:27] { + t.Errorf("Wrong result") + t.Errorf(fmt.Sprintf("expected: %s", expected)) + t.Errorf(fmt.Sprintf("actual: %s", result)) + } + }) +} + +func TestGetAllowSubDomain(t *testing.T) { + tests := []struct { + name string + origin string + allowOrigins []string + expected string + }{ + {"test wildcard return true", "https://example.com", []string{"*"}, "https://example.com"}, + {"test match same domain", "https://example.com", []string{"https://example.com"}, "https://example.com"}, + {"test matched subdomain", "https://exmaple.com.test", []string{"https://example.com.*"}, "https://exmaple.com.test"}, + {"test not matched", "https://hello.world.com", []string{"https://exmaple.com"}, ""}, + } + + for _, it := range tests { + t.Run(it.name, func(t *testing.T) { + result := getAllowSubDomain(it.origin, it.allowOrigins) + if result != it.expected { + t.Errorf("Not matched") + t.Errorf(fmt.Sprintf("expected: %s", it.expected)) + t.Errorf(fmt.Sprintf("actual: %s", result)) + } + }) + } +} + +func TestGetAllowOrigin(t *testing.T) { + tests := []struct { + name string + origin string + allowOriginPatterns []string + expected string + }{ + {"test match case", "https://example.com", []string{wrapRegExp("https://example.*")}, "https://example.com"}, + {"test not match case", "https://example.com", []string{wrapRegExp("https://hello.world.com")}, ""}, + } + + for _, it := range tests { + t.Run(it.name, func(t *testing.T) { + result := getAllowOrigin(it.origin, it.allowOriginPatterns) + if result != it.expected { + t.Errorf("Not matched") + t.Errorf(fmt.Sprintf("expected: %s", it.expected)) + t.Errorf(fmt.Sprintf("actual: %s", result)) + } + }) + } +} diff --git a/middleware/utils.go b/middleware/utils.go index 4b2d22d..24f1765 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -3,6 +3,8 @@ package middleware import ( "regexp" "strings" + + "github.com/poteto0/poteto/constant" ) // EX: https://example.com:* => ^https://example\.com:.*$ @@ -14,13 +16,8 @@ func wrapRegExp(target string) string { return pattern } -// http vs https -func matchScheme(domain, pattern string) bool { - didx := strings.Index(domain, ":") - pidx := strings.Index(pattern, ":") - return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx] -} - +// just sub domain +// only wild card func matchSubdomain(domain, pattern string) bool { if !matchScheme(domain, pattern) { return false @@ -32,35 +29,61 @@ func matchSubdomain(domain, pattern string) bool { return false } - domAuth := domain[didx+3:] - // to avoid long loop by invalid long domain - if len(domAuth) > 253 { + // more fast on opp + domAuth := domain[didx+3:] // after [://] + // avoid too long + if len(domAuth) > constant.MAX_DOMAIN_LENGTH { return false } patAuth := pattern[pidx+3:] + // Opposite by . domComp := strings.Split(domAuth, ".") + domComp = reverseStringArray(domComp) + // do pattern patComp := strings.Split(patAuth, ".") - for i := len(domComp)/2 - 1; i >= 0; i-- { - opp := len(domComp) - 1 - i - domComp[i], domComp[opp] = domComp[opp], domComp[i] - } - for i := len(patComp)/2 - 1; i >= 0; i-- { - opp := len(patComp) - 1 - i - patComp[i], patComp[opp] = patComp[opp], patComp[i] - } + patComp = reverseStringArray(patComp) - for i, v := range domComp { + for i, dom := range domComp { if len(patComp) <= i { return false } - p := patComp[i] - if p == "*" { + + pat := patComp[i] + if pat == "*" { return true } - if p != v { + + if pat != dom { return false } } return false } + +// http vs https +func matchScheme(domain, pattern string) bool { + didx := strings.Index(domain, ":") + pidx := strings.Index(pattern, ":") + return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx] +} + +func reverseStringArray(targets []string) []string { + n := len(targets) + for i := n/2 - 1; i >= 0; i-- { + oppidx := n - i - 1 + targets[i], targets[oppidx] = targets[oppidx], targets[i] + } + + return targets +} + +func matchMethod(method string, allowMethods []string) bool { + for _, m := range allowMethods { + if m == method { + return true + } + } + + return false +} diff --git a/middleware/utils_test.go b/middleware/utils_test.go index 8bbc30a..24de2a9 100644 --- a/middleware/utils_test.go +++ b/middleware/utils_test.go @@ -2,7 +2,10 @@ package middleware import ( "fmt" + "net/http" "testing" + + "github.com/poteto0/poteto/utils" ) func TestWrapRegExp(t *testing.T) { @@ -25,6 +28,31 @@ func TestWrapRegExp(t *testing.T) { } } +func TestMatchSubDomain(t *testing.T) { + tests := []struct { + name string + domain string + pattern string + expected bool + }{ + {"test same url", "http://hello.world.com.test", "http://hello.world.com.test", false}, + {"test http & https return false", "http://hello.world.com.test", "https://hello.world.com.*", false}, + {"test not :// type return false", "hello.world.com.test", "hello.world.com.test", false}, + {"test wild card pattern return true", "http://hello.world.com.test", "http://hello.world.com.*", true}, + } + + for _, it := range tests { + t.Run(it.name, func(t *testing.T) { + result := matchSubdomain(it.domain, it.pattern) + if result != it.expected { + t.Errorf("Not matched") + t.Errorf(fmt.Sprintf("expected: %t", it.expected)) + t.Errorf(fmt.Sprintf("actual: %t", result)) + } + }) + } +} + func TestMatchScheme(t *testing.T) { tests := []struct { name string @@ -51,3 +79,36 @@ func TestMatchScheme(t *testing.T) { }) } } + +func TestreverseStringArray(t *testing.T) { + targets := []string{"!!", "world", "hello"} + expected := []string{"hello", "world", "!!"} + + result := reverseStringArray(targets) + if !utils.SliceUtils(result, expected) { + t.Errorf("Not matched") + } +} + +func TestMatchMethod(t *testing.T) { + tests := []struct { + name string + target string + allowMethods []string + expected bool + }{ + {"test including method return true", http.MethodGet, []string{http.MethodGet}, true}, + {"test not including method return false", http.MethodPost, []string{http.MethodGet}, false}, + } + + for _, it := range tests { + t.Run(it.name, func(t *testing.T) { + result := matchMethod(it.target, it.allowMethods) + if result != it.expected { + t.Errorf("Not matched") + t.Errorf(fmt.Sprintf("expected: %t", it.expected)) + t.Errorf(fmt.Sprintf("actual: %t", result)) + } + }) + } +} diff --git a/utils/math.go b/utils/math.go new file mode 100644 index 0000000..596e7d8 --- /dev/null +++ b/utils/math.go @@ -0,0 +1,15 @@ +package utils + +func SliceUtils[T comparable](as, bs []T) bool { + if len(as) != len(bs) { + return false + } + + for i := 0; i < len(as); i++ { + if as[i] != bs[i] { + return false + } + } + + return true +}