Skip to content

Commit 7218685

Browse files
committed
introduce service.gpus
Signed-off-by: Nicolas De Loof <[email protected]>
1 parent a59035a commit 7218685

File tree

8 files changed

+276
-204
lines changed

8 files changed

+276
-204
lines changed

loader/loader_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2224,6 +2224,32 @@ services:
22242224
assert.ErrorContains(t, err, `capabilities is required`)
22252225
}
22262226

2227+
func TestServiceGpus(t *testing.T) {
2228+
p, err := loadYAML(`
2229+
name: service-gpus
2230+
services:
2231+
test:
2232+
image: redis:alpine
2233+
gpus:
2234+
- driver: nvidia
2235+
- driver: 3dfx
2236+
device_ids: ["voodoo2"]
2237+
capabilities: ["directX"]
2238+
`)
2239+
assert.NilError(t, err)
2240+
assert.DeepEqual(t, p.Services["test"].Gpus, []types.DeviceRequest{
2241+
{
2242+
Driver: "nvidia",
2243+
Count: -1,
2244+
},
2245+
{
2246+
Capabilities: []string{"directX"},
2247+
Driver: "3dfx",
2248+
IDs: []string{"voodoo2"},
2249+
},
2250+
})
2251+
}
2252+
22272253
func TestServicePullPolicy(t *testing.T) {
22282254
actual, err := loadYAML(`
22292255
name: service-pull-policy

schema/compose-spec.json

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@
267267
},
268268
"external_links": {"type": "array", "items": {"type": "string"}, "uniqueItems": true},
269269
"extra_hosts": {"$ref": "#/definitions/extra_hosts"},
270+
"gpus": {"$ref": "#/definitions/gpus"},
270271
"group_add": {
271272
"type": "array",
272273
"items": {
@@ -651,6 +652,23 @@
651652
}
652653
},
653654

655+
"gpus": {
656+
"id": "#/definitions/gpus",
657+
"type": "array",
658+
"items": {
659+
"type": "object",
660+
"properties": {
661+
"capabilities": {"$ref": "#/definitions/list_of_strings"},
662+
"count": {"type": ["string", "integer"]},
663+
"device_ids": {"$ref": "#/definitions/list_of_strings"},
664+
"driver":{"type": "string"},
665+
"options":{"$ref": "#/definitions/list_or_dict"}
666+
},
667+
"additionalProperties": false,
668+
"patternProperties": {"^x-": {}}
669+
}
670+
},
671+
654672
"include": {
655673
"id": "#/definitions/include",
656674
"oneOf": [

transform/canonical.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ func init() {
2929
transformers["services.*.build.secrets.*"] = transformFileMount
3030
transformers["services.*.build.additional_contexts"] = transformKeyValue
3131
transformers["services.*.depends_on"] = transformDependsOn
32-
transformers["services.*.deploy.resources.reservations.devices.*"] = transformDeviceRequest
3332
transformers["services.*.env_file"] = transformEnvFile
3433
transformers["services.*.extends"] = transformExtends
3534
transformers["services.*.networks"] = transformServiceNetworks

transform/defaults.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ func init() {
2626
defaultValues["services.*.build"] = defaultBuildContext
2727
defaultValues["services.*.secrets.*"] = defaultSecretMount
2828
defaultValues["services.*.ports.*"] = portDefaults
29+
defaultValues["services.*.deploy.resources.reservations.devices.*"] = deviceRequestDefaults
30+
defaultValues["services.*.gpus.*"] = deviceRequestDefaults
2931
}
3032

3133
// SetDefaultValues transforms a compose model to set default values to missing attributes

transform/devices.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,15 @@ import (
2222
"github.com/compose-spec/compose-go/v2/tree"
2323
)
2424

25-
func transformDeviceRequest(data any, p tree.Path, ignoreParseError bool) (any, error) {
26-
switch v := data.(type) {
27-
case map[string]any:
28-
_, hasCount := v["count"]
29-
_, hasIds := v["device_ids"]
30-
if hasCount && hasIds {
31-
return nil, fmt.Errorf(`%s: "count" and "device_ids" attributes are exclusive`, p)
32-
}
33-
if !hasCount && !hasIds {
34-
v["count"] = "all"
35-
}
36-
return transformMapping(v, p, ignoreParseError)
37-
default:
25+
func deviceRequestDefaults(data any, p tree.Path, _ bool) (any, error) {
26+
v, ok := data.(map[string]any)
27+
if !ok {
3828
return data, fmt.Errorf("%s: invalid type %T for device request", p, v)
3929
}
30+
_, hasCount := v["count"]
31+
_, hasIds := v["device_ids"]
32+
if !hasCount && !hasIds {
33+
v["count"] = "all"
34+
}
35+
return v, nil
4036
}

0 commit comments

Comments
 (0)