Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate auth errors from service #3609

Draft
wants to merge 5 commits into
base: version/0-48-0-RC1
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions cmd/state-svc/internal/messages/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package messages

import (
"github.com/ActiveState/cli/internal/graph"
"github.com/google/uuid"
)

func NewMessage(topic string, message string) *graph.Message {
return &graph.Message{
ID: uuid.New().String(),
Topic: topic,
Message: message,
}
}
60 changes: 60 additions & 0 deletions cmd/state-svc/internal/messages/queue.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package messages

import (
"github.com/ActiveState/cli/internal/errs"
"github.com/ActiveState/cli/internal/graph"
"github.com/ActiveState/cli/internal/logging"
)

type Queue struct {
queue map[string]map[string]*graph.Message
}

func NewQueue() *Queue {
return &Queue{
queue: make(map[string]map[string]*graph.Message),
}
}

func (q *Queue) Queue(topic string, message string) error {
if _, ok := q.queue[topic]; !ok {
q.queue[topic] = make(map[string]*graph.Message)
}
msg := NewMessage(topic, message)
logging.Debug("Queued message: %s, %s", msg.ID, msg.Message)
q.queue[topic][msg.ID] = msg
return nil
}

func (q *Queue) Messages() ([]*graph.Message, error) {
var messages []*graph.Message
for _, topic := range q.queue {
for _, message := range topic {
messages = append(messages, message)
}
}
return messages, nil
}

func (q *Queue) Dequeue(messageIDs []string) error {
for _, messageID := range messageIDs {
err := q.dequeueMessages(messageID)
if err != nil {
return errs.Wrap(err, "failed to dequeue message")
}
}
return nil
}

func (q *Queue) dequeueMessages(messageID string) error {
for _, topic := range q.queue {
for _, msg := range topic {
if msg.ID != messageID {
continue
}
delete(topic, messageID)
return nil
}
}
return nil
}
169 changes: 169 additions & 0 deletions cmd/state-svc/internal/messages/queue_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package messages

import (
"testing"

"github.com/ActiveState/cli/internal/graph"
"github.com/stretchr/testify/assert"
)

func TestNewQueue(t *testing.T) {
q := NewQueue()
assert.NotNil(t, q)
assert.Empty(t, q.queue)
}

func TestQueue_Queue(t *testing.T) {
tests := []struct {
name string
messages []graph.Message
want map[string]int // topic -> expected message count
}{
{
name: "queue first message",
messages: []graph.Message{
{Topic: "topic1", Message: "message1"},
},
want: map[string]int{"topic1": 1},
},
{
name: "queue second message in same topic",
messages: []graph.Message{
{Topic: "topic1", Message: "message1"},
{Topic: "topic1", Message: "message2"},
},
want: map[string]int{"topic1": 2},
},
{
name: "queue message in different topic",
messages: []graph.Message{
{Topic: "topic1", Message: "message1"},
{Topic: "topic2", Message: "message3"},
},
want: map[string]int{"topic1": 1, "topic2": 1},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := NewQueue()

for _, m := range tt.messages {
err := q.Queue(m.Topic, m.Message)
assert.NoError(t, err)
}

for topic, count := range tt.want {
assert.Len(t, q.queue[topic], count)
}
})
}
}

func TestQueue_Messages(t *testing.T) {
tests := []struct {
name string
messages []graph.Message
wantCount int
}{
{
name: "empty queue",
messages: nil,
wantCount: 0,
},
{
name: "single message",
messages: []graph.Message{
{Topic: "topic1", Message: "message1"},
},
wantCount: 1,
},
{
name: "multiple messages across topics",
messages: []graph.Message{
{Topic: "topic1", Message: "message1"},
{Topic: "topic1", Message: "message2"},
{Topic: "topic2", Message: "message3"},
},
wantCount: 3,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := NewQueue()

for _, m := range tt.messages {
err := q.Queue(m.Topic, m.Message)
assert.NoError(t, err)
}

msgs, err := q.Messages()
assert.NoError(t, err)
assert.Len(t, msgs, tt.wantCount)
})
}
}

func TestQueue_Dequeue(t *testing.T) {
tests := []struct {
name string
messages []graph.Message
dequeueIDs []string
wantRemaining int
}{
{
name: "dequeue single message",
messages: []graph.Message{
{Topic: "topic1", Message: "message1"},
},
dequeueIDs: nil, // Will be populated during test with actual message ID
wantRemaining: 0,
},
{
name: "dequeue multiple messages",
messages: []graph.Message{
{Topic: "topic1", Message: "message1"},
{Topic: "topic2", Message: "message2"},
},
dequeueIDs: nil, // Will be populated during test with actual message IDs
wantRemaining: 0,
},
{
name: "dequeue non-existent message",
messages: []graph.Message{
{Topic: "topic1", Message: "message1"},
},
dequeueIDs: []string{"non-existent-id"},
wantRemaining: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := NewQueue()

for _, m := range tt.messages {
err := q.Queue(m.Topic, m.Message)
assert.NoError(t, err)
}

if tt.dequeueIDs == nil {
msgs, err := q.Messages()
assert.NoError(t, err)

tt.dequeueIDs = make([]string, len(msgs))
for i, msg := range msgs {
tt.dequeueIDs[i] = msg.ID
}
}

err := q.Dequeue(tt.dequeueIDs)
assert.NoError(t, err)

remaining, err := q.Messages()
assert.NoError(t, err)
assert.Len(t, remaining, tt.wantRemaining)
})
}
}
51 changes: 46 additions & 5 deletions cmd/state-svc/internal/resolver/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package resolver
import (
"context"
"encoding/json"
"errors"
"os"
"runtime/debug"
"sort"
Expand All @@ -11,6 +12,7 @@ import (

"github.com/ActiveState/cli/cmd/state-svc/internal/graphqltypes"
"github.com/ActiveState/cli/cmd/state-svc/internal/hash"
"github.com/ActiveState/cli/cmd/state-svc/internal/messages"
"github.com/ActiveState/cli/cmd/state-svc/internal/notifications"
"github.com/ActiveState/cli/cmd/state-svc/internal/rtwatcher"
genserver "github.com/ActiveState/cli/cmd/state-svc/internal/server/generated"
Expand All @@ -24,6 +26,7 @@ import (
"github.com/ActiveState/cli/internal/graph"
"github.com/ActiveState/cli/internal/logging"
configMediator "github.com/ActiveState/cli/internal/mediators/config"
msgs "github.com/ActiveState/cli/internal/messages"
"github.com/ActiveState/cli/internal/poller"
"github.com/ActiveState/cli/internal/rtutils/ptr"
"github.com/ActiveState/cli/internal/runbits/panics"
Expand All @@ -35,7 +38,8 @@ import (

type Resolver struct {
cfg *config.Instance
messages *notifications.Notifications
notifications *notifications.Notifications
messages *messages.Queue
updatePoller *poller.Poller
authPoller *poller.Poller
projectIDCache *projectcache.ID
Expand All @@ -50,11 +54,12 @@ type Resolver struct {
// var _ genserver.ResolverRoot = &Resolver{} // Must implement ResolverRoot

func New(cfg *config.Instance, an *sync.Client, auth *authentication.Auth) (*Resolver, error) {
msg, err := notifications.New(cfg, auth)
notif, err := notifications.New(cfg, auth)
if err != nil {
return nil, errs.Wrap(err, "Could not initialize messages")
}

msg := messages.NewQueue()
upchecker := updater.NewDefaultChecker(cfg, an)
pollUpdate := poller.New(1*time.Hour, func() (interface{}, error) {
defer func() {
Expand All @@ -74,11 +79,23 @@ func New(cfg *config.Instance, an *sync.Client, auth *authentication.Auth) (*Res
}

pollAuth := poller.New(time.Duration(int64(time.Millisecond)*pollRate), func() (interface{}, error) {
logging.Debug("Polling for authenticated state")
defer func() {
panics.LogAndPanic(recover(), debug.Stack())
}()
if auth.SyncRequired() {
return nil, auth.Sync()
logging.Debug("Sync required")
if err := auth.Sync(); err != nil {
logging.Debug("Syncing authenticated state: %s", err.Error())
var invalidTokenErr *authentication.ErrInvalidToken
if errors.As(err, &invalidTokenErr) {
logging.Debug("Queuing invalid API token error")
msg.Queue(msgs.TopicErrorAuthToken, "Invalid API token")
} else {
logging.Warning("Could not sync authenticated state: %s", err.Error())
}
}
return nil, nil
}
return nil, nil
})
Expand All @@ -88,6 +105,7 @@ func New(cfg *config.Instance, an *sync.Client, auth *authentication.Auth) (*Res
anForClient := sync.New(anaConsts.SrcStateTool, cfg, auth, nil)
return &Resolver{
cfg,
notif,
msg,
pollUpdate,
pollAuth,
Expand All @@ -102,7 +120,7 @@ func New(cfg *config.Instance, an *sync.Client, auth *authentication.Auth) (*Res
}

func (r *Resolver) Close() error {
r.messages.Close()
r.notifications.Close()
r.updatePoller.Close()
r.authPoller.Close()
r.anForClient.Close()
Expand Down Expand Up @@ -250,7 +268,30 @@ func (r *Resolver) ReportRuntimeUsage(_ context.Context, pid int, exec, source s
func (r *Resolver) CheckNotifications(ctx context.Context, command string, flags []string) ([]*graph.NotificationInfo, error) {
defer func() { panics.LogAndPanic(recover(), debug.Stack()) }()
logging.Debug("Check notifications resolver")
return r.messages.Check(command, flags)
return r.notifications.Check(command, flags)
}

func (r *Resolver) CheckMessages(ctx context.Context) ([]*graph.Message, error) {
logging.Debug("Check messages resolver")
var messages []*graph.Message
var err error

defer func() {
var sentMessageIDs []string
for _, msg := range messages {
sentMessageIDs = append(sentMessageIDs, msg.ID)
}
if err := r.messages.Dequeue(sentMessageIDs); err != nil {
logging.Error("Could not dequeue messages: %s", errs.JoinMessage(err))
}
panics.LogAndPanic(recover(), debug.Stack())
}()

messages, err = r.messages.Messages()
if err != nil {
return nil, errs.Wrap(err, "Could not get messages")
}
return messages, nil
}

func (r *Resolver) ConfigChanged(ctx context.Context, key string) (*graph.ConfigChangedResponse, error) {
Expand Down
Loading
Loading