Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce new streaming RPCs to the admin server. #320

Merged
merged 3 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions fleetspeak/src/server/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ func (s adminServer) ListClients(ctx context.Context, req *spb.ListClientsReques
}, nil
}

func (s adminServer) StreamClientIds(_ *spb.StreamClientIdsRequest, srv spb.Admin_StreamClientIdsServer) error {
callback := func(id common.ClientID) error {
return srv.Send(&spb.StreamClientIdsResponse{
ClientId: id.Bytes(),
})
}
return s.store.StreamClientIds(srv.Context(), callback)
}

func (s adminServer) ListClientContacts(ctx context.Context, req *spb.ListClientContactsRequest) (*spb.ListClientContactsResponse, error) {
id, err := common.BytesToClientID(req.ClientId)
if err != nil {
Expand All @@ -138,6 +147,19 @@ func (s adminServer) ListClientContacts(ctx context.Context, req *spb.ListClient
}, nil
}

func (s adminServer) StreamClientContacts(req *spb.StreamClientContactsRequest, srv spb.Admin_StreamClientContactsServer) error {
callback := func(contact *spb.ClientContact) error {
return srv.Send(&spb.StreamClientContactsResponse{
Contact: contact,
})
}
id, err := common.BytesToClientID(req.ClientId)
if err != nil {
return err
}
return s.store.StreamClientContacts(srv.Context(), id, callback)
}

func (s adminServer) InsertMessage(ctx context.Context, m *fspb.Message) (*fspb.EmptyMessage, error) {
// At this point, we mostly trust the message we get, but do some basic
// sanity checks and generate missing metadata.
Expand Down
6 changes: 6 additions & 0 deletions fleetspeak/src/server/db/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ type ClientStore interface {
// returns all clients.
ListClients(ctx context.Context, ids []common.ClientID) ([]*spb.Client, error)

// StreamClientIds streams the IDs of all available clients.
StreamClientIds(ctx context.Context, callback func(common.ClientID) error) error

// GetClientData retrieves the current data about the client identified
// by id.
GetClientData(ctx context.Context, id common.ClientID) (*ClientData, error)
Expand Down Expand Up @@ -229,6 +232,9 @@ type ClientStore interface {
// older than a few weeks.
ListClientContacts(ctx context.Context, id common.ClientID) ([]*spb.ClientContact, error)

// StreamClientContacts is a streaming version of ListClientContacts.
StreamClientContacts(ctx context.Context, id common.ClientID, callback func(*spb.ClientContact) error) error

// LinkMessagesToContact associates messages with a contact - it records
// that they were sent or received during the given contact.
LinkMessagesToContact(ctx context.Context, contact ContactID, msgs []common.MessageID) error
Expand Down
80 changes: 66 additions & 14 deletions fleetspeak/src/server/dbtesting/clientstore_suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,23 +250,42 @@ func clientStoreTest(t *testing.T, ds db.Store) {
t.Errorf("ListClients error: want [%v] got [%v]", want, got)
}

contacts, err := ds.ListClientContacts(ctx, clientID)
if err != nil {
t.Errorf("ListClientContacts returned error: %v", err)
}
if len(contacts) != 1 {
t.Errorf("ListClientContacts returned %d results, expected 1.", len(contacts))
} else {
if contacts[0].SentNonce != 42 || contacts[0].ReceivedNonce != 54 {
t.Errorf("ListClientContact[0] should return nonces (42, 54), got (%d, %d)",
contacts[0].SentNonce, contacts[0].ReceivedNonce)
}
if contacts[0].ObservedAddress != longAddr {
t.Errorf("ListClientContact[0] should return address %s, got %s",
longAddr, contacts[0].ObservedAddress)
checkClientContacts := func(t *testing.T, contacts []*spb.ClientContact) {
if len(contacts) != 1 {
t.Errorf("ListClientContacts returned %d results, expected 1.", len(contacts))
} else {
if contacts[0].SentNonce != 42 || contacts[0].ReceivedNonce != 54 {
t.Errorf("ListClientContact[0] should return nonces (42, 54), got (%d, %d)",
contacts[0].SentNonce, contacts[0].ReceivedNonce)
}
if contacts[0].ObservedAddress != longAddr {
t.Errorf("ListClientContact[0] should return address %s, got %s",
longAddr, contacts[0].ObservedAddress)
}
}
}

t.Run("ListClientContacts", func(t *testing.T) {
contacts, err := ds.ListClientContacts(ctx, clientID)
if err != nil {
t.Errorf("ListClientContacts returned error: %v", err)
}
checkClientContacts(t, contacts)
})

t.Run("StreamClientContacts", func(t *testing.T) {
var contacts []*spb.ClientContact
callback := func(contact *spb.ClientContact) error {
contacts = append(contacts, contact)
return nil
}
err := ds.StreamClientContacts(ctx, clientID, callback)
if err != nil {
t.Errorf("StreamClientContacts returned error: %v", err)
}
checkClientContacts(t, contacts)
})

if err := ds.BlacklistClient(ctx, clientID); err != nil {
t.Errorf("Error blacklisting client: %v", err)
}
Expand Down Expand Up @@ -350,6 +369,38 @@ Cases:
}
}

func streamClientIdsTest(t *testing.T, ds db.Store) {
ctx := context.Background()

clientIds := []common.ClientID{clientID, clientID2, clientID3}

for _, cid := range clientIds {
if err := ds.AddClient(ctx, cid, &db.ClientData{Key: []byte("test key")}); err != nil {
t.Fatalf("AddClient [%v] failed: %v", clientID, err)
}
}

var result []common.ClientID

callback := func(id common.ClientID) error {
result = append(result, id)
return nil
}

err := ds.StreamClientIds(ctx, callback)
if err != nil {
t.Fatalf("StreamClientIds failed", err)
}

sort.Slice(result, func(i int, j int) bool {
return bytes.Compare(result[i].Bytes(), result[j].Bytes()) < 0
})

if !reflect.DeepEqual(result, clientIds) {
t.Errorf("StreamClientIds returned unexpected result. Got: [%v]. Want: [%v].", result, clientIds)
}
}

func fetchResourceUsageRecordsTest(t *testing.T, ds db.Store) {
ctx := context.Background()
key := []byte("Test key")
Expand Down Expand Up @@ -486,6 +537,7 @@ func clientStoreTestSuite(t *testing.T, env DbTestEnv) {
runTestSuite(t, env, map[string]func(*testing.T, db.Store){
"ClientStoreTest": clientStoreTest,
"ListClientsTest": listClientsTest,
"StreamClientIdsTest": streamClientIdsTest,
"FetchResourceUsageRecordsTest": fetchResourceUsageRecordsTest,
})
})
Expand Down
50 changes: 41 additions & 9 deletions fleetspeak/src/server/mysql/clientstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,32 @@ func uint64ToBytes(i uint64) []byte {
return b
}

func (d *Datastore) StreamClientIds(ctx context.Context, callback func(common.ClientID) error) error {
return d.runInTx(ctx, true, func(tx *sql.Tx) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

runInTx is a retry loop, which could lead to strange behavior (duplicate events, mostly). Maybe streaming reads should just call "runOnce" directly, and recognize that the caller might need to retry?

Also, it might be a time to fix (well, implement) the handling of the readonly flag. At least, recent versions of the golang sql package provide a way to pass it in to BeginTx. I don't know how complete the actual mysql driver support is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done switching to runOnce.

I've created #321 for the readonly flag.

rs, err := tx.QueryContext(ctx, "SELECT client_id FROM clients")
if err != nil {
return err
}
defer rs.Close()
for rs.Next() {
var bid []byte
err := rs.Scan(&bid)
if err != nil {
return err
}
id, err := common.BytesToClientID(bid)
if err != nil {
return err
}
err = callback(id)
if err != nil {
return err
}
}
return nil
})
}

func (d *Datastore) ListClients(ctx context.Context, ids []common.ClientID) ([]*spb.Client, error) {
// Return value map, maps string client ids to the return values.
var retm map[string]*spb.Client
Expand Down Expand Up @@ -265,10 +291,8 @@ func (d *Datastore) RecordClientContact(ctx context.Context, data db.ContactData
return res, err
}

func (d *Datastore) ListClientContacts(ctx context.Context, id common.ClientID) ([]*spb.ClientContact, error) {
var res []*spb.ClientContact
if err := d.runInTx(ctx, true, func(tx *sql.Tx) error {
res = nil
func (d *Datastore) StreamClientContacts(ctx context.Context, id common.ClientID, callback func(*spb.ClientContact) error) error {
return d.runInTx(ctx, true, func(tx *sql.Tx) error {
rows, err := tx.QueryContext(
ctx,
"SELECT time, sent_nonce, received_nonce, address FROM client_contacts WHERE client_id = ?",
Expand Down Expand Up @@ -304,14 +328,22 @@ func (d *Datastore) ListClientContacts(ctx context.Context, id common.ClientID)
}
c.Timestamp = ts

res = append(res, c)
err = callback(c)
if err != nil {
return err
}
}
return nil
}); err != nil {
return nil, err
}
})
}

return res, nil
func (d *Datastore) ListClientContacts(ctx context.Context, id common.ClientID) ([]*spb.ClientContact, error) {
var res []*spb.ClientContact
callback := func(c *spb.ClientContact) error {
res = append(res, c)
return nil
}
return res, d.StreamClientContacts(ctx, id, callback)
}

func (d *Datastore) LinkMessagesToContact(ctx context.Context, contact db.ContactID, ids []common.MessageID) error {
Expand Down
Loading