Skip to content

Commit b9cb68e

Browse files
authored
Modularize connection handling (#173)
1 parent de41d3e commit b9cb68e

File tree

2 files changed

+74
-32
lines changed

2 files changed

+74
-32
lines changed

service/metrics/metrics.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,13 @@ func (c *measuredConn) Write(b []byte) (int, error) {
5353
return n, err
5454
}
5555

56-
func (c *measuredConn) ReadFrom(r io.Reader) (int64, error) {
57-
n, err := io.Copy(c.StreamConn, r)
56+
func (c *measuredConn) ReadFrom(r io.Reader) (n int64, err error) {
57+
if rf, ok := c.StreamConn.(io.ReaderFrom); ok {
58+
// Prefer ReadFrom if we are calling ReadFrom. Otherwise io.Copy will try WriteTo first.
59+
n, err = rf.ReadFrom(r)
60+
} else {
61+
n, err = io.Copy(c.StreamConn, r)
62+
}
5863
*c.writeCount += n
5964
return n, err
6065
}

service/tcp.go

+67-30
Original file line numberDiff line numberDiff line change
@@ -239,25 +239,22 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
239239
logger.Debugf("Done with status %v, duration %v", status, connDuration)
240240
}
241241

242-
func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
243-
// Set a deadline to receive the address to the target.
244-
clientConn.SetReadDeadline(time.Now().Add(h.readTimeout))
245-
246-
// 1. Find the cipher and acess key id.
242+
func (h *tcpHandler) authenticate(clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, transport.StreamConn, *onet.ConnectionError) {
243+
// TODO(fortuna): Offer alternative transports.
244+
// Find the cipher and acess key id.
247245
cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), h.ciphers)
248246
h.m.AddTCPCipherSearch(keyErr == nil, timeToCipher)
249247
if keyErr != nil {
250248
logger.Debugf("Failed to find a valid cipher after reading %v bytes: %v", proxyMetrics.ClientProxy, keyErr)
251249
const status = "ERR_CIPHER"
252-
h.absorbProbe(listenerPort, clientConn, status, proxyMetrics)
253-
return "", onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
250+
return "", nil, onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
254251
}
255252
var id string
256253
if cipherEntry != nil {
257254
id = cipherEntry.ID
258255
}
259256

260-
// 2. Check if the connection is a replay.
257+
// Check if the connection is a replay.
261258
isServerSalt := cipherEntry.SaltGenerator.IsServerSalt(clientSalt)
262259
// Only check the cache if findAccessKey succeeded and the salt is unrecognized.
263260
if isServerSalt || !h.replayCache.Add(cipherEntry.ID, clientSalt) {
@@ -267,38 +264,39 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
267264
} else {
268265
status = "ERR_REPLAY_CLIENT"
269266
}
270-
h.absorbProbe(listenerPort, clientConn, status, proxyMetrics)
271267
logger.Debugf(status+": %v sent %d bytes", clientConn.RemoteAddr(), proxyMetrics.ClientProxy)
272-
return id, onet.NewConnectionError(status, "Replay detected", nil)
268+
return id, nil, onet.NewConnectionError(status, "Replay detected", nil)
273269
}
274-
275-
// 3. Read target address and dial it.
276270
ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey)
277-
tgtAddr, err := socks.ReadAddr(ssr)
271+
ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey)
272+
ssw.SetSaltGenerator(cipherEntry.SaltGenerator)
273+
return id, transport.WrapConn(clientConn, ssr, ssw), nil
274+
}
278275

279-
// Clear the deadline for the target address
280-
clientConn.SetReadDeadline(time.Time{})
276+
func getProxyRequest(clientConn transport.StreamConn) (string, error) {
277+
// TODO(fortuna): Use Shadowsocks proxy, HTTP CONNECT or SOCKS5 based on first byte:
278+
// case 1, 3 or 4: Shadowsocks (address type)
279+
// case 5: SOCKS5 (protocol version)
280+
// case "C": HTTP CONNECT (first char of method)
281+
tgtAddr, err := socks.ReadAddr(clientConn)
281282
if err != nil {
282-
// Drain to prevent a close on cipher error.
283-
io.Copy(io.Discard, clientConn)
284-
return id, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err)
283+
return "", err
285284
}
286-
tgtConn, dialErr := h.dialer.DialStream(ctx, tgtAddr.String())
285+
return tgtAddr.String(), nil
286+
}
287+
288+
func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError {
289+
tgtConn, dialErr := dialer.DialStream(ctx, tgtAddr)
287290
if dialErr != nil {
288291
// We don't drain so dial errors and invalid addresses are communicated quickly.
289-
return id, ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target")
292+
return ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target")
290293
}
291-
tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
292294
defer tgtConn.Close()
293-
294-
// 4. Bridge the client and target connections
295295
logger.Debugf("proxy %s <-> %s", clientConn.RemoteAddr().String(), tgtConn.RemoteAddr().String())
296-
ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey)
297-
ssw.SetSaltGenerator(cipherEntry.SaltGenerator)
298296

299297
fromClientErrCh := make(chan error)
300298
go func() {
301-
_, fromClientErr := ssr.WriteTo(tgtConn)
299+
_, fromClientErr := io.Copy(tgtConn, clientConn)
302300
if fromClientErr != nil {
303301
// Drain to prevent a close in the case of a cipher error.
304302
io.Copy(io.Discard, clientConn)
@@ -310,19 +308,58 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
310308
tgtConn.CloseWrite()
311309
fromClientErrCh <- fromClientErr
312310
}()
313-
_, fromTargetErr := ssw.ReadFrom(tgtConn)
311+
_, fromTargetErr := io.Copy(clientConn, tgtConn)
314312
// Send FIN to client.
315313
clientConn.CloseWrite()
316314
tgtConn.CloseRead()
317315

318316
fromClientErr := <-fromClientErrCh
319317
if fromClientErr != nil {
320-
return id, onet.NewConnectionError("ERR_RELAY_CLIENT", "Failed to relay traffic from client", fromClientErr)
318+
return onet.NewConnectionError("ERR_RELAY_CLIENT", "Failed to relay traffic from client", fromClientErr)
321319
}
322320
if fromTargetErr != nil {
323-
return id, onet.NewConnectionError("ERR_RELAY_TARGET", "Failed to relay traffic from target", fromTargetErr)
321+
return onet.NewConnectionError("ERR_RELAY_TARGET", "Failed to relay traffic from target", fromTargetErr)
324322
}
325-
return id, nil
323+
return nil
324+
}
325+
326+
func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
327+
// Set a deadline to receive the address to the target.
328+
readDeadline := time.Now().Add(h.readTimeout)
329+
if deadline, ok := ctx.Deadline(); ok {
330+
outerConn.SetDeadline(deadline)
331+
if deadline.Before(readDeadline) {
332+
readDeadline = deadline
333+
}
334+
}
335+
outerConn.SetReadDeadline(readDeadline)
336+
337+
id, innerConn, authErr := h.authenticate(outerConn, proxyMetrics)
338+
if authErr != nil {
339+
// Drain to protect against probing attacks.
340+
h.absorbProbe(listenerPort, outerConn, authErr.Status, proxyMetrics)
341+
return id, authErr
342+
}
343+
344+
// Read target address and dial it.
345+
tgtAddr, err := getProxyRequest(innerConn)
346+
// Clear the deadline for the target address
347+
outerConn.SetReadDeadline(time.Time{})
348+
if err != nil {
349+
// Drain to prevent a close on cipher error.
350+
io.Copy(io.Discard, outerConn)
351+
return id, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err)
352+
}
353+
354+
dialer := transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) {
355+
tgtConn, err := h.dialer.DialStream(ctx, tgtAddr)
356+
if err != nil {
357+
return nil, err
358+
}
359+
tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
360+
return tgtConn, nil
361+
})
362+
return id, proxyConnection(ctx, dialer, tgtAddr, innerConn)
326363
}
327364

328365
// Keep the connection open until we hit the authentication deadline to protect against probing attacks

0 commit comments

Comments
 (0)