@@ -289,6 +289,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
289
289
if err != nil {
290
290
return conn , err
291
291
}
292
+
292
293
greeting := conn .Greeting ()
293
294
if greeting .Salt == "" {
294
295
conn .Close ()
@@ -309,7 +310,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
309
310
}
310
311
}
311
312
312
- if err := authenticate (conn , d .Auth , d .Username , d .Password ,
313
+ if err := authenticate (ctx , conn , d .Auth , d .Username , d .Password ,
313
314
conn .Greeting ().Salt ); err != nil {
314
315
conn .Close ()
315
316
return nil , fmt .Errorf ("failed to authenticate: %w" , err )
@@ -340,7 +341,7 @@ func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
340
341
protocolInfo : d .RequiredProtocolInfo ,
341
342
}
342
343
343
- protocolConn .protocolInfo , err = identify (& protocolConn )
344
+ protocolConn .protocolInfo , err = identify (ctx , & protocolConn )
344
345
if err != nil {
345
346
protocolConn .Close ()
346
347
return nil , fmt .Errorf ("failed to identify: %w" , err )
@@ -372,11 +373,12 @@ func (d GreetingDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
372
373
greetingConn := greetingConn {
373
374
Conn : conn ,
374
375
}
375
- version , salt , err := readGreeting (greetingConn )
376
+ version , salt , err := readGreeting (ctx , & greetingConn )
376
377
if err != nil {
377
378
greetingConn .Close ()
378
379
return nil , fmt .Errorf ("failed to read greeting: %w" , err )
379
380
}
381
+
380
382
greetingConn .greeting = Greeting {
381
383
Version : version ,
382
384
Salt : salt ,
@@ -410,31 +412,62 @@ func parseAddress(address string) (string, string) {
410
412
return network , address
411
413
}
412
414
415
+ // ioWaiter waits in a background until an io operation done or a context
416
+ // is expired. It closes the connection and writes a context error into the
417
+ // output channel if a context expired.
418
+ func ioWaiter (ctx context.Context , conn Conn , done <- chan struct {}) <- chan error {
419
+ doneWait := make (chan error , 1 )
420
+
421
+ go func () {
422
+ defer close (doneWait )
423
+
424
+ select {
425
+ case <- ctx .Done ():
426
+ conn .Close ()
427
+ <- done
428
+ doneWait <- ctx .Err ()
429
+ case <- done :
430
+ doneWait <- nil
431
+ }
432
+ }()
433
+
434
+ return doneWait
435
+ }
436
+
413
437
// readGreeting reads a greeting message.
414
- func readGreeting (reader io. Reader ) (string , string , error ) {
438
+ func readGreeting (ctx context. Context , conn Conn ) (string , string , error ) {
415
439
var version , salt string
416
440
441
+ doneRead := make (chan struct {})
442
+ doneWait := ioWaiter (ctx , conn , doneRead )
443
+
417
444
data := make ([]byte , 128 )
418
- _ , err := io .ReadFull (reader , data )
445
+ _ , err := io .ReadFull (conn , data )
446
+ close (doneRead )
447
+
419
448
if err == nil {
420
449
version = bytes .NewBuffer (data [:64 ]).String ()
421
450
salt = bytes .NewBuffer (data [64 :108 ]).String ()
422
451
}
423
452
453
+ if waitErr := <- doneWait ; waitErr != nil {
454
+ err = waitErr
455
+ }
456
+
424
457
return version , salt , err
425
458
}
426
459
427
460
// identify sends info about client protocol, receives info
428
461
// about server protocol in response and stores it in the connection.
429
- func identify (conn Conn ) (ProtocolInfo , error ) {
462
+ func identify (ctx context. Context , conn Conn ) (ProtocolInfo , error ) {
430
463
var info ProtocolInfo
431
464
432
465
req := NewIdRequest (clientProtocolInfo )
433
- if err := writeRequest (conn , req ); err != nil {
466
+ if err := writeRequest (ctx , conn , req ); err != nil {
434
467
return info , err
435
468
}
436
469
437
- resp , err := readResponse (conn , req )
470
+ resp , err := readResponse (ctx , conn , req )
438
471
if err != nil {
439
472
if resp != nil &&
440
473
resp .Header ().Error == iproto .ER_UNKNOWN_REQUEST_TYPE {
@@ -495,7 +528,7 @@ func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error {
495
528
}
496
529
497
530
// authenticate authenticates for a connection.
498
- func authenticate (c Conn , auth Auth , user string , pass string , salt string ) error {
531
+ func authenticate (ctx context. Context , c Conn , auth Auth , user , pass , salt string ) error {
499
532
var req Request
500
533
var err error
501
534
@@ -511,37 +544,74 @@ func authenticate(c Conn, auth Auth, user string, pass string, salt string) erro
511
544
return errors .New ("unsupported method " + auth .String ())
512
545
}
513
546
514
- if err = writeRequest (c , req ); err != nil {
547
+ if err = writeRequest (ctx , c , req ); err != nil {
515
548
return err
516
549
}
517
- if _ , err = readResponse (c , req ); err != nil {
550
+ if _ , err = readResponse (ctx , c , req ); err != nil {
518
551
return err
519
552
}
520
553
return nil
521
554
}
522
555
523
556
// writeRequest writes a request to the writer.
524
- func writeRequest (w writeFlusher , req Request ) error {
557
+ func writeRequest (ctx context. Context , conn Conn , req Request ) error {
525
558
var packet smallWBuf
526
559
err := pack (& packet , msgpack .NewEncoder (& packet ), 0 , req , ignoreStreamId , nil )
527
560
528
561
if err != nil {
529
562
return fmt .Errorf ("pack error: %w" , err )
530
563
}
531
- if _ , err = w .Write (packet .b ); err != nil {
564
+
565
+ doneWrite := make (chan struct {})
566
+ doneWait := ioWaiter (ctx , conn , doneWrite )
567
+
568
+ _ , err = conn .Write (packet .b )
569
+ close (doneWrite )
570
+
571
+ if waitErr := <- doneWait ; waitErr != nil {
572
+ err = waitErr
573
+ }
574
+
575
+ if err != nil {
532
576
return fmt .Errorf ("write error: %w" , err )
533
577
}
534
- if err = w .Flush (); err != nil {
578
+
579
+ doneWrite = make (chan struct {})
580
+ doneWait = ioWaiter (ctx , conn , doneWrite )
581
+
582
+ err = conn .Flush ()
583
+ close (doneWrite )
584
+
585
+ if waitErr := <- doneWait ; waitErr != nil {
586
+ err = waitErr
587
+ }
588
+
589
+ if err != nil {
535
590
return fmt .Errorf ("flush error: %w" , err )
536
591
}
592
+
593
+ if waitErr := <- doneWait ; waitErr != nil {
594
+ err = waitErr
595
+ }
596
+
537
597
return err
538
598
}
539
599
540
600
// readResponse reads a response from the reader.
541
- func readResponse (r io. Reader , req Request ) (Response , error ) {
601
+ func readResponse (ctx context. Context , conn Conn , req Request ) (Response , error ) {
542
602
var lenbuf [packetLengthBytes ]byte
543
603
544
- respBytes , err := read (r , lenbuf [:])
604
+ doneRead := make (chan struct {})
605
+ doneWait := ioWaiter (ctx , conn , doneRead )
606
+
607
+ respBytes , err := read (conn , lenbuf [:])
608
+
609
+ close (doneRead )
610
+
611
+ if waitErr := <- doneWait ; waitErr != nil {
612
+ err = waitErr
613
+ }
614
+
545
615
if err != nil {
546
616
return nil , fmt .Errorf ("read error: %w" , err )
547
617
}
@@ -555,10 +625,12 @@ func readResponse(r io.Reader, req Request) (Response, error) {
555
625
if err != nil {
556
626
return nil , fmt .Errorf ("decode response header error: %w" , err )
557
627
}
628
+
558
629
resp , err := req .Response (header , & buf )
559
630
if err != nil {
560
631
return nil , fmt .Errorf ("creating response error: %w" , err )
561
632
}
633
+
562
634
_ , err = resp .Decode ()
563
635
if err != nil {
564
636
switch err .(type ) {
@@ -568,5 +640,6 @@ func readResponse(r io.Reader, req Request) (Response, error) {
568
640
return resp , fmt .Errorf ("decode response body error: %w" , err )
569
641
}
570
642
}
643
+
571
644
return resp , nil
572
645
}
0 commit comments