diff --git a/proxy.go b/proxy.go index 1b4ff05..04da117 100644 --- a/proxy.go +++ b/proxy.go @@ -40,6 +40,7 @@ type Proxy struct { // but MUST NOT call WriteHeader on the http.ResponseWriter. func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error { if s.closed.Load() { + w.WriteHeader(http.StatusServiceUnavailable) return net.ErrClosed } @@ -65,6 +66,8 @@ func (s *Proxy) Proxy(w http.ResponseWriter, r *Request) error { // but MUST NOT call WriteHeader on the http.ResponseWriter. func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *net.UDPConn) error { if s.closed.Load() { + conn.Close() + w.WriteHeader(http.StatusServiceUnavailable) return net.ErrClosed } @@ -75,14 +78,7 @@ func (s *Proxy) ProxyConnectedSocket(w http.ResponseWriter, _ *Request, conn *ne w.WriteHeader(http.StatusOK) str := w.(http3.HTTPStreamer).HTTPStream() - s.mx.Lock() - if s.closed.Load() { - str.CancelRead(quic.StreamErrorCode(http3.ErrCodeNoError)) - str.Close() - conn.Close() - w.WriteHeader(http.StatusServiceUnavailable) - } if s.conns == nil { s.conns = make(map[proxyEntry]struct{}) } diff --git a/proxy_test.go b/proxy_test.go index 547632d..5438c3f 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -50,7 +50,7 @@ func TestProxyCloseProxiedConn(t *testing.T) { remoteServerConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) require.NoError(t, err) - s := Proxy{} + p := Proxy{} req := newRequest(fmt.Sprintf("https://localhost:1234/masque?h=localhost&p=%d", remoteServerConn.LocalAddr().(*net.UDPAddr).Port)) rec := httptest.NewRecorder() done := make(chan struct{}) @@ -75,7 +75,7 @@ func TestProxyCloseProxiedConn(t *testing.T) { }) r, err := ParseRequest(req, uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}")) require.NoError(t, err) - go s.Proxy(&http3ResponseWriter{ResponseWriter: rec, str: str}, r) + go p.Proxy(&http3ResponseWriter{ResponseWriter: rec, str: str}, r) require.Equal(t, http.StatusOK, rec.Code) b := make([]byte, 100) @@ -97,12 +97,35 @@ func TestProxyCloseProxiedConn(t *testing.T) { } func TestProxyDialFailure(t *testing.T) { - s := Proxy{} + p := Proxy{} r := newRequest("https://localhost:1234/masque?h=localhost&p=70000") // invalid port number req, err := ParseRequest(r, uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}")) require.NoError(t, err) rec := httptest.NewRecorder() - require.ErrorContains(t, s.Proxy(rec, req), "invalid port") + require.ErrorContains(t, p.Proxy(rec, req), "invalid port") require.Equal(t, http.StatusGatewayTimeout, rec.Code) } + +func TestProxyingAfterClose(t *testing.T) { + p := &Proxy{} + require.NoError(t, p.Close()) + + r := newRequest("https://localhost:1234/masque?h=localhost&p=1234") + req, err := ParseRequest(r, uritemplate.MustNew("https://localhost:1234/masque?h={target_host}&p={target_port}")) + require.NoError(t, err) + + t.Run("proxying", func(t *testing.T) { + rec := httptest.NewRecorder() + require.ErrorIs(t, p.Proxy(rec, req), net.ErrClosed) + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + }) + + t.Run("proxying connected socket", func(t *testing.T) { + rec := httptest.NewRecorder() + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + require.NoError(t, err) + require.ErrorIs(t, p.ProxyConnectedSocket(rec, req, conn), net.ErrClosed) + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + }) +}