Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve dmsg client registration #253

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions internal/dmsg-discovery/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func New(log logrus.FieldLogger, db store.Storer, m discmetrics.Metrics, testMod
r.Post("/dmsg-discovery/entry/{pk}", api.setEntry())
r.Delete("/dmsg-discovery/entry", api.delEntry())
r.Get("/dmsg-discovery/entries", api.allEntries())
r.Get("/dmsg-discovery/visorEntries", api.allVisorEntries())
r.Delete("/dmsg-discovery/deregister", api.deregisterEntry())
r.Get("/dmsg-discovery/available_servers", api.getAvailableServers())
r.Get("/dmsg-discovery/all_servers", api.getAllServers())
Expand Down Expand Up @@ -163,6 +164,20 @@ func (a *API) allEntries() func(w http.ResponseWriter, r *http.Request) {
}
}

// allVisorEntries returns all visor client entries connected to dmsg
// URI: /dmsg-discovery/visorEntries
// Method: GET
func (a *API) allVisorEntries() func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
entries, err := a.db.AllVisorEntries(r.Context())
if err != nil {
a.handleError(w, r, err)
return
}
a.writeJSON(w, r, http.StatusOK, entries)
}
}

// deregisterEntry deletes the client entry associated with the PK requested by the network monitor
// URI: /dmsg-discovery/deregister/:pk
// Method: DELETE
Expand Down
16 changes: 16 additions & 0 deletions internal/dmsg-discovery/store/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ func (r *redisStore) SetEntry(ctx context.Context, entry *disc.Entry, timeout ti
return disc.ErrUnexpected
}
}
if entry.ClientType == "visor" {
err = r.client.SAdd(ctx, "visorClients", entry.Static.Hex()).Err()
if err != nil {
log.WithError(err).Errorf("Failed to add to visorClients (SAdd) from redis")
return disc.ErrUnexpected
}
}

return nil
}
Expand All @@ -107,6 +114,7 @@ func (r *redisStore) DelEntry(ctx context.Context, staticPubKey cipher.PubKey) e
// Delete pubkey from servers or clients set stored
r.client.SRem(ctx, "servers", staticPubKey.Hex())
r.client.SRem(ctx, "clients", staticPubKey.Hex())
r.client.SRem(ctx, "visorClients", staticPubKey.Hex())
return nil
}

Expand Down Expand Up @@ -233,3 +241,11 @@ func (r *redisStore) AllEntries(ctx context.Context) ([]string, error) {
}
return clients, err
}

func (r *redisStore) AllVisorEntries(ctx context.Context) ([]string, error) {
clients, err := r.client.SMembers(ctx, "visorClients").Result()
if err != nil {
return nil, err
}
return clients, err
}
3 changes: 3 additions & 0 deletions internal/dmsg-discovery/store/storer.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ type Storer interface {

// AllEntries returns all clients PKs.
AllEntries(ctx context.Context) ([]string, error)

// AllVisorEntries returns all clients PKs.
AllVisorEntries(ctx context.Context) ([]string, error)
}

// Config configures the Store object.
Expand Down
21 changes: 21 additions & 0 deletions internal/dmsg-discovery/store/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,24 @@ func (ms *MockStore) AllEntries(_ context.Context) ([]string, error) {
}
return entries, nil
}

// AllVisorEntries implements Storer CountEntries method for MockStore
func (ms *MockStore) AllVisorEntries(_ context.Context) ([]string, error) {
entries := []string{}

ms.mLock.RLock()
defer ms.mLock.RUnlock()

clients := arrayFromMap(ms.m)
for _, entryString := range clients {
var e disc.Entry

err := json.Unmarshal(entryString, &e)
if err != nil {
return nil, disc.ErrUnexpected
}

entries = append(entries, e.String())
}
return entries, nil
}
3 changes: 3 additions & 0 deletions pkg/disc/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ type Entry struct {
// Contains the instance's client meta if it's to be advertised as a DMSG Client.
Client *Client `json:"client,omitempty"`

// ClientType the instance's client_type meta if it's to be advertised as a DMSG Client.
ClientType string `json:"client_type,omitempty"`

// Contains the instance's server meta if it's to be advertised as a DMSG Server.
Server *Server `json:"server,omitempty"`

Expand Down
21 changes: 15 additions & 6 deletions pkg/dmsg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ type ClientCallbacks struct {

func (sc *ClientCallbacks) ensure() {
if sc.OnSessionDial == nil {
sc.OnSessionDial = func(network, addr string) (err error) { return nil }
sc.OnSessionDial = func(network, addr string) (err error) { return nil } //nolint
}
if sc.OnSessionDisconnect == nil {
sc.OnSessionDisconnect = func(network, addr string, err error) {}
sc.OnSessionDisconnect = func(network, addr string, err error) {} //nolint
}
}

Expand All @@ -44,6 +44,7 @@ type Config struct {
MinSessions int
UpdateInterval time.Duration // Duration between discovery entry updates.
Callbacks *ClientCallbacks
ClientType string
}

// Ensure ensures all config values are set.
Expand Down Expand Up @@ -108,10 +109,9 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, conf *Conf

// Init callback: on set session.
c.EntityCommon.setSessionCallback = func(ctx context.Context) error {
if err := c.EntityCommon.updateClientEntry(ctx, c.done); err != nil {
if err := c.EntityCommon.updateClientEntry(ctx, c.done, c.conf.ClientType); err != nil {
return err
}

// Client is 'ready' once we have successfully updated the discovery entry
// with at least one delegated server.
c.readyOnce.Do(func() { close(c.ready) })
Expand All @@ -120,7 +120,7 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, conf *Conf

// Init callback: on delete session.
c.EntityCommon.delSessionCallback = func(ctx context.Context) error {
err := c.EntityCommon.updateClientEntry(ctx, c.done)
err := c.EntityCommon.updateClientEntry(ctx, c.done, c.conf.ClientType)
return err
}

Expand Down Expand Up @@ -458,7 +458,7 @@ func (ce *Client) dialSession(ctx context.Context, entry *disc.Entry) (cs Client

// AllStreams returns all the streams of the current client.
func (ce *Client) AllStreams() (out []*Stream) {
fn := func(port uint16, pv netutil.PorterValue) (next bool) {
fn := func(port uint16, pv netutil.PorterValue) (next bool) { //nolint
if str, ok := pv.Value.(*Stream); ok {
out = append(out, str)
return true
Expand All @@ -485,6 +485,15 @@ func (ce *Client) AllEntries(ctx context.Context) (entries []string, err error)
return entries, err
}

// AllVisorEntries returns all the entries registered in discovery that are visor
func (ce *Client) AllVisorEntries(ctx context.Context) (entries []string, err error) {
err = netutil.NewDefaultRetrier(ce.log).Do(ctx, func() error {
entries, err = ce.dc.AllEntries(ctx)
return err
})
return entries, err
}

// ConnectedServersPK return keys of all connected dmsg servers
func (ce *Client) ConnectedServersPK() []string {
sessions := ce.allClientSessions(ce.porter)
Expand Down
5 changes: 3 additions & 2 deletions pkg/dmsg/entity_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (c *EntityCommon) updateServerEntryLoop(ctx context.Context, addr string, m
}
}

func (c *EntityCommon) updateClientEntry(ctx context.Context, done chan struct{}) (err error) {
func (c *EntityCommon) updateClientEntry(ctx context.Context, done chan struct{}, clientType string) (err error) {
if isClosed(done) {
return nil
}
Expand All @@ -245,12 +245,13 @@ func (c *EntityCommon) updateClientEntry(ctx context.Context, done chan struct{}
entry, err := c.dc.Entry(ctx, c.pk)
if err != nil {
entry = disc.NewClientEntry(c.pk, 0, srvPKs)
entry.ClientType = clientType
if err := entry.Sign(c.sk); err != nil {
return err
}
return c.dc.PostEntry(ctx, entry)
}

entry.ClientType = clientType
entry.Client.DelegatedServers = srvPKs
c.log.WithField("entry", entry).Debug("Updating entry.")
return c.dc.PutEntry(ctx, c.sk, entry)
Expand Down
Loading