@@ -289,6 +289,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
289
289
if err != nil {
290
290
return conn , err
291
291
}
292
+
292
293
greeting := conn .Greeting ()
293
294
if greeting .Salt == "" {
294
295
conn .Close ()
@@ -309,7 +310,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
309
310
}
310
311
}
311
312
312
- if err := authenticate (conn , d .Auth , d .Username , d .Password ,
313
+ if err := authenticate (ctx , conn , d .Auth , d .Username , d .Password ,
313
314
conn .Greeting ().Salt ); err != nil {
314
315
conn .Close ()
315
316
return nil , fmt .Errorf ("failed to authenticate: %w" , err )
@@ -340,7 +341,7 @@ func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
340
341
protocolInfo : d .RequiredProtocolInfo ,
341
342
}
342
343
343
- protocolConn .protocolInfo , err = identify (& protocolConn )
344
+ protocolConn .protocolInfo , err = identify (ctx , & protocolConn )
344
345
if err != nil {
345
346
protocolConn .Close ()
346
347
return nil , fmt .Errorf ("failed to identify: %w" , err )
@@ -372,11 +373,12 @@ func (d GreetingDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
372
373
greetingConn := greetingConn {
373
374
Conn : conn ,
374
375
}
375
- version , salt , err := readGreeting (greetingConn )
376
+ version , salt , err := readGreeting (ctx , & greetingConn )
376
377
if err != nil {
377
378
greetingConn .Close ()
378
379
return nil , fmt .Errorf ("failed to read greeting: %w" , err )
379
380
}
381
+
380
382
greetingConn .greeting = Greeting {
381
383
Version : version ,
382
384
Salt : salt ,
@@ -410,31 +412,67 @@ func parseAddress(address string) (string, string) {
410
412
return network , address
411
413
}
412
414
415
+ // ioWaiter waits in a background until an io operation done or a context
416
+ // is expired. It closes the connection and writes a context error into the
417
+ // output channel on context expiration.
418
+ //
419
+ // A user of the helper should close the first output channel after an IO
420
+ // operation done and read an error from a second channel to get the result
421
+ // of waiting.
422
+ func ioWaiter (ctx context.Context , conn Conn ) (chan <- struct {}, <- chan error ) {
423
+ doneIO := make (chan struct {})
424
+ doneWait := make (chan error , 1 )
425
+
426
+ go func () {
427
+ defer close (doneWait )
428
+
429
+ select {
430
+ case <- ctx .Done ():
431
+ conn .Close ()
432
+ <- doneIO
433
+ doneWait <- ctx .Err ()
434
+ case <- doneIO :
435
+ doneWait <- nil
436
+ }
437
+ }()
438
+
439
+ return doneIO , doneWait
440
+ }
441
+
413
442
// readGreeting reads a greeting message.
414
- func readGreeting (reader io. Reader ) (string , string , error ) {
443
+ func readGreeting (ctx context. Context , conn Conn ) (string , string , error ) {
415
444
var version , salt string
416
445
446
+ doneRead , doneWait := ioWaiter (ctx , conn )
447
+
417
448
data := make ([]byte , 128 )
418
- _ , err := io .ReadFull (reader , data )
449
+ _ , err := io .ReadFull (conn , data )
450
+
451
+ close (doneRead )
452
+
419
453
if err == nil {
420
454
version = bytes .NewBuffer (data [:64 ]).String ()
421
455
salt = bytes .NewBuffer (data [64 :108 ]).String ()
422
456
}
423
457
458
+ if waitErr := <- doneWait ; waitErr != nil {
459
+ err = waitErr
460
+ }
461
+
424
462
return version , salt , err
425
463
}
426
464
427
465
// identify sends info about client protocol, receives info
428
466
// about server protocol in response and stores it in the connection.
429
- func identify (conn Conn ) (ProtocolInfo , error ) {
467
+ func identify (ctx context. Context , conn Conn ) (ProtocolInfo , error ) {
430
468
var info ProtocolInfo
431
469
432
470
req := NewIdRequest (clientProtocolInfo )
433
- if err := writeRequest (conn , req ); err != nil {
471
+ if err := writeRequest (ctx , conn , req ); err != nil {
434
472
return info , err
435
473
}
436
474
437
- resp , err := readResponse (conn , req )
475
+ resp , err := readResponse (ctx , conn , req )
438
476
if err != nil {
439
477
if resp != nil &&
440
478
resp .Header ().Error == iproto .ER_UNKNOWN_REQUEST_TYPE {
@@ -495,7 +533,7 @@ func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error {
495
533
}
496
534
497
535
// authenticate authenticates for a connection.
498
- func authenticate (c Conn , auth Auth , user string , pass string , salt string ) error {
536
+ func authenticate (ctx context. Context , c Conn , auth Auth , user , pass , salt string ) error {
499
537
var req Request
500
538
var err error
501
539
@@ -511,37 +549,73 @@ func authenticate(c Conn, auth Auth, user string, pass string, salt string) erro
511
549
return errors .New ("unsupported method " + auth .String ())
512
550
}
513
551
514
- if err = writeRequest (c , req ); err != nil {
552
+ if err = writeRequest (ctx , c , req ); err != nil {
515
553
return err
516
554
}
517
- if _ , err = readResponse (c , req ); err != nil {
555
+ if _ , err = readResponse (ctx , c , req ); err != nil {
518
556
return err
519
557
}
520
558
return nil
521
559
}
522
560
523
561
// writeRequest writes a request to the writer.
524
- func writeRequest (w writeFlusher , req Request ) error {
562
+ func writeRequest (ctx context. Context , conn Conn , req Request ) error {
525
563
var packet smallWBuf
526
564
err := pack (& packet , msgpack .NewEncoder (& packet ), 0 , req , ignoreStreamId , nil )
527
565
528
566
if err != nil {
529
567
return fmt .Errorf ("pack error: %w" , err )
530
568
}
531
- if _ , err = w .Write (packet .b ); err != nil {
569
+
570
+ doneWrite , doneWait := ioWaiter (ctx , conn )
571
+
572
+ _ , err = conn .Write (packet .b )
573
+
574
+ close (doneWrite )
575
+
576
+ if waitErr := <- doneWait ; waitErr != nil {
577
+ err = waitErr
578
+ }
579
+
580
+ if err != nil {
532
581
return fmt .Errorf ("write error: %w" , err )
533
582
}
534
- if err = w .Flush (); err != nil {
583
+
584
+ doneWrite , doneWait = ioWaiter (ctx , conn )
585
+
586
+ err = conn .Flush ()
587
+
588
+ close (doneWrite )
589
+
590
+ if waitErr := <- doneWait ; waitErr != nil {
591
+ err = waitErr
592
+ }
593
+
594
+ if err != nil {
535
595
return fmt .Errorf ("flush error: %w" , err )
536
596
}
597
+
598
+ if waitErr := <- doneWait ; waitErr != nil {
599
+ err = waitErr
600
+ }
601
+
537
602
return err
538
603
}
539
604
540
605
// readResponse reads a response from the reader.
541
- func readResponse (r io. Reader , req Request ) (Response , error ) {
606
+ func readResponse (ctx context. Context , conn Conn , req Request ) (Response , error ) {
542
607
var lenbuf [packetLengthBytes ]byte
543
608
544
- respBytes , err := read (r , lenbuf [:])
609
+ doneRead , doneWait := ioWaiter (ctx , conn )
610
+
611
+ respBytes , err := read (conn , lenbuf [:])
612
+
613
+ close (doneRead )
614
+
615
+ if waitErr := <- doneWait ; waitErr != nil {
616
+ err = waitErr
617
+ }
618
+
545
619
if err != nil {
546
620
return nil , fmt .Errorf ("read error: %w" , err )
547
621
}
@@ -555,10 +629,12 @@ func readResponse(r io.Reader, req Request) (Response, error) {
555
629
if err != nil {
556
630
return nil , fmt .Errorf ("decode response header error: %w" , err )
557
631
}
632
+
558
633
resp , err := req .Response (header , & buf )
559
634
if err != nil {
560
635
return nil , fmt .Errorf ("creating response error: %w" , err )
561
636
}
637
+
562
638
_ , err = resp .Decode ()
563
639
if err != nil {
564
640
switch err .(type ) {
@@ -568,5 +644,6 @@ func readResponse(r io.Reader, req Request) (Response, error) {
568
644
return resp , fmt .Errorf ("decode response body error: %w" , err )
569
645
}
570
646
}
647
+
571
648
return resp , nil
572
649
}
0 commit comments