Skip to content

Commit

Permalink
dcerpc: fix possible deadlocks in sender/receiver close flow
Browse files Browse the repository at this point in the history
  • Loading branch information
oiweiwei committed Sep 13, 2024
1 parent eb1b248 commit 0dd6eba
Showing 1 changed file with 69 additions and 33 deletions.
102 changes: 69 additions & 33 deletions dcerpc/transport_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,17 @@ func (c *call) ID() uint32 {
// with the next fragment by calling `Ready`. (this is done to acquire security
// lock before next request will be able to use it).
func (c *call) ReadBuffer(ctx context.Context, p []byte) (Header, error) {
var hdr Header
var ok bool

// wait.
hdr, ok := <-c.outQ
if !ok {
return Header{}, io.ErrUnexpectedEOF
select {
case hdr, ok = <-c.outQ:
if !ok {
return Header{}, io.ErrUnexpectedEOF
}
case <-ctx.Done():
return Header{}, ctx.Err()
}
if !c.noCopy {
copy(p, c.recv)
Expand All @@ -73,7 +80,10 @@ func (c *call) ReadBuffer(ctx context.Context, p []byte) (Header, error) {
// the reader will be performed.
func (c *call) Ready(ctx context.Context) {
// done.
c.inQ <- nil
select {
case c.inQ <- nil:
case <-ctx.Done():
}
}

// WriteBuffer function copies the buffer `p` to xmit buffer and notifies
Expand All @@ -84,24 +94,30 @@ func (c *call) WriteBuffer(ctx context.Context, hdr Header, p []byte) error {
copy(c.xmit, p)
}
// ready.
c.outQ <- hdr
select {
case c.outQ <- hdr:
case <-ctx.Done():
return ctx.Err()
}
// wait for response.
return <-c.inQ
select {
case err := <-c.inQ:
return err
case <-ctx.Done():
return ctx.Err()
}
}

// WriteBuffer function writes the data `p` to the wire.
func (c *transport) WriteBuffer(ctx context.Context, hdr Header, p []byte) error {

ctx, cancel := context.WithTimeout(ctx, c.settings.Timeout)
defer cancel()

if int(hdr.FragLength) > c.settings.MaxXmitFrag {
return ErrPacketTooLong
}

p = p[:hdr.FragLength]

return doWithTimeout(ctx, func() error {
return doWithTimeout(ctx, c.settings.Timeout, func() error {
for n := 0; n < int(hdr.FragLength); {
actual, err := c.cc.Write(p[n:])
if err != nil {
Expand All @@ -128,7 +144,7 @@ func (c *transport) ReadBuffer(ctx context.Context, p []byte) (Header, error) {
p = p[:c.settings.MaxRecvFrag]

// read header.
if err := doWithTimeout(ctx, func() error {
if err := doWithTimeout(ctx, c.settings.Timeout, func() error {
for n := 0; n < HeaderSize; {
actual, err := c.cc.Read(p[n:HeaderSize])
if err != nil {
Expand All @@ -149,7 +165,7 @@ func (c *transport) ReadBuffer(ctx context.Context, p []byte) (Header, error) {
return hdr, ErrPacketTooLong
}
// read remaining fragment.
if err := doWithTimeout(ctx, func() error {
if err := doWithTimeout(ctx, c.settings.Timeout, func() error {
for n := HeaderSize; n < int(hdr.FragLength); {
actual, err := c.cc.Read(p[n:hdr.FragLength])
if err != nil {
Expand Down Expand Up @@ -178,7 +194,7 @@ func (t *transport) recvLoop(ctx context.Context) error {
t.logger.Error().Uint32("call_id", call.ID()).Err(err).Msg("serving response error")
}
case <-ctx.Done():
return t.WithErr(ctx.Err())
return nil
}
}
}
Expand Down Expand Up @@ -221,6 +237,8 @@ func (t *transport) recv(ctx context.Context, call *call) error {
case call.outQ <- hdr:
case <-deadline.C:
return t.WithErr(fmt.Errorf("caller-receiver timer expired"))
case <-ctx.Done():
return nil
}

// wait for buffer copy.
Expand All @@ -229,6 +247,8 @@ func (t *transport) recv(ctx context.Context, call *call) error {
case <-call.inQ:
case <-deadline.C:
return t.WithErr(fmt.Errorf("caller-ready timer expired"))
case <-ctx.Done():
return nil
}

// remove caller from the wait queue.
Expand All @@ -245,14 +265,16 @@ func (t *transport) sendLoop(ctx context.Context) error {
t.logger.Debug().Msg("started sender routine")

for {

select {
case call := <-t.txQ:
t.logger.Debug().Uint32("call_id", call.ID()).Msg("serving call")
if err := t.send(ctx, call); err != nil {

t.logger.Error().Uint32("call_id", call.ID()).Err(err).Msg("serving call error")
}
case <-ctx.Done():
return t.WithErr(ctx.Err())
return nil
}
}
}
Expand All @@ -269,7 +291,12 @@ func (t *transport) send(ctx context.Context, call *call) error {

for {
// wait for write done.
hdr := <-call.outQ
var hdr Header
select {
case hdr = <-call.outQ:
case <-ctx.Done():
return nil
}
// write buffer.
err := t.WriteBuffer(ctx, hdr, t.tx)

Expand All @@ -279,6 +306,8 @@ func (t *transport) send(ctx context.Context, call *call) error {
case call.inQ <- err:
case <-deadline.C:
err = fmt.Errorf("client-ack timer expired")
case <-ctx.Done():
return nil
}

if err := t.WithErr(err); err != nil {
Expand All @@ -296,11 +325,17 @@ func (t *transport) send(ctx context.Context, call *call) error {
}

// publish to receiver queue.
t.rxQ <- call
select {
case t.rxQ <- call:
case <-ctx.Done():
}

// wait for receive completes.
if t.settings.Multiplexing {
<-call.done
select {
case <-call.done:
case <-ctx.Done():
}
}

return nil
Expand All @@ -326,27 +361,28 @@ func clearTimer(t **time.Timer) {
}

// doWithTimeout.
func doWithTimeout(ctx context.Context, f func() error) error {
func doWithTimeout(ctx context.Context, timout time.Duration, f func() error) error {

done := make(chan error, 1)
go func() {
done <- f()
}()

loop:
for {
select {
case err := <-done:
if err != nil {
return err
}
break loop
case <-ctx.Done():
if ctx.Err() != nil {
return ctx.Err()
}
return context.DeadlineExceeded
}
timer := time.NewTimer(timout)

var err error

select {
case err = <-done:
case <-timer.C:
err = context.DeadlineExceeded
case <-ctx.Done():
err = ctx.Err()
}
return nil

if !timer.Stop() {
<-timer.C
}

return err
}

0 comments on commit 0dd6eba

Please sign in to comment.