Skip to content

Commit

Permalink
Merge pull request #253 from mrpalide/feat/improve-dmsg-client-regist…
Browse files Browse the repository at this point in the history
…er-logic

Improve dmsg client registration
  • Loading branch information
mrpalide authored Feb 26, 2024
2 parents 22e2bc0 + 0cc33d6 commit 86c43e8
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 8 deletions.
15 changes: 15 additions & 0 deletions internal/dmsg-discovery/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func New(log logrus.FieldLogger, db store.Storer, m discmetrics.Metrics, testMod
r.Post("/dmsg-discovery/entry/{pk}", api.setEntry())
r.Delete("/dmsg-discovery/entry", api.delEntry())
r.Get("/dmsg-discovery/entries", api.allEntries())
r.Get("/dmsg-discovery/visorEntries", api.allVisorEntries())
r.Delete("/dmsg-discovery/deregister", api.deregisterEntry())
r.Get("/dmsg-discovery/available_servers", api.getAvailableServers())
r.Get("/dmsg-discovery/all_servers", api.getAllServers())
Expand Down Expand Up @@ -163,6 +164,20 @@ func (a *API) allEntries() func(w http.ResponseWriter, r *http.Request) {
}
}

// allVisorEntries returns all visor client entries connected to dmsg
// URI: /dmsg-discovery/visorEntries
// Method: GET
func (a *API) allVisorEntries() func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
entries, err := a.db.AllVisorEntries(r.Context())
if err != nil {
a.handleError(w, r, err)
return
}
a.writeJSON(w, r, http.StatusOK, entries)
}
}

// deregisterEntry deletes the client entry associated with the PK requested by the network monitor
// URI: /dmsg-discovery/deregister/:pk
// Method: DELETE
Expand Down
16 changes: 16 additions & 0 deletions internal/dmsg-discovery/store/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ func (r *redisStore) SetEntry(ctx context.Context, entry *disc.Entry, timeout ti
return disc.ErrUnexpected
}
}
if entry.ClientType == "visor" {
err = r.client.SAdd(ctx, "visorClients", entry.Static.Hex()).Err()
if err != nil {
log.WithError(err).Errorf("Failed to add to visorClients (SAdd) from redis")
return disc.ErrUnexpected
}
}

return nil
}
Expand All @@ -107,6 +114,7 @@ func (r *redisStore) DelEntry(ctx context.Context, staticPubKey cipher.PubKey) e
// Delete pubkey from servers or clients set stored
r.client.SRem(ctx, "servers", staticPubKey.Hex())
r.client.SRem(ctx, "clients", staticPubKey.Hex())
r.client.SRem(ctx, "visorClients", staticPubKey.Hex())
return nil
}

Expand Down Expand Up @@ -233,3 +241,11 @@ func (r *redisStore) AllEntries(ctx context.Context) ([]string, error) {
}
return clients, err
}

func (r *redisStore) AllVisorEntries(ctx context.Context) ([]string, error) {
clients, err := r.client.SMembers(ctx, "visorClients").Result()
if err != nil {
return nil, err
}
return clients, err
}
3 changes: 3 additions & 0 deletions internal/dmsg-discovery/store/storer.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ type Storer interface {

// AllEntries returns all clients PKs.
AllEntries(ctx context.Context) ([]string, error)

// AllVisorEntries returns all clients PKs.
AllVisorEntries(ctx context.Context) ([]string, error)
}

// Config configures the Store object.
Expand Down
21 changes: 21 additions & 0 deletions internal/dmsg-discovery/store/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,24 @@ func (ms *MockStore) AllEntries(_ context.Context) ([]string, error) {
}
return entries, nil
}

// AllVisorEntries implements Storer CountEntries method for MockStore
func (ms *MockStore) AllVisorEntries(_ context.Context) ([]string, error) {
entries := []string{}

ms.mLock.RLock()
defer ms.mLock.RUnlock()

clients := arrayFromMap(ms.m)
for _, entryString := range clients {
var e disc.Entry

err := json.Unmarshal(entryString, &e)
if err != nil {
return nil, disc.ErrUnexpected
}

entries = append(entries, e.String())
}
return entries, nil
}
3 changes: 3 additions & 0 deletions pkg/disc/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ type Entry struct {
// Contains the instance's client meta if it's to be advertised as a DMSG Client.
Client *Client `json:"client,omitempty"`

// ClientType the instance's client_type meta if it's to be advertised as a DMSG Client.
ClientType string `json:"client_type,omitempty"`

// Contains the instance's server meta if it's to be advertised as a DMSG Server.
Server *Server `json:"server,omitempty"`

Expand Down
21 changes: 15 additions & 6 deletions pkg/dmsg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ type ClientCallbacks struct {

func (sc *ClientCallbacks) ensure() {
if sc.OnSessionDial == nil {
sc.OnSessionDial = func(network, addr string) (err error) { return nil }
sc.OnSessionDial = func(network, addr string) (err error) { return nil } //nolint
}
if sc.OnSessionDisconnect == nil {
sc.OnSessionDisconnect = func(network, addr string, err error) {}
sc.OnSessionDisconnect = func(network, addr string, err error) {} //nolint
}
}

Expand All @@ -44,6 +44,7 @@ type Config struct {
MinSessions int
UpdateInterval time.Duration // Duration between discovery entry updates.
Callbacks *ClientCallbacks
ClientType string
}

// Ensure ensures all config values are set.
Expand Down Expand Up @@ -108,10 +109,9 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, conf *Conf

// Init callback: on set session.
c.EntityCommon.setSessionCallback = func(ctx context.Context) error {
if err := c.EntityCommon.updateClientEntry(ctx, c.done); err != nil {
if err := c.EntityCommon.updateClientEntry(ctx, c.done, c.conf.ClientType); err != nil {
return err
}

// Client is 'ready' once we have successfully updated the discovery entry
// with at least one delegated server.
c.readyOnce.Do(func() { close(c.ready) })
Expand All @@ -120,7 +120,7 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, conf *Conf

// Init callback: on delete session.
c.EntityCommon.delSessionCallback = func(ctx context.Context) error {
err := c.EntityCommon.updateClientEntry(ctx, c.done)
err := c.EntityCommon.updateClientEntry(ctx, c.done, c.conf.ClientType)
return err
}

Expand Down Expand Up @@ -458,7 +458,7 @@ func (ce *Client) dialSession(ctx context.Context, entry *disc.Entry) (cs Client

// AllStreams returns all the streams of the current client.
func (ce *Client) AllStreams() (out []*Stream) {
fn := func(port uint16, pv netutil.PorterValue) (next bool) {
fn := func(port uint16, pv netutil.PorterValue) (next bool) { //nolint
if str, ok := pv.Value.(*Stream); ok {
out = append(out, str)
return true
Expand All @@ -485,6 +485,15 @@ func (ce *Client) AllEntries(ctx context.Context) (entries []string, err error)
return entries, err
}

// AllVisorEntries returns all the entries registered in discovery that are visor
func (ce *Client) AllVisorEntries(ctx context.Context) (entries []string, err error) {
err = netutil.NewDefaultRetrier(ce.log).Do(ctx, func() error {
entries, err = ce.dc.AllEntries(ctx)
return err
})
return entries, err
}

// ConnectedServersPK return keys of all connected dmsg servers
func (ce *Client) ConnectedServersPK() []string {
sessions := ce.allClientSessions(ce.porter)
Expand Down
5 changes: 3 additions & 2 deletions pkg/dmsg/entity_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (c *EntityCommon) updateServerEntryLoop(ctx context.Context, addr string, m
}
}

func (c *EntityCommon) updateClientEntry(ctx context.Context, done chan struct{}) (err error) {
func (c *EntityCommon) updateClientEntry(ctx context.Context, done chan struct{}, clientType string) (err error) {
if isClosed(done) {
return nil
}
Expand All @@ -245,12 +245,13 @@ func (c *EntityCommon) updateClientEntry(ctx context.Context, done chan struct{}
entry, err := c.dc.Entry(ctx, c.pk)
if err != nil {
entry = disc.NewClientEntry(c.pk, 0, srvPKs)
entry.ClientType = clientType
if err := entry.Sign(c.sk); err != nil {
return err
}
return c.dc.PostEntry(ctx, entry)
}

entry.ClientType = clientType
entry.Client.DelegatedServers = srvPKs
c.log.WithField("entry", entry).Debug("Updating entry.")
return c.dc.PutEntry(ctx, c.sk, entry)
Expand Down

0 comments on commit 86c43e8

Please sign in to comment.