Skip to content

Commit

Permalink
fix request handling after closing the proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 20, 2024
1 parent 0916585 commit 2e37bbf
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
10 changes: 3 additions & 7 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type Proxy struct {
// but MUST NOT call WriteHeader on the http.ResponseWriter.
func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error {
if s.closed.Load() {
w.WriteHeader(http.StatusServiceUnavailable)
return net.ErrClosed
}

Expand All @@ -65,6 +66,8 @@ func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error {
// but MUST NOT call WriteHeader on the http.ResponseWriter.
func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *net.UDPConn) error {
if s.closed.Load() {
conn.Close()
w.WriteHeader(http.StatusServiceUnavailable)
return net.ErrClosed
}

Expand All @@ -75,14 +78,7 @@ func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *ne
w.WriteHeader(http.StatusOK)

str := w.(http3.HTTPStreamer).HTTPStream()

s.mx.Lock()
if s.closed.Load() {
str.CancelRead(quic.StreamErrorCode(http3.ErrCodeNoError))
str.Close()
conn.Close()
w.WriteHeader(http.StatusServiceUnavailable)
}
if s.conns == nil {
s.conns = make(map[proxyEntry]struct{})
}
Expand Down
31 changes: 27 additions & 4 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestProxyCloseProxiedConn(t *testing.T) {
remoteServerConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)

s := Proxy{}
p := Proxy{}
req := newRequest(fmt.Sprintf("https://localhost:1234/masque?h=localhost&p=%d", remoteServerConn.LocalAddr().(*net.UDPAddr).Port))
rec := httptest.NewRecorder()
done := make(chan struct{})
Expand All @@ -75,7 +75,7 @@ func TestProxyCloseProxiedConn(t *testing.T) {
})
r, err := ParseRequest(req, uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}"))
require.NoError(t, err)
go s.Proxy(&http3ResponseWriter{ResponseWriter: rec, str: str}, r)
go p.Proxy(&http3ResponseWriter{ResponseWriter: rec, str: str}, r)
require.Equal(t, http.StatusOK, rec.Code)

b := make([]byte, 100)
Expand All @@ -97,12 +97,35 @@ func TestProxyCloseProxiedConn(t *testing.T) {
}

func TestProxyDialFailure(t *testing.T) {
s := Proxy{}
p := Proxy{}
r := newRequest("https://localhost:1234/masque?h=localhost&p=70000") // invalid port number
req, err := ParseRequest(r, uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}"))
require.NoError(t, err)
rec := httptest.NewRecorder()

require.ErrorContains(t, s.Proxy(rec, req), "invalid port")
require.ErrorContains(t, p.Proxy(rec, req), "invalid port")
require.Equal(t, http.StatusGatewayTimeout, rec.Code)
}

func TestProxyingAfterClose(t *testing.T) {
p := &Proxy{}
require.NoError(t, p.Close())

r := newRequest("https://localhost:1234/masque?h=localhost&p=1234")
req, err := ParseRequest(r, uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}"))
require.NoError(t, err)

t.Run("proxying", func(t *testing.T) {
rec := httptest.NewRecorder()
require.ErrorIs(t, p.Proxy(rec, req), net.ErrClosed)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
})

t.Run("proxying connected socket", func(t *testing.T) {
rec := httptest.NewRecorder()
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
require.ErrorIs(t, p.ProxyConnectedSocket(rec, req, conn), net.ErrClosed)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
})
}

0 comments on commit 2e37bbf

Please sign in to comment.