Skip to content

Commit 0cbcd61

Browse files
committed
Refactor to make packet handler an association handler.
1 parent 5704718 commit 0cbcd61

File tree

7 files changed

+358
-484
lines changed

7 files changed

+358
-484
lines changed

caddy/shadowsocks_handler.go

+3-9
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ type ShadowsocksHandler struct {
5151
Keys []KeyConfig `json:"keys,omitempty"`
5252

5353
streamHandler outline.StreamHandler
54-
packetHandler outline.PacketHandler
54+
associationHandler outline.AssociationHandler
5555
metrics outline.ServiceMetrics
56-
tgtListener transport.PacketListener
5756
logger *slog.Logger
5857
}
5958

@@ -106,13 +105,12 @@ func (h *ShadowsocksHandler) Provision(ctx caddy.Context) error {
106105
ciphers := outline.NewCipherList()
107106
ciphers.Update(cipherList)
108107

109-
h.streamHandler, h.packetHandler = outline.NewShadowsocksHandlers(
108+
h.streamHandler, h.associationHandler = outline.NewShadowsocksHandlers(
110109
outline.WithLogger(h.logger),
111110
outline.WithCiphers(ciphers),
112111
outline.WithMetrics(h.metrics),
113112
outline.WithReplayCache(&app.ReplayCache),
114113
)
115-
h.tgtListener = outline.MakeTargetUDPListener(defaultNatTimeout, 0)
116114
return nil
117115
}
118116

@@ -122,11 +120,7 @@ func (h *ShadowsocksHandler) Handle(cx *layer4.Connection, _ layer4.Handler) err
122120
case transport.StreamConn:
123121
h.streamHandler.HandleStream(cx.Context, conn, h.metrics.AddOpenTCPConnection(conn))
124122
case net.Conn:
125-
assoc, err := outline.NewPacketAssociation(conn, h.tgtListener, h.metrics.AddOpenUDPAssociation(conn))
126-
if err != nil {
127-
return fmt.Errorf("failed to handle association: %v", err)
128-
}
129-
outline.HandleAssociation(assoc, h.packetHandler.HandlePacket)
123+
h.associationHandler.HandleAssociation(cx.Context, conn, h.metrics.AddOpenUDPAssociation(conn))
130124
default:
131125
return fmt.Errorf("failed to handle unknown connection type: %t", conn)
132126
}

cmd/outline-ss-server/main.go

+10-20
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,11 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) {
225225
ciphers := service.NewCipherList()
226226
ciphers.Update(cipherList)
227227

228-
streamHandler, packetHandler := service.NewShadowsocksHandlers(
228+
streamHandler, associationHandler := service.NewShadowsocksHandlers(
229229
service.WithCiphers(ciphers),
230230
service.WithMetrics(s.serviceMetrics),
231231
service.WithReplayCache(&s.replayCache),
232+
service.WithPacketListener(service.MakeTargetUDPListener(s.natTimeout, 0)),
232233
service.WithLogger(slog.Default()),
233234
)
234235
ln, err := lnSet.ListenStream(addr)
@@ -245,27 +246,22 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) {
245246
return err
246247
}
247248
slog.Info("UDP service started.", "address", pc.LocalAddr().String())
248-
tgtListener := service.MakeTargetUDPListener(s.natTimeout, 0)
249-
go service.PacketServe(pc, func(conn net.Conn) (service.PacketAssociation, error) {
250-
m := s.serviceMetrics.AddOpenUDPAssociation(conn)
251-
assoc, err := service.NewPacketAssociation(conn, tgtListener, m)
252-
if err != nil {
253-
return nil, fmt.Errorf("failed to handle association: %v", err)
254-
}
255-
return assoc, nil
256-
}, packetHandler.HandlePacket, s.serverMetrics)
249+
go service.PacketServe(pc, func(ctx context.Context, conn net.Conn) {
250+
associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn))
251+
}, s.serverMetrics)
257252
}
258253

259254
for _, serviceConfig := range config.Services {
260255
ciphers, err := newCipherListFromConfig(serviceConfig)
261256
if err != nil {
262257
return fmt.Errorf("failed to create cipher list from config: %v", err)
263258
}
264-
streamHandler, packetHandler := service.NewShadowsocksHandlers(
259+
streamHandler, associationHandler := service.NewShadowsocksHandlers(
265260
service.WithCiphers(ciphers),
266261
service.WithMetrics(s.serviceMetrics),
267262
service.WithReplayCache(&s.replayCache),
268263
service.WithStreamDialer(service.MakeValidatingTCPStreamDialer(onet.RequirePublicIP, serviceConfig.Dialer.Fwmark)),
264+
service.WithPacketListener(service.MakeTargetUDPListener(s.natTimeout, serviceConfig.Dialer.Fwmark)),
269265
service.WithLogger(slog.Default()),
270266
)
271267
if err != nil {
@@ -298,15 +294,9 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) {
298294
}
299295
return serviceConfig.Dialer.Fwmark
300296
}())
301-
tgtListener := service.MakeTargetUDPListener(s.natTimeout, serviceConfig.Dialer.Fwmark)
302-
go service.PacketServe(pc, func(conn net.Conn) (service.PacketAssociation, error) {
303-
m := s.serviceMetrics.AddOpenUDPAssociation(conn)
304-
assoc, err := service.NewPacketAssociation(conn, tgtListener, m)
305-
if err != nil {
306-
return nil, fmt.Errorf("failed to handle association: %v", err)
307-
}
308-
return assoc, nil
309-
}, packetHandler.HandlePacket, s.serverMetrics)
297+
go service.PacketServe(pc, func(ctx context.Context, conn net.Conn) {
298+
associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn))
299+
}, s.serverMetrics)
310300
}
311301
}
312302
totalCipherCount += len(serviceConfig.Keys)

internal/integration_test/integration_test.go

+12-15
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,14 @@ func TestUDPEcho(t *testing.T) {
317317
if err != nil {
318318
t.Fatal(err)
319319
}
320-
proxy := service.NewPacketHandler(cipherList, &fakeShadowsocksMetrics{})
320+
proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{})
321321

322322
proxy.SetTargetIPValidator(allowAll)
323323
natMetrics := &natTestMetrics{}
324324
associationMetrics := &fakeUDPAssociationMetrics{}
325-
go service.PacketServe(proxyConn, func(conn net.Conn) (service.PacketAssociation, error) {
326-
assoc, _ := service.NewPacketAssociation(conn, &transport.UDPListener{Address: ""}, associationMetrics)
327-
return assoc, nil
328-
}, proxy.Handle, natMetrics)
325+
go service.PacketServe(proxyConn, func(ctx context.Context, conn net.Conn) {
326+
proxy.HandleAssociation(ctx, conn, associationMetrics)
327+
}, natMetrics)
329328

330329
cryptoKey, err := shadowsocks.NewEncryptionKey(shadowsocks.CHACHA20IETFPOLY1305, secrets[0])
331330
require.NoError(t, err)
@@ -546,14 +545,13 @@ func BenchmarkUDPEcho(b *testing.B) {
546545
if err != nil {
547546
b.Fatal(err)
548547
}
549-
proxy := service.NewPacketHandler(cipherList, &fakeShadowsocksMetrics{})
548+
proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{})
550549
proxy.SetTargetIPValidator(allowAll)
551550
done := make(chan struct{})
552551
go func() {
553-
service.PacketServe(server, func(conn net.Conn) (service.PacketAssociation, error) {
554-
assoc, _ := service.NewPacketAssociation(conn, &transport.UDPListener{Address: ""}, nil)
555-
return assoc, nil
556-
}, proxy.Handle, &natTestMetrics{})
552+
service.PacketServe(server, func(ctx context.Context, conn net.Conn) {
553+
proxy.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{})
554+
}, &natTestMetrics{})
557555
done <- struct{}{}
558556
}()
559557

@@ -593,14 +591,13 @@ func BenchmarkUDPManyKeys(b *testing.B) {
593591
if err != nil {
594592
b.Fatal(err)
595593
}
596-
proxy := service.NewPacketHandler(cipherList, &fakeShadowsocksMetrics{})
594+
proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{})
597595
proxy.SetTargetIPValidator(allowAll)
598596
done := make(chan struct{})
599597
go func() {
600-
service.PacketServe(proxyConn, func(conn net.Conn) (service.PacketAssociation, error) {
601-
assoc, _ := service.NewPacketAssociation(conn, &transport.UDPListener{Address: ""}, nil)
602-
return assoc, nil
603-
}, proxy.Handle, &natTestMetrics{})
598+
service.PacketServe(proxyConn, func(ctx context.Context, conn net.Conn) {
599+
proxy.HandleAssociation(ctx, conn, &fakeUDPAssociationMetrics{})
600+
}, &natTestMetrics{})
604601
done <- struct{}{}
605602
}()
606603

service/shadowsocks.go

+18-9
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,8 @@ import (
2424
onet "github.com/Jigsaw-Code/outline-ss-server/net"
2525
)
2626

27-
const (
28-
// 59 seconds is most common timeout for servers that do not respond to invalid requests
29-
tcpReadTimeout time.Duration = 59 * time.Second
30-
)
27+
// 59 seconds is most common timeout for servers that do not respond to invalid requests
28+
const tcpReadTimeout time.Duration = 59 * time.Second
3129

3230
// ShadowsocksConnMetrics is used to report Shadowsocks related metrics on connections.
3331
type ShadowsocksConnMetrics interface {
@@ -51,11 +49,12 @@ type ssService struct {
5149
targetIPValidator onet.TargetIPValidator
5250
replayCache *ReplayCache
5351

54-
streamDialer transport.StreamDialer
52+
streamDialer transport.StreamDialer
53+
packetListener transport.PacketListener
5554
}
5655

5756
// NewShadowsocksHandlers creates new Shadowsocks stream and packet handlers.
58-
func NewShadowsocksHandlers(opts ...Option) (StreamHandler, PacketHandler) {
57+
func NewShadowsocksHandlers(opts ...Option) (StreamHandler, AssociationHandler) {
5958
s := &ssService{
6059
logger: noopLogger(),
6160
}
@@ -74,10 +73,13 @@ func NewShadowsocksHandlers(opts ...Option) (StreamHandler, PacketHandler) {
7473
}
7574
sh.SetLogger(s.logger)
7675

77-
ph := NewPacketHandler(s.ciphers, &ssConnMetrics{s.metrics.AddUDPCipherSearch})
78-
ph.SetLogger(s.logger)
76+
ah := NewAssociationHandler(s.ciphers, &ssConnMetrics{s.metrics.AddUDPCipherSearch})
77+
if s.packetListener != nil {
78+
ah.SetTargetPacketListener(s.packetListener)
79+
}
80+
ah.SetLogger(s.logger)
7981

80-
return sh, ph
82+
return sh, ah
8183
}
8284

8385
// WithLogger can be used to provide a custom log target. If not provided,
@@ -115,6 +117,13 @@ func WithStreamDialer(dialer transport.StreamDialer) Option {
115117
}
116118
}
117119

120+
// WithPacketListener option function.
121+
func WithPacketListener(listener transport.PacketListener) Option {
122+
return func(s *ssService) {
123+
s.packetListener = listener
124+
}
125+
}
126+
118127
type ssConnMetrics struct {
119128
metricFunc func(accessKeyFound bool, timeToCipher time.Duration)
120129
}

service/tcp_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
368368
testMetrics := &probeTestMetrics{}
369369
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil)
370370
handler := NewStreamHandler(authFunc, 200*time.Millisecond)
371-
handler.SetTargetDialerStream(MakeValidatingTCPStreamDialer(allowAll, 0))
371+
handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0))
372372
done := make(chan struct{})
373373
go func() {
374374
StreamServe(
@@ -406,7 +406,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
406406
testMetrics := &probeTestMetrics{}
407407
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil)
408408
handler := NewStreamHandler(authFunc, 200*time.Millisecond)
409-
handler.SetTargetDialerStream(MakeValidatingTCPStreamDialer(allowAll, 0))
409+
handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0))
410410
done := make(chan struct{})
411411
go func() {
412412
StreamServe(
@@ -445,7 +445,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
445445
testMetrics := &probeTestMetrics{}
446446
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil)
447447
handler := NewStreamHandler(authFunc, 200*time.Millisecond)
448-
handler.SetTargetDialerStream(MakeValidatingTCPStreamDialer(allowAll, 0))
448+
handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0))
449449
done := make(chan struct{})
450450
go func() {
451451
StreamServe(
@@ -747,7 +747,7 @@ func TestStreamServeEarlyClose(t *testing.T) {
747747
err = tcpListener.Close()
748748
require.NoError(t, err)
749749
// This should return quickly, without timing out or calling the handler.
750-
StreamServeStream(WrapStreamAcceptFunc(tcpListener.AcceptTCP), nil)
750+
StreamServe(WrapStreamAcceptFunc(tcpListener.AcceptTCP), nil)
751751
}
752752

753753
// Makes sure the TCP listener returns [io.ErrClosed] on Close().

0 commit comments

Comments
 (0)