diff --git a/go.mod b/go.mod index 9615fdd..52d1a8a 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,16 @@ module github.com/signalfx/signalflow-client-go go 1.21 + +require ( + github.com/gorilla/websocket v1.5.1 + github.com/signalfx/signalfx-go v1.34.0 + github.com/stretchr/testify v1.9.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/net v0.17.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..0288f06 --- /dev/null +++ b/go.sum @@ -0,0 +1,16 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/signalfx/signalfx-go v1.34.0 h1:OQ6tyMY4efWB57EPIQqrpWrAfcSdyfa+bLtmAe7GLfE= +github.com/signalfx/signalfx-go v1.34.0/go.mod h1:IpGZLPvCKNFyspAXoS480jB02mocTpo0KYd8jbl6/T8= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/signalflow/client.go b/signalflow/client.go new file mode 100644 index 0000000..a68ae00 --- /dev/null +++ b/signalflow/client.go @@ -0,0 +1,372 @@ +package signalflow + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/url" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + + "github.com/signalfx/signalflow-client-go/signalflow/messages" +) + +// Client for SignalFlow via websockets (SSE is not currently supported). +type Client struct { + // Access token for the org + token string + userAgent string + defaultMetadataTimeout time.Duration + nextChannelNum int64 + conn *wsConn + readTimeout time.Duration + // How long to wait for writes to the websocket to finish + writeTimeout time.Duration + streamURL *url.URL + onError OnErrorFunc + channelsByName map[string]chan messages.Message + + // These are the lower-level WebSocket level channels for byte messages + outgoingTextMsgs chan *outgoingMessage + incomingTextMsgs chan []byte + incomingBinaryMsgs chan []byte + connectedCh chan struct{} + + isClosed atomic.Bool + sync.Mutex + cancel context.CancelFunc +} + +type clientMessageRequest struct { + msg interface{} + resultCh chan error +} + +// ClientParam is the common type of configuration functions for the SignalFlow client +type ClientParam func(*Client) error + +// StreamURL lets you set the full URL to the stream endpoint, including the +// path. 
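+//
+// A brief sketch of how this option might be used; the endpoint shown is the
+// client's default us0 URL and should be replaced with your own:
+//
+//	c, err := NewClient(StreamURL("wss://stream.us0.signalfx.com/v2/signalflow"))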
+func StreamURL(streamEndpoint string) ClientParam { + return func(c *Client) error { + var err error + c.streamURL, err = url.Parse(streamEndpoint) + return err + } +} + +// StreamURLForRealm can be used to configure the websocket url for a specific +// SignalFx realm. +func StreamURLForRealm(realm string) ClientParam { + return func(c *Client) error { + var err error + c.streamURL, err = url.Parse(fmt.Sprintf("wss://stream.%s.signalfx.com/v2/signalflow", realm)) + return err + } +} + +// AccessToken can be used to provide a SignalFx organization access token or +// user access token to the SignalFlow client. +func AccessToken(token string) ClientParam { + return func(c *Client) error { + c.token = token + return nil + } +} + +// UserAgent allows setting the `userAgent` field when authenticating to +// SignalFlow. This can be useful for accounting how many jobs are started +// from each client. +func UserAgent(userAgent string) ClientParam { + return func(c *Client) error { + c.userAgent = userAgent + return nil + } +} + +// ReadTimeout sets the duration to wait between messages that come on the +// websocket. If the resolution of the job is very low, this should be +// increased. +func ReadTimeout(timeout time.Duration) ClientParam { + return func(c *Client) error { + if timeout <= 0 { + return errors.New("ReadTimeout cannot be <= 0") + } + c.readTimeout = timeout + return nil + } +} + +// WriteTimeout sets the maximum duration to wait to send a single message when +// writing messages to the SignalFlow server over the WebSocket connection. +func WriteTimeout(timeout time.Duration) ClientParam { + return func(c *Client) error { + if timeout <= 0 { + return errors.New("WriteTimeout cannot be <= 0") + } + c.writeTimeout = timeout + return nil + } +} + +type OnErrorFunc func(err error) + +func OnError(f OnErrorFunc) ClientParam { + return func(c *Client) error { + c.onError = f + return nil + } +} + +// NewClient makes a new SignalFlow client that will immediately try and +// connect to the SignalFlow backend. 
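+//
+// A minimal usage sketch (the realm, token, and program values below are
+// placeholders, not real credentials):
+//
+//	c, err := NewClient(StreamURLForRealm("us0"), AccessToken("MY_ORG_TOKEN"))
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	defer c.Close()
+//
+//	comp, err := c.Execute(context.Background(), &ExecuteRequest{
+//		Program: "data('cpu.utilization').publish()",
+//	})
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	for msg := range comp.Data() {
+//		fmt.Println(msg)
+//	}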
+func NewClient(options ...ClientParam) (*Client, error) {
+	c := &Client{
+		streamURL: &url.URL{
+			Scheme: "wss",
+			Host:   "stream.us0.signalfx.com",
+			Path:   "/v2/signalflow",
+		},
+		readTimeout:    1 * time.Minute,
+		writeTimeout:   5 * time.Second,
+		channelsByName: make(map[string]chan messages.Message),
+
+		outgoingTextMsgs:   make(chan *outgoingMessage),
+		incomingTextMsgs:   make(chan []byte),
+		incomingBinaryMsgs: make(chan []byte),
+		connectedCh:        make(chan struct{}),
+	}
+
+	for i := range options {
+		if err := options[i](c); err != nil {
+			return nil, err
+		}
+	}
+
+	c.conn = &wsConn{
+		StreamURL:          c.streamURL,
+		OutgoingTextMsgs:   c.outgoingTextMsgs,
+		IncomingTextMsgs:   c.incomingTextMsgs,
+		IncomingBinaryMsgs: c.incomingBinaryMsgs,
+		ConnectedCh:        c.connectedCh,
+		ConnectTimeout:     10 * time.Second,
+		ReadTimeout:        c.readTimeout,
+		WriteTimeout:       c.writeTimeout,
+		OnError:            c.onError,
+		PostDisconnectCallback: func() {
+			c.closeRegisteredChannels()
+		},
+		PostConnectMessage: func() []byte {
+			bytes, err := c.makeAuthRequest()
+			if err != nil {
+				c.sendErrIfWanted(fmt.Errorf("failed to send auth: %w", err))
+				return nil
+			}
+			return bytes
+		},
+	}
+
+	var ctx context.Context
+	ctx, c.cancel = context.WithCancel(context.Background())
+
+	go c.conn.Run(ctx)
+	go c.run(ctx)
+
+	return c, nil
+}
+
+func (c *Client) newUniqueChannelName() string {
+	name := fmt.Sprintf("ch-%d", atomic.AddInt64(&c.nextChannelNum, 1))
+	return name
+}
+
+func (c *Client) sendErrIfWanted(err error) {
+	if c.onError != nil {
+		c.onError(err)
+	}
+}
+
+// run dispatches all incoming messages from a single goroutine. (Writes to the
+// websocket itself are serialized separately in wsConn, as required by the
+// websocket library.)
+func (c *Client) run(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case msg := <-c.incomingTextMsgs:
+			err := c.handleMessage(msg, websocket.TextMessage)
+			if err != nil {
+				c.sendErrIfWanted(fmt.Errorf("error handling SignalFlow text message: %w", err))
+			}
+		case msg := <-c.incomingBinaryMsgs:
+			err := c.handleMessage(msg, websocket.BinaryMessage)
+			if err != nil {
+				c.sendErrIfWanted(fmt.Errorf("error handling SignalFlow binary message: %w", err))
+			}
+		}
+	}
+}
+
+func (c *Client) sendMessage(ctx context.Context, message interface{}) error {
+	msgBytes, err := c.serializeMessage(message)
+	if err != nil {
+		return err
+	}
+
+	resultCh := make(chan error, 1)
+	select {
+	case c.outgoingTextMsgs <- &outgoingMessage{
+		bytes:    msgBytes,
+		resultCh: resultCh,
+	}:
+		return <-resultCh
+	case <-ctx.Done():
+		close(resultCh)
+		return ctx.Err()
+	}
+}
+
+func (c *Client) serializeMessage(message interface{}) ([]byte, error) {
+	msgBytes, err := json.Marshal(message)
+	if err != nil {
+		return nil, fmt.Errorf("could not marshal SignalFlow request: %w", err)
+	}
+	return msgBytes, nil
+}
+
+func (c *Client) handleMessage(msgBytes []byte, msgTyp int) error {
+	message, err := messages.ParseMessage(msgBytes, msgTyp == websocket.TextMessage)
+	if err != nil {
+		return fmt.Errorf("could not parse SignalFlow message: %w", err)
+	}
+
+	if cm, ok := message.(messages.ChannelMessage); ok {
+		channelName := cm.Channel()
+		c.Lock()
+		channel, ok := c.channelsByName[channelName]
+		if !ok {
+			// The channel should have existed before, but now doesn't,
+			// probably because it was closed.
+			c.Unlock()
+			return nil
+		} else if channelName == "" {
+			c.Unlock()
+			c.acceptMessage(message)
+			return nil
+		}
+		channel <- message
+		c.Unlock()
+	} else {
+		return c.acceptMessage(message)
+	}
+	return nil
+}
+
+// acceptMessage accepts non-channel specific messages. 
The only one that I +// know of is the authenticated response. +func (c *Client) acceptMessage(message messages.Message) error { + if _, ok := message.(*messages.AuthenticatedMessage); ok { + return nil + } else if msg, ok := message.(*messages.BaseJSONMessage); ok { + data := msg.RawData() + if data != nil && data["event"] == "KEEP_ALIVE" { + // Ignore keep alive messages + return nil + } + } + + return fmt.Errorf("unknown SignalFlow message received: %v", message) +} + +// Sends the authenticate message but does not wait for a response. +func (c *Client) makeAuthRequest() ([]byte, error) { + return c.serializeMessage(&AuthRequest{ + Token: c.token, + UserAgent: c.userAgent, + }) +} + +// Execute a SignalFlow job and return a channel upon which informational +// messages and data will flow. +// See https://dev.splunk.com/observability/docs/signalflow/messages/websocket_request_messages#Execute-a-computation +func (c *Client) Execute(ctx context.Context, req *ExecuteRequest) (*Computation, error) { + if req.Channel == "" { + req.Channel = c.newUniqueChannelName() + } + + err := c.sendMessage(ctx, req) + if err != nil { + return nil, err + } + + return newComputation(c.registerChannel(req.Channel), req.Channel, c), nil +} + +// Detach from a computation but keep it running. See +// https://dev.splunk.com/observability/docs/signalflow/messages/websocket_request_messages#Detach-from-a-computation. +func (c *Client) Detach(ctx context.Context, req *DetachRequest) error { + // We are assuming that the detach request will always come from the same + // client that started it with the Execute method above, and thus the + // connection is still active (i.e. we don't need to call ensureInitialized + // here). If the websocket connection does drop, all jobs started by that + // connection get detached/stopped automatically. + return c.sendMessage(ctx, req) +} + +// Stop sends a job stop request message to the backend. It does not wait for +// jobs to actually be stopped. +// See https://dev.splunk.com/observability/docs/signalflow/messages/websocket_request_messages#Stop-a-computation +func (c *Client) Stop(ctx context.Context, req *StopRequest) error { + // We are assuming that the stop request will always come from the same + // client that started it with the Execute method above, and thus the + // connection is still active (i.e. we don't need to call ensureInitialized + // here). If the websocket connection does drop, all jobs started by that + // connection get stopped automatically. + return c.sendMessage(ctx, req) +} + +func (c *Client) registerChannel(name string) chan messages.Message { + ch := make(chan messages.Message) + + c.Lock() + c.channelsByName[name] = ch + c.Unlock() + + return ch +} + +func (c *Client) closeRegisteredChannels() { + c.Lock() + for _, ch := range c.channelsByName { + close(ch) + } + c.channelsByName = map[string]chan messages.Message{} + c.Unlock() +} + +// Close the client and shutdown any ongoing connections and goroutines. The client cannot be +// reused after Close. Calling any of the client methods after Close() is undefined and will likely +// result in a panic. 
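+//
+// A typical shutdown pattern, mirroring the tests, is to close the client and
+// then drain any still-open computation channels:
+//
+//	c.Close()
+//	for range comp.Data() {
+//		// drains until the channel is closed by the shutdown
+//	}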
+func (c *Client) Close() { + if c.isClosed.Load() { + panic("cannot close client more than once") + } + c.isClosed.Store(true) + + c.cancel() + c.closeRegisteredChannels() + +DRAIN: + for { + select { + case outMsg := <-c.outgoingTextMsgs: + outMsg.resultCh <- io.EOF + default: + break DRAIN + } + } + close(c.outgoingTextMsgs) +} diff --git a/signalflow/client_test.go b/signalflow/client_test.go new file mode 100644 index 0000000..e6603e1 --- /dev/null +++ b/signalflow/client_test.go @@ -0,0 +1,361 @@ +package signalflow + +import ( + "context" + "fmt" + "log" + "math/rand" + "os" + "runtime" + "runtime/pprof" + "sync" + "testing" + "time" + + "github.com/signalfx/signalflow-client-go/signalflow/messages" + "github.com/signalfx/signalfx-go/idtool" + "github.com/stretchr/testify/require" +) + +func TestAuthenticationFlow(t *testing.T) { + t.Parallel() + fakeBackend := NewRunningFakeBackend() + defer fakeBackend.Stop() + + c, err := NewClient(StreamURL(fakeBackend.URL()), AccessToken(fakeBackend.AccessToken)) + require.Nil(t, err) + defer c.Close() + + comp, err := c.Execute(context.Background(), &ExecuteRequest{ + Program: "data('cpu.utilization').publish()", + }) + require.Nil(t, err) + + resolution, _ := comp.Resolution(context.Background()) + require.Equal(t, 1*time.Second, resolution) + + require.Equal(t, []map[string]interface{}{ + { + "type": "authenticate", + "token": fakeBackend.AccessToken, + }, + { + "type": "execute", + "channel": "ch-1", + "immediate": false, + "maxDelay": 0., + "program": "data('cpu.utilization').publish()", + "resolution": 0., + "start": 0., + "stop": 0., + "timezone": "", + }, + }, fakeBackend.received) +} + +func TestBasicComputation(t *testing.T) { + t.Parallel() + fakeBackend := NewRunningFakeBackend() + defer fakeBackend.Stop() + + c, err := NewClient(StreamURL(fakeBackend.URL()), AccessToken(fakeBackend.AccessToken)) + require.Nil(t, err) + defer c.Close() + + tsids := []idtool.ID{idtool.ID(rand.Int63()), idtool.ID(rand.Int63())} + for i, host := range []string{"host1", "host2"} { + fakeBackend.AddTSIDMetadata(tsids[i], &messages.MetadataProperties{ + Metric: "jobs_queued", + CustomProperties: map[string]string{ + "host": host, + }, + }) + } + + for i, val := range []float64{5, 10} { + fakeBackend.SetTSIDFloatData(tsids[i], val) + } + + program := "data('cpu.utilization').publish()" + fakeBackend.AddProgramTSIDs(program, tsids) + + comp, err := c.Execute(context.Background(), &ExecuteRequest{ + Program: program, + Resolution: 1 * time.Second, + }) + require.Nil(t, err) + + resolution, _ := comp.Resolution(context.Background()) + require.Equal(t, 1*time.Second, resolution) + + dataMsg := <-comp.Data() + require.Len(t, dataMsg.Payloads, 2) + require.Equal(t, dataMsg.Payloads[0].Float64(), float64(5)) + require.Equal(t, dataMsg.Payloads[1].Float64(), float64(10)) +} + +func TestMultipleComputations(t *testing.T) { + t.Parallel() + fakeBackend := NewRunningFakeBackend() + defer fakeBackend.Stop() + + c, err := NewClient(StreamURL(fakeBackend.URL()), AccessToken(fakeBackend.AccessToken)) + require.Nil(t, err) + defer c.Close() + + for i := 1; i < 50; i++ { + comp, err := c.Execute(context.Background(), &ExecuteRequest{ + Program: "data('cpu.utilization').publish()", + Resolution: time.Duration(i) * time.Second, + }) + require.Nil(t, err) + + resolution, _ := comp.Resolution(context.Background()) + require.Equal(t, time.Duration(i)*time.Second, resolution) + require.Equal(t, fmt.Sprintf("ch-%d", i), comp.name) + } +} + +func TestShutdown(t *testing.T) { + 
fakeBackend := NewRunningFakeBackend() + defer fakeBackend.Stop() + + c, err := NewClient(StreamURL(fakeBackend.URL()), AccessToken(fakeBackend.AccessToken)) + require.Nil(t, err) + + var comps []*Computation + for i := 1; i < 3; i++ { + comp, err := c.Execute(context.Background(), &ExecuteRequest{ + Program: "data('cpu.utilization').publish()", + Resolution: time.Duration(i) * time.Second, + }) + require.Nil(t, err) + comps = append(comps, comp) + + resolution, _ := comp.Resolution(context.Background()) + require.Equal(t, time.Duration(i)*time.Second, resolution) + require.Equal(t, fmt.Sprintf("ch-%d", i), comp.name) + } + + c.Close() + + for _, comp := range comps { + _, ok := <-comp.Data() + require.False(t, ok) + } +} + +func TestReconnect(t *testing.T) { + t.Parallel() + fakeBackend := NewRunningFakeBackend() + defer fakeBackend.Stop() + + c, err := NewClient(StreamURL(fakeBackend.URL()), AccessToken(fakeBackend.AccessToken)) + require.Nil(t, err) + defer c.Close() + + comp, err := c.Execute(context.Background(), &ExecuteRequest{ + Program: "data('cpu.utilization').publish()", + }) + require.Nil(t, err) + + resolution, _ := comp.Resolution(context.Background()) + require.Equal(t, 1*time.Second, resolution) + + require.Equal(t, []map[string]interface{}{ + { + "type": "authenticate", + "token": fakeBackend.AccessToken, + }, + { + "type": "execute", + "channel": "ch-1", + "immediate": false, + "maxDelay": 0., + "program": "data('cpu.utilization').publish()", + "resolution": 0., + "start": 0., + "stop": 0., + "timezone": "", + }, + }, fakeBackend.received) + + fakeBackend.KillExistingConnections() + + for { + _, ok := <-comp.Data() + if !ok { + break + } + } + + comp, err = c.Execute(context.Background(), &ExecuteRequest{ + Program: "data('cpu.utilization').publish()", + }) + require.Nil(t, err) + + resolution, _ = comp.Resolution(context.Background()) + require.Equal(t, 1*time.Second, resolution) + + log.Printf("%v", fakeBackend.received) + require.Equal(t, []map[string]interface{}{ + { + "type": "authenticate", + "token": fakeBackend.AccessToken, + }, + { + "type": "execute", + "channel": "ch-1", + "immediate": false, + "maxDelay": 0., + "program": "data('cpu.utilization').publish()", + "resolution": 0., + "start": 0., + "stop": 0., + "timezone": "", + }, + { + "type": "authenticate", + "token": fakeBackend.AccessToken, + }, + { + "type": "execute", + "channel": "ch-2", + "immediate": false, + "maxDelay": 0., + "program": "data('cpu.utilization').publish()", + "resolution": 0., + "start": 0., + "stop": 0., + "timezone": "", + }, + }, fakeBackend.received) +} + +func TestReconnectAfterBackendDown(t *testing.T) { + t.Parallel() + fakeBackend := NewRunningFakeBackend() + defer fakeBackend.Stop() + + c, err := NewClient(StreamURL(fakeBackend.URL()), AccessToken(fakeBackend.AccessToken)) + require.Nil(t, err) + + defer c.Close() + + comp, err := c.Execute(context.Background(), &ExecuteRequest{ + Program: "data('cpu.utilization').publish()", + }) + require.Nil(t, err) + + resolution, _ := comp.Resolution(context.Background()) + require.Equal(t, 1*time.Second, resolution) + + require.Equal(t, []map[string]interface{}{ + { + "type": "authenticate", + "token": fakeBackend.AccessToken, + }, + { + "type": "execute", + "channel": "ch-1", + "immediate": false, + "maxDelay": 0., + "program": "data('cpu.utilization').publish()", + "resolution": 0., + "start": 0., + "stop": 0., + "timezone": "", + }, + }, fakeBackend.received) + + fakeBackend.Stop() + for { + _, ok := <-comp.Data() + if !ok { + break + 
} + } + + time.Sleep(7 * time.Second) + fakeBackend.Restart() + + comp, err = c.Execute(context.Background(), &ExecuteRequest{ + Program: "data('cpu.utilization').publish()", + }) + require.Nil(t, err) + + resolution, _ = comp.Resolution(context.Background()) + require.Equal(t, 1*time.Second, resolution) + + require.Equal(t, []map[string]interface{}{ + { + "type": "authenticate", + "token": fakeBackend.AccessToken, + }, + { + "type": "execute", + "channel": "ch-1", + "immediate": false, + "maxDelay": 0., + "program": "data('cpu.utilization').publish()", + "resolution": 0., + "start": 0., + "stop": 0., + "timezone": "", + }, + { + "type": "authenticate", + "token": fakeBackend.AccessToken, + }, + { + "type": "execute", + "channel": "ch-2", + "immediate": false, + "maxDelay": 0., + "program": "data('cpu.utilization').publish()", + "resolution": 0., + "start": 0., + "stop": 0., + "timezone": "", + }, + }, fakeBackend.received) +} + +func TestFailedConnGoroutineShutdown(t *testing.T) { + defer func() { + time.Sleep(2 * time.Second) + pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + }() + + fakeBackend := NewRunningFakeBackend() + fakeBackend.Stop() + + startingGoroutines := runtime.NumGoroutine() + clients := make([]*Client, 100) + var wg sync.WaitGroup + for i := range clients { + var err error + clients[i], err = NewClient(StreamURL(fakeBackend.URL()), AccessToken(fakeBackend.AccessToken)) + require.Nil(t, err) + + wg.Add(1) + go func(c *Client) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + _, err = c.Execute(ctx, &ExecuteRequest{ + Program: "data('cpu.utilization').publish()", + }) + cancel() + t.Logf("execute error: %v", err) + require.Error(t, err) + wg.Done() + }(clients[i]) + } + wg.Wait() + + for _, c := range clients { + c.Close() + } + time.Sleep(1 * time.Second) + + require.InDelta(t, startingGoroutines, runtime.NumGoroutine(), 10) +} diff --git a/signalflow/computation.go b/signalflow/computation.go new file mode 100644 index 0000000..efc2732 --- /dev/null +++ b/signalflow/computation.go @@ -0,0 +1,381 @@ +package signalflow + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/signalfx/signalflow-client-go/signalflow/messages" + "github.com/signalfx/signalfx-go/idtool" +) + +// Computation is a single running SignalFlow job +type Computation struct { + sync.Mutex + channel <-chan messages.Message + name string + client *Client + dataCh chan *messages.DataMessage + // An intermediate channel for data messages where they can be buffered if + // nothing is currently pulling data messages. 
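+	// (The bufferMessages goroutine started in newComputation drains
+	// dataChBuffer into dataCh, so a slow consumer of Data() does not
+	// block message processing.)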
+ dataChBuffer chan *messages.DataMessage + eventCh chan *messages.EventMessage + infoCh chan *messages.InfoMessage + eventChBuffer chan *messages.EventMessage + expirationCh chan *messages.ExpiredTSIDMessage + expirationChBuffer chan *messages.ExpiredTSIDMessage + infoChBuffer chan *messages.InfoMessage + + errMutex sync.RWMutex + lastError error + + handle asyncMetadata[string] + resolutionMS asyncMetadata[int] + lagMS asyncMetadata[int] + maxDelayMS asyncMetadata[int] + matchedSize asyncMetadata[int] + limitSize asyncMetadata[int] + matchedNoTimeseriesQuery asyncMetadata[string] + groupByMissingProperties asyncMetadata[[]string] + + tsidMetadata map[idtool.ID]*asyncMetadata[*messages.MetadataProperties] +} + +// ComputationError exposes the underlying metadata of a computation error +type ComputationError struct { + Code int + Message string + ErrorType string +} + +func (e *ComputationError) Error() string { + err := fmt.Sprintf("%v", e.Code) + if e.ErrorType != "" { + err = fmt.Sprintf("%v (%v)", e.Code, e.ErrorType) + } + if e.Message != "" { + err = fmt.Sprintf("%v: %v", err, e.Message) + } + return err +} + +func newComputation(channel <-chan messages.Message, name string, client *Client) *Computation { + comp := &Computation{ + channel: channel, + name: name, + client: client, + dataCh: make(chan *messages.DataMessage), + dataChBuffer: make(chan *messages.DataMessage), + eventCh: make(chan *messages.EventMessage), + infoCh: make(chan *messages.InfoMessage), + eventChBuffer: make(chan *messages.EventMessage), + expirationCh: make(chan *messages.ExpiredTSIDMessage), + expirationChBuffer: make(chan *messages.ExpiredTSIDMessage), + infoChBuffer: make(chan *messages.InfoMessage), + tsidMetadata: make(map[idtool.ID]*asyncMetadata[*messages.MetadataProperties]), + } + + go bufferMessages(comp.dataChBuffer, comp.dataCh) + go bufferMessages(comp.expirationChBuffer, comp.expirationCh) + go bufferMessages(comp.eventChBuffer, comp.eventCh) + go bufferMessages(comp.infoChBuffer, comp.infoCh) + + go func() { + err := comp.watchMessages() + + if !errors.Is(err, errChannelClosed) { + comp.errMutex.Lock() + comp.lastError = err + comp.errMutex.Unlock() + } + + comp.shutdown() + }() + + return comp +} + +// Handle of the computation. Will wait as long as the given ctx is not closed. If ctx is closed an +// error will be returned. +func (c *Computation) Handle(ctx context.Context) (string, error) { + return c.handle.Get(ctx) +} + +// Resolution of the job. Will wait as long as the given ctx is not closed. If ctx is closed an +// error will be returned. +func (c *Computation) Resolution(ctx context.Context) (time.Duration, error) { + resMS, err := c.resolutionMS.Get(ctx) + return time.Duration(resMS) * time.Millisecond, err +} + +// Lag detected for the job. Will wait as long as the given ctx is not closed. If ctx is closed an +// error will be returned. +func (c *Computation) Lag(ctx context.Context) (time.Duration, error) { + lagMS, err := c.lagMS.Get(ctx) + return time.Duration(lagMS) * time.Millisecond, err +} + +// MaxDelay detected of the job. Will wait as long as the given ctx is not closed. If ctx is closed an +// error will be returned. +func (c *Computation) MaxDelay(ctx context.Context) (time.Duration, error) { + maxDelayMS, err := c.maxDelayMS.Get(ctx) + return time.Duration(maxDelayMS) * time.Millisecond, err +} + +// MatchedSize detected of the job. Will wait as long as the given ctx is not closed. If ctx is closed an +// error will be returned. 
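+//
+// As with the other metadata accessors, callers will usually bound the wait
+// with a context timeout, for example (the duration is arbitrary):
+//
+//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+//	defer cancel()
+//	matched, err := comp.MatchedSize(ctx)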
+func (c *Computation) MatchedSize(ctx context.Context) (int, error) { + return c.matchedSize.Get(ctx) +} + +// LimitSize detected of the job. Will wait as long as the given ctx is not closed. If ctx is closed an +// error will be returned. +func (c *Computation) LimitSize(ctx context.Context) (int, error) { + return c.limitSize.Get(ctx) +} + +// MatchedNoTimeseriesQuery if it matched no active timeseries. Will wait as long as the given ctx +// is not closed. If ctx is closed an error will be returned. +func (c *Computation) MatchedNoTimeseriesQuery(ctx context.Context) (string, error) { + return c.matchedNoTimeseriesQuery.Get(ctx) +} + +// GroupByMissingProperties are timeseries that don't contain the required dimensions. Will wait as +// long as the given ctx is not closed. If ctx is closed an error will be returned. +func (c *Computation) GroupByMissingProperties(ctx context.Context) ([]string, error) { + return c.groupByMissingProperties.Get(ctx) +} + +// TSIDMetadata for a particular tsid. Will wait as long as the given ctx is not closed. If ctx is closed an +// error will be returned. +func (c *Computation) TSIDMetadata(ctx context.Context, tsid idtool.ID) (*messages.MetadataProperties, error) { + c.Lock() + if _, ok := c.tsidMetadata[tsid]; !ok { + c.tsidMetadata[tsid] = &asyncMetadata[*messages.MetadataProperties]{} + } + md := c.tsidMetadata[tsid] + c.Unlock() + return md.Get(ctx) +} + +// Err returns the last fatal error that caused the computation to stop, if +// any. Will be nil if the computation stopped in an expected manner. +func (c *Computation) Err() error { + c.errMutex.RLock() + defer c.errMutex.RUnlock() + + return c.lastError +} + +func (c *Computation) watchMessages() error { + for { + m, ok := <-c.channel + if !ok { + return nil + } + if err := c.processMessage(m); err != nil { + return err + } + } +} + +var errChannelClosed = errors.New("computation channel is closed") + +func (c *Computation) processMessage(m messages.Message) error { + switch v := m.(type) { + case *messages.JobStartControlMessage: + c.handle.Set(v.Handle) + case *messages.EndOfChannelControlMessage, *messages.ChannelAbortControlMessage: + return errChannelClosed + case *messages.DataMessage: + c.dataChBuffer <- v + case *messages.ExpiredTSIDMessage: + c.Lock() + delete(c.tsidMetadata, idtool.IDFromString(v.TSID)) + c.Unlock() + c.expirationChBuffer <- v + case *messages.InfoMessage: + switch v.MessageBlock.Code { + case messages.JobRunningResolution: + c.resolutionMS.Set(v.MessageBlock.Contents.(messages.JobRunningResolutionContents).ResolutionMS()) + case messages.JobDetectedLag: + c.lagMS.Set(v.MessageBlock.Contents.(messages.JobDetectedLagContents).LagMS()) + case messages.JobInitialMaxDelay: + c.maxDelayMS.Set(v.MessageBlock.Contents.(messages.JobInitialMaxDelayContents).MaxDelayMS()) + case messages.FindLimitedResultSet: + c.matchedSize.Set(v.MessageBlock.Contents.(messages.FindLimitedResultSetContents).MatchedSize()) + c.limitSize.Set(v.MessageBlock.Contents.(messages.FindLimitedResultSetContents).LimitSize()) + case messages.FindMatchedNoTimeseries: + c.matchedNoTimeseriesQuery.Set(v.MessageBlock.Contents.(messages.FindMatchedNoTimeseriesContents).MatchedNoTimeseriesQuery()) + case messages.GroupByMissingProperty: + c.groupByMissingProperties.Set(v.MessageBlock.Contents.(messages.GroupByMissingPropertyContents).GroupByMissingProperties()) + } + c.infoChBuffer <- v + case *messages.ErrorMessage: + rawData := v.RawData() + computationError := ComputationError{} + if code, ok := 
rawData["error"]; ok { + computationError.Code = int(code.(float64)) + } + if msg, ok := rawData["message"]; ok && msg != nil { + computationError.Message = msg.(string) + } + if errType, ok := rawData["errorType"]; ok { + computationError.ErrorType = errType.(string) + } + return &computationError + case *messages.MetadataMessage: + c.Lock() + if _, ok := c.tsidMetadata[v.TSID]; !ok { + c.tsidMetadata[v.TSID] = &asyncMetadata[*messages.MetadataProperties]{} + } + c.tsidMetadata[v.TSID].Set(&v.Properties) + c.Unlock() + case *messages.EventMessage: + c.eventChBuffer <- v + } + return nil +} + +func bufferMessages[T any](in chan *T, out chan *T) { + buffer := make([]*T, 0) + var nextMessage *T + + defer func() { + if nextMessage != nil { + out <- nextMessage + } + for i := range buffer { + out <- buffer[i] + } + + close(out) + }() + for { + if len(buffer) > 0 { + if nextMessage == nil { + nextMessage, buffer = buffer[0], buffer[1:] + } + + select { + case out <- nextMessage: + nextMessage = nil + case msg, ok := <-in: + if !ok { + return + } + buffer = append(buffer, msg) + } + } else { + msg, ok := <-in + if !ok { + return + } + buffer = append(buffer, msg) + } + } +} + +// Data returns the channel on which data messages come. This channel will be closed when the +// computation is finished. To prevent goroutine leaks, you should read all messages from this +// channel until it is closed. +func (c *Computation) Data() <-chan *messages.DataMessage { + return c.dataCh +} + +// Expirations returns a channel that will be sent messages about expired TSIDs, i.e. time series +// that are no longer valid for this computation. This channel will be closed when the computation +// is finished. To prevent goroutine leaks, you should read all messages from this channel until it +// is closed. +func (c *Computation) Expirations() <-chan *messages.ExpiredTSIDMessage { + return c.expirationCh +} + +// Events returns a channel that receives event/alert messages from the signalflow computation. +func (c *Computation) Events() <-chan *messages.EventMessage { + return c.eventCh +} + +// Info returns a channel that receives info messages from the signalflow computation. +func (c *Computation) Info() <-chan *messages.InfoMessage { + return c.infoCh +} + +// Detach the computation on the backend +func (c *Computation) Detach(ctx context.Context) error { + return c.DetachWithReason(ctx, "") +} + +// DetachWithReason detaches the computation with a given reason. This reason will +// be reflected in the control message that signals the end of the job/channel +func (c *Computation) DetachWithReason(ctx context.Context, reason string) error { + return c.client.Detach(ctx, &DetachRequest{ + Reason: reason, + Channel: c.name, + }) +} + +// Stop the computation on the backend. +func (c *Computation) Stop(ctx context.Context) error { + return c.StopWithReason(ctx, "") +} + +// StopWithReason stops the computation with a given reason. This reason will +// be reflected in the control message that signals the end of the job/channel. 
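+//
+// Note that this waits (bounded by ctx) for the job handle to arrive, since
+// the stop request is sent by handle rather than by channel. The reason
+// string is free-form, e.g.:
+//
+//	if err := comp.StopWithReason(ctx, "no longer needed"); err != nil {
+//		log.Printf("stop failed: %v", err)
+//	}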
+func (c *Computation) StopWithReason(ctx context.Context, reason string) error { + handle, err := c.handle.Get(ctx) + if err != nil { + return err + } + return c.client.Stop(ctx, &StopRequest{ + Reason: reason, + Handle: handle, + }) +} + +func (c *Computation) shutdown() { + close(c.dataChBuffer) + close(c.expirationChBuffer) + close(c.infoChBuffer) +} + +var ErrMetadataTimeout = errors.New("metadata value did not come in time") + +type asyncMetadata[T any] struct { + sync.Mutex + sig chan struct{} + isSet bool + val T +} + +func (a *asyncMetadata[T]) ensureInit() { + a.Lock() + if a.sig == nil { + a.sig = make(chan struct{}) + } + a.Unlock() +} + +func (a *asyncMetadata[T]) Set(val T) { + a.ensureInit() + a.Lock() + a.val = val + if !a.isSet { + close(a.sig) + a.isSet = true + } + a.Unlock() +} + +func (a *asyncMetadata[T]) Get(ctx context.Context) (T, error) { + a.ensureInit() + select { + case <-ctx.Done(): + var t T + return t, ErrMetadataTimeout + case <-a.sig: + return a.val, nil + } +} diff --git a/signalflow/computation_test.go b/signalflow/computation_test.go new file mode 100644 index 0000000..facf116 --- /dev/null +++ b/signalflow/computation_test.go @@ -0,0 +1,411 @@ +package signalflow + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/signalfx/signalflow-client-go/signalflow/messages" + "github.com/signalfx/signalfx-go/idtool" + "github.com/stretchr/testify/require" +) + +func TestBuffersDataMessages(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- &messages.DataMessage{ + Payloads: []messages.DataPayload{ + { + TSID: idtool.ID(4000), + }, + }, + } + ch <- &messages.MetadataMessage{ + TSID: idtool.ID(4000), + } + + md, _ := comp.TSIDMetadata(context.Background(), 4000) + require.NotNil(t, md) + + ch <- &messages.InfoMessage{} + + msg := waitForMsg(t, comp.Data(), comp) + require.Equal(t, idtool.ID(4000), msg.Payloads[0].TSID) + + ch <- &messages.DataMessage{ + Payloads: []messages.DataPayload{ + { + TSID: idtool.ID(4001), + }, + }, + } + msg = waitForMsg(t, comp.Data(), comp) + require.Equal(t, idtool.ID(4001), msg.Payloads[0].TSID) +} + +func waitForMsg[T any](t *testing.T, ch <-chan *T, comp *Computation) *T { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + for { + select { + case m, ok := <-ch: + if !ok { + require.FailNow(t, "message channel closed unexpected") + } + return m + case <-ctx.Done(): + require.FailNow(t, "message didn't arrive in timeout with error: %v", comp.Err()) + } + } +} + +func TestBuffersExpiryMessages(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- &messages.ExpiredTSIDMessage{ + TSID: idtool.ID(4000).String(), + } + ch <- &messages.MetadataMessage{ + TSID: idtool.ID(4000), + } + + md, _ := comp.TSIDMetadata(context.Background(), 4000) + require.NotNil(t, md) + + ch <- &messages.InfoMessage{} + + msg := waitForMsg(t, comp.Expirations(), comp) + require.Equal(t, idtool.ID(4000).String(), msg.TSID) + + ch <- &messages.ExpiredTSIDMessage{ + TSID: idtool.ID(4001).String(), + } + msg = waitForMsg(t, comp.Expirations(), comp) + require.Equal(t, idtool.ID(4001).String(), msg.TSID) +} + +func TestBuffersEventMessages(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := 
newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- &messages.EventMessage{} + ch <- &messages.MetadataMessage{ + TSID: idtool.ID(4000), + } + + md, _ := comp.TSIDMetadata(context.Background(), 4000) + require.NotNil(t, md) + + ch <- &messages.InfoMessage{} + + msg := waitForMsg(t, comp.Events(), comp) + require.NotNil(t, msg) + + ch <- &messages.EventMessage{} + msg = waitForMsg(t, comp.Events(), comp) + require.NotNil(t, msg) +} + +func TestBuffersInfoMessages(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- &messages.InfoMessage{} + ch <- &messages.MetadataMessage{ + TSID: idtool.ID(4000), + } + + md, _ := comp.TSIDMetadata(context.Background(), 4000) + require.NotNil(t, md) + + ch <- &messages.InfoMessage{} + + msg := waitForMsg(t, comp.Info(), comp) + require.NotNil(t, msg) + + ch <- &messages.InfoMessage{} + msg = waitForMsg(t, comp.Info(), comp) + require.NotNil(t, msg) +} + +func mustParse(m messages.Message, err error) messages.Message { + if err != nil { + panic(err) + } + return m +} + +func TestResolutionMetadata(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + + wg := sync.WaitGroup{} + + // Ensure multiple calls get the same result and also wait for the message + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resolution, _ := comp.Resolution(context.Background()) + require.Equal(t, 5*time.Second, resolution) + }() + } + + ch <- mustParse(messages.ParseMessage([]byte(`{ + "type": "message", + "message": { + "messageCode": "JOB_RUNNING_RESOLUTION", + "contents": { + "resolutionMs": 5000 + } + } + }`), true)) + + wg.Wait() +} + +func TestMaxDelayMetadata(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- mustParse(messages.ParseMessage([]byte(`{ + "type": "message", + "message": { + "messageCode": "JOB_INITIAL_MAX_DELAY", + "contents": { + "maxDelayMs": 1000 + } + } + }`), true)) + + maxDelay, _ := comp.MaxDelay(context.Background()) + require.Equal(t, 1*time.Second, maxDelay) +} + +func TestLagMetadata(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- mustParse(messages.ParseMessage([]byte(`{ + "type": "message", + "message": { + "messageCode": "JOB_DETECTED_LAG", + "contents": { + "lagMs": 3500 + } + } + }`), true)) + + lag, _ := comp.Lag(context.Background()) + require.Equal(t, 3500*time.Millisecond, lag) +} + +func TestFindLimitedResultSetMetadata(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- mustParse(messages.ParseMessage([]byte(`{ + "type": "message", + "message": { + "messageCode": "FIND_LIMITED_RESULT_SET", + "contents": { + "matchedSize": 123456789, + "limitSize": 50000 + } + } + }`), true)) + + matchedSize, _ := comp.MatchedSize(context.Background()) + require.Equal(t, 123456789, matchedSize) + + limitSize, _ := comp.LimitSize(context.Background()) + require.Equal(t, 50000, limitSize) +} + +func 
TestMatchedNoTimeseriesQueryMetaData(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- mustParse(messages.ParseMessage([]byte(`{ + "type": "message", + "message": { + "messageCode": "FIND_MATCHED_NO_TIMESERIES", + "contents": { + "query": "abc" + } + } + }`), true)) + + noMatched, _ := comp.MatchedNoTimeseriesQuery(context.Background()) + require.Equal(t, "abc", noMatched) +} + +func TestGroupByMissingPropertyMetaData(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- mustParse(messages.ParseMessage([]byte(`{ + "type": "message", + "message": { + "messageCode": "GROUPBY_MISSING_PROPERTY", + "contents": { + "propertyNames": ["x", "y", "z"] + } + } + }`), true)) + + missingProps, _ := comp.GroupByMissingProperties(context.Background()) + require.Equal(t, []string{"x", "y", "z"}, missingProps) +} + +func TestHandle(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- mustParse(messages.ParseMessage([]byte(`{ + "type": "control-message", + "event": "JOB_START", + "handle": "AAAABBBB" + }`), true)) + + handle, _ := comp.Handle(context.Background()) + require.Equal(t, "AAAABBBB", handle) +} + +func TestComputationError(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- mustParse(messages.ParseMessage([]byte(`{ + "type": "error", + "error": 400, + "errorType": "ANALYTICS_PROGRAM_NAME_ERROR", + "message": "We hit some error" + }`), true)) + + err := waitForComputationError(t, comp) + var ce *ComputationError + if !errors.As(err, &ce) { + t.FailNow() + } + require.Equal(t, 400, ce.Code) + require.Equal(t, "ANALYTICS_PROGRAM_NAME_ERROR", ce.ErrorType) + require.Equal(t, "We hit some error", ce.Message) +} + +func TestComputationErrorWithNullMessage(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + ch <- mustParse(messages.ParseMessage([]byte(`{ + "type": "error", + "error": 400, + "errorType": "ANALYTICS_INTERNAL_ERROR", + "message": null + }`), true)) + + err := waitForComputationError(t, comp) + var ce *ComputationError + if !errors.As(err, &ce) { + t.FailNow() + } + require.Equal(t, 400, ce.Code) + require.Equal(t, "ANALYTICS_INTERNAL_ERROR", ce.ErrorType) + require.Equal(t, "", ce.Message) +} + +func waitForComputationError(t *testing.T, comp *Computation) error { + t.Helper() + start := time.Now() + var err error + for time.Since(start) < 3*time.Second { + err = comp.Err() + if err != nil { + return err + } + time.Sleep(50 * time.Millisecond) + } + require.FailNow(t, "computation did not fail") + return nil +} + +func TestComputationFinish(t *testing.T) { + t.Parallel() + ch := make(chan messages.Message) + comp := newComputation(ch, "ch1", &Client{ + defaultMetadataTimeout: 1 * time.Second, + }) + defer close(ch) + go func() { + ch <- mustParse(messages.ParseMessage([]byte(`{ + "type": "control-message", + "event": "JOB_START", + "handle": "AAAABBBB" + }`), true)) + + ch <- &messages.MetadataMessage{ + TSID: idtool.ID(4000), + } + + ch 
<- &messages.DataMessage{ + Payloads: []messages.DataPayload{ + { + TSID: idtool.ID(4000), + }, + }, + } + + ch <- &messages.EndOfChannelControlMessage{} + }() + + for msg := range comp.Data() { + require.Equal(t, idtool.ID(4000), msg.Payloads[0].TSID) + } + + // The for loop should exit when the end of channel message comes through +} diff --git a/signalflow/conn.go b/signalflow/conn.go new file mode 100644 index 0000000..6da947a --- /dev/null +++ b/signalflow/conn.go @@ -0,0 +1,193 @@ +package signalflow + +import ( + "context" + "fmt" + "net/url" + "path" + "time" + + "github.com/gorilla/websocket" +) + +// How long to wait between connections in case of a bad connection. +var reconnectDelay = 5 * time.Second + +type wsConn struct { + StreamURL *url.URL + + OutgoingTextMsgs chan *outgoingMessage + IncomingTextMsgs chan []byte + IncomingBinaryMsgs chan []byte + ConnectedCh chan struct{} + + ConnectTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + OnError OnErrorFunc + PostDisconnectCallback func() + PostConnectMessage func() []byte +} + +type outgoingMessage struct { + bytes []byte + resultCh chan error +} + +// Run keeps the connection alive and puts all incoming messages into a channel +// as needed. +func (c *wsConn) Run(ctx context.Context) { + var conn *websocket.Conn + defer func() { + if conn != nil { + conn.Close() + } + }() + + for { + if conn != nil { + conn.Close() + time.Sleep(reconnectDelay) + } + // This will get run on before the first connection as well. + if c.PostDisconnectCallback != nil { + c.PostDisconnectCallback() + } + + if ctx.Err() != nil { + return + } + + dialCtx, cancel := context.WithTimeout(ctx, c.ConnectTimeout) + var err error + conn, err = c.connect(dialCtx) + cancel() + if err != nil { + c.sendErrIfWanted(fmt.Errorf("Error connecting to SignalFlow websocket: %w", err)) + continue + } + + err = c.postConnect(conn) + if err != nil { + c.sendErrIfWanted(fmt.Errorf("Error setting up SignalFlow websocket: %w", err)) + continue + } + + err = c.readAndWriteMessages(conn) + if err == nil { + return + } + c.sendErrIfWanted(fmt.Errorf("Error in SignalFlow websocket: %w", err)) + } +} + +type messageWithType struct { + bytes []byte + msgType int +} + +func (c *wsConn) readAndWriteMessages(conn *websocket.Conn) error { + readMessageCh := make(chan messageWithType) + readErrCh := make(chan error) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for { + bytes, typ, err := readNextMessage(conn, c.ReadTimeout) + if err != nil { + select { + case readErrCh <- err: + case <-ctx.Done(): + } + return + } + readMessageCh <- messageWithType{ + bytes: bytes, + msgType: typ, + } + } + }() + + for { + select { + case msg, ok := <-readMessageCh: + if !ok { + return nil + } + if msg.msgType == websocket.TextMessage { + c.IncomingTextMsgs <- msg.bytes + } else { + c.IncomingBinaryMsgs <- msg.bytes + } + case err := <-readErrCh: + return err + case msg, ok := <-c.OutgoingTextMsgs: + if !ok { + return nil + } + err := c.writeMessage(conn, msg.bytes) + msg.resultCh <- err + if err != nil { + return err + } + } + } +} + +func (c *wsConn) sendErrIfWanted(err error) { + if c.OnError != nil { + c.OnError(err) + } +} + +func (c *wsConn) Close() { + close(c.IncomingTextMsgs) + close(c.IncomingBinaryMsgs) +} + +func (c *wsConn) connect(ctx context.Context) (*websocket.Conn, error) { + connectURL := *c.StreamURL + connectURL.Path = path.Join(c.StreamURL.Path, "connect") + conn, _, err := 
websocket.DefaultDialer.DialContext(ctx, connectURL.String(), nil) + if err != nil { + return nil, fmt.Errorf("could not connect Signalflow websocket: %w", err) + } + return conn, nil +} + +func (c *wsConn) postConnect(conn *websocket.Conn) error { + if c.PostConnectMessage != nil { + msg := c.PostConnectMessage() + if msg != nil { + return c.writeMessage(conn, msg) + } + } + return nil +} + +func readNextMessage(conn *websocket.Conn, timeout time.Duration) (data []byte, msgType int, err error) { + if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { + return nil, 0, fmt.Errorf("could not set read timeout in SignalFlow client: %w", err) + } + + typ, bytes, err := conn.ReadMessage() + if err != nil { + return nil, 0, err + } + return bytes, typ, nil +} + +func (c *wsConn) writeMessage(conn *websocket.Conn, msgBytes []byte) error { + err := conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + if err != nil { + return fmt.Errorf("could not set write timeout for SignalFlow request: %w", err) + } + + err = conn.WriteMessage(websocket.TextMessage, msgBytes) + if err != nil { + return err + } + return nil +} diff --git a/signalfow/doc.go b/signalflow/doc.go similarity index 100% rename from signalfow/doc.go rename to signalflow/doc.go diff --git a/signalflow/fake_backend.go b/signalflow/fake_backend.go new file mode 100644 index 0000000..f3166f9 --- /dev/null +++ b/signalflow/fake_backend.go @@ -0,0 +1,382 @@ +package signalflow + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "log" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/signalfx/signalflow-client-go/signalflow/messages" + "github.com/signalfx/signalfx-go/idtool" +) + +var upgrader = websocket.Upgrader{} // use default options + +type tsidVal struct { + TSID idtool.ID + Val float64 +} + +// FakeBackend is useful for testing, both internal to this package and +// externally. It supports basic messages and allows for the specification of +// metadata and data messages that map to a particular program. 
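+//
+// A typical test setup looks roughly like this (the program text and values
+// are arbitrary):
+//
+//	fb := NewRunningFakeBackend()
+//	defer fb.Stop()
+//
+//	tsid := idtool.ID(1234)
+//	fb.AddTSIDMetadata(tsid, &messages.MetadataProperties{Metric: "jobs_queued"})
+//	fb.SetTSIDFloatData(tsid, 5)
+//	fb.AddProgramTSIDs("data('jobs_queued').publish()", []idtool.ID{tsid})
+//
+//	client, err := fb.Client()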
+type FakeBackend struct { + sync.Mutex + + AccessToken string + authenticated bool + + conns map[*websocket.Conn]bool + + received []map[string]interface{} + metadataByTSID map[idtool.ID]*messages.MetadataProperties + dataByTSID map[idtool.ID]*float64 + tsidsByProgram map[string][]idtool.ID + programErrors map[string]string + runningJobsByProgram map[string]int + cancelFuncsByHandle map[string]context.CancelFunc + cancelFuncsByChannel map[string]context.CancelFunc + server *httptest.Server + handleIdx int +} + +func (f *FakeBackend) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithCancel(context.Background()) + + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + panic(err) + } + f.registerConn(c) + defer c.Close() + defer cancel() + + textMsgs := make(chan string) + binMsgs := make(chan []byte) + go func() { + for { + var err error + select { + case m := <-textMsgs: + err = c.WriteMessage(websocket.TextMessage, []byte(m)) + case m := <-binMsgs: + err = c.WriteMessage(websocket.BinaryMessage, m) + case <-ctx.Done(): + f.unregisterConn(c) + return + } + if err != nil { + log.Printf("Could not write message: %v", err) + } + } + }() + + for { + _, message, err := c.ReadMessage() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + log.Println("read err:", err) + } + return + } + + var in map[string]interface{} + if err := json.Unmarshal(message, &in); err != nil { + log.Println("error unmarshalling: ", err) + } + f.received = append(f.received, in) + + err = f.handleMessage(ctx, in, textMsgs, binMsgs) + if err != nil { + log.Printf("Error handling fake backend message, closing connection: %v", err) + return + } + } +} + +func (f *FakeBackend) registerConn(conn *websocket.Conn) { + f.Lock() + f.conns[conn] = true + f.Unlock() +} + +func (f *FakeBackend) unregisterConn(conn *websocket.Conn) { + f.Lock() + delete(f.conns, conn) + f.Unlock() +} + +func (f *FakeBackend) handleMessage(ctx context.Context, message map[string]interface{}, textMsgs chan<- string, binMsgs chan<- []byte) error { + typ, ok := message["type"].(string) + if !ok { + textMsgs <- `{"type": "error"}` + return nil + } + + switch typ { + case "authenticate": + token, _ := message["token"].(string) + if f.AccessToken == "" || token == f.AccessToken { + textMsgs <- `{"type": "authenticated"}` + f.authenticated = true + } else { + textMsgs <- `{"type": "error", "message": "Invalid auth token"}` + return errors.New("bad auth token") + } + case "stop": + if cancel := f.cancelFuncsByHandle[message["handle"].(string)]; cancel != nil { + cancel() + } + case "detach": + if cancel := f.cancelFuncsByChannel[message["channel"].(string)]; cancel != nil { + cancel() + } + case "execute": + if !f.authenticated { + return errors.New("not authenticated") + } + program, _ := message["program"].(string) + ch, _ := message["channel"].(string) + + if errMsg := f.programErrors[program]; errMsg != "" { + textMsgs <- fmt.Sprintf(`{"type": "error", "message": "%s", "channel": "%s"}`, errMsg, ch) + } + + programTSIDs := f.tsidsByProgram[program] + + handle := fmt.Sprintf("handle-%d", f.handleIdx) + f.handleIdx++ + + execCtx, cancel := context.WithCancel(ctx) + f.cancelFuncsByHandle[handle] = cancel + f.cancelFuncsByChannel[ch] = cancel + + log.Printf("Executing SignalFlow program %s with tsids %v and handle %s", program, programTSIDs, handle) + f.runningJobsByProgram[program]++ + + var resolutionMs int + for _, tsid := range programTSIDs { + if md := f.metadataByTSID[tsid]; md != nil { + if md.ResolutionMS > 
resolutionMs { + resolutionMs = md.ResolutionMS + } + } + } + + messageResMs, _ := message["resolution"].(float64) + if messageResMs != 0.0 { + resolutionMs = int(messageResMs) + } + + if resolutionMs == 0 { + resolutionMs = 1000 + } + + // use start and stop to control ending the fakebackend + var stopMs uint64 + var startMs uint64 + messageStopMs, _ := message["stop"].(float64) + if messageStopMs != 0.0 { + stopMs = uint64(messageStopMs) + } + + messageStartMs, _ := message["start"].(float64) + if messageStartMs != 0.0 { + startMs = uint64(messageStartMs) + } + + if startMs == 0 { + startMs = uint64(time.Now().UnixNano() / (1000 * 1000)) + } + + textMsgs <- fmt.Sprintf(`{"type": "control-message", "channel": "%s", "event": "STREAM_START"}`, ch) + textMsgs <- fmt.Sprintf(`{"type": "control-message", "channel": "%s", "event": "JOB_START", "handle": "%s"}`, ch, handle) + textMsgs <- fmt.Sprintf(`{"type": "message", "channel": "%s", "logicalTimestampMs": 1464736034000, "message": {"contents": {"resolutionMs" : %d}, "messageCode": "JOB_RUNNING_RESOLUTION", "timestampMs": 1464736033000}}`, ch, int64(resolutionMs)) + + for _, tsid := range programTSIDs { + if md := f.metadataByTSID[tsid]; md != nil { + propJSON, err := json.Marshal(md) + if err != nil { + log.Printf("Error serializing metadata to json: %v", err) + continue + } + textMsgs <- fmt.Sprintf(`{"type": "metadata", "tsId": "%s", "channel": "%s", "properties": %s}`, tsid, ch, propJSON) + } + } + + log.Print("done sending metadata messages") + + // Send data periodically until the connection is closed. + iterations := 0 + go func() { + t := time.NewTicker(time.Duration(resolutionMs) * time.Millisecond) + for { + select { + case <-execCtx.Done(): + log.Printf("sending done") + f.Lock() + f.runningJobsByProgram[program]-- + f.Unlock() + return + case <-t.C: + f.Lock() + valsWithTSID := []tsidVal{} + for _, tsid := range programTSIDs { + if data := f.dataByTSID[tsid]; data != nil { + valsWithTSID = append(valsWithTSID, tsidVal{TSID: tsid, Val: *data}) + } + } + f.Unlock() + metricTime := startMs + uint64(iterations*resolutionMs) + if stopMs != 0 && metricTime > stopMs { + log.Printf("sending channel end") + // tell the client the computation is complete + textMsgs <- fmt.Sprintf(`{"type": "control-message", "channel": "%s", "event": "END_OF_CHANNEL", "handle": "%s"}`, ch, handle) + return + } + log.Printf("sending data message") + binMsgs <- makeDataMessage(ch, valsWithTSID, metricTime) + log.Printf("done sending data message") + iterations++ + } + } + }() + } + return nil +} + +func makeDataMessage(channel string, valsWithTSID []tsidVal, now uint64) []byte { + var ch [16]byte + copy(ch[:], channel) + header := messages.BinaryMessageHeader{ + Version: 1, + MessageType: 5, + Flags: 0, + Reserved: 0, + Channel: ch, + } + w := new(bytes.Buffer) + binary.Write(w, binary.BigEndian, &header) + + dataHeader := messages.DataMessageHeader{ + TimestampMillis: now, + ElementCount: uint32(len(valsWithTSID)), + } + binary.Write(w, binary.BigEndian, &dataHeader) + + for i := range valsWithTSID { + var valBytes [8]byte + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, valsWithTSID[i].Val) + copy(valBytes[:], buf.Bytes()) + + payload := messages.DataPayload{ + Type: messages.ValTypeDouble, + TSID: valsWithTSID[i].TSID, + Val: valBytes, + } + + binary.Write(w, binary.BigEndian, &payload) + } + + return w.Bytes() +} + +func (f *FakeBackend) Start() { + f.metadataByTSID = map[idtool.ID]*messages.MetadataProperties{} + f.dataByTSID = 
map[idtool.ID]*float64{} + f.tsidsByProgram = map[string][]idtool.ID{} + f.programErrors = map[string]string{} + f.runningJobsByProgram = map[string]int{} + f.cancelFuncsByHandle = map[string]context.CancelFunc{} + f.cancelFuncsByChannel = map[string]context.CancelFunc{} + f.conns = map[*websocket.Conn]bool{} + f.server = httptest.NewServer(f) +} + +func (f *FakeBackend) Stop() { + f.KillExistingConnections() + f.server.Close() +} + +func (f *FakeBackend) Restart() { + l, err := net.Listen("tcp", f.server.Listener.Addr().String()) + if err != nil { + panic("Could not relisten: " + err.Error()) + } + f.server = httptest.NewUnstartedServer(f) + f.server.Listener = l + f.server.Start() +} + +func (f *FakeBackend) Client() (*Client, error) { + return NewClient(StreamURL(f.URL()), AccessToken(f.AccessToken)) +} + +func (f *FakeBackend) AddProgramError(program string, errorMsg string) { + f.Lock() + f.programErrors[program] = errorMsg + f.Unlock() +} + +func (f *FakeBackend) AddProgramTSIDs(program string, tsids []idtool.ID) { + f.Lock() + f.tsidsByProgram[program] = tsids + f.Unlock() +} + +func (f *FakeBackend) AddTSIDMetadata(tsid idtool.ID, props *messages.MetadataProperties) { + f.Lock() + f.metadataByTSID[tsid] = props + f.Unlock() +} + +func (f *FakeBackend) SetTSIDFloatData(tsid idtool.ID, val float64) { + f.Lock() + f.dataByTSID[tsid] = &val + f.Unlock() +} + +func (f *FakeBackend) RemoveTSIDData(tsid idtool.ID) { + f.Lock() + delete(f.dataByTSID, tsid) + f.Unlock() +} + +func (f *FakeBackend) URL() string { + return strings.Replace(f.server.URL, "http", "ws", 1) +} + +func (f *FakeBackend) KillExistingConnections() { + f.Lock() + for conn := range f.conns { + conn.Close() + } + f.Unlock() +} + +// RunningJobsForProgram returns how many currently executing jobs there are +// for a particular program text. 
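+//
+// Tests can poll it to check that a stop or detach request took effect, for
+// example (the durations are arbitrary):
+//
+//	require.Eventually(t, func() bool {
+//		return fakeBackend.RunningJobsForProgram(program) == 0
+//	}, 5*time.Second, 50*time.Millisecond)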
+func (f *FakeBackend) RunningJobsForProgram(program string) int { + f.Lock() + defer f.Unlock() + return f.runningJobsByProgram[program] +} + +func NewRunningFakeBackend() *FakeBackend { + f := &FakeBackend{ + AccessToken: "abcd", + } + f.Start() + return f +} diff --git a/signalflow/fake_backend_test.go b/signalflow/fake_backend_test.go new file mode 100644 index 0000000..7e37fc8 --- /dev/null +++ b/signalflow/fake_backend_test.go @@ -0,0 +1,125 @@ +package signalflow + +import ( + "context" + "math/rand" + "strconv" + "testing" + "time" + + "github.com/signalfx/signalflow-client-go/signalflow/messages" + "github.com/signalfx/signalfx-go/idtool" + "github.com/stretchr/testify/assert" +) + +const program = "testflow" + +type testCase struct { + timeSeriesProperties []map[string]string + name string + startMs int64 // epoch time ms to start the metrics + stopMs int64 // epoch time ms to stop the metrics + resolutionSecs int64 // seconds gap between the metrics + expectedTimestamps []int64 + numberOfSfxClients int // count of SFX clients to connect to fakebackend +} + +func TestFakeBackend(t *testing.T) { + t.Parallel() + + now := time.Now() + testCases := []testCase{ + { + timeSeriesProperties: []map[string]string{}, + name: "no metrics with one resolution window", + stopMs: now.UnixNano() / (1000 * 1000), + startMs: now.Add(-2*time.Second).UnixNano() / (1000 * 1000), + resolutionSecs: 2, + expectedTimestamps: []int64{ + now.Add(-2*time.Second).UnixNano() / (1000 * 1000), + now.UnixNano() / (1000 * 1000), + }, + numberOfSfxClients: 2, + }, + { + timeSeriesProperties: []map[string]string{ + { + "dim1": "val1", + "dim2": "val2", + }, + { + "dim1": "val1", + }, + }, + name: "some metrics across 2 resolution windows", + stopMs: now.UnixNano() / (1000 * 1000), + startMs: now.Add(-4*time.Second).UnixNano() / (1000 * 1000), + resolutionSecs: 2, + expectedTimestamps: []int64{ + now.Add(-4*time.Second).UnixNano() / (1000 * 1000), + now.Add(-2*time.Second).UnixNano() / (1000 * 1000), + now.UnixNano() / (1000 * 1000), + }, + numberOfSfxClients: 2, + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + fakeBackend := NewRunningFakeBackend() + tsids := []idtool.ID{} + for range testCase.timeSeriesProperties { + tsids = append(tsids, idtool.ID(rand.Int63())) + } + for i, ts := range testCase.timeSeriesProperties { + fakeBackend.AddTSIDMetadata(tsids[i], &messages.MetadataProperties{ + Metric: program, + CustomProperties: ts, + }) + fakeBackend.SetTSIDFloatData(tsids[i], 0) + } + + fakeBackend.AddProgramTSIDs(program, tsids) + + // connect N clients so we can prove the fakebackend is not killed by the first client disconnecting + for i := 1; i <= testCase.numberOfSfxClients; i++ { + sfxClient, _ := NewClient(StreamURL(fakeBackend.URL()), AccessToken(fakeBackend.AccessToken)) + processClient(t, sfxClient, testCase, i) + } + }) + } +} + +func processClient(t *testing.T, sfxClient *Client, testCase testCase, connectionCount int) { + t.Helper() + data, _ := sfxClient.Execute(context.Background(), &ExecuteRequest{ + Program: program, + StartMs: testCase.startMs, + StopMs: testCase.stopMs, + ResolutionMs: testCase.resolutionSecs * 1000, + }) + + timestamps := []int64{} + datapointCount := 0 + for msg := range data.Data() { + timestamps = append(timestamps, int64(msg.TimestampMillis)) + datapoints := []map[string]string{} + for _, pl := range msg.Payloads { + meta, _ := data.TSIDMetadata(context.Background(), pl.TSID) + dims := 
map[string]string{} + for k, v := range meta.CustomProperties { + dims[k] = v + } + datapoints = append(datapoints, dims) + datapointCount++ + } + // the datapoints should always match the MTS properties that were fed in + assert.Equal(t, testCase.timeSeriesProperties, datapoints, testCase.name+": datapoints are wrong on connection "+strconv.Itoa(connectionCount)) + } + + assert.Equal(t, testCase.expectedTimestamps, timestamps, testCase.name+": timestamps in metrics are wrong on connection "+strconv.Itoa(connectionCount)) + // the number of datapoints should be the number of resolution windows multiplied by the number of MTS in each timestamp payload + assert.Equal(t, len(testCase.expectedTimestamps)*len(testCase.timeSeriesProperties), datapointCount, testCase.name+": amount of datapoints unexpected on connection "+strconv.Itoa(connectionCount)) +} diff --git a/signalflow/messages/binary.go b/signalflow/messages/binary.go new file mode 100644 index 0000000..cc20e1d --- /dev/null +++ b/signalflow/messages/binary.go @@ -0,0 +1,185 @@ +package messages + +import ( + "bytes" + "compress/gzip" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + + "github.com/signalfx/signalfx-go/idtool" +) + +type DataPayload struct { + Type ValType + TSID idtool.ID + Val [8]byte +} + +// Value returns the numeric value as an interface{}. +func (dp *DataPayload) Value() interface{} { + switch dp.Type { + case ValTypeLong: + return dp.Int64() + case ValTypeDouble: + return dp.Float64() + case ValTypeInt: + return dp.Int32() + default: + return nil + } +} + +func (dp *DataPayload) Int64() int64 { + n := binary.BigEndian.Uint64(dp.Val[:]) + return int64(n) +} + +func (dp *DataPayload) Float64() float64 { + bits := binary.BigEndian.Uint64(dp.Val[:]) + return math.Float64frombits(bits) +} + +func (dp *DataPayload) Int32() int32 { + var n int32 + _ = binary.Read(bytes.NewBuffer(dp.Val[:]), binary.BigEndian, &n) + return n +} + +// DataMessage is a set of datapoints that share a common timestamp. +type DataMessage struct { + BaseMessage + BaseChannelMessage + TimestampedMessage + Payloads []DataPayload +} + +func (dm *DataMessage) String() string { + pls := make([]map[string]interface{}, 0) + for _, pl := range dm.Payloads { + pls = append(pls, map[string]interface{}{ + "type": pl.Type, + "tsid": pl.TSID, + "value": pl.Value(), + }) + } + + return fmt.Sprintf("%v", map[string]interface{}{ + "channel": dm.Channel(), + "timestamp": dm.Timestamp(), + "payloads": pls, + }) +} + +type DataMessageHeader struct { + TimestampMillis uint64 + ElementCount uint32 +} + +type ValType uint8 + +const ( + ValTypeLong ValType = 1 + ValTypeDouble ValType = 2 + ValTypeInt ValType = 3 +) + +func (vt ValType) String() string { + switch vt { + case ValTypeLong: + return "long" + case ValTypeDouble: + return "double" + case ValTypeInt: + return "int32" + } + return "Unknown" +} + +// BinaryMessageHeader represents the first 20 bytes of every binary websocket +// message from the backend.
+// https://developers.signalfx.com/signalflow_analytics/rest_api_messages/stream_messages_specification.html#_binary_encoding_of_websocket_messages +type BinaryMessageHeader struct { + Version uint8 + MessageType uint8 + Flags uint8 + Reserved uint8 + Channel [16]byte +} + +const ( + compressed uint8 = 1 << iota + jsonEncoded = 1 << iota +) + +func parseBinaryHeader(msg []byte) (string, bool /* isCompressed */, bool /* isJSON */, []byte /* rest of message */, error) { + if len(msg) <= 20 { + return "", false, false, nil, fmt.Errorf("expected SignalFlow message of at least 21 bytes, got %d bytes", len(msg)) + } + + r := bytes.NewReader(msg[:20]) + var header BinaryMessageHeader + err := binary.Read(r, binary.BigEndian, &header) + if err != nil { + return "", false, false, nil, err + } + + isCompressed := header.Flags&compressed != 0 + isJSON := header.Flags&jsonEncoded != 0 + + return string(header.Channel[:bytes.IndexByte(header.Channel[:], 0)]), isCompressed, isJSON, msg[20:], err +} + +func parseBinaryMessage(msg []byte) (Message, error) { + channel, isCompressed, isJSON, rest, err := parseBinaryHeader(msg) + if err != nil { + return nil, err + } + + if isCompressed { + reader, err := gzip.NewReader(bytes.NewReader(rest)) + if err != nil { + return nil, err + } + rest, err = io.ReadAll(reader) + if err != nil { + return nil, err + } + } + + if isJSON { + return nil, errors.New("cannot handle json binary message") + } + + r := bytes.NewReader(rest[:12]) + var header DataMessageHeader + err = binary.Read(r, binary.BigEndian, &header) + if err != nil { + return nil, err + } + + var payloads []DataPayload + for i := 0; i < int(header.ElementCount); i++ { + r := bytes.NewReader(rest[12+17*i : 12+17*(i+1)]) + var payload DataPayload + if err := binary.Read(r, binary.BigEndian, &payload); err != nil { + return nil, err + } + payloads = append(payloads, payload) + } + + return &DataMessage{ + BaseMessage: BaseMessage{ + Typ: DataType, + }, + BaseChannelMessage: BaseChannelMessage{ + Chan: channel, + }, + TimestampedMessage: TimestampedMessage{ + TimestampMillis: header.TimestampMillis, + }, + Payloads: payloads, + }, nil +} diff --git a/signalflow/messages/binary_test.go b/signalflow/messages/binary_test.go new file mode 100644 index 0000000..cc2fe2c --- /dev/null +++ b/signalflow/messages/binary_test.go @@ -0,0 +1,43 @@ +package messages + +import ( + "encoding/base64" + "testing" + "time" + + "github.com/signalfx/signalfx-go/idtool" + "github.com/stretchr/testify/assert" +) + +// DECODED {:type=>"data", :logicalTimestampMs=>1504064040000, :logicalTimestamp=>2017-08-30 03:34:00 +0000, :data=>{3079061720=>691.1, 3553579776=>5828.0, 2479549961=>9939.4, 2038453579=>94.8, 3928812177=>2952.2, 2885058095=>686.0, 3689271047=>695.4, 683255203=>756.3, 202128297=>5800.4, 1462695611=>2796.9, 3391947226=>44.8, 1321572762=>1302.3, 1136315563=>8700.8, 122567741=>16128.1, 800290351=>762.5, 1533912710=>2439.7}, :channel=>"channel-1"} +const binaryMsgBase64 = "AQUBAGNoYW5uZWwtMQAAAAAAAAAfiwgAAAAAAAAAY2BgjDM0fuPAwMAgwAQkGLa3Hbjh0DrjDBCcBQtcPh/M4LDtCIjJABaYfOodp8PhzM3GQAAWqGwI9XYIB/MhAq90Xk10WC6QBgJggdXf6vQdWgsQZtx+ep7doXU3QovGzqWLHdoXIbTw8FivdNi2AiEQrvdtt8PSmwiHndJlvOXgloZQ4Xd05iyHKZEIQ513nVntcOAfQgW7n42tw/kGHrgZ+pvL9B3aryAcFp27rM1hMT9YCwAua3WrHAEAAA==" + +func TestDecodeBinaryMessage(t *testing.T) { + rawMsg, err := base64.StdEncoding.DecodeString(binaryMsgBase64) + if err != nil { + panic("Could not decode test message") + } + + msg, err := parseBinaryMessage(rawMsg) + if err != nil { + t.Fatalf("could not 
parse message: %v", err) + } + + dm, ok := msg.(*DataMessage) + assert.True(t, ok, "message was not data message") + assert.NotNil(t, dm) + + assert.Equal(t, dm.Type(), DataType) + assert.Equal(t, dm.Channel(), "channel-1") + + assert.Equal(t, dm.TimestampMillis, uint64(1504064040000)) + assert.Equal( + t, + dm.Timestamp().Unix(), + time.Date(2017, 8, 29, 23, 34, 0, 0, time.FixedZone("EDT", -4*60*60)).Unix(), + ) + + assert.Len(t, dm.Payloads, 16) + assert.Equal(t, dm.Payloads[0].Value(), 691.1) + assert.Equal(t, dm.Payloads[0].TSID, idtool.ID(3079061720)) +} diff --git a/signalflow/messages/control.go b/signalflow/messages/control.go new file mode 100644 index 0000000..687a3e7 --- /dev/null +++ b/signalflow/messages/control.go @@ -0,0 +1,30 @@ +package messages + +// The event types used in the control-message messages. These are not used for +// "event" type messages. +const ( + StreamStartEvent = "STREAM_START" + JobStartEvent = "JOB_START" + JobProgressEvent = "JOB_PROGRESS" + ChannelAbortEvent = "CHANNEL_ABORT" + EndOfChannelEvent = "END_OF_CHANNEL" +) + +type BaseControlMessage struct { + BaseJSONChannelMessage + TimestampedMessage + Event string `json:"event"` +} + +type JobStartControlMessage struct { + BaseControlMessage + Handle string `json:"handle"` +} + +type EndOfChannelControlMessage struct { + BaseControlMessage +} + +type ChannelAbortControlMessage struct { + BaseControlMessage +} diff --git a/signalflow/messages/error.go b/signalflow/messages/error.go new file mode 100644 index 0000000..5239858 --- /dev/null +++ b/signalflow/messages/error.go @@ -0,0 +1,19 @@ +package messages + +type ErrorContext struct { + BindingName string `json:"bindingName"` + Column int `json:"column"` + Line int `json:"line"` + ProgramText string `json:"programText"` + Reference string `json:"reference"` + Traceback interface{} `json:"traceback"` +} + +type ErrorMessage struct { + BaseJSONChannelMessage + + Context ErrorContext `json:"context"` + Error int `json:"error"` + ErrorType string `json:"errorType"` + Message string `json:"message"` +} diff --git a/signalflow/messages/event.go b/signalflow/messages/event.go new file mode 100644 index 0000000..791745c --- /dev/null +++ b/signalflow/messages/event.go @@ -0,0 +1,5 @@ +package messages + +type EventMessage struct { + BaseJSONChannelMessage +} diff --git a/signalflow/messages/info.go b/signalflow/messages/info.go new file mode 100644 index 0000000..860ad65 --- /dev/null +++ b/signalflow/messages/info.go @@ -0,0 +1,122 @@ +package messages + +import ( + "encoding/json" + "time" +) + +const ( + JobRunningResolution = "JOB_RUNNING_RESOLUTION" + JobDetectedLag = "JOB_DETECTED_LAG" + JobInitialMaxDelay = "JOB_INITIAL_MAX_DELAY" + FindLimitedResultSet = "FIND_LIMITED_RESULT_SET" + FindMatchedNoTimeseries = "FIND_MATCHED_NO_TIMESERIES" + GroupByMissingProperty = "GROUPBY_MISSING_PROPERTY" +) + +type MessageBlock struct { + TimestampedMessage + Code string `json:"messageCode"` + Level string `json:"messageLevel"` + NumInputTimeseries int `json:"numInputTimeSeries"` + // If the messageCode field in the message is known, this will be an + // instance that has more specific methods to access the known fields. You + // can always access the original content by treating this value as a + // map[string]interface{}.
+ Contents interface{} `json:"-"` + ContentsRaw map[string]interface{} `json:"contents"` +} + +type InfoMessage struct { + BaseJSONChannelMessage + LogicalTimestampMillis uint64 `json:"logicalTimestampMs"` + MessageBlock `json:"message"` +} + +func (im *InfoMessage) UnmarshalJSON(raw []byte) error { + type IM InfoMessage + if err := json.Unmarshal(raw, (*IM)(im)); err != nil { + return err + } + + mb := &im.MessageBlock + switch mb.Code { + case JobRunningResolution: + mb.Contents = JobRunningResolutionContents(mb.ContentsRaw) + case JobDetectedLag: + mb.Contents = JobDetectedLagContents(mb.ContentsRaw) + case JobInitialMaxDelay: + mb.Contents = JobInitialMaxDelayContents(mb.ContentsRaw) + case FindLimitedResultSet: + mb.Contents = FindLimitedResultSetContents(mb.ContentsRaw) + case FindMatchedNoTimeseries: + mb.Contents = FindMatchedNoTimeseriesContents(mb.ContentsRaw) + case GroupByMissingProperty: + mb.Contents = GroupByMissingPropertyContents(mb.ContentsRaw) + default: + mb.Contents = mb.ContentsRaw + } + + return nil +} + +func (im *InfoMessage) LogicalTimestamp() time.Time { + return time.Unix(0, int64(im.LogicalTimestampMillis*uint64(time.Millisecond))) +} + +type JobRunningResolutionContents map[string]interface{} + +func (jm JobRunningResolutionContents) ResolutionMS() int { + field, _ := jm["resolutionMs"].(float64) + return int(field) +} + +type JobDetectedLagContents map[string]interface{} + +func (jm JobDetectedLagContents) LagMS() int { + field, _ := jm["lagMs"].(float64) + return int(field) +} + +type JobInitialMaxDelayContents map[string]interface{} + +func (jm JobInitialMaxDelayContents) MaxDelayMS() int { + field, _ := jm["maxDelayMs"].(float64) + return int(field) +} + +type FindLimitedResultSetContents map[string]interface{} + +func (jm FindLimitedResultSetContents) MatchedSize() int { + field, _ := jm["matchedSize"].(float64) + return int(field) +} + +func (jm FindLimitedResultSetContents) LimitSize() int { + field, _ := jm["limitSize"].(float64) + return int(field) +} + +type FindMatchedNoTimeseriesContents map[string]interface{} + +func (jm FindMatchedNoTimeseriesContents) MatchedNoTimeseriesQuery() string { + field, _ := jm["query"].(string) + return field +} + +type GroupByMissingPropertyContents map[string]interface{} + +func (jm GroupByMissingPropertyContents) GroupByMissingProperties() []string { + propNames := make([]string, len(jm["propertyNames"].([]interface{}))) + for i, v := range jm["propertyNames"].([]interface{}) { + propNames[i] = v.(string) + } + return propNames +} + +// ExpiredTSIDMessage is received when a timeseries has expired and is no +// longer relevant to a computation.
+type ExpiredTSIDMessage struct { + BaseJSONChannelMessage + TSID string `json:"tsId"` +} diff --git a/signalflow/messages/json.go b/signalflow/messages/json.go new file mode 100644 index 0000000..a42b5b6 --- /dev/null +++ b/signalflow/messages/json.go @@ -0,0 +1,44 @@ +package messages + +import ( + "encoding/json" +) + +func parseJSONMessage(baseMessage Message, msg []byte) (JSONMessage, error) { + var out JSONMessage + switch baseMessage.Type() { + case AuthenticatedType: + out = &AuthenticatedMessage{} + case ControlMessageType: + var base BaseControlMessage + if err := json.Unmarshal(msg, &base); err != nil { + return nil, err + } + + switch base.Event { + case JobStartEvent: + out = &JobStartControlMessage{} + case EndOfChannelEvent: + out = &EndOfChannelControlMessage{} + case ChannelAbortEvent: + out = &ChannelAbortControlMessage{} + default: + return &base, nil + } + case ErrorType: + out = &ErrorMessage{} + case MetadataType: + out = &MetadataMessage{} + case ExpiredTSIDType: + out = &ExpiredTSIDMessage{} + case MessageType: + out = &InfoMessage{} + case EventType: + out = &EventMessage{} + default: + out = &BaseJSONMessage{} + } + err := json.Unmarshal(msg, out) + out.JSONBase().rawMessage = msg + return out, err +} diff --git a/signalflow/messages/metadata.go b/signalflow/messages/metadata.go new file mode 100644 index 0000000..69059fc --- /dev/null +++ b/signalflow/messages/metadata.go @@ -0,0 +1,77 @@ +package messages + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/signalfx/signalfx-go/idtool" +) + +type MetadataMessage struct { + BaseJSONChannelMessage + TSID idtool.ID `json:"tsId"` + Properties MetadataProperties `json:"properties"` +} + +type MetadataProperties struct { + Metric string `json:"sf_metric"` + OriginatingMetric string `json:"sf_originatingMetric"` + ResolutionMS int `json:"sf_resolutionMs"` + CreatedOnMS int `json:"sf_createdOnMs"` + // Additional SignalFx-generated properties about this time series. Many + // of these are exposed directly in fields on this struct. + InternalProperties map[string]interface{} `json:"-"` + // Custom properties applied to the timeseries through various means, + // including dimensions, properties on matching dimensions, etc. + CustomProperties map[string]string `json:"-"` +} + +func (mp *MetadataProperties) UnmarshalJSON(b []byte) error { + // Deserialize it at first to get all the well-known fields put in place so + // we don't have to manually assign them below. + type Alias MetadataProperties + if err := json.Unmarshal(b, (*Alias)(mp)); err != nil { + return err + } + + // Deserialize it again to a generic map so we can get at all the fields. 
var propMap map[string]interface{} + if err := json.Unmarshal(b, &propMap); err != nil { + return err + } + + mp.InternalProperties = make(map[string]interface{}) + mp.CustomProperties = make(map[string]string) + for k, v := range propMap { + if strings.HasPrefix(k, "sf_") { + mp.InternalProperties[k] = v + } else { + mp.CustomProperties[k] = fmt.Sprintf("%v", v) + } + } + return nil +} + +func (mp *MetadataProperties) MarshalJSON() ([]byte, error) { + type Alias MetadataProperties + intermediate, err := json.Marshal((*Alias)(mp)) + if err != nil { + return nil, err + } + + out := map[string]interface{}{} + err = json.Unmarshal(intermediate, &out) + if err != nil { + return nil, err + } + + for k, v := range mp.InternalProperties { + out[k] = v + } + for k, v := range mp.CustomProperties { + out[k] = v + } + + return json.Marshal(out) +} diff --git a/signalflow/messages/types.go b/signalflow/messages/types.go new file mode 100644 index 0000000..00ccd01 --- /dev/null +++ b/signalflow/messages/types.go @@ -0,0 +1,126 @@ +package messages + +import ( + "encoding/json" + "fmt" + "time" +) + +// See https://developers.signalfx.com/signalflow_analytics/rest_api_messages/stream_messages_specification.html +const ( + AuthenticatedType = "authenticated" + ControlMessageType = "control-message" + ErrorType = "error" + MetadataType = "metadata" + MessageType = "message" + DataType = "data" + EventType = "event" + WebsocketErrorType = "websocket-error" + ExpiredTSIDType = "expired-tsid" +) + +type BaseMessage struct { + Typ string `json:"type"` +} + +func (bm *BaseMessage) Type() string { + return bm.Typ +} + +func (bm *BaseMessage) String() string { + return fmt.Sprintf("%s message", bm.Typ) +} + +func (bm *BaseMessage) Base() *BaseMessage { + return bm +} + +var _ Message = &BaseMessage{} + +type Message interface { + Type() string + Base() *BaseMessage +} + +type ChannelMessage interface { + Channel() string +} + +type BaseChannelMessage struct { + Chan string `json:"channel,omitempty"` +} + +func (bcm *BaseChannelMessage) Channel() string { + return bcm.Chan +} + +type JSONMessage interface { + Message + JSONBase() *BaseJSONMessage + RawData() map[string]interface{} +} + +type BaseJSONMessage struct { + BaseMessage + rawMessage []byte + rawData map[string]interface{} +} + +func (j *BaseJSONMessage) JSONBase() *BaseJSONMessage { + return j +} + +// RawData returns the raw message deserialized from JSON. Only applicable for JSON +// messages. Useful if the message type doesn't have a concrete struct type implemented +// in this library (e.g. due to an upgrade to the SignalFlow protocol). +func (j *BaseJSONMessage) RawData() map[string]interface{} { + if j.rawData == nil { + if err := json.Unmarshal(j.rawMessage, &j.rawData); err != nil { + // This shouldn't ever error since it wouldn't have been initially + // deserialized if there were parse errors. But in case it does + // just return nil.
+ return nil + } + } + return j.rawData +} + +func (j *BaseJSONMessage) String() string { + return j.BaseMessage.String() + string(j.rawMessage) +} + +type BaseJSONChannelMessage struct { + BaseJSONMessage + BaseChannelMessage +} + +func (j *BaseJSONChannelMessage) String() string { + return string(j.BaseJSONMessage.rawMessage) +} + +type TimestampedMessage struct { + TimestampMillis uint64 `json:"timestampMs"` +} + +func (m *TimestampedMessage) Timestamp() time.Time { + return time.Unix(0, int64(m.TimestampMillis*uint64(time.Millisecond))) +} + +type AuthenticatedMessage struct { + BaseJSONMessage + OrgID string `json:"orgId"` + UserID string `json:"userId"` +} + +// The way to distinguish between JSON and binary messages is the websocket +// message type. +func ParseMessage(msg []byte, isText bool) (Message, error) { + if isText { + var baseMessage BaseMessage + if err := json.Unmarshal(msg, &baseMessage); err != nil { + return nil, fmt.Errorf("couldn't unmarshal JSON websocket message: %w", err) + } + return parseJSONMessage(&baseMessage, msg) + } + return parseBinaryMessage(msg) +} diff --git a/signalflow/messages/types_test.go b/signalflow/messages/types_test.go new file mode 100644 index 0000000..229ee50 --- /dev/null +++ b/signalflow/messages/types_test.go @@ -0,0 +1,18 @@ +package messages + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseMessage(t *testing.T) { + t.Run("END_OF_CHANNEL", func(t *testing.T) { + msg, err := ParseMessage([]byte(` + {"channel": "ch-1", "event": "END_OF_CHANNEL", "timestampMs": 1607115512410, "type": "control-message"} + `), true) + + require.NoError(t, err) + require.IsType(t, &EndOfChannelControlMessage{}, msg) + }) +} diff --git a/signalflow/requests.go b/signalflow/requests.go new file mode 100644 index 0000000..57227a0 --- /dev/null +++ b/signalflow/requests.go @@ -0,0 +1,91 @@ +package signalflow + +import ( + "encoding/json" + "time" +) + +type AuthType string + +func (at AuthType) MarshalJSON() ([]byte, error) { + return []byte(`"authenticate"`), nil +} + +type AuthRequest struct { + // This should not be set manually. + Type AuthType `json:"type"` + // The Auth token for the org + Token string `json:"token"` + UserAgent string `json:"userAgent,omitempty"` +} + +type ExecuteType string + +func (ExecuteType) MarshalJSON() ([]byte, error) { + return []byte(`"execute"`), nil +} + +// See +// https://dev.splunk.com/observability/docs/signalflow/messages/websocket_request_messages#Execute-message-properties +// for details on the fields. +type ExecuteRequest struct { + // This should not be set manually + Type ExecuteType `json:"type"` + Program string `json:"program"` + Channel string `json:"channel"` + Start time.Time `json:"-"` + Stop time.Time `json:"-"` + Resolution time.Duration `json:"-"` + MaxDelay time.Duration `json:"-"` + StartMs int64 `json:"start"` + StopMs int64 `json:"stop"` + ResolutionMs int64 `json:"resolution"` + MaxDelayMs int64 `json:"maxDelay"` + Immediate bool `json:"immediate"` + Timezone string `json:"timezone"` +} + +// MarshalJSON does some assignments to allow using more native Go types for +// time/duration. 
+func (er ExecuteRequest) MarshalJSON() ([]byte, error) { + if !er.Start.IsZero() { + er.StartMs = er.Start.UnixNano() / int64(time.Millisecond) + } + if !er.Stop.IsZero() { + er.StopMs = er.Stop.UnixNano() / int64(time.Millisecond) + } + if er.Resolution != 0 { + er.ResolutionMs = er.Resolution.Nanoseconds() / int64(time.Millisecond) + } + if er.MaxDelay != 0 { + er.MaxDelayMs = er.MaxDelay.Nanoseconds() / int64(time.Millisecond) + } + type alias ExecuteRequest + return json.Marshal(alias(er)) +} + +type DetachType string + +func (DetachType) MarshalJSON() ([]byte, error) { + return []byte(`"detach"`), nil +} + +type DetachRequest struct { + // This should not be set manually + Type DetachType `json:"type"` + Channel string `json:"channel"` + Reason string `json:"reason"` +} + +type StopType string + +func (StopType) MarshalJSON() ([]byte, error) { + return []byte(`"stop"`), nil +} + +type StopRequest struct { + // This should not be set manually + Type StopType `json:"type"` + Handle string `json:"handle"` + Reason string `json:"reason"` +} diff --git a/signalflow/requests_test.go b/signalflow/requests_test.go new file mode 100644 index 0000000..00fd9f1 --- /dev/null +++ b/signalflow/requests_test.go @@ -0,0 +1,25 @@ +package signalflow + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSerializeExecuteRequest(t *testing.T) { + t.Parallel() + er := ExecuteRequest{ + Program: "data(cpu.utilization).publish()", + Start: time.Unix(5000, 0), + Stop: time.Unix(6000, 0), + Resolution: 5 * time.Second, + MaxDelay: 3 * time.Second, + StopMs: 6500, + } + + serialized, err := json.Marshal(er) + require.Nil(t, err) + require.Equal(t, `{"type":"execute","program":"data(cpu.utilization).publish()","channel":"","start":5000000,"stop":6000000,"resolution":5000,"maxDelay":3000,"immediate":false,"timezone":""}`, string(serialized)) +}
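For reference, the pieces added in this change compose as follows when exercised together: start a FakeBackend, register a time series for a program, then stream it through a Client. The snippet below is an illustrative sketch only and is not part of the diff; it relies on the FakeBackend helpers, ExecuteRequest fields, and the Data()/TSIDMetadata() accessors exercised by the tests above, and the program text and TSID are arbitrary placeholder values.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/signalfx/signalflow-client-go/signalflow"
	"github.com/signalfx/signalflow-client-go/signalflow/messages"
	"github.com/signalfx/signalfx-go/idtool"
)

func main() {
	// Start an in-process fake SignalFlow backend (normally used in tests).
	backend := signalflow.NewRunningFakeBackend()
	defer backend.Stop()

	// Register one time series for an arbitrary program text.
	const program = "data('cpu.utilization').publish()" // placeholder program text
	tsid := idtool.ID(1234)                             // placeholder TSID
	backend.AddTSIDMetadata(tsid, &messages.MetadataProperties{
		Metric:           "cpu.utilization",
		CustomProperties: map[string]string{"host": "host-1"},
	})
	backend.SetTSIDFloatData(tsid, 42.5)
	backend.AddProgramTSIDs(program, []idtool.ID{tsid})

	// Connect a client to the fake backend and execute the program.
	client, err := backend.Client()
	if err != nil {
		panic(err)
	}
	now := time.Now()
	comp, err := client.Execute(context.Background(), &signalflow.ExecuteRequest{
		Program:    program,
		Start:      now.Add(-2 * time.Second),
		Stop:       now,
		Resolution: time.Second,
	})
	if err != nil {
		panic(err)
	}

	// Read data messages until the fake backend sends END_OF_CHANNEL
	// (it stops once the requested stop timestamp is reached).
	for msg := range comp.Data() {
		for _, pl := range msg.Payloads {
			meta, _ := comp.TSIDMetadata(context.Background(), pl.TSID)
			fmt.Printf("ts=%d tsid=%v value=%v dims=%v\n",
				msg.TimestampMillis, pl.TSID, pl.Value(), meta.CustomProperties)
		}
	}
}

Note that Start, Stop, and Resolution are given here as native Go types; ExecuteRequest.MarshalJSON converts them to the millisecond fields, as the serialization test above verifies.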