Skip to content

Commit 6a86ac4

Browse files
committed
dial: Connect() does not cancel by context if no i/o
NetDialer, GreetingDialer, AuthDialer and ProtocolDialer may not cancel Dial() on context expiration when network connection hangs. The issue occurred because context wasn't properly handled during network I/O operations, potentially causing infinite waiting. Part of TNTP-2018
1 parent c043771 commit 6a86ac4

File tree

4 files changed

+214
-27
lines changed

4 files changed

+214
-27
lines changed

CHANGELOG.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.
1414

1515
### Fixed
1616

17+
- Connect() may not cancel Dial() call on context expiration if network
18+
connection hangs (#443).
19+
1720
## [v2.3.1] - 2025-04-03
1821

1922
The patch releases fixes expected Connect() behavior and reduces allocations.
2023

2124
### Added
2225

2326
- A usage of sync.Pool of msgpack.Decoder saves 2 object allocations per
24-
a response decoding.
27+
a response decoding (#440).
2528

2629
### Changed
2730

connection.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ func (conn *Connection) dial(ctx context.Context) error {
489489
}
490490

491491
req := newWatchRequest(key.(string))
492-
if err = writeRequest(c, req); err != nil {
492+
if err = writeRequest(ctx, c, req); err != nil {
493493
st <- state
494494
return false
495495
}

dial.go

+93-16
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
289289
if err != nil {
290290
return conn, err
291291
}
292+
292293
greeting := conn.Greeting()
293294
if greeting.Salt == "" {
294295
conn.Close()
@@ -309,7 +310,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
309310
}
310311
}
311312

312-
if err := authenticate(conn, d.Auth, d.Username, d.Password,
313+
if err := authenticate(ctx, conn, d.Auth, d.Username, d.Password,
313314
conn.Greeting().Salt); err != nil {
314315
conn.Close()
315316
return nil, fmt.Errorf("failed to authenticate: %w", err)
@@ -340,7 +341,7 @@ func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
340341
protocolInfo: d.RequiredProtocolInfo,
341342
}
342343

343-
protocolConn.protocolInfo, err = identify(&protocolConn)
344+
protocolConn.protocolInfo, err = identify(ctx, &protocolConn)
344345
if err != nil {
345346
protocolConn.Close()
346347
return nil, fmt.Errorf("failed to identify: %w", err)
@@ -372,11 +373,12 @@ func (d GreetingDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
372373
greetingConn := greetingConn{
373374
Conn: conn,
374375
}
375-
version, salt, err := readGreeting(greetingConn)
376+
version, salt, err := readGreeting(ctx, &greetingConn)
376377
if err != nil {
377378
greetingConn.Close()
378379
return nil, fmt.Errorf("failed to read greeting: %w", err)
379380
}
381+
380382
greetingConn.greeting = Greeting{
381383
Version: version,
382384
Salt: salt,
@@ -410,31 +412,67 @@ func parseAddress(address string) (string, string) {
410412
return network, address
411413
}
412414

415+
// ioWaiter waits in a background until an io operation done or a context
416+
// is expired. It closes the connection and writes a context error into the
417+
// output channel on context expiration.
418+
//
419+
// A user of the helper should close the first output channel after an IO
420+
// operation done and read an error from a second channel to get the result
421+
// of waiting.
422+
func ioWaiter(ctx context.Context, conn Conn) (chan<- struct{}, <-chan error) {
423+
doneIO := make(chan struct{})
424+
doneWait := make(chan error, 1)
425+
426+
go func() {
427+
defer close(doneWait)
428+
429+
select {
430+
case <-ctx.Done():
431+
conn.Close()
432+
<-doneIO
433+
doneWait <- ctx.Err()
434+
case <-doneIO:
435+
doneWait <- nil
436+
}
437+
}()
438+
439+
return doneIO, doneWait
440+
}
441+
413442
// readGreeting reads a greeting message.
414-
func readGreeting(reader io.Reader) (string, string, error) {
443+
func readGreeting(ctx context.Context, conn Conn) (string, string, error) {
415444
var version, salt string
416445

446+
doneRead, doneWait := ioWaiter(ctx, conn)
447+
417448
data := make([]byte, 128)
418-
_, err := io.ReadFull(reader, data)
449+
_, err := io.ReadFull(conn, data)
450+
451+
close(doneRead)
452+
419453
if err == nil {
420454
version = bytes.NewBuffer(data[:64]).String()
421455
salt = bytes.NewBuffer(data[64:108]).String()
422456
}
423457

458+
if waitErr := <-doneWait; waitErr != nil {
459+
err = waitErr
460+
}
461+
424462
return version, salt, err
425463
}
426464

427465
// identify sends info about client protocol, receives info
428466
// about server protocol in response and stores it in the connection.
429-
func identify(conn Conn) (ProtocolInfo, error) {
467+
func identify(ctx context.Context, conn Conn) (ProtocolInfo, error) {
430468
var info ProtocolInfo
431469

432470
req := NewIdRequest(clientProtocolInfo)
433-
if err := writeRequest(conn, req); err != nil {
471+
if err := writeRequest(ctx, conn, req); err != nil {
434472
return info, err
435473
}
436474

437-
resp, err := readResponse(conn, req)
475+
resp, err := readResponse(ctx, conn, req)
438476
if err != nil {
439477
if resp != nil &&
440478
resp.Header().Error == iproto.ER_UNKNOWN_REQUEST_TYPE {
@@ -495,7 +533,7 @@ func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error {
495533
}
496534

497535
// authenticate authenticates for a connection.
498-
func authenticate(c Conn, auth Auth, user string, pass string, salt string) error {
536+
func authenticate(ctx context.Context, c Conn, auth Auth, user, pass, salt string) error {
499537
var req Request
500538
var err error
501539

@@ -511,37 +549,73 @@ func authenticate(c Conn, auth Auth, user string, pass string, salt string) erro
511549
return errors.New("unsupported method " + auth.String())
512550
}
513551

514-
if err = writeRequest(c, req); err != nil {
552+
if err = writeRequest(ctx, c, req); err != nil {
515553
return err
516554
}
517-
if _, err = readResponse(c, req); err != nil {
555+
if _, err = readResponse(ctx, c, req); err != nil {
518556
return err
519557
}
520558
return nil
521559
}
522560

523561
// writeRequest writes a request to the writer.
524-
func writeRequest(w writeFlusher, req Request) error {
562+
func writeRequest(ctx context.Context, conn Conn, req Request) error {
525563
var packet smallWBuf
526564
err := pack(&packet, msgpack.NewEncoder(&packet), 0, req, ignoreStreamId, nil)
527565

528566
if err != nil {
529567
return fmt.Errorf("pack error: %w", err)
530568
}
531-
if _, err = w.Write(packet.b); err != nil {
569+
570+
doneWrite, doneWait := ioWaiter(ctx, conn)
571+
572+
_, err = conn.Write(packet.b)
573+
574+
close(doneWrite)
575+
576+
if waitErr := <-doneWait; waitErr != nil {
577+
err = waitErr
578+
}
579+
580+
if err != nil {
532581
return fmt.Errorf("write error: %w", err)
533582
}
534-
if err = w.Flush(); err != nil {
583+
584+
doneWrite, doneWait = ioWaiter(ctx, conn)
585+
586+
err = conn.Flush()
587+
588+
close(doneWrite)
589+
590+
if waitErr := <-doneWait; waitErr != nil {
591+
err = waitErr
592+
}
593+
594+
if err != nil {
535595
return fmt.Errorf("flush error: %w", err)
536596
}
597+
598+
if waitErr := <-doneWait; waitErr != nil {
599+
err = waitErr
600+
}
601+
537602
return err
538603
}
539604

540605
// readResponse reads a response from the reader.
541-
func readResponse(r io.Reader, req Request) (Response, error) {
606+
func readResponse(ctx context.Context, conn Conn, req Request) (Response, error) {
542607
var lenbuf [packetLengthBytes]byte
543608

544-
respBytes, err := read(r, lenbuf[:])
609+
doneRead, doneWait := ioWaiter(ctx, conn)
610+
611+
respBytes, err := read(conn, lenbuf[:])
612+
613+
close(doneRead)
614+
615+
if waitErr := <-doneWait; waitErr != nil {
616+
err = waitErr
617+
}
618+
545619
if err != nil {
546620
return nil, fmt.Errorf("read error: %w", err)
547621
}
@@ -555,10 +629,12 @@ func readResponse(r io.Reader, req Request) (Response, error) {
555629
if err != nil {
556630
return nil, fmt.Errorf("decode response header error: %w", err)
557631
}
632+
558633
resp, err := req.Response(header, &buf)
559634
if err != nil {
560635
return nil, fmt.Errorf("creating response error: %w", err)
561636
}
637+
562638
_, err = resp.Decode()
563639
if err != nil {
564640
switch err.(type) {
@@ -568,5 +644,6 @@ func readResponse(r io.Reader, req Request) (Response, error) {
568644
return resp, fmt.Errorf("decode response body error: %w", err)
569645
}
570646
}
647+
571648
return resp, nil
572649
}

0 commit comments

Comments
 (0)