Skip to content

Commit

Permalink
Enable multi-gpu containers (#442)
Browse files Browse the repository at this point in the history
- Adds SDK support for multiple GPUs per container

**NOTE**: This limits the maximum `gpu_count` to 2 for now.
  • Loading branch information
luke-lombardi authored Nov 20, 2024
1 parent e97794b commit f6f55fa
Show file tree
Hide file tree
Showing 16 changed files with 593 additions and 521 deletions.
4 changes: 2 additions & 2 deletions pkg/abstractions/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ func (cs *CmdContainerService) ExecuteCommand(in *pb.CommandExecutionRequest, st
gpuRequest = append(gpuRequest, stubConfig.Runtime.Gpu.String())
}

gpuCount := 0
if len(gpuRequest) > 0 {
gpuCount := stubConfig.Runtime.GpuCount
if stubConfig.RequiresGPU() && gpuCount == 0 {
gpuCount = 1
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/abstractions/endpoint/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ func (i *endpointInstance) startContainers(containersToRun int) error {
gpuRequest = append(gpuRequest, i.StubConfig.Runtime.Gpu.String())
}

gpuCount := 0
if len(gpuRequest) > 0 {
gpuCount := i.StubConfig.Runtime.GpuCount
if i.StubConfig.RequiresGPU() && gpuCount == 0 {
gpuCount = 1
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/abstractions/function/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func (t *FunctionTask) run(ctx context.Context, stub *types.StubWithRelated) err
gpuRequest = append(gpuRequest, stubConfig.Runtime.Gpu.String())
}

gpuCount := 0
if len(gpuRequest) > 0 {
gpuCount := stubConfig.Runtime.GpuCount
if stubConfig.RequiresGPU() && gpuCount == 0 {
gpuCount = 1
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/abstractions/taskqueue/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func (i *taskQueueInstance) startContainers(containersToRun int) error {
gpuRequest = append(gpuRequest, i.StubConfig.Runtime.Gpu.String())
}

gpuCount := 0
if len(gpuRequest) > 0 {
gpuCount := i.StubConfig.Runtime.GpuCount
if i.StubConfig.RequiresGPU() && gpuCount == 0 {
gpuCount = 1
}

Expand Down
1 change: 1 addition & 0 deletions pkg/common/config.default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ gateway:
stubLimits:
memory: 32768
maxReplicas: 10
maxGpuCount: 2
imageService:
localCacheEnabled: true
registryStore: local
Expand Down
1 change: 1 addition & 0 deletions pkg/gateway/gateway.proto
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ message GetOrCreateStubRequest {
TaskPolicy task_policy = 23;
uint32 concurrent_requests = 24;
string extra = 25;
uint32 gpu_count = 26;
}

message GetOrCreateStubResponse {
Expand Down
83 changes: 48 additions & 35 deletions pkg/gateway/services/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,37 +30,6 @@ func (gws *GatewayService) GetOrCreateStub(ctx context.Context, in *pb.GetOrCrea

gpus := types.GPUTypesFromString(in.Gpu)

if len(gpus) > 0 {
concurrencyLimit, err := gws.backendRepo.GetConcurrencyLimitByWorkspaceId(ctx, authInfo.Workspace.ExternalId)
if err != nil && concurrencyLimit != nil && concurrencyLimit.GPULimit <= 0 {
return &pb.GetOrCreateStubResponse{
Ok: false,
ErrMsg: "GPU concurrency limit is 0.",
}, nil
}

gpuCounts, err := gws.providerRepo.GetGPUCounts(gws.appConfig.Worker.Pools)
if err != nil {
return &pb.GetOrCreateStubResponse{
Ok: false,
ErrMsg: "Failed to get GPU counts.",
}, nil
}

// T4s are currently in a different pool than other GPUs and won't show up in gpu counts
lowGpus := []string{}

for _, gpu := range gpus {
if gpuCounts[gpu.String()] <= 1 && gpu.String() != types.GPU_T4.String() {
lowGpus = append(lowGpus, gpu.String())
}
}

if len(lowGpus) > 0 {
warning = fmt.Sprintf("GPU capacity for %s is currently low.", strings.Join(lowGpus, ", "))
}
}

autoscaler := &types.Autoscaler{}
if in.Autoscaler.Type == "" {
autoscaler.Type = types.QueueDepthAutoscaler
Expand All @@ -83,12 +52,20 @@ func (gws *GatewayService) GetOrCreateStub(ctx context.Context, in *pb.GetOrCrea
in.Extra = "{}"
}

if in.GpuCount > gws.appConfig.GatewayService.StubLimits.MaxGpuCount {
return &pb.GetOrCreateStubResponse{
Ok: false,
ErrMsg: fmt.Sprintf("GPU count must be %d or less.", gws.appConfig.GatewayService.StubLimits.MaxGpuCount),
}, nil
}

stubConfig := types.StubConfigV1{
Runtime: types.Runtime{
Cpu: in.Cpu,
Gpus: gpus,
Memory: in.Memory,
ImageId: in.ImageId,
Cpu: in.Cpu,
Gpus: gpus,
GpuCount: in.GpuCount,
Memory: in.Memory,
ImageId: in.ImageId,
},
Handler: in.Handler,
OnStart: in.OnStart,
Expand All @@ -106,6 +83,42 @@ func (gws *GatewayService) GetOrCreateStub(ctx context.Context, in *pb.GetOrCrea
Extra: json.RawMessage(in.Extra),
}

// Ensure GPU count is at least 1 if a GPU is required
if stubConfig.RequiresGPU() && in.GpuCount == 0 {
stubConfig.Runtime.GpuCount = 1
}

if stubConfig.RequiresGPU() {
concurrencyLimit, err := gws.backendRepo.GetConcurrencyLimitByWorkspaceId(ctx, authInfo.Workspace.ExternalId)
if err != nil && concurrencyLimit != nil && concurrencyLimit.GPULimit <= 0 {
return &pb.GetOrCreateStubResponse{
Ok: false,
ErrMsg: "GPU concurrency limit is 0.",
}, nil
}

gpuCounts, err := gws.providerRepo.GetGPUCounts(gws.appConfig.Worker.Pools)
if err != nil {
return &pb.GetOrCreateStubResponse{
Ok: false,
ErrMsg: "Failed to get GPU counts.",
}, nil
}

// T4s are currently in a different pool than other GPUs and won't show up in gpu counts
lowGpus := []string{}

for _, gpu := range gpus {
if gpuCounts[gpu.String()] <= 1 && gpu.String() != types.GPU_T4.String() {
lowGpus = append(lowGpus, gpu.String())
}
}

if len(lowGpus) > 0 {
warning = fmt.Sprintf("GPU capacity for %s is currently low.", strings.Join(lowGpus, ", "))
}
}

// Get secrets
for _, secret := range in.Secrets {
secret, err := gws.backendRepo.GetSecretByName(ctx, authInfo.Workspace, secret.Name)
Expand Down
15 changes: 10 additions & 5 deletions pkg/types/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ type StubConfigV1 struct {
Extra json.RawMessage `json:"extra"`
}

// RequiresGPU reports whether this stub's runtime configuration asks for a
// GPU, either via the legacy single-GPU field or the multi-GPU list.
func (c *StubConfigV1) RequiresGPU() bool {
	return c.Runtime.Gpu != "" || len(c.Runtime.Gpus) > 0
}

type AutoscalerType string

const (
Expand Down Expand Up @@ -290,11 +294,12 @@ type Image struct {
}

type Runtime struct {
Cpu int64 `json:"cpu"`
Gpu GpuType `json:"gpu"`
Memory int64 `json:"memory"`
ImageId string `json:"image_id"`
Gpus []GpuType `json:"gpus"`
Cpu int64 `json:"cpu"`
Gpu GpuType `json:"gpu"`
GpuCount uint32 `json:"gpu_count"`
Memory int64 `json:"memory"`
ImageId string `json:"image_id"`
Gpus []GpuType `json:"gpus"`
}

type GpuType string
Expand Down
1 change: 1 addition & 0 deletions pkg/types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ type CORSConfig struct {
// StubLimits holds per-stub resource ceilings enforced by the gateway.
type StubLimits struct {
	// Memory is the maximum memory a stub may request, in MiB — presumably;
	// the default config sets 32768. TODO confirm unit.
	Memory uint64 `key:"memory" json:"memory"`
	// MaxReplicas caps the number of replicas a stub may scale to.
	MaxReplicas uint64 `key:"maxReplicas" json:"max_replicas"`
	// MaxGpuCount caps gpu_count per container; requests above it are
	// rejected by GetOrCreateStub. Default config sets 2.
	MaxGpuCount uint32 `key:"maxGpuCount" json:"max_gpu_count"`
}

type GatewayServiceConfig struct {
Expand Down
Loading

0 comments on commit f6f55fa

Please sign in to comment.