Skip to content

Commit

Permalink
not tested cors
Browse files Browse the repository at this point in the history
  • Loading branch information
poteto0 committed Sep 16, 2024
1 parent 8b3e641 commit 460a84f
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 0 deletions.
2 changes: 2 additions & 0 deletions constant/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@ package constant
const (
HEADER_CONTENT_TYPE string = "Content-Type"
APPLICATION_JSON string = "application/json"
HEADER_ORIGIN string = "Origin"
HEADER_VARY string = "vary"
)
16 changes: 16 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ type Context interface {
writeContentType(value string)
SetPath(path string)
GetResponse() *response
GetRequest() *http.Request
GetRequestHeaderValue(key string) string
JsonSerialize(value any) error
NoContent() error
}

type context struct {
Expand Down Expand Up @@ -57,7 +60,20 @@ func (ctx *context) GetResponse() *response {
return ctx.response.(*response)
}

func (ctx *context) GetRequest() *http.Request {
return ctx.request
}

func (ctx *context) GetRequestHeaderValue(key string) string {
return ctx.request.Header.Get(key)
}

func (ctx *context) JsonSerialize(value any) error {
encoder := json.NewEncoder(ctx.GetResponse())
return encoder.Encode(value)
}

func (c *context) NoContent() error {
c.response.WriteHeader(http.StatusNoContent)
return nil
}
2 changes: 2 additions & 0 deletions handler_func.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
package poteto

type HandlerFunc func(ctx Context) error

type MiddlewareFunc func(next HandlerFunc) HandlerFunc
97 changes: 97 additions & 0 deletions middleware/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package middleware

import (
"net/http"
"regexp"
"strings"

"github.com/poteto0/poteto"
"github.com/poteto0/poteto/constant"
)

type CORSConfig struct {
AllowOrigins []string `yaml:"allow_origins"`
AllowMethods []string `yaml:"allow_methods"`
}

var DefaultCORSConfig = CORSConfig{
AllowOrigins: []string{"*"},
AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete},
}

func CORSWithConfig(config CORSConfig) poteto.MiddlewareFunc {
if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
}

if len(config.AllowMethods) == 0 {
config.AllowMethods = DefaultCORSConfig.AllowMethods
}

allowOriginPatterns := []string{}
for _, origin := range config.AllowOrigins {
pattern := regexp.QuoteMeta(origin)
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
pattern = strings.ReplaceAll(pattern, "\\?", ".")
pattern = "^" + pattern + "$"
allowOriginPatterns = append(allowOriginPatterns, pattern)
}

return func(next poteto.HandlerFunc) poteto.HandlerFunc {
return func(ctx poteto.Context) error {
req := ctx.GetRequest()
res := ctx.GetResponse()
origin := req.Header.Get(constant.HEADER_ORIGIN)

res.Header().Add(constant.HEADER_VARY, constant.HEADER_ORIGIN)
preflight := req.Method == http.MethodOptions

// Not From Browser
if origin == "" {
if !preflight {
return next(ctx)
}
return ctx.NoContent()
}

allowOrigin := getAllowOrigin(origin, allowOriginPatterns)

// Origin not allowed
if allowOrigin == "" {
if !preflight {
return next(ctx)
}
return ctx.NoContent()
}

if matchMethod(req.Method, config.AllowMethods) {
return next(ctx)
}

return ctx.NoContent()
}
}
}

func getAllowOrigin(origin string, allowOrigins []string) string {
for _, o := range allowOrigins {
if o == "*" || o == origin {
return origin
}
if matchSubdomain(origin, o) {
return origin
}
}

return ""
}

func matchMethod(method string, allowMethods []string) bool {
for _, m := range allowMethods {
if m == method {
return true
}
}

return false
}
53 changes: 53 additions & 0 deletions middleware/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package middleware

import "strings"

func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}

func matchSubdomain(domain, pattern string) bool {
if !matchScheme(domain, pattern) {
return false
}

didx := strings.Index(domain, "://")
pidx := strings.Index(pattern, "://")
if didx == -1 || pidx == -1 {
return false
}

domAuth := domain[didx+3:]
// to avoid long loop by invalid long domain
if len(domAuth) > 253 {
return false
}
patAuth := pattern[pidx+3:]

domComp := strings.Split(domAuth, ".")
patComp := strings.Split(patAuth, ".")
for i := len(domComp)/2 - 1; i >= 0; i-- {
opp := len(domComp) - 1 - i
domComp[i], domComp[opp] = domComp[opp], domComp[i]
}
for i := len(patComp)/2 - 1; i >= 0; i-- {
opp := len(patComp) - 1 - i
patComp[i], patComp[opp] = patComp[opp], patComp[i]
}

for i, v := range domComp {
if len(patComp) <= i {
return false
}
p := patComp[i]
if p == "*" {
return true
}
if p != v {
return false
}
}
return false
}

0 comments on commit 460a84f

Please sign in to comment.