diff --git a/internal/proxy/proxy_state.go b/internal/proxy/proxy_state.go index 5ca4df3..e0b10b7 100644 --- a/internal/proxy/proxy_state.go +++ b/internal/proxy/proxy_state.go @@ -216,10 +216,11 @@ func (ps *ProxyState) SetupRaft(r *raft.Raft, oc chan raft.Observation) { } func (ps *ProxyState) WaitForChanges() error { - ps.proxyLock.RLock() - defer ps.proxyLock.RUnlock() if rft := ps.Raft(); rft != nil { return rft.Barrier(time.Second * 5).Error() + } else { + ps.proxyLock.RLock() + defer ps.proxyLock.RUnlock() } return nil } diff --git a/pkg/util/http.go b/pkg/util/http.go index 741f6af..6676a48 100644 --- a/pkg/util/http.go +++ b/pkg/util/http.go @@ -17,7 +17,8 @@ func WriteStatusCodeError(w http.ResponseWriter, code int) { // the request has passed through. func GetTrustedIP(r *http.Request, depth int) string { ips := r.Header.Values("X-Forwarded-For") - if len(ips) == 0 || depth > len(ips) { + depth = min(depth, len(ips)) + if depth <= 0 { remoteHost, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr diff --git a/pkg/util/http_test.go b/pkg/util/http_test.go new file mode 100644 index 0000000..46d1e19 --- /dev/null +++ b/pkg/util/http_test.go @@ -0,0 +1,54 @@ +package util_test + +import ( + "math" + "net/http" + "testing" + + "github.com/dgate-io/dgate/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestGetTrustedIP_Depth(t *testing.T) { + req := requestWithXForwardedFor(t, "1.2.3.4", "1.2.3.5", "1.2.3.6") + + t.Run("Depth 0", func(t *testing.T) { + require.Equal(t, util.GetTrustedIP(req, 0), "127.0.0.1") + }) + + t.Run("Depth 1", func(t *testing.T) { + require.Equal(t, util.GetTrustedIP(req, 1), "1.2.3.6") + }) + + t.Run("Depth 2", func(t *testing.T) { + require.Equal(t, util.GetTrustedIP(req, 2), "1.2.3.5") + }) + + t.Run("Depth 3", func(t *testing.T) { + require.Equal(t, util.GetTrustedIP(req, 3), "1.2.3.4") + }) + + t.Run("Depth too High", func(t *testing.T) { + require.Equal(t, util.GetTrustedIP(req, 4), "1.2.3.4") + require.Equal(t, util.GetTrustedIP(req, 8), "1.2.3.4") + require.Equal(t, util.GetTrustedIP(req, 16), "1.2.3.4") + }) + + t.Run("Depth too Low", func(t *testing.T) { + require.Equal(t, util.GetTrustedIP(req, -1), "127.0.0.1") + require.Equal(t, util.GetTrustedIP(req, -10), "127.0.0.1") + require.Equal(t, util.GetTrustedIP(req, math.MinInt), "127.0.0.1") + }) +} + +func requestWithXForwardedFor(t *testing.T, ips ...string) *http.Request { + req, err := http.NewRequest("GET", "http://localhost:8080", nil) + if err != nil { + t.Fatal(err) + } + req.RemoteAddr = "127.0.0.1" + for _, ip := range ips { + req.Header.Add("X-Forwarded-For", ip) + } + return req +}