Skip to content

Commit

Permalink
get back aa6fafd, accidentally reverted in 03cad9f
Browse files Browse the repository at this point in the history
  • Loading branch information
yusing committed Oct 6, 2024
1 parent de7805f commit 929b7f7
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 31 deletions.
29 changes: 20 additions & 9 deletions internal/net/http/header_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package http

import (
"net/http"
"slices"
)

func RemoveHop(h http.Header) {
Expand All @@ -25,18 +24,30 @@ func CopyHeader(dst, src http.Header) {
}
}

func FilterHeaders(h http.Header, allowed []string) {
if allowed == nil {
return
func FilterHeaders(h http.Header, allowed []string) http.Header {
if len(allowed) == 0 {
return h
}

for i := range allowed {
allowed[i] = http.CanonicalHeaderKey(allowed[i])
filtered := make(http.Header)

for i, header := range allowed {
values := h.Values(header)
if len(values) == 0 {
continue
}
filtered[http.CanonicalHeaderKey(allowed[i])] = append([]string(nil), values...)
}

for key := range h {
if !slices.Contains(allowed, key) {
h.Del(key)
return filtered
}

func HeaderToMap(h http.Header) map[string]string {
result := make(map[string]string)
for k, v := range h {
if len(v) > 0 {
result[k] = v[0] // Take the first value
}
}
return result
}
19 changes: 10 additions & 9 deletions internal/net/http/middleware/forward_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
}

// TODO: use tr from reverse proxy
tr, ok := fa.forwardAuthOpts.transport.(*http.Transport)
tr, ok := fa.transport.(*http.Transport)
if ok {
tr = tr.Clone()
} else {
Expand All @@ -81,51 +81,52 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
nil,
)
if err != nil {
fa.m.AddTracef("new request err to %s", fa.Address).With("error", err)
fa.m.AddTracef("new request err to %s", fa.Address).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}

gpHTTP.CopyHeader(faReq.Header, req.Header)
gpHTTP.RemoveHop(faReq.Header)

gpHTTP.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
faReq.Header = gpHTTP.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
fa.m.AddTraceRequest("forward auth request", faReq)

faResp, err := fa.client.Do(faReq)
if err != nil {
fa.m.AddTracef("failed to call %s", fa.Address).With("error", err)
fa.m.AddTracef("failed to call %s", fa.Address).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer faResp.Body.Close()

body, err := io.ReadAll(faResp.Body)
if err != nil {
fa.m.AddTracef("failed to read response body from %s", fa.Address).With("error", err)
fa.m.AddTracef("failed to read response body from %s", fa.Address).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}

if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
fa.m.AddTracef("status %d", faResp.StatusCode)
fa.m.AddTraceResponse("forward auth response", faResp)
gpHTTP.CopyHeader(w.Header(), faResp.Header)
gpHTTP.RemoveHop(w.Header())

redirectURL, err := faResp.Location()
if err != nil {
fa.m.AddTracef("failed to get location from %s", fa.Address).With("error", err)
fa.m.AddTracef("failed to get location from %s", fa.Address).WithError(err).WithResponse(faResp)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
fa.m.AddTracef("redirect to %q", redirectURL.String())
fa.m.AddTracef("redirect to %q", redirectURL.String()).WithResponse(faResp)
}

w.WriteHeader(faResp.StatusCode)

if _, err = w.Write(body); err != nil {
fa.m.AddTracef("failed to write response body from %s", fa.Address).With("error", err)
fa.m.AddTracef("failed to write response body from %s", fa.Address).WithError(err).WithResponse(faResp)
}
return
}
Expand Down
38 changes: 25 additions & 13 deletions internal/net/http/middleware/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,30 @@ package middleware

import (
"fmt"
"net/http"
"sync"
"time"

gpHTTP "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
)

type Trace struct {
Time string `json:"time,omitempty"`
Caller string `json:"caller,omitempty"`
URL string `json:"url,omitempty"`
Message string `json:"msg"`
ReqHeaders http.Header `json:"req_headers,omitempty"`
RespHeaders http.Header `json:"resp_headers,omitempty"`
RespStatus int `json:"resp_status,omitempty"`
Additional map[string]any `json:"additional,omitempty"`
Time string `json:"time,omitempty"`
Caller string `json:"caller,omitempty"`
URL string `json:"url,omitempty"`
Message string `json:"msg"`
ReqHeaders map[string]string `json:"req_headers,omitempty"`
RespHeaders map[string]string `json:"resp_headers,omitempty"`
RespStatus int `json:"resp_status,omitempty"`
Additional map[string]any `json:"additional,omitempty"`
}

type Traces []*Trace

var traces = Traces{}
var tracesMu sync.Mutex

const MaxTraceNum = 1000
const MaxTraceNum = 100

func GetAllTrace() []*Trace {
return traces
Expand All @@ -36,7 +36,7 @@ func (tr *Trace) WithRequest(req *Request) *Trace {
return nil
}
tr.URL = req.RequestURI
tr.ReqHeaders = req.Header.Clone()
tr.ReqHeaders = gpHTTP.HeaderToMap(req.Header)
return tr
}

Expand All @@ -45,8 +45,8 @@ func (tr *Trace) WithResponse(resp *Response) *Trace {
return nil
}
tr.URL = resp.Request.RequestURI
tr.ReqHeaders = resp.Request.Header.Clone()
tr.RespHeaders = resp.Header.Clone()
tr.ReqHeaders = gpHTTP.HeaderToMap(resp.Request.Header)
tr.RespHeaders = gpHTTP.HeaderToMap(resp.Header)
tr.RespStatus = resp.StatusCode
return tr
}
Expand All @@ -63,6 +63,18 @@ func (tr *Trace) With(what string, additional any) *Trace {
return tr
}

func (tr *Trace) WithError(err error) *Trace {
if tr == nil {
return nil
}

if tr.Additional == nil {
tr.Additional = map[string]any{}
}
tr.Additional["error"] = err.Error()
return tr
}

func (m *Middleware) EnableTrace() {
m.trace = true
for _, child := range m.children {
Expand Down

0 comments on commit 929b7f7

Please sign in to comment.