Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add min_containers to QueueDepthAutoscaler and instance controller for loading and reloading instances #846

Merged
14 changes: 8 additions & 6 deletions pkg/abstractions/common/container_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,28 @@ import (
)

type ContainerEventManager struct {
ctx context.Context
containerPrefix string
keyEventChan chan common.KeyEvent
keyEventManager *common.KeyEventManager
instanceFactory func(stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error)
instanceFactory func(ctx context.Context, stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error)
}

func NewContainerEventManager(containerPrefix string, keyEventManager *common.KeyEventManager, instanceFactory func(stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error)) (*ContainerEventManager, error) {
func NewContainerEventManager(ctx context.Context, containerPrefix string, keyEventManager *common.KeyEventManager, instanceFactory func(ctx context.Context, stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error)) (*ContainerEventManager, error) {
keyEventChan := make(chan common.KeyEvent)

return &ContainerEventManager{
ctx: ctx,
containerPrefix: containerPrefix,
instanceFactory: instanceFactory,
keyEventChan: keyEventChan,
keyEventManager: keyEventManager,
}, nil
}

func (em *ContainerEventManager) Listen(ctx context.Context) {
go em.keyEventManager.ListenForPattern(ctx, common.RedisKeys.SchedulerContainerState(em.containerPrefix), em.keyEventChan)
go em.handleContainerEvents(ctx)
func (em *ContainerEventManager) Listen() {
go em.keyEventManager.ListenForPattern(em.ctx, common.RedisKeys.SchedulerContainerState(em.containerPrefix), em.keyEventChan)
go em.handleContainerEvents(em.ctx)
}

func (em *ContainerEventManager) handleContainerEvents(ctx context.Context) {
Expand Down Expand Up @@ -59,7 +61,7 @@ func (em *ContainerEventManager) handleContainerEvents(ctx context.Context) {
containerIdParts := strings.Split(containerId, "-")
stubId := strings.Join(containerIdParts[1:6], "-")

instance, err := em.instanceFactory(stubId)
instance, err := em.instanceFactory(em.ctx, stubId)
if err != nil {
continue
}
Expand Down
110 changes: 109 additions & 1 deletion pkg/abstractions/common/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ import (
"github.com/beam-cloud/beta9/pkg/scheduler"
"github.com/beam-cloud/beta9/pkg/types"
"github.com/rs/zerolog/log"
"k8s.io/utils/ptr"
)

const IgnoreScalingEventInterval = 10 * time.Second

type IAutoscaledInstance interface {
ConsumeScaleResult(*AutoscalerResult)
ConsumeContainerEvent(types.ContainerEvent)
HandleScalingEvent(int) error
Reload() error
}

type AutoscaledInstanceState struct {
Expand Down Expand Up @@ -175,7 +178,7 @@ func (i *AutoscaledInstance) WaitForContainer(ctx context.Context, duration time
}

func (i *AutoscaledInstance) ConsumeScaleResult(result *AutoscalerResult) {
i.ScaleEventChan <- result.DesiredContainers
i.ScaleEventChan <- max(result.DesiredContainers, int(i.StubConfig.Autoscaler.MinContainers))
}

func (i *AutoscaledInstance) ConsumeContainerEvent(event types.ContainerEvent) {
Expand Down Expand Up @@ -321,3 +324,108 @@ func (i *AutoscaledInstance) emitUnhealthyEvent(stubId, currentState, reason str
log.Info().Str("instance_name", i.Name).Msgf("%s\n", reason)
go i.EventRepo.PushStubStateUnhealthy(i.Workspace.ExternalId, stubId, currentState, state, reason, containers)
}

type InstanceController struct {
luke-lombardi marked this conversation as resolved.
Show resolved Hide resolved
ctx context.Context
getOrCreateInstance func(ctx context.Context, stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error)
StubTypes []string
backendRepo repository.BackendRepository
redisClient *common.RedisClient
}

func NewController(
ctx context.Context,
getOrCreateInstance func(ctx context.Context, stubId string, options ...func(IAutoscaledInstance)) (IAutoscaledInstance, error),
stubTypes []string,
backendRepo repository.BackendRepository,
redisClient *common.RedisClient,
) *InstanceController {
return &InstanceController{
ctx: ctx,
getOrCreateInstance: getOrCreateInstance,
StubTypes: stubTypes,
backendRepo: backendRepo,
redisClient: redisClient,
}
}

func (c *InstanceController) Init() error {
eventBus := common.NewEventBus(
c.redisClient,
common.EventBusSubscriber{Type: common.EventTypeReloadInstance, Callback: func(e *common.Event) bool {
stubId := e.Args["stub_id"].(string)
stubType := e.Args["stub_type"].(string)

correctStub := false
for _, t := range c.StubTypes {
if t == stubType {
correctStub = true
break
}
}

if !correctStub {
return true
}

if err := c.reload(stubId); err != nil {
return false
}

return true
}},
)
go eventBus.ReceiveEvents(c.ctx)

if err := c.load(); err != nil {
return err
}

return nil
}

func (c *InstanceController) Warmup(
ctx context.Context,
stubId string,
) error {
instance, err := c.getOrCreateInstance(ctx, stubId)
if err != nil {
return err
}

return instance.HandleScalingEvent(1)
}

func (c *InstanceController) load() error {
stubs, err := c.backendRepo.ListDeploymentsWithRelated(
c.ctx,
types.DeploymentFilter{
StubType: c.StubTypes,
MinContainersGTE: 1,
Active: ptr.To(true),
},
)
if err != nil {
return err
}

for _, stub := range stubs {
_, err := c.getOrCreateInstance(c.ctx, stub.Stub.ExternalId)
if err != nil {
return err
}
}

return nil
}

func (c *InstanceController) reload(stubId string) error {
instance, err := c.getOrCreateInstance(c.ctx, stubId)
if err != nil {
return err
}

instance.Reload()

return nil
}
48 changes: 10 additions & 38 deletions pkg/abstractions/endpoint/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type HttpEndpointService struct {
endpointInstances *common.SafeMap[*endpointInstance]
tailscale *network.Tailscale
taskDispatcher *task.Dispatcher
controller *abstractions.InstanceController
}

var (
Expand Down Expand Up @@ -102,34 +103,17 @@ func NewHTTPEndpointService(

// Listen for container events with a certain prefix
// For example if a container is created, destroyed, or updated
eventManager, err := abstractions.NewContainerEventManager(endpointContainerPrefix, keyEventManager, es.InstanceFactory)
eventManager, err := abstractions.NewContainerEventManager(ctx, endpointContainerPrefix, keyEventManager, es.InstanceFactory)
if err != nil {
return nil, err
}
eventManager.Listen(ctx)
eventManager.Listen()

eventBus := common.NewEventBus(
opts.RedisClient,
common.EventBusSubscriber{Type: common.EventTypeReloadInstance, Callback: func(e *common.Event) bool {
stubId := e.Args["stub_id"].(string)
stubType := e.Args["stub_type"].(string)

if stubType != types.StubTypeEndpointDeployment && stubType != types.StubTypeASGIDeployment {
// Assume the callback succeeded to avoid retries
return true
}

instance, err := es.getOrCreateEndpointInstance(es.ctx, stubId)
if err != nil {
return false
}

instance.Reload()

return true
}},
)
go eventBus.ReceiveEvents(ctx)
es.controller = abstractions.NewController(ctx, es.InstanceFactory, []string{types.StubTypeEndpointDeployment, types.StubTypeASGIDeployment}, es.backendRepo, es.rdb)
err = es.controller.Init()
if err != nil {
return nil, err
}

// Register task dispatcher
es.taskDispatcher.Register(string(types.ExecutorEndpoint), es.endpointTaskFactory)
Expand Down Expand Up @@ -205,20 +189,8 @@ func (es *HttpEndpointService) forwardRequest(
return task.Execute(ctx.Request().Context(), ctx)
}

func (es *HttpEndpointService) warmup(
ctx echo.Context,
stubId string,
) error {
instance, err := es.getOrCreateEndpointInstance(ctx.Request().Context(), stubId)
if err != nil {
return err
}

return instance.HandleScalingEvent(1)
}

func (es *HttpEndpointService) InstanceFactory(stubId string, options ...func(abstractions.IAutoscaledInstance)) (abstractions.IAutoscaledInstance, error) {
return es.getOrCreateEndpointInstance(es.ctx, stubId)
func (es *HttpEndpointService) InstanceFactory(ctx context.Context, stubId string, options ...func(abstractions.IAutoscaledInstance)) (abstractions.IAutoscaledInstance, error) {
return es.getOrCreateEndpointInstance(ctx, stubId)
}

func (es *HttpEndpointService) getOrCreateEndpointInstance(ctx context.Context, stubId string, options ...func(*endpointInstance)) (*endpointInstance, error) {
Expand Down
4 changes: 2 additions & 2 deletions pkg/abstractions/endpoint/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ func (g *endpointGroup) warmup(
return err
}

return g.es.warmup(
ctx,
return g.es.controller.Warmup(
ctx.Request().Context(),
stubId,
)
}
3 changes: 2 additions & 1 deletion pkg/abstractions/taskqueue/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ func (g *taskQueueGroup) TaskQueueWarmUp(ctx echo.Context) error {
return err
}

err = g.tq.warmup(
err = g.tq.controller.Warmup(
ctx.Request().Context(),
stubId,
)
if err != nil {
Expand Down
44 changes: 10 additions & 34 deletions pkg/abstractions/taskqueue/taskqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type RedisTaskQueue struct {
queueClient *taskQueueClient
tailscale *network.Tailscale
eventRepo repository.EventRepository
controller *abstractions.InstanceController
}

func NewRedisTaskQueueService(
Expand Down Expand Up @@ -103,34 +104,18 @@ func NewRedisTaskQueueService(

// Listen for container events with a certain prefix
// For example if a container is created, destroyed, or updated
eventManager, err := abstractions.NewContainerEventManager(taskQueueContainerPrefix, keyEventManager, tq.InstanceFactory)
eventManager, err := abstractions.NewContainerEventManager(ctx, taskQueueContainerPrefix, keyEventManager, tq.InstanceFactory)
if err != nil {
return nil, err
}
eventManager.Listen(ctx)
eventManager.Listen()

eventBus := common.NewEventBus(
opts.RedisClient,
common.EventBusSubscriber{Type: common.EventTypeReloadInstance, Callback: func(e *common.Event) bool {
stubId := e.Args["stub_id"].(string)
stubType := e.Args["stub_type"].(string)

if stubType != types.StubTypeTaskQueueDeployment {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jsun-m the place where you're doing the correctStub check, you need to return true I think or else it will resend the event

// Assume the callback succeeded to avoid retries
return true
}

instance, err := tq.getOrCreateQueueInstance(stubId)
if err != nil {
return false
}

instance.Reload()

return true
}},
)
go eventBus.ReceiveEvents(ctx)
// Initialize deployment manager
tq.controller = abstractions.NewController(ctx, tq.InstanceFactory, []string{types.StubTypeTaskQueueDeployment}, opts.BackendRepo, opts.RedisClient)
err = tq.controller.Init()
if err != nil {
return nil, err
}

// Register task dispatcher
tq.taskDispatcher.Register(string(types.ExecutorTaskQueue), tq.taskQueueTaskFactory)
Expand Down Expand Up @@ -184,15 +169,6 @@ func (tq *RedisTaskQueue) getStubConfig(stubId string) (*types.StubConfigV1, err
return config, nil
}

func (tq *RedisTaskQueue) warmup(stubId string) error {
instance, err := tq.getOrCreateQueueInstance(stubId)
if err != nil {
return err
}

return instance.HandleScalingEvent(1)
}

func (tq *RedisTaskQueue) put(ctx context.Context, authInfo *auth.AuthInfo, stubId string, payload *types.TaskPayload) (string, error) {
stubConfig, err := tq.getStubConfig(stubId)
if err != nil {
Expand Down Expand Up @@ -547,7 +523,7 @@ func (tq *RedisTaskQueue) TaskQueueLength(ctx context.Context, in *pb.TaskQueueL
}, nil
}

func (tq *RedisTaskQueue) InstanceFactory(stubId string, options ...func(abstractions.IAutoscaledInstance)) (abstractions.IAutoscaledInstance, error) {
func (tq *RedisTaskQueue) InstanceFactory(ctx context.Context, stubId string, options ...func(abstractions.IAutoscaledInstance)) (abstractions.IAutoscaledInstance, error) {
return tq.getOrCreateQueueInstance(stubId)
}

Expand Down
10 changes: 5 additions & 5 deletions pkg/api/v1/deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,16 @@ func (g *DeploymentGroup) DeleteDeployment(ctx echo.Context) error {
return HTTPBadRequest("Deployment not found")
}

// Stop deployment first
if err := g.stopDeployments([]types.DeploymentWithRelated{*deploymentWithRelated}, ctx); err != nil {
return HTTPInternalServerError("Failed to stop deployment")
}

// Delete deployment
if err := g.backendRepo.DeleteDeployment(ctx.Request().Context(), deploymentWithRelated.Deployment); err != nil {
return HTTPInternalServerError("Failed to delete deployment")
}

// Stop deployment
if err := g.stopDeployments([]types.DeploymentWithRelated{*deploymentWithRelated}, ctx); err != nil {
luke-lombardi marked this conversation as resolved.
Show resolved Hide resolved
return HTTPInternalServerError("Failed to stop deployment")
}

return ctx.NoContent(http.StatusOK)
}

Expand Down
1 change: 1 addition & 0 deletions pkg/gateway/gateway.proto
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ message Autoscaler {
string type = 1;
uint32 max_containers = 2;
uint32 tasks_per_container = 3;
uint32 min_containers = 4;
}

message TaskPolicy {
Expand Down
Loading
Loading