Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

20240916/middleware first #13

Merged
merged 3 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ jobs:
with:
go-version: "1.21.x"
- name: Run Test
run: go test -cover -bench . -benchmem
run: go test ./... -cover -bench . -benchmem
3 changes: 3 additions & 0 deletions constant/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@ package constant
const (
HEADER_CONTENT_TYPE string = "Content-Type"
APPLICATION_JSON string = "application/json"
HEADER_ORIGIN string = "Origin"
HEADER_VARY string = "vary"
MAX_DOMAIN_LENGTH int = 255
)
16 changes: 16 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ type Context interface {
writeContentType(value string)
SetPath(path string)
GetResponse() *response
GetRequest() *http.Request
GetRequestHeaderValue(key string) string
JsonSerialize(value any) error
NoContent() error
}

type context struct {
Expand Down Expand Up @@ -57,7 +60,20 @@ func (ctx *context) GetResponse() *response {
return ctx.response.(*response)
}

func (ctx *context) GetRequest() *http.Request {
return ctx.request
}

func (ctx *context) GetRequestHeaderValue(key string) string {
return ctx.request.Header.Get(key)
}

func (ctx *context) JsonSerialize(value any) error {
encoder := json.NewEncoder(ctx.GetResponse())
return encoder.Encode(value)
}

func (c *context) NoContent() error {
c.response.WriteHeader(http.StatusNoContent)
return nil
}
2 changes: 2 additions & 0 deletions handler_func.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
package poteto

type HandlerFunc func(ctx Context) error

type MiddlewareFunc func(next HandlerFunc) HandlerFunc
95 changes: 95 additions & 0 deletions middleware/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package middleware

import (
"net/http"
"regexp"

"github.com/poteto0/poteto"
"github.com/poteto0/poteto/constant"
)

type CORSConfig struct {
AllowOrigins []string `yaml:"allow_origins"`
AllowMethods []string `yaml:"allow_methods"`
}

var DefaultCORSConfig = CORSConfig{
AllowOrigins: []string{"*"},
AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete},
}

func CORSWithConfig(config CORSConfig) poteto.MiddlewareFunc {
if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
}

if len(config.AllowMethods) == 0 {
config.AllowMethods = DefaultCORSConfig.AllowMethods
}

allowOriginPatterns := []string{}
for _, origin := range config.AllowOrigins {
pattern := wrapRegExp(origin)
allowOriginPatterns = append(allowOriginPatterns, pattern)
}

return func(next poteto.HandlerFunc) poteto.HandlerFunc {
return func(ctx poteto.Context) error {
req := ctx.GetRequest()
res := ctx.GetResponse()
origin := req.Header.Get(constant.HEADER_ORIGIN)

res.Header().Add(constant.HEADER_VARY, constant.HEADER_ORIGIN)
preflight := req.Method == http.MethodOptions

// Not From Browser
if origin == "" {
if !preflight {
return next(ctx)
}
return ctx.NoContent()
}

allowSubDomain := getAllowSubDomain(origin, config.AllowOrigins)
// allowed origin path
allowOrigin := getAllowOrigin(allowSubDomain, allowOriginPatterns)

// Origin not allowed
if allowOrigin == "" {
if !preflight {
return next(ctx)
}
return ctx.NoContent()
}

// allowed method
if matchMethod(req.Method, config.AllowMethods) {
return next(ctx)
}

return ctx.NoContent()
}
}
}

func getAllowSubDomain(origin string, allowOrigins []string) string {
for _, o := range allowOrigins {
if o == "*" || o == origin {
return origin
}
if matchSubdomain(origin, o) {
return origin
}
}

return ""
}

func getAllowOrigin(origin string, allowOriginPatterns []string) string {
for _, pattern := range allowOriginPatterns {
if match, _ := regexp.MatchString(pattern, origin); match {
return origin
}
}
return ""
}
92 changes: 92 additions & 0 deletions middleware/cors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package middleware

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/poteto0/poteto"
)

type TestVal struct {
Name string `json:"name"`
Val string `json:"val"`
}

func TestCORSWithConfigByDefault(t *testing.T) {
config := CORSConfig{
AllowOrigins: []string{},
AllowMethods: []string{},
}

t.Run("allow all origins", func(t *testing.T) {
cors := CORSWithConfig(config)

w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "https://example.com/test", nil)
context := poteto.NewContext(w, req)

handler := func(ctx poteto.Context) error {
return ctx.JSON(http.StatusOK, TestVal{Name: "test", Val: "val"})
}

cors_handler := cors(handler)
cors_handler(context)
result := w.Body.String()
expected := `{"name":"test","val":"val"}`
if result[0:27] != expected[0:27] {
t.Errorf("Wrong result")
t.Errorf(fmt.Sprintf("expected: %s", expected))
t.Errorf(fmt.Sprintf("actual: %s", result))
}
})
}

func TestGetAllowSubDomain(t *testing.T) {
tests := []struct {
name string
origin string
allowOrigins []string
expected string
}{
{"test wildcard return true", "https://example.com", []string{"*"}, "https://example.com"},
{"test match same domain", "https://example.com", []string{"https://example.com"}, "https://example.com"},
{"test matched subdomain", "https://exmaple.com.test", []string{"https://example.com.*"}, "https://exmaple.com.test"},
{"test not matched", "https://hello.world.com", []string{"https://exmaple.com"}, ""},
}

for _, it := range tests {
t.Run(it.name, func(t *testing.T) {
result := getAllowSubDomain(it.origin, it.allowOrigins)
if result != it.expected {
t.Errorf("Not matched")
t.Errorf(fmt.Sprintf("expected: %s", it.expected))
t.Errorf(fmt.Sprintf("actual: %s", result))
}
})
}
}

func TestGetAllowOrigin(t *testing.T) {
tests := []struct {
name string
origin string
allowOriginPatterns []string
expected string
}{
{"test match case", "https://example.com", []string{wrapRegExp("https://example.*")}, "https://example.com"},
{"test not match case", "https://example.com", []string{wrapRegExp("https://hello.world.com")}, ""},
}

for _, it := range tests {
t.Run(it.name, func(t *testing.T) {
result := getAllowOrigin(it.origin, it.allowOriginPatterns)
if result != it.expected {
t.Errorf("Not matched")
t.Errorf(fmt.Sprintf("expected: %s", it.expected))
t.Errorf(fmt.Sprintf("actual: %s", result))
}
})
}
}
89 changes: 89 additions & 0 deletions middleware/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package middleware

import (
"regexp"
"strings"

"github.com/poteto0/poteto/constant"
)

// EX: https://example.com:* => ^https://example\.com:.*$
func wrapRegExp(target string) string {
pattern := regexp.QuoteMeta(target) // .をescapeする
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
pattern = strings.ReplaceAll(pattern, "\\?", ".")
pattern = "^" + pattern + "$"
return pattern
}

// just sub domain
// only wild card
func matchSubdomain(domain, pattern string) bool {
if !matchScheme(domain, pattern) {
return false
}

didx := strings.Index(domain, "://")
pidx := strings.Index(pattern, "://")
if didx == -1 || pidx == -1 {
return false
}

// more fast on opp
domAuth := domain[didx+3:] // after [://]
// avoid too long
if len(domAuth) > constant.MAX_DOMAIN_LENGTH {
return false
}
patAuth := pattern[pidx+3:]

// Opposite by .
domComp := strings.Split(domAuth, ".")
domComp = reverseStringArray(domComp)
// do pattern
patComp := strings.Split(patAuth, ".")
patComp = reverseStringArray(patComp)

for i, dom := range domComp {
if len(patComp) <= i {
return false
}

pat := patComp[i]
if pat == "*" {
return true
}

if pat != dom {
return false
}
}
return false
}

// http vs https
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}

func reverseStringArray(targets []string) []string {
n := len(targets)
for i := n/2 - 1; i >= 0; i-- {
oppidx := n - i - 1
targets[i], targets[oppidx] = targets[oppidx], targets[i]
}

return targets
}

func matchMethod(method string, allowMethods []string) bool {
for _, m := range allowMethods {
if m == method {
return true
}
}

return false
}
Loading
Loading