Skip to content

Commit 7d78206

Browse files
committed
dial: Connect() does not cancel by context if no i/o
NetDialer, GreetingDialer, AuthDialer and ProtocolDialer may not cancel Dial() on context expiration when network connection hangs. The issue occurred because context wasn't properly handled during network I/O operations, potentially causing infinite waiting. Part of TNTP-2018
1 parent c043771 commit 7d78206

File tree

4 files changed

+210
-27
lines changed

4 files changed

+210
-27
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.
1414

1515
### Fixed
1616

17+
- Connect() may not cancel Dial() call on context expiration if network
18+
connection hangs (#443).
19+
1720
## [v2.3.1] - 2025-04-03
1821

1922
The patch releases fixes expected Connect() behavior and reduces allocations.
2023

2124
### Added
2225

2326
- A usage of sync.Pool of msgpack.Decoder saves 2 object allocations per
24-
a response decoding.
27+
a response decoding (#440).
2528

2629
### Changed
2730

connection.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ func (conn *Connection) dial(ctx context.Context) error {
489489
}
490490

491491
req := newWatchRequest(key.(string))
492-
if err = writeRequest(c, req); err != nil {
492+
if err = writeRequest(ctx, c, req); err != nil {
493493
st <- state
494494
return false
495495
}

dial.go

Lines changed: 89 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
289289
if err != nil {
290290
return conn, err
291291
}
292+
292293
greeting := conn.Greeting()
293294
if greeting.Salt == "" {
294295
conn.Close()
@@ -309,7 +310,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
309310
}
310311
}
311312

312-
if err := authenticate(conn, d.Auth, d.Username, d.Password,
313+
if err := authenticate(ctx, conn, d.Auth, d.Username, d.Password,
313314
conn.Greeting().Salt); err != nil {
314315
conn.Close()
315316
return nil, fmt.Errorf("failed to authenticate: %w", err)
@@ -340,7 +341,7 @@ func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
340341
protocolInfo: d.RequiredProtocolInfo,
341342
}
342343

343-
protocolConn.protocolInfo, err = identify(&protocolConn)
344+
protocolConn.protocolInfo, err = identify(ctx, &protocolConn)
344345
if err != nil {
345346
protocolConn.Close()
346347
return nil, fmt.Errorf("failed to identify: %w", err)
@@ -372,11 +373,12 @@ func (d GreetingDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
372373
greetingConn := greetingConn{
373374
Conn: conn,
374375
}
375-
version, salt, err := readGreeting(greetingConn)
376+
version, salt, err := readGreeting(ctx, &greetingConn)
376377
if err != nil {
377378
greetingConn.Close()
378379
return nil, fmt.Errorf("failed to read greeting: %w", err)
379380
}
381+
380382
greetingConn.greeting = Greeting{
381383
Version: version,
382384
Salt: salt,
@@ -410,31 +412,62 @@ func parseAddress(address string) (string, string) {
410412
return network, address
411413
}
412414

415+
// ioWaiter waits in a background until an io operation done or a context
416+
// is expired. It closes the connection and writes a context error into the
417+
// output channel if a context expired.
418+
func ioWaiter(ctx context.Context, conn Conn, done <-chan struct{}) <-chan error {
419+
doneWait := make(chan error, 1)
420+
421+
go func() {
422+
defer close(doneWait)
423+
424+
select {
425+
case <-ctx.Done():
426+
conn.Close()
427+
<-done
428+
doneWait <- ctx.Err()
429+
case <-done:
430+
doneWait <- nil
431+
}
432+
}()
433+
434+
return doneWait
435+
}
436+
413437
// readGreeting reads a greeting message.
414-
func readGreeting(reader io.Reader) (string, string, error) {
438+
func readGreeting(ctx context.Context, conn Conn) (string, string, error) {
415439
var version, salt string
416440

441+
doneRead := make(chan struct{})
442+
doneWait := ioWaiter(ctx, conn, doneRead)
443+
417444
data := make([]byte, 128)
418-
_, err := io.ReadFull(reader, data)
445+
_, err := io.ReadFull(conn, data)
446+
close(doneRead)
447+
419448
if err == nil {
420449
version = bytes.NewBuffer(data[:64]).String()
421450
salt = bytes.NewBuffer(data[64:108]).String()
422451
}
423452

453+
if waitErr := <-doneWait; waitErr != nil {
454+
err = waitErr
455+
}
456+
424457
return version, salt, err
425458
}
426459

427460
// identify sends info about client protocol, receives info
428461
// about server protocol in response and stores it in the connection.
429-
func identify(conn Conn) (ProtocolInfo, error) {
462+
func identify(ctx context.Context, conn Conn) (ProtocolInfo, error) {
430463
var info ProtocolInfo
431464

432465
req := NewIdRequest(clientProtocolInfo)
433-
if err := writeRequest(conn, req); err != nil {
466+
if err := writeRequest(ctx, conn, req); err != nil {
434467
return info, err
435468
}
436469

437-
resp, err := readResponse(conn, req)
470+
resp, err := readResponse(ctx, conn, req)
438471
if err != nil {
439472
if resp != nil &&
440473
resp.Header().Error == iproto.ER_UNKNOWN_REQUEST_TYPE {
@@ -495,7 +528,7 @@ func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error {
495528
}
496529

497530
// authenticate authenticates for a connection.
498-
func authenticate(c Conn, auth Auth, user string, pass string, salt string) error {
531+
func authenticate(ctx context.Context, c Conn, auth Auth, user, pass, salt string) error {
499532
var req Request
500533
var err error
501534

@@ -511,37 +544,74 @@ func authenticate(c Conn, auth Auth, user string, pass string, salt string) erro
511544
return errors.New("unsupported method " + auth.String())
512545
}
513546

514-
if err = writeRequest(c, req); err != nil {
547+
if err = writeRequest(ctx, c, req); err != nil {
515548
return err
516549
}
517-
if _, err = readResponse(c, req); err != nil {
550+
if _, err = readResponse(ctx, c, req); err != nil {
518551
return err
519552
}
520553
return nil
521554
}
522555

523556
// writeRequest writes a request to the writer.
524-
func writeRequest(w writeFlusher, req Request) error {
557+
func writeRequest(ctx context.Context, conn Conn, req Request) error {
525558
var packet smallWBuf
526559
err := pack(&packet, msgpack.NewEncoder(&packet), 0, req, ignoreStreamId, nil)
527560

528561
if err != nil {
529562
return fmt.Errorf("pack error: %w", err)
530563
}
531-
if _, err = w.Write(packet.b); err != nil {
564+
565+
doneWrite := make(chan struct{})
566+
doneWait := ioWaiter(ctx, conn, doneWrite)
567+
568+
_, err = conn.Write(packet.b)
569+
close(doneWrite)
570+
571+
if waitErr := <-doneWait; waitErr != nil {
572+
err = waitErr
573+
}
574+
575+
if err != nil {
532576
return fmt.Errorf("write error: %w", err)
533577
}
534-
if err = w.Flush(); err != nil {
578+
579+
doneWrite = make(chan struct{})
580+
doneWait = ioWaiter(ctx, conn, doneWrite)
581+
582+
err = conn.Flush()
583+
close(doneWrite)
584+
585+
if waitErr := <-doneWait; waitErr != nil {
586+
err = waitErr
587+
}
588+
589+
if err != nil {
535590
return fmt.Errorf("flush error: %w", err)
536591
}
592+
593+
if waitErr := <-doneWait; waitErr != nil {
594+
err = waitErr
595+
}
596+
537597
return err
538598
}
539599

540600
// readResponse reads a response from the reader.
541-
func readResponse(r io.Reader, req Request) (Response, error) {
601+
func readResponse(ctx context.Context, conn Conn, req Request) (Response, error) {
542602
var lenbuf [packetLengthBytes]byte
543603

544-
respBytes, err := read(r, lenbuf[:])
604+
doneRead := make(chan struct{})
605+
doneWait := ioWaiter(ctx, conn, doneRead)
606+
607+
respBytes, err := read(conn, lenbuf[:])
608+
609+
close(doneRead)
610+
611+
if waitErr := <-doneWait; waitErr != nil {
612+
err = waitErr
613+
}
614+
545615
if err != nil {
546616
return nil, fmt.Errorf("read error: %w", err)
547617
}
@@ -555,10 +625,12 @@ func readResponse(r io.Reader, req Request) (Response, error) {
555625
if err != nil {
556626
return nil, fmt.Errorf("decode response header error: %w", err)
557627
}
628+
558629
resp, err := req.Response(header, &buf)
559630
if err != nil {
560631
return nil, fmt.Errorf("creating response error: %w", err)
561632
}
633+
562634
_, err = resp.Decode()
563635
if err != nil {
564636
switch err.(type) {
@@ -568,5 +640,6 @@ func readResponse(r io.Reader, req Request) (Response, error) {
568640
return resp, fmt.Errorf("decode response body error: %w", err)
569641
}
570642
}
643+
571644
return resp, nil
572645
}

0 commit comments

Comments
 (0)