Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 7d78206

Browse files
committedApr 10, 2025··
dial: Connect() does not cancel by context if no i/o
NetDialer, GreetingDialer, AuthDialer and ProtocolDialer may not cancel Dial() on context expiration when the network connection hangs. The issue occurred because the context wasn't properly handled during network I/O operations, which could cause Dial() to wait forever. Part of TNTP-2018.
1 parent c043771 commit 7d78206

File tree

4 files changed

+210
-27
lines changed

4 files changed

+210
-27
lines changed
 

‎CHANGELOG.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.
1414

1515
### Fixed
1616

17+
- Connect() may not cancel Dial() call on context expiration if network
18+
connection hangs (#443).
19+
1720
## [v2.3.1] - 2025-04-03
1821

1922
The patch release fixes the expected Connect() behavior and reduces allocations.
2023

2124
### Added
2225

2326
- A usage of sync.Pool of msgpack.Decoder saves 2 object allocations per
24-
a response decoding.
27+
a response decoding (#440).
2528

2629
### Changed
2730

‎connection.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ func (conn *Connection) dial(ctx context.Context) error {
489489
}
490490

491491
req := newWatchRequest(key.(string))
492-
if err = writeRequest(c, req); err != nil {
492+
if err = writeRequest(ctx, c, req); err != nil {
493493
st <- state
494494
return false
495495
}

‎dial.go

+89-16
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
289289
if err != nil {
290290
return conn, err
291291
}
292+
292293
greeting := conn.Greeting()
293294
if greeting.Salt == "" {
294295
conn.Close()
@@ -309,7 +310,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
309310
}
310311
}
311312

312-
if err := authenticate(conn, d.Auth, d.Username, d.Password,
313+
if err := authenticate(ctx, conn, d.Auth, d.Username, d.Password,
313314
conn.Greeting().Salt); err != nil {
314315
conn.Close()
315316
return nil, fmt.Errorf("failed to authenticate: %w", err)
@@ -340,7 +341,7 @@ func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
340341
protocolInfo: d.RequiredProtocolInfo,
341342
}
342343

343-
protocolConn.protocolInfo, err = identify(&protocolConn)
344+
protocolConn.protocolInfo, err = identify(ctx, &protocolConn)
344345
if err != nil {
345346
protocolConn.Close()
346347
return nil, fmt.Errorf("failed to identify: %w", err)
@@ -372,11 +373,12 @@ func (d GreetingDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
372373
greetingConn := greetingConn{
373374
Conn: conn,
374375
}
375-
version, salt, err := readGreeting(greetingConn)
376+
version, salt, err := readGreeting(ctx, &greetingConn)
376377
if err != nil {
377378
greetingConn.Close()
378379
return nil, fmt.Errorf("failed to read greeting: %w", err)
379380
}
381+
380382
greetingConn.greeting = Greeting{
381383
Version: version,
382384
Salt: salt,
@@ -410,31 +412,62 @@ func parseAddress(address string) (string, string) {
410412
return network, address
411413
}
412414

415+
// ioWaiter waits in a background until an io operation done or a context
416+
// is expired. It closes the connection and writes a context error into the
417+
// output channel if a context expired.
418+
func ioWaiter(ctx context.Context, conn Conn, done <-chan struct{}) <-chan error {
419+
doneWait := make(chan error, 1)
420+
421+
go func() {
422+
defer close(doneWait)
423+
424+
select {
425+
case <-ctx.Done():
426+
conn.Close()
427+
<-done
428+
doneWait <- ctx.Err()
429+
case <-done:
430+
doneWait <- nil
431+
}
432+
}()
433+
434+
return doneWait
435+
}
436+
413437
// readGreeting reads a greeting message.
414-
func readGreeting(reader io.Reader) (string, string, error) {
438+
func readGreeting(ctx context.Context, conn Conn) (string, string, error) {
415439
var version, salt string
416440

441+
doneRead := make(chan struct{})
442+
doneWait := ioWaiter(ctx, conn, doneRead)
443+
417444
data := make([]byte, 128)
418-
_, err := io.ReadFull(reader, data)
445+
_, err := io.ReadFull(conn, data)
446+
close(doneRead)
447+
419448
if err == nil {
420449
version = bytes.NewBuffer(data[:64]).String()
421450
salt = bytes.NewBuffer(data[64:108]).String()
422451
}
423452

453+
if waitErr := <-doneWait; waitErr != nil {
454+
err = waitErr
455+
}
456+
424457
return version, salt, err
425458
}
426459

427460
// identify sends info about client protocol, receives info
428461
// about server protocol in response and stores it in the connection.
429-
func identify(conn Conn) (ProtocolInfo, error) {
462+
func identify(ctx context.Context, conn Conn) (ProtocolInfo, error) {
430463
var info ProtocolInfo
431464

432465
req := NewIdRequest(clientProtocolInfo)
433-
if err := writeRequest(conn, req); err != nil {
466+
if err := writeRequest(ctx, conn, req); err != nil {
434467
return info, err
435468
}
436469

437-
resp, err := readResponse(conn, req)
470+
resp, err := readResponse(ctx, conn, req)
438471
if err != nil {
439472
if resp != nil &&
440473
resp.Header().Error == iproto.ER_UNKNOWN_REQUEST_TYPE {
@@ -495,7 +528,7 @@ func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error {
495528
}
496529

497530
// authenticate authenticates for a connection.
498-
func authenticate(c Conn, auth Auth, user string, pass string, salt string) error {
531+
func authenticate(ctx context.Context, c Conn, auth Auth, user, pass, salt string) error {
499532
var req Request
500533
var err error
501534

@@ -511,37 +544,74 @@ func authenticate(c Conn, auth Auth, user string, pass string, salt string) erro
511544
return errors.New("unsupported method " + auth.String())
512545
}
513546

514-
if err = writeRequest(c, req); err != nil {
547+
if err = writeRequest(ctx, c, req); err != nil {
515548
return err
516549
}
517-
if _, err = readResponse(c, req); err != nil {
550+
if _, err = readResponse(ctx, c, req); err != nil {
518551
return err
519552
}
520553
return nil
521554
}
522555

523556
// writeRequest writes a request to the writer.
524-
func writeRequest(w writeFlusher, req Request) error {
557+
func writeRequest(ctx context.Context, conn Conn, req Request) error {
525558
var packet smallWBuf
526559
err := pack(&packet, msgpack.NewEncoder(&packet), 0, req, ignoreStreamId, nil)
527560

528561
if err != nil {
529562
return fmt.Errorf("pack error: %w", err)
530563
}
531-
if _, err = w.Write(packet.b); err != nil {
564+
565+
doneWrite := make(chan struct{})
566+
doneWait := ioWaiter(ctx, conn, doneWrite)
567+
568+
_, err = conn.Write(packet.b)
569+
close(doneWrite)
570+
571+
if waitErr := <-doneWait; waitErr != nil {
572+
err = waitErr
573+
}
574+
575+
if err != nil {
532576
return fmt.Errorf("write error: %w", err)
533577
}
534-
if err = w.Flush(); err != nil {
578+
579+
doneWrite = make(chan struct{})
580+
doneWait = ioWaiter(ctx, conn, doneWrite)
581+
582+
err = conn.Flush()
583+
close(doneWrite)
584+
585+
if waitErr := <-doneWait; waitErr != nil {
586+
err = waitErr
587+
}
588+
589+
if err != nil {
535590
return fmt.Errorf("flush error: %w", err)
536591
}
592+
593+
if waitErr := <-doneWait; waitErr != nil {
594+
err = waitErr
595+
}
596+
537597
return err
538598
}
539599

540600
// readResponse reads a response from the reader.
541-
func readResponse(r io.Reader, req Request) (Response, error) {
601+
func readResponse(ctx context.Context, conn Conn, req Request) (Response, error) {
542602
var lenbuf [packetLengthBytes]byte
543603

544-
respBytes, err := read(r, lenbuf[:])
604+
doneRead := make(chan struct{})
605+
doneWait := ioWaiter(ctx, conn, doneRead)
606+
607+
respBytes, err := read(conn, lenbuf[:])
608+
609+
close(doneRead)
610+
611+
if waitErr := <-doneWait; waitErr != nil {
612+
err = waitErr
613+
}
614+
545615
if err != nil {
546616
return nil, fmt.Errorf("read error: %w", err)
547617
}
@@ -555,10 +625,12 @@ func readResponse(r io.Reader, req Request) (Response, error) {
555625
if err != nil {
556626
return nil, fmt.Errorf("decode response header error: %w", err)
557627
}
628+
558629
resp, err := req.Response(header, &buf)
559630
if err != nil {
560631
return nil, fmt.Errorf("creating response error: %w", err)
561632
}
633+
562634
_, err = resp.Decode()
563635
if err != nil {
564636
switch err.(type) {
@@ -568,5 +640,6 @@ func readResponse(r io.Reader, req Request) (Response, error) {
568640
return resp, fmt.Errorf("decode response body error: %w", err)
569641
}
570642
}
643+
571644
return resp, nil
572645
}

‎dial_test.go

+116-9
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ type mockIoConn struct {
8787
readbuf, writebuf bytes.Buffer
8888
// Calls readWg/writeWg.Wait() in Read()/Flush().
8989
readWg, writeWg sync.WaitGroup
90+
// wgDoneOnClose tells Close() to call Done() on both wait groups.
91+
wgDoneOnClose bool
9092
// How many times to wait before a wg.Wait() call.
9193
readWgDelay, writeWgDelay int
9294
// Write()/Read()/Flush()/Close() calls count.
@@ -137,6 +139,12 @@ func (m *mockIoConn) Flush() error {
137139
}
138140

139141
func (m *mockIoConn) Close() error {
142+
if m.wgDoneOnClose {
143+
m.readWg.Done()
144+
m.writeWg.Done()
145+
m.wgDoneOnClose = false
146+
}
147+
140148
m.closeCnt++
141149
return nil
142150
}
@@ -165,6 +173,7 @@ func newMockIoConn() *mockIoConn {
165173
conn := new(mockIoConn)
166174
conn.readWg.Add(1)
167175
conn.writeWg.Add(1)
176+
conn.wgDoneOnClose = true
168177
return conn
169178
}
170179

@@ -201,9 +210,6 @@ func TestConn_Close(t *testing.T) {
201210
conn.Close()
202211

203212
assert.Equal(t, 1, dialer.conn.closeCnt)
204-
205-
dialer.conn.readWg.Done()
206-
dialer.conn.writeWg.Done()
207213
}
208214

209215
type stubAddr struct {
@@ -224,8 +230,6 @@ func TestConn_Addr(t *testing.T) {
224230
conn.addr = stubAddr{str: addr}
225231
})
226232
defer func() {
227-
dialer.conn.readWg.Done()
228-
dialer.conn.writeWg.Done()
229233
conn.Close()
230234
}()
231235

@@ -242,8 +246,6 @@ func TestConn_Greeting(t *testing.T) {
242246
conn.greeting = greeting
243247
})
244248
defer func() {
245-
dialer.conn.readWg.Done()
246-
dialer.conn.writeWg.Done()
247249
conn.Close()
248250
}()
249251

@@ -263,8 +265,6 @@ func TestConn_ProtocolInfo(t *testing.T) {
263265
conn.info = info
264266
})
265267
defer func() {
266-
dialer.conn.readWg.Done()
267-
dialer.conn.writeWg.Done()
268268
conn.Close()
269269
}()
270270

@@ -284,6 +284,7 @@ func TestConn_ReadWrite(t *testing.T) {
284284
0x01, 0xce, 0x00, 0x00, 0x00, 0x02,
285285
0x80, // Body map.
286286
})
287+
conn.wgDoneOnClose = false
287288
})
288289
defer func() {
289290
dialer.conn.writeWg.Done()
@@ -579,6 +580,24 @@ func TestNetDialer_Dial(t *testing.T) {
579580
}
580581
}
581582

583+
func TestNetDialer_Dial_hang_connection(t *testing.T) {
584+
l, err := net.Listen("tcp", "127.0.0.1:0")
585+
require.NoError(t, err)
586+
defer l.Close()
587+
588+
dialer := tarantool.NetDialer{
589+
Address: l.Addr().String(),
590+
}
591+
592+
ctx, cancel := test_helpers.GetConnectContext()
593+
defer cancel()
594+
595+
conn, err := dialer.Dial(ctx, tarantool.DialOpts{})
596+
597+
require.Nil(t, conn)
598+
require.Error(t, err, context.DeadlineExceeded)
599+
}
600+
582601
func TestNetDialer_Dial_requirements(t *testing.T) {
583602
l, err := net.Listen("tcp", "127.0.0.1:0")
584603
require.NoError(t, err)
@@ -685,6 +704,7 @@ func TestAuthDialer_Dial_DialerError(t *testing.T) {
685704

686705
ctx, cancel := test_helpers.GetConnectContext()
687706
defer cancel()
707+
688708
conn, err := dialer.Dial(ctx, tarantool.DialOpts{})
689709
if conn != nil {
690710
conn.Close()
@@ -717,6 +737,38 @@ func TestAuthDialer_Dial_NoSalt(t *testing.T) {
717737
}
718738
}
719739

740+
func TestConn_AuthDialer_hang_connection(t *testing.T) {
741+
salt := fmt.Sprintf("%s", testDialSalt)
742+
salt = base64.StdEncoding.EncodeToString([]byte(salt))
743+
mock := &mockIoDialer{
744+
init: func(conn *mockIoConn) {
745+
conn.greeting.Salt = salt
746+
conn.readWgDelay = 0
747+
conn.writeWgDelay = 0
748+
},
749+
}
750+
dialer := tarantool.AuthDialer{
751+
Dialer: mock,
752+
Username: "test",
753+
Password: "test",
754+
}
755+
756+
ctx, cancel := test_helpers.GetConnectContext()
757+
defer cancel()
758+
759+
conn, err := tarantool.Connect(ctx, &dialer,
760+
tarantool.Opts{
761+
Timeout: 1000 * time.Second, // Avoid pings.
762+
SkipSchema: true,
763+
})
764+
765+
require.Nil(t, conn)
766+
require.Error(t, err, context.DeadlineExceeded)
767+
require.Equal(t, mock.conn.writeCnt, 1)
768+
require.Equal(t, mock.conn.readCnt, 0)
769+
require.Greater(t, mock.conn.closeCnt, 1)
770+
}
771+
720772
func TestAuthDialer_Dial(t *testing.T) {
721773
salt := fmt.Sprintf("%s", testDialSalt)
722774
salt = base64.StdEncoding.EncodeToString([]byte(salt))
@@ -726,6 +778,7 @@ func TestAuthDialer_Dial(t *testing.T) {
726778
conn.writeWgDelay = 1
727779
conn.readWgDelay = 2
728780
conn.readbuf.Write(okResponse)
781+
conn.wgDoneOnClose = false
729782
},
730783
}
731784
defer func() {
@@ -758,6 +811,7 @@ func TestAuthDialer_Dial_PapSha256Auth(t *testing.T) {
758811
conn.writeWgDelay = 1
759812
conn.readWgDelay = 2
760813
conn.readbuf.Write(okResponse)
814+
conn.wgDoneOnClose = false
761815
},
762816
}
763817
defer func() {
@@ -800,6 +854,34 @@ func TestProtocolDialer_Dial_DialerError(t *testing.T) {
800854
assert.EqualError(t, err, "some error")
801855
}
802856

857+
func TestConn_ProtocolDialer_hang_connection(t *testing.T) {
858+
mock := &mockIoDialer{
859+
init: func(conn *mockIoConn) {
860+
conn.info = tarantool.ProtocolInfo{Version: 1}
861+
conn.readWgDelay = 0
862+
conn.writeWgDelay = 0
863+
},
864+
}
865+
dialer := tarantool.ProtocolDialer{
866+
Dialer: mock,
867+
}
868+
869+
ctx, cancel := test_helpers.GetConnectContext()
870+
defer cancel()
871+
872+
conn, err := tarantool.Connect(ctx, &dialer,
873+
tarantool.Opts{
874+
Timeout: 1000 * time.Second, // Avoid pings.
875+
SkipSchema: true,
876+
})
877+
878+
require.Nil(t, conn)
879+
require.Error(t, err, context.DeadlineExceeded)
880+
require.Equal(t, mock.conn.writeCnt, 1)
881+
require.Equal(t, mock.conn.readCnt, 0)
882+
require.Greater(t, mock.conn.closeCnt, 1)
883+
}
884+
803885
func TestProtocolDialer_Dial_IdentifyFailed(t *testing.T) {
804886
dialer := tarantool.ProtocolDialer{
805887
Dialer: &mockIoDialer{
@@ -898,6 +980,31 @@ func TestGreetingDialer_Dial_DialerError(t *testing.T) {
898980
assert.EqualError(t, err, "some error")
899981
}
900982

983+
func TestConn_GreetingDialer_hang_connection(t *testing.T) {
984+
mock := &mockIoDialer{
985+
init: func(conn *mockIoConn) {
986+
conn.readWgDelay = 0
987+
},
988+
}
989+
dialer := tarantool.GreetingDialer{
990+
Dialer: mock,
991+
}
992+
993+
ctx, cancel := test_helpers.GetConnectContext()
994+
defer cancel()
995+
996+
conn, err := tarantool.Connect(ctx, &dialer,
997+
tarantool.Opts{
998+
Timeout: 1000 * time.Second, // Avoid pings.
999+
SkipSchema: true,
1000+
})
1001+
1002+
require.Nil(t, conn)
1003+
require.Error(t, err, context.DeadlineExceeded)
1004+
require.Equal(t, mock.conn.readCnt, 1)
1005+
require.Greater(t, mock.conn.closeCnt, 1)
1006+
}
1007+
9011008
func TestGreetingDialer_Dial_GreetingFailed(t *testing.T) {
9021009
dialer := tarantool.GreetingDialer{
9031010
Dialer: &mockIoDialer{

0 commit comments

Comments
 (0)
Please sign in to comment.