Skip to content

Commit

Permalink
zmq4: make sure msgreaders and msgwriters are properly closed
Browse files Browse the repository at this point in the history
Fixes #34.
  • Loading branch information
sbinet committed Nov 8, 2018
1 parent c1e1b9d commit 94bac15
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 0 deletions.
42 changes: 42 additions & 0 deletions msgio.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type rpool interface {
io.Closer

addConn(r *msgReader)
rmConn(r *msgReader)
read(ctx context.Context, msg *Msg) error
}

Expand All @@ -25,6 +26,7 @@ type wpool interface {
io.Closer

addConn(w *msgWriter)
rmConn(r *msgWriter)
write(ctx context.Context, msg Msg) error
}

Expand Down Expand Up @@ -104,6 +106,22 @@ func (q *qreader) addConn(r *msgReader) {
q.mu.Unlock()
}

func (q *qreader) rmConn(r *msgReader) {
q.mu.Lock()
defer q.mu.Unlock()

cur := -1
for i := range q.rs {
if q.rs[i] == r {
cur = i
break
}
}
if cur >= 0 {
q.rs = append(q.rs[:cur], q.rs[cur+1:]...)
}
}

func (q *qreader) read(ctx context.Context, msg *Msg) error {
q.sem.lock()
select {
Expand All @@ -114,6 +132,9 @@ func (q *qreader) read(ctx context.Context, msg *Msg) error {
}

func (q *qreader) listen(ctx context.Context, r *msgReader) {
defer q.rmConn(r)
defer r.Close()

for {
var msg Msg
err := r.read(ctx, &msg)
Expand Down Expand Up @@ -164,6 +185,22 @@ func (mw *mwriter) addConn(w *msgWriter) {
mw.mu.Unlock()
}

func (mw *mwriter) rmConn(w *msgWriter) {
mw.mu.Lock()
defer mw.mu.Unlock()

cur := -1
for i := range mw.ws {
if mw.ws[i] == w {
cur = i
break
}
}
if cur >= 0 {
mw.ws = append(mw.ws[:cur], mw.ws[cur+1:]...)
}
}

func (w *mwriter) write(ctx context.Context, msg Msg) error {
w.sem.lock()
grp, ctx := errgroup.WithContext(ctx)
Expand Down Expand Up @@ -204,6 +241,8 @@ func (lw *lbwriter) addConn(w *msgWriter) {
go lw.listen(lw.ctx, w)
}

func (*lbwriter) rmConn(w *msgWriter) {}

func (lw *lbwriter) write(ctx context.Context, msg Msg) error {
lw.sem.lock()
select {
Expand All @@ -215,6 +254,9 @@ func (lw *lbwriter) write(ctx context.Context, msg Msg) error {
}

func (lw *lbwriter) listen(ctx context.Context, w *msgWriter) {
defer lw.rmConn(w)
defer w.Close()

for {
select {
case <-ctx.Done():
Expand Down
35 changes: 35 additions & 0 deletions pub.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,22 @@ func (q *pubQReader) addConn(r *msgReader) {
q.mu.Unlock()
}

func (q *pubQReader) rmConn(r *msgReader) {
q.mu.Lock()
defer q.mu.Unlock()

cur := -1
for i := range q.rs {
if q.rs[i] == r {
cur = i
break
}
}
if cur >= 0 {
q.rs = append(q.rs[:cur], q.rs[cur+1:]...)
}
}

func (q *pubQReader) read(ctx context.Context, msg *Msg) error {
q.sem.lock()
select {
Expand All @@ -122,6 +138,9 @@ func (q *pubQReader) read(ctx context.Context, msg *Msg) error {
}

func (q *pubQReader) listen(ctx context.Context, r *msgReader) {
defer q.rmConn(r)
defer r.Close()

for {
var msg Msg
err := r.read(ctx, &msg)
Expand Down Expand Up @@ -189,6 +208,22 @@ func (mw *pubMWriter) addConn(w *msgWriter) {
mw.mu.Unlock()
}

func (mw *pubMWriter) rmConn(w *msgWriter) {
mw.mu.Lock()
defer mw.mu.Unlock()

cur := -1
for i := range mw.ws {
if mw.ws[i] == w {
cur = i
break
}
}
if cur >= 0 {
mw.ws = append(mw.ws[:cur], mw.ws[cur+1:]...)
}
}

func (w *pubMWriter) write(ctx context.Context, msg Msg) error {
w.sem.lock()
grp, ctx := errgroup.WithContext(ctx)
Expand Down
35 changes: 35 additions & 0 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,22 @@ func (q *routerQReader) addConn(r *msgReader) {
q.mu.Unlock()
}

func (q *routerQReader) rmConn(r *msgReader) {
q.mu.Lock()
defer q.mu.Unlock()

cur := -1
for i := range q.rs {
if q.rs[i] == r {
cur = i
break
}
}
if cur >= 0 {
q.rs = append(q.rs[:cur], q.rs[cur+1:]...)
}
}

func (q *routerQReader) read(ctx context.Context, msg *Msg) error {
q.sem.lock()
select {
Expand All @@ -121,6 +137,9 @@ func (q *routerQReader) read(ctx context.Context, msg *Msg) error {
}

func (q *routerQReader) listen(ctx context.Context, r *msgReader) {
defer q.rmConn(r)
defer r.Close()

id := []byte(r.r.Peer.Meta[sysSockID])
for {
var msg Msg
Expand Down Expand Up @@ -173,6 +192,22 @@ func (mw *routerMWriter) addConn(w *msgWriter) {
mw.mu.Unlock()
}

func (mw *routerMWriter) rmConn(w *msgWriter) {
mw.mu.Lock()
defer mw.mu.Unlock()

cur := -1
for i := range mw.ws {
if mw.ws[i] == w {
cur = i
break
}
}
if cur >= 0 {
mw.ws = append(mw.ws[:cur], mw.ws[cur+1:]...)
}
}

func (w *routerMWriter) write(ctx context.Context, msg Msg) error {
w.sem.lock()
grp, ctx := errgroup.WithContext(ctx)
Expand Down

0 comments on commit 94bac15

Please sign in to comment.