Skip to content

Commit

Permalink
fix: add test for trustedIp code
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbajoe committed Jun 14, 2024
1 parent 9bcbce4 commit ff9ecb3
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
5 changes: 3 additions & 2 deletions internal/proxy/proxy_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,11 @@ func (ps *ProxyState) SetupRaft(r *raft.Raft, oc chan raft.Observation) {
}

func (ps *ProxyState) WaitForChanges() error {
ps.proxyLock.RLock()
defer ps.proxyLock.RUnlock()
if rft := ps.Raft(); rft != nil {
return rft.Barrier(time.Second * 5).Error()
} else {
ps.proxyLock.RLock()
defer ps.proxyLock.RUnlock()
}
return nil
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/util/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ func WriteStatusCodeError(w http.ResponseWriter, code int) {
// the request has passed through.
func GetTrustedIP(r *http.Request, depth int) string {
ips := r.Header.Values("X-Forwarded-For")
if len(ips) == 0 || depth > len(ips) {
depth = min(depth, len(ips))
if depth <= 0 {
remoteHost, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
Expand Down
54 changes: 54 additions & 0 deletions pkg/util/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package util_test

import (
"math"
"net/http"
"testing"

"github.com/dgate-io/dgate/pkg/util"
"github.com/stretchr/testify/require"
)

func TestGetTrustedIP_Depth(t *testing.T) {
req := requestWithXForwardedFor(t, "1.2.3.4", "1.2.3.5", "1.2.3.6")

t.Run("Depth 0", func(t *testing.T) {
require.Equal(t, util.GetTrustedIP(req, 0), "127.0.0.1")
})

t.Run("Depth 1", func(t *testing.T) {
require.Equal(t, util.GetTrustedIP(req, 1), "1.2.3.6")
})

t.Run("Depth 2", func(t *testing.T) {
require.Equal(t, util.GetTrustedIP(req, 2), "1.2.3.5")
})

t.Run("Depth 3", func(t *testing.T) {
require.Equal(t, util.GetTrustedIP(req, 3), "1.2.3.4")
})

t.Run("Depth too High", func(t *testing.T) {
require.Equal(t, util.GetTrustedIP(req, 4), "1.2.3.4")
require.Equal(t, util.GetTrustedIP(req, 8), "1.2.3.4")
require.Equal(t, util.GetTrustedIP(req, 16), "1.2.3.4")
})

t.Run("Depth too Low", func(t *testing.T) {
require.Equal(t, util.GetTrustedIP(req, -1), "127.0.0.1")
require.Equal(t, util.GetTrustedIP(req, -10), "127.0.0.1")
require.Equal(t, util.GetTrustedIP(req, math.MinInt), "127.0.0.1")
})
}

func requestWithXForwardedFor(t *testing.T, ips ...string) *http.Request {
req, err := http.NewRequest("GET", "http://localhost:8080", nil)
if err != nil {
t.Fatal(err)
}
req.RemoteAddr = "127.0.0.1"
for _, ip := range ips {
req.Header.Add("X-Forwarded-For", ip)
}
return req
}

0 comments on commit ff9ecb3

Please sign in to comment.