Skip to content

Commit

Permalink
all: introduce transport.Transport interface and Transport plugin mec…
Browse files Browse the repository at this point in the history
…hanism

Updates go-zeromq/zmq4#87.
  • Loading branch information
sbinet committed Oct 21, 2020
1 parent 361b05f commit 6459199
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 22 deletions.
27 changes: 27 additions & 0 deletions internal/inproc/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2020 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package inproc

import (
"context"
"net"

"github.com/go-zeromq/zmq4/transport"
)

// Transport implements the zmq4 Transport interface for the inproc transport.
type Transport struct{}

func (Transport) Dial(ctx context.Context, dialer transport.Dialer, addr string) (net.Conn, error) {
return Dial(addr)
}

func (Transport) Listen(ctx context.Context, addr string) (net.Listener, error) {
return Listen(addr)
}

var (
_ transport.Transport = (*Transport)(nil)
)
35 changes: 13 additions & 22 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"sync"
"time"

"github.com/go-zeromq/zmq4/internal/inproc"
"golang.org/x/xerrors"
)

Expand Down Expand Up @@ -179,15 +178,10 @@ func (sck *socket) Listen(endpoint string) error {

var l net.Listener

switch network {
case "ipc":
l, err = net.Listen("unix", addr)
case "tcp":
l, err = net.Listen("tcp", addr)
case "udp":
l, err = net.Listen("udp", addr)
case "inproc":
l, err = inproc.Listen(addr)
trans, ok := drivers.get(network)
switch {
case ok:
l, err = trans.Listen(sck.ctx, addr)
default:
panic("zmq4: unknown protocol " + network)
}
Expand Down Expand Up @@ -239,18 +233,15 @@ func (sck *socket) Dial(endpoint string) error {
return err
}

retries := 0
var conn net.Conn
var (
conn net.Conn
trans, ok = drivers.get(network)
retries = 0
)
connect:
switch network {
case "ipc":
conn, err = sck.dialer.DialContext(sck.ctx, "unix", addr)
case "tcp":
conn, err = sck.dialer.DialContext(sck.ctx, "tcp", addr)
case "udp":
conn, err = sck.dialer.DialContext(sck.ctx, "udp", addr)
case "inproc":
conn, err = inproc.Dial(addr)
switch {
case ok:
conn, err = trans.Dial(sck.ctx, &sck.dialer, addr)
default:
panic("zmq4: unknown protocol " + network)
}
Expand All @@ -261,7 +252,7 @@ connect:
time.Sleep(sck.retry)
goto connect
}
return xerrors.Errorf("zmq4: could not dial to %q: %w", endpoint, err)
return xerrors.Errorf("zmq4: could not dial to %q (retry=%v): %w", endpoint, sck.retry, err)
}

if conn == nil {
Expand Down
72 changes: 72 additions & 0 deletions transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2018 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package zmq4

import (
"fmt"
"sort"
"sync"

"github.com/go-zeromq/zmq4/internal/inproc"
"github.com/go-zeromq/zmq4/transport"
)

// Transports returns the sorted list of currently registered transports.
func Transports() []string {
return drivers.names()
}

// RegisterTransport registers a new transport with the zmq4 package.
func RegisterTransport(name string, trans transport.Transport) error {
return drivers.add(name, trans)
}

type transports struct {
sync.RWMutex
db map[string]transport.Transport
}

func (ts *transports) get(name string) (transport.Transport, bool) {
ts.RLock()
defer ts.RUnlock()

v, ok := ts.db[name]
return v, ok
}

func (ts *transports) add(name string, trans transport.Transport) error {
ts.Lock()
defer ts.Unlock()

if old, dup := ts.db[name]; dup {
return fmt.Errorf("zmq4: duplicate transport %q (%T)", name, old)
}

ts.db[name] = trans
return nil
}

func (ts *transports) names() []string {
ts.RLock()
defer ts.RUnlock()

o := make([]string, 0, len(ts.db))
for k := range ts.db {
o = append(o, k)
}
sort.Strings(o)
return o
}

var drivers = transports{
db: make(map[string]transport.Transport),
}

func init() {
RegisterTransport("ipc", transport.New("unix"))
RegisterTransport("tcp", transport.New("tcp"))
RegisterTransport("udp", transport.New("udp"))
RegisterTransport("inproc", inproc.Transport{})
}
38 changes: 38 additions & 0 deletions transport/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2020 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package transport

import (
"context"
"net"
)

type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// Transport is the zmq4 transport interface that wraps
// the Dial and Listen methods.
type Transport interface {
Dial(ctx context.Context, dialer Dialer, addr string) (net.Conn, error)
Listen(ctx context.Context, addr string) (net.Listener, error)
}

type netTransport struct {
prot string
}

// New returns a new net-based transport with the given network (e.g "tcp").
func New(network string) Transport {
return netTransport{prot: network}
}

func (trans netTransport) Dial(ctx context.Context, dialer Dialer, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, trans.prot, addr)
}

func (trans netTransport) Listen(ctx context.Context, addr string) (net.Listener, error) {
return net.Listen(trans.prot, addr)
}
31 changes: 31 additions & 0 deletions transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright 2018 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package zmq4

import (
"reflect"
"testing"

"github.com/go-zeromq/zmq4/internal/inproc"
)

func TestTransport(t *testing.T) {
if got, want := Transports(), []string{"inproc", "ipc", "tcp", "udp"}; !reflect.DeepEqual(got, want) {
t.Fatalf("invalid list of transports.\ngot= %q\nwant=%q", got, want)
}

err := RegisterTransport("tcp", inproc.Transport{})
if err == nil {
t.Fatalf("expected a duplicate-registration error")
}
if got, want := err.Error(), "zmq4: duplicate transport \"tcp\" (transport.netTransport)"; got != want {
t.Fatalf("invalid duplicate registration error:\ngot= %s\nwant=%s", got, want)
}

err = RegisterTransport("inproc2", inproc.Transport{})
if err != nil {
t.Fatalf("could not register 'inproc2': %+v", err)
}
}

0 comments on commit 6459199

Please sign in to comment.