Skip to content

Commit

Permalink
refactor: decouple the TCP handler from the listener (#193)
Browse files Browse the repository at this point in the history
* refactor: don't link the TCP handler to a specific listener

* Pass in address instead.
  • Loading branch information
sbruens authored Jul 23, 2024
1 parent 6a0a242 commit 55aadb4
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 18 deletions.
2 changes: 1 addition & 1 deletion cmd/outline-ss-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (s *SSServer) startPort(portNum int) error {
port := &ssPort{tcpListener: listener, packetConn: packetConn, cipherList: service.NewCipherList()}
authFunc := service.NewShadowsocksStreamAuthenticator(port.cipherList, &s.replayCache, s.m)
// TODO: Register initial data metrics at zero.
tcpHandler := service.NewTCPHandler(listener.Addr().String(), authFunc, s.m, tcpReadTimeout)
tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout)
packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m)
s.ports[portNum] = port
go service.StreamServe(service.WrapStreamListener(listener.AcceptTCP), tcpHandler.Handle)
Expand Down
8 changes: 4 additions & 4 deletions internal/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func TestTCPEcho(t *testing.T) {
const testTimeout = 200 * time.Millisecond
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -202,7 +202,7 @@ func TestRestrictedAddresses(t *testing.T) {
const testTimeout = 200 * time.Millisecond
testMetrics := &statusMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
done := make(chan struct{})
go func() {
service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -384,7 +384,7 @@ func BenchmarkTCPThroughput(b *testing.B) {
const testTimeout = 200 * time.Millisecond
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -448,7 +448,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) {
const testTimeout = 200 * time.Millisecond
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := service.NewTCPHandler(proxyListener.Addr().String(), authFunc, testMetrics, testTimeout)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down
9 changes: 4 additions & 5 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,8 @@ type tcpHandler struct {
}

// NewTCPService creates a TCPService
func NewTCPHandler(listenerId string, authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler {
func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler {
return &tcpHandler{
listenerId: listenerId,
m: m,
readTimeout: timeout,
authenticate: authenticate,
Expand Down Expand Up @@ -342,7 +341,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.S
id, innerConn, authErr := h.authenticate(outerConn)
if authErr != nil {
// Drain to protect against probing attacks.
h.absorbProbe(outerConn, authErr.Status, proxyMetrics)
h.absorbProbe(outerConn, outerConn.LocalAddr().String(), authErr.Status, proxyMetrics)
return id, authErr
}
h.m.AddAuthenticatedTCPConnection(outerConn.RemoteAddr(), id)
Expand Down Expand Up @@ -370,12 +369,12 @@ func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.S

// Keep the connection open until we hit the authentication deadline to protect against probing attacks
// `proxyMetrics` is a pointer because its value is being mutated by `clientConn`.
func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) {
func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, addr, status string, proxyMetrics *metrics.ProxyMetrics) {
// This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe.
_, drainErr := io.Copy(io.Discard, clientConn) // drain socket
drainResult := drainErrToString(drainErr)
logger.Debugf("Drain error: %v, drain result: %v", drainErr, drainResult)
h.m.AddTCPProbe(status, drainResult, h.listenerId, proxyMetrics.ClientProxy)
h.m.AddTCPProbe(status, drainResult, addr, proxyMetrics.ClientProxy)
}

func drainErrToString(drainErr error) string {
Expand Down
16 changes: 8 additions & 8 deletions service/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func TestProbeRandom(t *testing.T) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -358,7 +358,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -393,7 +393,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -429,7 +429,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -472,7 +472,7 @@ func TestProbeServerBytesModified(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, 200*time.Millisecond)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -503,7 +503,7 @@ func TestReplayDefense(t *testing.T) {
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, testTimeout)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
Expand Down Expand Up @@ -582,7 +582,7 @@ func TestReverseReplayDefense(t *testing.T) {
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, testTimeout)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
Expand Down Expand Up @@ -653,7 +653,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(listener.Addr().String(), authFunc, testMetrics, testTimeout)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout)

done := make(chan struct{})
go func() {
Expand Down

0 comments on commit 55aadb4

Please sign in to comment.