Skip to content

dial: Connect() does not cancel by context if no i/o #443

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.

### Fixed

- Connect() may not cancel Dial() call on context expiration if network
connection hangs (#443).

## [v2.3.1] - 2025-04-03

The patch releases fixes expected Connect() behavior and reduces allocations.

### Added

- A usage of sync.Pool of msgpack.Decoder saves 2 object allocations per
a response decoding.
a response decoding (#440).

### Changed

Expand Down
2 changes: 1 addition & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ func (conn *Connection) dial(ctx context.Context) error {
}

req := newWatchRequest(key.(string))
if err = writeRequest(c, req); err != nil {
if err = writeRequest(ctx, c, req); err != nil {
st <- state
return false
}
Expand Down
109 changes: 93 additions & 16 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
if err != nil {
return conn, err
}

greeting := conn.Greeting()
if greeting.Salt == "" {
conn.Close()
Expand All @@ -309,7 +310,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
}
}

if err := authenticate(conn, d.Auth, d.Username, d.Password,
if err := authenticate(ctx, conn, d.Auth, d.Username, d.Password,
conn.Greeting().Salt); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to authenticate: %w", err)
Expand Down Expand Up @@ -340,7 +341,7 @@ func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
protocolInfo: d.RequiredProtocolInfo,
}

protocolConn.protocolInfo, err = identify(&protocolConn)
protocolConn.protocolInfo, err = identify(ctx, &protocolConn)
if err != nil {
protocolConn.Close()
return nil, fmt.Errorf("failed to identify: %w", err)
Expand Down Expand Up @@ -372,11 +373,12 @@ func (d GreetingDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
greetingConn := greetingConn{
Conn: conn,
}
version, salt, err := readGreeting(greetingConn)
version, salt, err := readGreeting(ctx, &greetingConn)
if err != nil {
greetingConn.Close()
return nil, fmt.Errorf("failed to read greeting: %w", err)
}

greetingConn.greeting = Greeting{
Version: version,
Salt: salt,
Expand Down Expand Up @@ -410,31 +412,67 @@ func parseAddress(address string) (string, string) {
return network, address
}

// ioWaiter waits in a background until an io operation done or a context
// is expired. It closes the connection and writes a context error into the
// output channel on context expiration.
//
// A user of the helper should close the first output channel after an IO
// operation done and read an error from a second channel to get the result
// of waiting.
func ioWaiter(ctx context.Context, conn Conn) (chan<- struct{}, <-chan error) {
doneIO := make(chan struct{})
doneWait := make(chan error, 1)

go func() {
defer close(doneWait)

select {
case <-ctx.Done():
conn.Close()
<-doneIO
doneWait <- ctx.Err()
case <-doneIO:
doneWait <- nil
}
}()

return doneIO, doneWait
}

// readGreeting reads a greeting message.
func readGreeting(reader io.Reader) (string, string, error) {
func readGreeting(ctx context.Context, conn Conn) (string, string, error) {
var version, salt string

doneRead, doneWait := ioWaiter(ctx, conn)

data := make([]byte, 128)
_, err := io.ReadFull(reader, data)
_, err := io.ReadFull(conn, data)

close(doneRead)

if err == nil {
version = bytes.NewBuffer(data[:64]).String()
salt = bytes.NewBuffer(data[64:108]).String()
}

if waitErr := <-doneWait; waitErr != nil {
err = waitErr
}

return version, salt, err
}

// identify sends info about client protocol, receives info
// about server protocol in response and stores it in the connection.
func identify(conn Conn) (ProtocolInfo, error) {
func identify(ctx context.Context, conn Conn) (ProtocolInfo, error) {
var info ProtocolInfo

req := NewIdRequest(clientProtocolInfo)
if err := writeRequest(conn, req); err != nil {
if err := writeRequest(ctx, conn, req); err != nil {
return info, err
}

resp, err := readResponse(conn, req)
resp, err := readResponse(ctx, conn, req)
if err != nil {
if resp != nil &&
resp.Header().Error == iproto.ER_UNKNOWN_REQUEST_TYPE {
Expand Down Expand Up @@ -495,7 +533,7 @@ func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error {
}

// authenticate authenticates for a connection.
func authenticate(c Conn, auth Auth, user string, pass string, salt string) error {
func authenticate(ctx context.Context, c Conn, auth Auth, user, pass, salt string) error {
var req Request
var err error

Expand All @@ -511,37 +549,73 @@ func authenticate(c Conn, auth Auth, user string, pass string, salt string) erro
return errors.New("unsupported method " + auth.String())
}

if err = writeRequest(c, req); err != nil {
if err = writeRequest(ctx, c, req); err != nil {
return err
}
if _, err = readResponse(c, req); err != nil {
if _, err = readResponse(ctx, c, req); err != nil {
return err
}
return nil
}

// writeRequest writes a request to the writer.
func writeRequest(w writeFlusher, req Request) error {
func writeRequest(ctx context.Context, conn Conn, req Request) error {
var packet smallWBuf
err := pack(&packet, msgpack.NewEncoder(&packet), 0, req, ignoreStreamId, nil)

if err != nil {
return fmt.Errorf("pack error: %w", err)
}
if _, err = w.Write(packet.b); err != nil {

doneWrite, doneWait := ioWaiter(ctx, conn)

_, err = conn.Write(packet.b)

close(doneWrite)

if waitErr := <-doneWait; waitErr != nil {
err = waitErr
}

if err != nil {
return fmt.Errorf("write error: %w", err)
}
if err = w.Flush(); err != nil {

doneWrite, doneWait = ioWaiter(ctx, conn)

err = conn.Flush()

close(doneWrite)

if waitErr := <-doneWait; waitErr != nil {
err = waitErr
}

if err != nil {
return fmt.Errorf("flush error: %w", err)
}

if waitErr := <-doneWait; waitErr != nil {
err = waitErr
}

return err
}

// readResponse reads a response from the reader.
func readResponse(r io.Reader, req Request) (Response, error) {
func readResponse(ctx context.Context, conn Conn, req Request) (Response, error) {
var lenbuf [packetLengthBytes]byte

respBytes, err := read(r, lenbuf[:])
doneRead, doneWait := ioWaiter(ctx, conn)

respBytes, err := read(conn, lenbuf[:])

close(doneRead)

if waitErr := <-doneWait; waitErr != nil {
err = waitErr
}

if err != nil {
return nil, fmt.Errorf("read error: %w", err)
}
Expand All @@ -555,10 +629,12 @@ func readResponse(r io.Reader, req Request) (Response, error) {
if err != nil {
return nil, fmt.Errorf("decode response header error: %w", err)
}

resp, err := req.Response(header, &buf)
if err != nil {
return nil, fmt.Errorf("creating response error: %w", err)
}

_, err = resp.Decode()
if err != nil {
switch err.(type) {
Expand All @@ -568,5 +644,6 @@ func readResponse(r io.Reader, req Request) (Response, error) {
return resp, fmt.Errorf("decode response body error: %w", err)
}
}

return resp, nil
}
Loading
Loading