Add worker management command to CLI (#590)
- Add worker management command group
- Add list, cordon, uncordon, and drain worker commands
- Prevent requests from being scheduled on disabled workers
- Exclude disabled workers from pool capacity calculations

Resolve BE-1601
nickpetrovic authored Oct 8, 2024
1 parent 8697786 commit 945175c
Showing 14 changed files with 1,709 additions and 228 deletions.
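
Note: the scheduler and pool-sizing changes referenced in the commit message live in files not shown below. A minimal sketch of the intended behavior (skipping disabled workers when placing container requests and when summing pool capacity) could look like the following; the function names and package placement are illustrative assumptions, not the actual scheduler code.

package scheduler

import "github.com/beam-cloud/beta9/pkg/types"

// schedulableWorkers filters out cordoned/drained workers so container
// requests are never placed on them (illustrative sketch only).
func schedulableWorkers(workers []*types.Worker) []*types.Worker {
    eligible := make([]*types.Worker, 0, len(workers))
    for _, w := range workers {
        if w.Status == types.WorkerStatusDisabled {
            continue
        }
        eligible = append(eligible, w)
    }
    return eligible
}

// poolFreeCapacity sums free resources while ignoring disabled workers, so
// cordoned workers do not count toward pool capacity (illustrative sketch only).
func poolFreeCapacity(workers []*types.Worker) (freeCpu int64, freeMemory int64) {
    for _, w := range workers {
        if w.Status == types.WorkerStatusDisabled {
            continue
        }
        freeCpu += w.FreeCpu
        freeMemory += w.FreeMemory
    }
    return freeCpu, freeMemory
}
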
4 changes: 4 additions & 0 deletions pkg/gateway/gateway.go
@@ -58,6 +58,7 @@ type Gateway struct {
EventRepo repository.EventRepository
Tailscale *network.Tailscale
metricsRepo repository.MetricsRepository
workerRepo repository.WorkerRepository
Storage storage.Storage
Scheduler *scheduler.Scheduler
ctx context.Context
@@ -127,6 +128,7 @@ func NewGateway() (*Gateway, error) {

containerRepo := repository.NewContainerRedisRepository(redisClient)
providerRepo := repository.NewProviderRedisRepository(redisClient)
workerRepo := repository.NewWorkerRedisRepository(redisClient, config.Worker)
taskRepo := repository.NewTaskRedisRepository(redisClient)
taskDispatcher, err := task.NewDispatcher(ctx, taskRepo)
if err != nil {
@@ -144,6 +146,7 @@ func NewGateway() (*Gateway, error) {
gateway.TaskDispatcher = taskDispatcher
gateway.metricsRepo = metricsRepo
gateway.EventRepo = eventRepo
gateway.workerRepo = workerRepo

return gateway, nil
}
@@ -367,6 +370,7 @@ func (g *Gateway) registerServices() error {
TaskDispatcher: g.TaskDispatcher,
RedisClient: g.RedisClient,
EventRepo: g.EventRepo,
WorkerRepo: g.workerRepo,
})
if err != nil {
return err
57 changes: 57 additions & 0 deletions pkg/gateway/gateway.proto
@@ -51,6 +51,12 @@ service GatewayService {
rpc CreateToken(CreateTokenRequest) returns (CreateTokenResponse);
rpc ToggleToken(ToggleTokenRequest) returns (ToggleTokenResponse);
rpc DeleteToken(DeleteTokenRequest) returns (DeleteTokenResponse);

// Workers
rpc ListWorkers(ListWorkersRequest) returns (ListWorkersResponse);
rpc CordonWorker(CordonWorkerRequest) returns (CordonWorkerResponse);
rpc UncordonWorker(UncordonWorkerRequest) returns (UncordonWorkerResponse);
rpc DrainWorker(DrainWorkerRequest) returns (DrainWorkerResponse);
}

message AuthorizeRequest {}
@@ -445,3 +451,54 @@ message GetURLResponse {
string err_msg = 2;
string url = 3;
}

message Worker {
string id = 1;
string status = 2;
string gpu = 3;
string pool_name = 4;
string machine_id = 5;
int32 priority = 6;
int64 total_cpu = 7;
int64 total_memory = 8;
uint32 total_gpu_count = 9;
int64 free_cpu = 10;
int64 free_memory = 11;
uint32 free_gpu_count = 12;
repeated Container active_containers = 13;
}

message ListWorkersRequest {}

message ListWorkersResponse {
bool ok = 1;
string err_msg = 2;
repeated Worker workers = 3;
}

message CordonWorkerRequest {
string worker_id = 1;
}

message CordonWorkerResponse {
bool ok = 1;
string err_msg = 2;
}

message UncordonWorkerRequest {
string worker_id = 1;
}

message UncordonWorkerResponse {
bool ok = 1;
string err_msg = 2;
}

message DrainWorkerRequest {
string worker_id = 1;
}

message DrainWorkerResponse {
bool ok = 1;
string err_msg = 2;
}
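
For reference, a minimal Go client sketch against the new RPCs. It assumes the protoc-generated GatewayServiceClient in github.com/beam-cloud/beta9/proto, an insecure connection, and an illustrative gateway address; cluster-admin auth metadata is omitted here but required in practice, per the service implementation below.

package main

import (
    "context"
    "fmt"
    "log"

    pb "github.com/beam-cloud/beta9/proto"
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
)

func main() {
    // Gateway address is an assumption; auth is omitted for brevity.
    conn, err := grpc.NewClient("localhost:1993", grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    client := pb.NewGatewayServiceClient(conn)

    resp, err := client.ListWorkers(context.Background(), &pb.ListWorkersRequest{})
    if err != nil {
        log.Fatal(err)
    }
    if !resp.Ok {
        log.Fatalf("list workers failed: %s", resp.ErrMsg)
    }

    for _, w := range resp.Workers {
        fmt.Printf("%s\t%s\t%s\t%d active containers\n", w.Id, w.PoolName, w.Status, len(w.ActiveContainers))
    }

    // Cordon a worker so no new containers are scheduled on it (worker ID is illustrative).
    if _, err := client.CordonWorker(context.Background(), &pb.CordonWorkerRequest{WorkerId: "worker-abc123"}); err != nil {
        log.Fatal(err)
    }
}
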
3 changes: 3 additions & 0 deletions pkg/gateway/services/service.go
@@ -18,6 +18,7 @@ type GatewayService struct {
taskDispatcher *task.Dispatcher
redisClient *common.RedisClient
eventRepo repository.EventRepository
workerRepo repository.WorkerRepository
pb.UnimplementedGatewayServiceServer
}

@@ -30,6 +31,7 @@ type GatewayServiceOpts struct {
TaskDispatcher *task.Dispatcher
RedisClient *common.RedisClient
EventRepo repository.EventRepository
WorkerRepo repository.WorkerRepository
}

func NewGatewayService(opts *GatewayServiceOpts) (*GatewayService, error) {
@@ -42,5 +44,6 @@ func NewGatewayService(opts *GatewayServiceOpts) (*GatewayService, error) {
taskDispatcher: opts.TaskDispatcher,
redisClient: opts.RedisClient,
eventRepo: opts.EventRepo,
workerRepo: opts.WorkerRepo,
}, nil
}
199 changes: 199 additions & 0 deletions pkg/gateway/services/worker.go
@@ -0,0 +1,199 @@
package gatewayservices

import (
"context"
"errors"
"time"

"github.com/beam-cloud/beta9/pkg/auth"
"github.com/beam-cloud/beta9/pkg/types"
pb "github.com/beam-cloud/beta9/proto"
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/timestamppb"
)

func (gws *GatewayService) ListWorkers(ctx context.Context, in *pb.ListWorkersRequest) (*pb.ListWorkersResponse, error) {
if _, err := isClusterAdmin(ctx); err != nil {
return &pb.ListWorkersResponse{
Ok: false,
ErrMsg: err.Error(),
}, nil
}

workers, err := gws.workerRepo.GetAllWorkers()
if err != nil {
return &pb.ListWorkersResponse{
Ok: false,
ErrMsg: err.Error(),
}, nil
}

sortWorkers(workers)

pbWorkers := make([]*pb.Worker, len(workers))
for i, w := range workers {
pbWorkers[i] = &pb.Worker{
Id: w.Id,
Status: string(w.Status),
Gpu: w.Gpu,
PoolName: w.PoolName,
MachineId: w.MachineId,
Priority: w.Priority,
TotalCpu: w.TotalCpu,
TotalMemory: w.TotalMemory,
TotalGpuCount: w.TotalGpuCount,
FreeCpu: w.FreeCpu,
FreeMemory: w.FreeMemory,
FreeGpuCount: w.FreeGpuCount,
}

containers, err := gws.containerRepo.GetActiveContainersByWorkerId(w.Id)
if err != nil {
continue
}

pbWorkers[i].ActiveContainers = make([]*pb.Container, len(containers))
for j, c := range containers {
pbWorkers[i].ActiveContainers[j] = &pb.Container{
ContainerId: c.ContainerId,
WorkspaceId: string(c.WorkspaceId),
Status: string(c.Status),
ScheduledAt: timestamppb.New(time.Unix(c.ScheduledAt, 0)),
}
}
}

return &pb.ListWorkersResponse{
Ok: true,
Workers: pbWorkers,
}, nil
}

func sortWorkers(w []*types.Worker) {
slices.SortFunc(w, func(i, j *types.Worker) int {
if i.PoolName < j.PoolName {
return -1
}
if i.PoolName > j.PoolName {
return 1
}
if i.Status < j.Status {
return -1
}
if i.Status > j.Status {
return 1
}
if i.Id < j.Id {
return -1
}
if i.Id > j.Id {
return 1
}
return 0
})
}

func (gws *GatewayService) CordonWorker(ctx context.Context, in *pb.CordonWorkerRequest) (*pb.CordonWorkerResponse, error) {
if _, err := isClusterAdmin(ctx); err != nil {
return &pb.CordonWorkerResponse{
Ok: false,
ErrMsg: err.Error(),
}, nil
}

worker, err := gws.workerRepo.GetWorkerById(in.WorkerId)
if err != nil {
return &pb.CordonWorkerResponse{
Ok: false,
ErrMsg: err.Error(),
}, nil
}

if err := gws.workerRepo.UpdateWorkerStatus(worker.Id, types.WorkerStatusDisabled); err != nil {
return &pb.CordonWorkerResponse{
Ok: false,
ErrMsg: err.Error(),
}, nil
}

return &pb.CordonWorkerResponse{
Ok: true,
}, nil
}

func (gws *GatewayService) UncordonWorker(ctx context.Context, in *pb.UncordonWorkerRequest) (*pb.UncordonWorkerResponse, error) {
if _, err := isClusterAdmin(ctx); err != nil {
return &pb.UncordonWorkerResponse{
Ok: false,
ErrMsg: err.Error(),
}, nil
}

worker, err := gws.workerRepo.GetWorkerById(in.WorkerId)
if err != nil {
return &pb.UncordonWorkerResponse{
Ok: false,
ErrMsg: err.Error(),
}, nil
}

err = gws.workerRepo.UpdateWorkerStatus(worker.Id, types.WorkerStatusAvailable)
if err != nil {
return &pb.UncordonWorkerResponse{
Ok: false,
ErrMsg: err.Error(),
}, nil
}

return &pb.UncordonWorkerResponse{
Ok: true,
}, nil
}

func (gws *GatewayService) DrainWorker(ctx context.Context, in *pb.DrainWorkerRequest) (*pb.DrainWorkerResponse, error) {
if _, err := isClusterAdmin(ctx); err != nil {
return &pb.DrainWorkerResponse{
Ok: false,
ErrMsg: err.Error(),
}, nil
}

worker, err := gws.workerRepo.GetWorkerById(in.WorkerId)
if err != nil {
return &pb.DrainWorkerResponse{
Ok: false,
ErrMsg: err.Error(),
}, nil
}

containers, err := gws.containerRepo.GetActiveContainersByWorkerId(worker.Id)
if err != nil {
return &pb.DrainWorkerResponse{
Ok: false,
ErrMsg: err.Error(),
}, err
}

var group errgroup.Group
for _, container := range containers {
group.Go(func() error {
return gws.scheduler.Stop(container.ContainerId)
})
}
if err := group.Wait(); err != nil {
return &pb.DrainWorkerResponse{
Ok: false,
ErrMsg: err.Error(),
}, nil
}

return &pb.DrainWorkerResponse{
Ok: true,
}, nil
}

func isClusterAdmin(ctx context.Context) (bool, error) {
	authInfo, _ := auth.AuthInfoFromContext(ctx)
	if authInfo.Token.TokenType != types.TokenTypeClusterAdmin {
		return false, errors.New("This action is not permitted")
	}

	return true, nil
}
1 change: 1 addition & 0 deletions pkg/repository/base.go
@@ -18,6 +18,7 @@ type WorkerRepository interface {
GetAllWorkersOnMachine(machineId string) ([]*types.Worker, error)
AddWorker(w *types.Worker) error
ToggleWorkerAvailable(workerId string) error
UpdateWorkerStatus(workerId string, status types.WorkerStatus) error
RemoveWorker(w *types.Worker) error
SetWorkerKeepAlive(workerId string) error
UpdateWorkerCapacity(w *types.Worker, cr *types.ContainerRequest, ut types.CapacityUpdateType) error
36 changes: 20 additions & 16 deletions pkg/repository/worker_redis.go
@@ -96,19 +96,7 @@ func (r *WorkerRedisRepository) RemoveWorker(worker *types.Worker) error {
return nil
}

func (r *WorkerRedisRepository) SetWorkerKeepAlive(workerId string) error {
stateKey := common.RedisKeys.SchedulerWorkerState(workerId)

// Set TTL on state key
err := r.rdb.Expire(context.TODO(), stateKey, time.Duration(types.WorkerStateTtlS)*time.Second).Err()
if err != nil {
return fmt.Errorf("failed to set worker state ttl <%v>: %w", stateKey, err)
}

return nil
}

func (r *WorkerRedisRepository) ToggleWorkerAvailable(workerId string) error {
func (r *WorkerRedisRepository) UpdateWorkerStatus(workerId string, status types.WorkerStatus) error {
err := r.lock.Acquire(context.TODO(), common.RedisKeys.SchedulerWorkerLock(workerId), common.RedisLockOptions{TtlS: 10, Retries: 3})
if err != nil {
return err
@@ -121,12 +109,12 @@ func (r *WorkerRedisRepository) ToggleWorkerAvailable(workerId string) error {
return err
}

// Make worker available by setting status
// Update worker status
worker.Status = status
worker.ResourceVersion++
worker.Status = types.WorkerStatusAvailable
err = r.rdb.HSet(context.TODO(), stateKey, common.ToSlice(worker)).Err()
if err != nil {
return fmt.Errorf("failed to toggle worker state <%s>: %v", stateKey, err)
return fmt.Errorf("failed to update worker status <%s>: %v", stateKey, err)
}

// Set TTL on state key
@@ -138,6 +126,22 @@ func (r *WorkerRedisRepository) SetWorkerKeepAlive(workerId string) error {
return nil
}

func (r *WorkerRedisRepository) SetWorkerKeepAlive(workerId string) error {
stateKey := common.RedisKeys.SchedulerWorkerState(workerId)

// Set TTL on state key
err := r.rdb.Expire(context.TODO(), stateKey, time.Duration(types.WorkerStateTtlS)*time.Second).Err()
if err != nil {
return fmt.Errorf("failed to set worker state ttl <%v>: %w", stateKey, err)
}

return nil
}

func (r *WorkerRedisRepository) ToggleWorkerAvailable(workerId string) error {
return r.UpdateWorkerStatus(workerId, types.WorkerStatusAvailable)
}

// getWorkers retrieves a list of worker objects from the Redis store that match a given pattern.
// If useLock is set to true, a lock will be acquired for each worker and released after retrieval.
// If you can afford to not have the most up-to-date worker information, you can set useLock to false.
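
ToggleWorkerAvailable is now a thin wrapper over the new UpdateWorkerStatus, so existing callers keep their behavior while cordon, uncordon, and drain reuse the same locked read-modify-write path. A minimal sketch of driving the same transitions through the repository interface; the package and function names are illustrative:

package workerops

import (
    "github.com/beam-cloud/beta9/pkg/repository"
    "github.com/beam-cloud/beta9/pkg/types"
)

// cordonWorker marks a worker as disabled, the same repository call the
// CordonWorker RPC makes; the update runs under the per-worker scheduler
// lock and bumps ResourceVersion.
func cordonWorker(workerRepo repository.WorkerRepository, workerId string) error {
    return workerRepo.UpdateWorkerStatus(workerId, types.WorkerStatusDisabled)
}

// uncordonWorker restores scheduling by flipping the status back to available.
func uncordonWorker(workerRepo repository.WorkerRepository, workerId string) error {
    return workerRepo.UpdateWorkerStatus(workerId, types.WorkerStatusAvailable)
}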