Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add min_containers to QueueDepthAutoscaler and instance controller for loading and reloading instances #846

Merged
14 changes: 8 additions & 6 deletions pkg/abstractions/common/container_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,28 @@ import (
)

type ContainerEventManager struct {
ctx context.Context
containerPrefix string
keyEventChan chan common.KeyEvent
keyEventManager *common.KeyEventManager
instanceFactory func(stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error)
instanceFactory func(ctx context.Context, stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error)
}

func NewContainerEventManager(containerPrefix string, keyEventManager *common.KeyEventManager, instanceFactory func(stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error)) (*ContainerEventManager, error) {
func NewContainerEventManager(ctx context.Context, containerPrefix string, keyEventManager *common.KeyEventManager, instanceFactory func(ctx context.Context, stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error)) (*ContainerEventManager, error) {
keyEventChan := make(chan common.KeyEvent)

return &ContainerEventManager{
ctx: ctx,
containerPrefix: containerPrefix,
instanceFactory: instanceFactory,
keyEventChan: keyEventChan,
keyEventManager: keyEventManager,
}, nil
}

func (em *ContainerEventManager) Listen(ctx context.Context) {
go em.keyEventManager.ListenForPattern(ctx, common.RedisKeys.SchedulerContainerState(em.containerPrefix), em.keyEventChan)
go em.handleContainerEvents(ctx)
func (em *ContainerEventManager) Listen() {
go em.keyEventManager.ListenForPattern(em.ctx, common.RedisKeys.SchedulerContainerState(em.containerPrefix), em.keyEventChan)
go em.handleContainerEvents(em.ctx)
}

func (em *ContainerEventManager) handleContainerEvents(ctx context.Context) {
Expand Down Expand Up @@ -59,7 +61,7 @@ func (em *ContainerEventManager) handleContainerEvents(ctx context.Context) {
containerIdParts := strings.Split(containerId, "-")
stubId := strings.Join(containerIdParts[1:6], "-")

instance, err := em.instanceFactory(stubId)
instance, err := em.instanceFactory(em.ctx, stubId)
if err != nil {
continue
}
Expand Down
110 changes: 109 additions & 1 deletion pkg/abstractions/common/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ import (
"github.com/beam-cloud/beta9/pkg/scheduler"
"github.com/beam-cloud/beta9/pkg/types"
"github.com/rs/zerolog/log"
"k8s.io/utils/ptr"
)

const IgnoreScalingEventInterval = 10 * time.Second

type IAutoscaledInstance interface {
ConsumeScaleResult(*AutoscalerResult)
ConsumeContainerEvent(types.ContainerEvent)
HandleScalingEvent(int) error
Reload() error
}

type AutoscaledInstanceState struct {
Expand Down Expand Up @@ -175,7 +178,7 @@ func (i *AutoscaledInstance) WaitForContainer(ctx context.Context, duration time
}

func (i *AutoscaledInstance) ConsumeScaleResult(result *AutoscalerResult) {
i.ScaleEventChan <- result.DesiredContainers
i.ScaleEventChan <- max(result.DesiredContainers, int(i.StubConfig.Autoscaler.MinContainers))
}

func (i *AutoscaledInstance) ConsumeContainerEvent(event types.ContainerEvent) {
Expand Down Expand Up @@ -321,3 +324,108 @@ func (i *AutoscaledInstance) emitUnhealthyEvent(stubId, currentState, reason str
log.Info().Str("instance_name", i.Name).Msgf("%s\n", reason)
go i.EventRepo.PushStubStateUnhealthy(i.Workspace.ExternalId, stubId, currentState, state, reason, containers)
}

type Controller struct {
ctx context.Context
getOrCreateInstance func(ctx context.Context, stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error)
StubTypes []string
backendRepo repository.BackendRepository
redisClient *common.RedisClient
}

func NewController(
ctx context.Context,
getOrCreateInstance func(ctx context.Context, stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error),
stubTypes []string,
backendRepo repository.BackendRepository,
redisClient *common.RedisClient,
) *Controller {
return &Controller{
ctx: ctx,
getOrCreateInstance: getOrCreateInstance,
StubTypes: stubTypes,
backendRepo: backendRepo,
redisClient: redisClient,
}
}

func (c *Controller) Init() error {
eventBus := common.NewEventBus(
c.redisClient,
common.EventBusSubscriber{Type: common.EventTypeReloadInstance, Callback: func(e *common.Event) bool {
stubId := e.Args["stub_id"].(string)
stubType := e.Args["stub_type"].(string)

correctStub := false
for _, t := range c.StubTypes {
if t == stubType {
correctStub = true
break
}
}

if !correctStub {
return false
}

if err := c.reload(stubId); err != nil {
return false
}

return true
}},
)
go eventBus.ReceiveEvents(c.ctx)

if err := c.load(); err != nil {
return err
}

return nil
}

func (c *Controller) Warmup(
ctx context.Context,
stubId string,
) error {
instance, err := c.getOrCreateInstance(ctx, stubId)
if err != nil {
return err
}

return instance.HandleScalingEvent(1)
}

func (c *Controller) load() error {
stubs, err := c.backendRepo.ListDeploymentsWithRelated(
c.ctx,
types.DeploymentFilter{
StubType: c.StubTypes,
MinContainersGTE: 1,
Active: ptr.To(true),
},
)
if err != nil {
return err
}

for _, stub := range stubs {
_, err := c.getOrCreateInstance(c.ctx, stub.Stub.ExternalId)
if err != nil {
return err
}
}

return nil
}

func (c *Controller) reload(stubId string) error {
instance, err := c.getOrCreateInstance(c.ctx, stubId)
if err != nil {
return err
}

instance.Reload()

return nil
}
74 changes: 23 additions & 51 deletions pkg/abstractions/endpoint/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,20 @@ type EndpointService interface {

type HttpEndpointService struct {
pb.UnimplementedEndpointServiceServer
ctx context.Context
config types.AppConfig
rdb *common.RedisClient
keyEventManager *common.KeyEventManager
scheduler *scheduler.Scheduler
backendRepo repository.BackendRepository
workspaceRepo repository.WorkspaceRepository
containerRepo repository.ContainerRepository
eventRepo repository.EventRepository
taskRepo repository.TaskRepository
endpointInstances *common.SafeMap[*endpointInstance]
tailscale *network.Tailscale
taskDispatcher *task.Dispatcher
ctx context.Context
config types.AppConfig
rdb *common.RedisClient
keyEventManager *common.KeyEventManager
scheduler *scheduler.Scheduler
backendRepo repository.BackendRepository
workspaceRepo repository.WorkspaceRepository
containerRepo repository.ContainerRepository
eventRepo repository.EventRepository
taskRepo repository.TaskRepository
endpointInstances *common.SafeMap[*endpointInstance]
tailscale *network.Tailscale
taskDispatcher *task.Dispatcher
instanceController *abstractions.Controller
}

var (
Expand Down Expand Up @@ -102,34 +103,17 @@ func NewHTTPEndpointService(

// Listen for container events with a certain prefix
// For example if a container is created, destroyed, or updated
eventManager, err := abstractions.NewContainerEventManager(endpointContainerPrefix, keyEventManager, es.InstanceFactory)
eventManager, err := abstractions.NewContainerEventManager(ctx, endpointContainerPrefix, keyEventManager, es.InstanceFactory)
if err != nil {
return nil, err
}
eventManager.Listen(ctx)
eventManager.Listen()

eventBus := common.NewEventBus(
opts.RedisClient,
common.EventBusSubscriber{Type: common.EventTypeReloadInstance, Callback: func(e *common.Event) bool {
stubId := e.Args["stub_id"].(string)
stubType := e.Args["stub_type"].(string)

if stubType != types.StubTypeEndpointDeployment && stubType != types.StubTypeASGIDeployment {
// Assume the callback succeeded to avoid retries
return true
}

instance, err := es.getOrCreateEndpointInstance(es.ctx, stubId)
if err != nil {
return false
}

instance.Reload()

return true
}},
)
go eventBus.ReceiveEvents(ctx)
es.instanceController = abstractions.NewController(ctx, es.InstanceFactory, []string{types.StubTypeEndpointDeployment, types.StubTypeASGIDeployment}, es.backendRepo, es.rdb)
err = es.instanceController.Init()
if err != nil {
return nil, err
}

// Register task dispatcher
es.taskDispatcher.Register(string(types.ExecutorEndpoint), es.endpointTaskFactory)
Expand Down Expand Up @@ -205,20 +189,8 @@ func (es *HttpEndpointService) forwardRequest(
return task.Execute(ctx.Request().Context(), ctx)
}

func (es *HttpEndpointService) warmup(
ctx echo.Context,
stubId string,
) error {
instance, err := es.getOrCreateEndpointInstance(ctx.Request().Context(), stubId)
if err != nil {
return err
}

return instance.HandleScalingEvent(1)
}

func (es *HttpEndpointService) InstanceFactory(stubId string, options ...func(abstractions.IAutoscaledInstance)) (abstractions.IAutoscaledInstance, error) {
return es.getOrCreateEndpointInstance(es.ctx, stubId)
func (es *HttpEndpointService) InstanceFactory(ctx context.Context, stubId string, options ...func(abstractions.IAutoscaledInstance)) (abstractions.IAutoscaledInstance, error) {
return es.getOrCreateEndpointInstance(ctx, stubId)
}

func (es *HttpEndpointService) getOrCreateEndpointInstance(ctx context.Context, stubId string, options ...func(*endpointInstance)) (*endpointInstance, error) {
Expand Down
4 changes: 2 additions & 2 deletions pkg/abstractions/endpoint/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ func (g *endpointGroup) warmup(
return err
}

return g.es.warmup(
ctx,
return g.es.instanceController.Warmup(
ctx.Request().Context(),
stubId,
)
}
3 changes: 2 additions & 1 deletion pkg/abstractions/taskqueue/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ func (g *taskQueueGroup) TaskQueueWarmUp(ctx echo.Context) error {
return err
}

err = g.tq.warmup(
err = g.tq.instanceController.Warmup(
ctx.Request().Context(),
stubId,
)
if err != nil {
Expand Down
54 changes: 15 additions & 39 deletions pkg/abstractions/taskqueue/taskqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ type RedisTaskQueue struct {
backendRepo repository.BackendRepository
scheduler *scheduler.Scheduler
pb.UnimplementedTaskQueueServiceServer
queueInstances *common.SafeMap[*taskQueueInstance]
keyEventManager *common.KeyEventManager
queueClient *taskQueueClient
tailscale *network.Tailscale
eventRepo repository.EventRepository
queueInstances *common.SafeMap[*taskQueueInstance]
keyEventManager *common.KeyEventManager
queueClient *taskQueueClient
tailscale *network.Tailscale
eventRepo repository.EventRepository
instanceController *abstractions.Controller
}

func NewRedisTaskQueueService(
Expand Down Expand Up @@ -103,34 +104,18 @@ func NewRedisTaskQueueService(

// Listen for container events with a certain prefix
// For example if a container is created, destroyed, or updated
eventManager, err := abstractions.NewContainerEventManager(taskQueueContainerPrefix, keyEventManager, tq.InstanceFactory)
eventManager, err := abstractions.NewContainerEventManager(ctx, taskQueueContainerPrefix, keyEventManager, tq.InstanceFactory)
if err != nil {
return nil, err
}
eventManager.Listen(ctx)
eventManager.Listen()

eventBus := common.NewEventBus(
opts.RedisClient,
common.EventBusSubscriber{Type: common.EventTypeReloadInstance, Callback: func(e *common.Event) bool {
stubId := e.Args["stub_id"].(string)
stubType := e.Args["stub_type"].(string)

if stubType != types.StubTypeTaskQueueDeployment {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jsun-m the place where you're doing the correctStub check, you need to return true I think or else it will resend the event

// Assume the callback succeeded to avoid retries
return true
}

instance, err := tq.getOrCreateQueueInstance(stubId)
if err != nil {
return false
}

instance.Reload()

return true
}},
)
go eventBus.ReceiveEvents(ctx)
// Initialize deployment manager
tq.instanceController = abstractions.NewController(ctx, tq.InstanceFactory, []string{types.StubTypeTaskQueueDeployment}, opts.BackendRepo, opts.RedisClient)
err = tq.instanceController.Init()
if err != nil {
return nil, err
}

// Register task dispatcher
tq.taskDispatcher.Register(string(types.ExecutorTaskQueue), tq.taskQueueTaskFactory)
Expand Down Expand Up @@ -184,15 +169,6 @@ func (tq *RedisTaskQueue) getStubConfig(stubId string) (*types.StubConfigV1, err
return config, nil
}

func (tq *RedisTaskQueue) warmup(stubId string) error {
instance, err := tq.getOrCreateQueueInstance(stubId)
if err != nil {
return err
}

return instance.HandleScalingEvent(1)
}

func (tq *RedisTaskQueue) put(ctx context.Context, authInfo *auth.AuthInfo, stubId string, payload *types.TaskPayload) (string, error) {
stubConfig, err := tq.getStubConfig(stubId)
if err != nil {
Expand Down Expand Up @@ -547,7 +523,7 @@ func (tq *RedisTaskQueue) TaskQueueLength(ctx context.Context, in *pb.TaskQueueL
}, nil
}

func (tq *RedisTaskQueue) InstanceFactory(stubId string, options ...func(abstractions.IAutoscaledInstance)) (abstractions.IAutoscaledInstance, error) {
func (tq *RedisTaskQueue) InstanceFactory(ctx context.Context, stubId string, options ...func(abstractions.IAutoscaledInstance)) (abstractions.IAutoscaledInstance, error) {
return tq.getOrCreateQueueInstance(stubId)
}

Expand Down
Loading
Loading