Skip to content

Commit

Permalink
fix: OnRequest should wait all readable data consumed when sender clo…
Browse files Browse the repository at this point in the history
…se connection after send data
  • Loading branch information
joway committed Jul 31, 2023
1 parent 68ebb21 commit c6f8e6f
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 33 deletions.
6 changes: 5 additions & 1 deletion connection_lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"sync/atomic"
)

type who int32
type who = int32

const (
none who = iota
Expand Down Expand Up @@ -65,6 +65,10 @@ func (l *locker) isCloseBy(w who) (yes bool) {
return atomic.LoadInt32(&l.keychain[closing]) == int32(w)
}

func (l *locker) force(k key, v int32) {
atomic.StoreInt32(&l.keychain[k], v)
}

func (l *locker) lock(k key) (success bool) {
return atomic.CompareAndSwapInt32(&l.keychain[k], 0, 1)
}
Expand Down
2 changes: 1 addition & 1 deletion connection_onevent.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (c *connection) onProcess(isProcessable func(c *connection) bool, process f
if isProcessable(c) {
process(c)
}
for c.IsActive() && isProcessable(c) {
for !c.isCloseBy(user) && isProcessable(c) {
process(c)
}
// Handling callback if connection has been closed.
Expand Down
32 changes: 18 additions & 14 deletions connection_reactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,18 @@ import (

// onHup means close by poller.
func (c *connection) onHup(p Poll) error {
if c.closeBy(poller) {
c.triggerRead()
c.triggerWrite(ErrConnClosed)
// It depends on closing by user if OnConnect and OnRequest is nil, otherwise it needs to be released actively.
// It can be confirmed that the OnRequest goroutine has been exited before closecallback executing,
// and it is safe to close the buffer at this time.
var onConnect, _ = c.onConnectCallback.Load().(OnConnect)
var onRequest, _ = c.onRequestCallback.Load().(OnRequest)
if onConnect != nil || onRequest != nil {
c.closeCallback(true)
}
if !c.closeBy(poller) {
return nil
}
c.triggerRead()
c.triggerWrite(ErrConnClosed)
// It depends on closing by user if OnConnect and OnRequest is nil, otherwise it needs to be released actively.
// It can be confirmed that the OnRequest goroutine has been exited before closecallback executing,
// and it is safe to close the buffer at this time.
var onConnect, _ = c.onConnectCallback.Load().(OnConnect)
var onRequest, _ = c.onRequestCallback.Load().(OnRequest)
if onConnect != nil || onRequest != nil {
c.closeCallback(true)
}
return nil
}
Expand All @@ -48,9 +49,12 @@ func (c *connection) onClose() error {
c.closeCallback(true)
return nil
}
if c.isCloseBy(poller) {
// Connection with OnRequest of nil
// relies on the user to actively close the connection to recycle resources.
closedByPoller := c.isCloseBy(poller)
// force change closed by user
c.force(closing, user)

// If OnRequest is nil, relies on the user to actively close the connection to recycle resources.
if closedByPoller {
c.closeCallback(true)
}
return nil
Expand Down
45 changes: 45 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,3 +491,48 @@ func TestConnDetach(t *testing.T) {
err = ln.Close()
MustNil(t, err)
}

func TestParallelShortConnection(t *testing.T) {
ln, err := CreateListener("tcp", ":1234")
MustNil(t, err)

var received int64
el, err := NewEventLoop(func(ctx context.Context, connection Connection) error {
data, err := connection.Reader().Next(connection.Reader().Len())
if err != nil {
return err
}
atomic.AddInt64(&received, int64(len(data)))
//t.Logf("conn[%s] received: %d, active: %v", connection.RemoteAddr(), len(data), connection.IsActive())
return nil
})
go func() {
el.Serve(ln)
}()

conns := 100
sizePerConn := 1024 * 100
totalSize := conns * sizePerConn
var wg sync.WaitGroup
for i := 0; i < conns; i++ {
wg.Add(1)
go func() {
defer wg.Done()
conn, err := DialConnection("tcp", ":1234", time.Second)
MustNil(t, err)
n, err := conn.Writer().WriteBinary(make([]byte, sizePerConn))
MustNil(t, err)
MustTrue(t, n == sizePerConn)
err = conn.Writer().Flush()
MustNil(t, err)
err = conn.Close()
MustNil(t, err)
}()
}
wg.Wait()

for atomic.LoadInt64(&received) < int64(totalSize) {
t.Logf("received: %d, except: %d", atomic.LoadInt64(&received), totalSize)
time.Sleep(time.Millisecond * 100)
}
}
7 changes: 4 additions & 3 deletions poll_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,22 @@ func (p *defaultPoll) onhups() {
}

// readall read all left data before close connection
func readall(op *FDOperator, br barrier) (err error) {
func readall(op *FDOperator, br barrier) (total int, err error) {
var bs = br.bs
var ivs = br.ivs
var n int
for {
bs = op.Inputs(br.bs)
if len(bs) == 0 {
return nil
return total, nil
}

TryRead:
n, err = ioread(op.FD, bs, ivs)
op.InputAck(n)
total += n
if err != nil {
return err
return total, err
}
if n == 0 {
goto TryRead
Expand Down
22 changes: 15 additions & 7 deletions poll_default_bsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ func (p *defaultPoll) Wait() error {
continue
}

var totalRead int
evt := events[i]
triggerRead = evt.Filter == syscall.EVFILT_READ && evt.Flags&syscall.EV_ENABLE != 0
triggerWrite = evt.Filter == syscall.EVFILT_WRITE && evt.Flags&syscall.EV_ENABLE != 0
Expand All @@ -105,21 +106,28 @@ func (p *defaultPoll) Wait() error {
if len(bs) > 0 {
var n, err = ioread(operator.FD, bs, barriers[i].ivs)
operator.InputAck(n)
totalRead += n
if err != nil {
p.appendHup(operator)
continue
}
}
}
}
if triggerHup && triggerRead && operator.Inputs != nil { // read all left data if peer send and close
if err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) {
logger.Printf("NETPOLL: readall(fd=%d) before close: %s", operator.FD, err.Error())
}
}
if triggerHup {
p.appendHup(operator)
continue
if triggerRead && operator.Inputs != nil {
var leftRead int
// read all left data if peer send and close
if leftRead, err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) {
logger.Printf("NETPOLL: readall(fd=%d)=%d before close: %s", operator.FD, total, err.Error())
}
totalRead += leftRead
}
// only close connection if no further read bytes
if totalRead == 0 {
p.appendHup(operator)
continue
}
}
if triggerWrite {
if operator.OnWrite != nil {
Expand Down
24 changes: 17 additions & 7 deletions poll_default_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,14 @@ func (p *defaultPoll) Wait() (err error) {

func (p *defaultPoll) handler(events []epollevent) (closed bool) {
var triggerRead, triggerWrite, triggerHup, triggerError bool
var err error
for i := range events {
operator := p.getOperator(0, unsafe.Pointer(&events[i].data))
if operator == nil || !operator.do() {
continue
}

var totalRead int
evt := events[i].events
triggerRead = evt&syscall.EPOLLIN != 0
triggerWrite = evt&syscall.EPOLLOUT != 0
Expand Down Expand Up @@ -155,6 +157,7 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) {
if len(bs) > 0 {
var n, err = ioread(operator.FD, bs, p.barriers[i].ivs)
operator.InputAck(n)
totalRead += n
if err != nil {
p.appendHup(operator)
continue
Expand All @@ -164,14 +167,21 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) {
logger.Printf("NETPOLL: operator has critical problem! event=%d operator=%v", evt, operator)
}
}
if triggerHup && triggerRead && operator.Inputs != nil { // read all left data if peer send and close
if err := readall(operator, p.barriers[i]); err != nil && !errors.Is(err, ErrEOF) {
logger.Printf("NETPOLL: readall(fd=%d) before close: %s", operator.FD, err.Error())
}
}
if triggerHup {
p.appendHup(operator)
continue
if triggerRead && operator.Inputs != nil {
// read all left data if peer send and close
var leftRead int
// read all left data if peer send and close
if leftRead, err = readall(operator, p.barriers[i]); err != nil && !errors.Is(err, ErrEOF) {
logger.Printf("NETPOLL: readall(fd=%d)=%d before close: %s", operator.FD, total, err.Error())
}
totalRead += leftRead
}
// only close connection if no further read bytes
if totalRead == 0 {
p.appendHup(operator)
continue
}
}
if triggerError {
// Under block-zerocopy, the kernel may give an error callback, which is not a real error, just an EAGAIN.
Expand Down

0 comments on commit c6f8e6f

Please sign in to comment.