Skip to content

Commit

Permalink
Feat: add state snapshot events, marker task tracking, etc. (#713)
Browse files Browse the repository at this point in the history
- Add network state polling
- Consolidate events
- Tie markers to specifics tasks for easy tracing
  • Loading branch information
luke-lombardi authored Nov 14, 2024
1 parent 0682796 commit c58525b
Show file tree
Hide file tree
Showing 19 changed files with 562 additions and 168 deletions.
7 changes: 5 additions & 2 deletions pkg/abstractions/experimental/bot/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,9 @@ func (s *PetriBotService) PushBotEvent(ctx context.Context, in *pb.PushBotEventR
}

err = instance.botStateManager.pushEvent(instance.workspace.Name, instance.stub.ExternalId, in.SessionId, &BotEvent{
Type: BotEventType(in.EventType),
Value: in.EventValue,
Type: BotEventType(in.EventType),
Value: in.EventValue,
Metadata: in.Metadata,
})
if err != nil {
return &pb.PushBotEventResponse{Ok: false}, nil
Expand Down Expand Up @@ -233,7 +234,9 @@ func (s *PetriBotService) PushBotMarkers(ctx context.Context, in *pb.PushBotMark
marker := Marker{
LocationName: marker.LocationName,
Fields: fields,
SourceTaskId: in.SourceTaskId,
}

err = s.botStateManager.pushMarker(instance.workspace.Name, instance.stub.ExternalId, in.SessionId, locationName, marker)
if err != nil {
log.Printf("<bot %s> Failed to push marker: %s", instance.stub.ExternalId, err)
Expand Down
3 changes: 3 additions & 0 deletions pkg/abstractions/experimental/bot/bot.proto
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ message MarkerField {
message Marker {
string location_name = 1;
repeated MarkerField fields = 2;
string source_task_id = 3;
}

message PushBotMarkersRequest {
string stub_id = 1;
string session_id = 2;
map<string, MarkerList> markers = 3;
message MarkerList { repeated Marker markers = 4; }
string source_task_id = 5;
}

message PushBotMarkersResponse { bool ok = 1; }
Expand All @@ -47,6 +49,7 @@ message PushBotEventRequest {
string session_id = 2;
string event_type = 3;
string event_value = 4;
map<string, string> metadata = 5;
}

message PushBotEventResponse { bool ok = 1; }
25 changes: 10 additions & 15 deletions pkg/abstractions/experimental/bot/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func registerBotRoutes(g *echo.Group, pbs *PetriBotService) *botGroup {
}

const keepAliveInterval = 1 * time.Second
const eventPollingInterval = 1 * time.Second
const eventPollingInterval = 500 * time.Millisecond

func (g *botGroup) BotOpenSession(ctx echo.Context) error {
cc, _ := ctx.(*auth.HttpAuthContext)
Expand Down Expand Up @@ -125,6 +125,9 @@ func (g *botGroup) BotOpenSession(ctx echo.Context) error {
err = instance.botStateManager.pushEvent(instance.workspace.Name, instance.stub.ExternalId, sessionId, &BotEvent{
Type: BotEventTypeSessionCreated,
Value: sessionId,
Metadata: map[string]string{
string(MetadataSessionId): sessionId,
},
})
if err != nil {
return err
Expand Down Expand Up @@ -183,17 +186,8 @@ func (g *botGroup) BotOpenSession(ctx echo.Context) error {
continue
}

if event.Type == BotEventTypeUserMessage {
instance.botInterface.SendPrompt(sessionId, PromptTypeUser, event.Value)
continue
} else if event.Type == BotEventTypeTransitionMessage {
instance.botInterface.SendPrompt(sessionId, PromptTypeTransition, event.Value)
continue
} else if event.Type == BotEventTypeMemoryMessage {
instance.botInterface.SendPrompt(sessionId, PromptTypeMemory, event.Value)
continue
}

// Handle event and echo it back to the client
instance.eventChan <- event
serializedEvent, err := json.Marshal(event)
if err != nil {
continue
Expand All @@ -216,12 +210,13 @@ func (g *botGroup) BotOpenSession(ctx echo.Context) error {
break
}

var userRequest UserRequest
if err := json.Unmarshal(message, &userRequest); err != nil {
var event BotEvent
if err := json.Unmarshal(message, &event); err != nil {
continue
}

if err := instance.botStateManager.pushInputMessage(instance.workspace.Name, instance.stub.ExternalId, sessionId, userRequest.Msg); err != nil {
event.Metadata[string(MetadataSessionId)] = sessionId
if err := instance.botStateManager.pushUserEvent(instance.workspace.Name, instance.stub.ExternalId, sessionId, &event); err != nil {
continue
}
}
Expand Down
158 changes: 145 additions & 13 deletions pkg/abstractions/experimental/bot/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package bot

import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
Expand Down Expand Up @@ -36,6 +37,7 @@ type botInstance struct {
taskDispatcher *task.Dispatcher
authInfo *auth.AuthInfo
containerRepo repository.ContainerRepository
eventChan chan *BotEvent
}

type botInstanceOpts struct {
Expand Down Expand Up @@ -65,7 +67,7 @@ func newBotInstance(ctx context.Context, opts botInstanceOpts) (*botInstance, er
return nil, err
}

return &botInstance{
instance := &botInstance{
ctx: ctx,
appConfig: opts.AppConfig,
token: opts.Token,
Expand All @@ -83,7 +85,12 @@ func newBotInstance(ctx context.Context, opts botInstanceOpts) (*botInstance, er
Token: opts.Token,
},
containerRepo: opts.ContainerRepo,
}, nil
eventChan: make(chan *BotEvent),
}

go instance.monitorEvents()
go instance.sendNetworkState()
return instance, nil
}

func (i *botInstance) containersBySessionId() (map[string][]string, error) {
Expand Down Expand Up @@ -141,10 +148,8 @@ func (i *botInstance) Start() error {

lastActiveSessionAt = time.Now().Unix()
for _, session := range activeSessions {
if msg, err := i.botStateManager.popInputMessage(i.workspace.Name, i.stub.ExternalId, session.Id); err == nil {
if err := i.botInterface.SendPrompt(session.Id, PromptTypeUser, msg); err != nil {
continue
}
if event, err := i.botStateManager.popUserEvent(i.workspace.Name, i.stub.ExternalId, session.Id); err == nil {
i.eventChan <- event
}

// Run any network transitions that can run
Expand Down Expand Up @@ -218,20 +223,148 @@ func (i *botInstance) step(sessionId string) {
},
}

// If this transition requires explicit confirmation, we need to send a confirmation request before executing the task
if transition.Confirm {
t, err := i.taskDispatcher.Send(i.ctx, string(types.ExecutorBot), i.authInfo, i.stub.ExternalId, taskPayload, getDefaultTaskPolicy())
if err != nil {
i.handleTransitionFailed(sessionId, transition.Name, err)
continue
}

i.botStateManager.pushEvent(i.workspace.Name, i.stub.ExternalId, sessionId, &BotEvent{
Type: BotEventTypeConfirmRequest,
Value: transition.Name,
Metadata: map[string]string{
string(MetadataSessionId): sessionId,
string(MetadataTransitionName): transition.Name,
string(MetadataTaskId): t.Metadata().TaskId,
},
})

continue
}

t, err := i.taskDispatcher.SendAndExecute(i.ctx, string(types.ExecutorBot), i.authInfo, i.stub.ExternalId, taskPayload, getDefaultTaskPolicy())
if err != nil {
i.handleTransitionFailed(sessionId, transition.Name, err)
continue
}

i.botStateManager.pushEvent(i.workspace.Name, i.stub.ExternalId, sessionId, &BotEvent{
Type: BotEventTypeTransitionFired,
Value: transition.Name,
Metadata: map[string]string{
string(MetadataSessionId): sessionId,
string(MetadataTransitionName): transition.Name,
string(MetadataTaskId): t.Metadata().TaskId,
},
})
}
}

}()
}

func (i *botInstance) sendNetworkState() {
for {
select {
case <-i.ctx.Done():
return
case <-time.After(time.Second):
activeSessions, err := i.botStateManager.getActiveSessions(i.workspace.Name, i.stub.ExternalId)
if err != nil || len(activeSessions) == 0 {
continue
}

i.taskDispatcher.SendAndExecute(i.ctx, string(types.ExecutorBot), i.authInfo, i.stub.ExternalId, taskPayload, types.TaskPolicy{
MaxRetries: 0,
Timeout: 3600,
TTL: 3600,
Expires: time.Now().Add(time.Duration(3600) * time.Second),
for _, session := range activeSessions {
state := &BotNetworkSnapshot{
SessionId: session.Id,
LocationMarkerCounts: make(map[string]int64),
Config: i.botConfig,
}

for locationName := range i.botConfig.Locations {
count, err := i.botStateManager.countMarkers(i.workspace.Name, i.stub.ExternalId, session.Id, locationName)
if err != nil {
continue
}
state.LocationMarkerCounts[locationName] = count
}

stateJson, err := json.Marshal(state)
if err != nil {
continue
}

i.botStateManager.pushEvent(i.workspace.Name, i.stub.ExternalId, session.Id, &BotEvent{
Type: BotEventTypeNetworkState,
Value: string(stateJson),
Metadata: map[string]string{
string(MetadataSessionId): session.Id,
},
})
}
}
}()
}
}

func (i *botInstance) handleTransitionFailed(sessionId, transitionName string, err error) {
i.botStateManager.pushEvent(i.workspace.Name, i.stub.ExternalId, sessionId, &BotEvent{
Type: BotEventTypeTransitionFailed,
Value: transitionName,
Metadata: map[string]string{
string(MetadataSessionId): sessionId,
string(MetadataTransitionName): transitionName,
string(MetadataErrorMsg): err.Error(),
},
})
}

func getDefaultTaskPolicy() types.TaskPolicy {
return types.TaskPolicy{
MaxRetries: 0,
Timeout: 3600,
TTL: 3600,
Expires: time.Now().Add(time.Duration(3600) * time.Second),
}
}

func (i *botInstance) monitorEvents() error {
for {
select {
case <-i.ctx.Done():
return nil
case event := <-i.eventChan:
sessionId := event.Metadata[string(MetadataSessionId)]

switch event.Type {
case BotEventTypeUserMessage:
i.botInterface.SendPrompt(sessionId, PromptTypeUser, &PromptRequest{Msg: event.Value, RequestId: event.Metadata[string(MetadataRequestId)]})
case BotEventTypeTransitionMessage:
i.botInterface.SendPrompt(sessionId, PromptTypeTransition, &PromptRequest{Msg: event.Value})
case BotEventTypeMemoryMessage:
i.botInterface.SendPrompt(sessionId, PromptTypeMemory, &PromptRequest{Msg: event.Value})
case BotEventTypeConfirmResponse:
taskId := event.Metadata[string(MetadataTaskId)]
accepts := event.Metadata[string(MetadataAccept)] == "true"
transitionName := event.Metadata[string(MetadataTransitionName)]

task, err := i.taskDispatcher.Retrieve(i.ctx, i.workspace.Name, i.stub.ExternalId, taskId)
if err != nil {
continue
}

if accepts {
err = task.Execute(i.ctx)
if err != nil {
i.handleTransitionFailed(sessionId, transitionName, err)
}
} else {
task.Cancel(i.ctx, types.TaskRequestCancelled)
}
}
}
}
}

func (i *botInstance) run(transitionName, sessionId, taskId string) error {
Expand Down Expand Up @@ -290,7 +423,6 @@ func (i *botInstance) run(transitionName, sessionId, taskId string) error {
Stub: *i.stub,
})
if err != nil {
log.Printf("<bot %s> Error running transition %s: %s", i.stub.ExternalId, transitionName, err)
return err
}

Expand Down
Loading

0 comments on commit c58525b

Please sign in to comment.