Skip to content

Commit 02af3d8

Browse files
authored
Merge pull request #22 from elezar/fix-devices
Fix construction of linked devices.
2 parents 8fc3087 + b0ec32c commit 02af3d8

File tree

29 files changed

+1468
-29
lines changed

29 files changed

+1468
-29
lines changed

examples/devices/main.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/**
2+
# Copyright 2024 NVIDIA CORPORATION
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package main
18+
19+
import (
20+
"fmt"
21+
"os"
22+
23+
"github.com/NVIDIA/go-gpuallocator/gpuallocator"
24+
)
25+
26+
func main() {
27+
dl, err := gpuallocator.NewDevices()
28+
if err != nil {
29+
fmt.Printf("error getting devices: %v\n", err)
30+
os.Exit(1)
31+
}
32+
33+
fmt.Printf("Found %d devices:\n", len(dl))
34+
for i, device := range dl {
35+
fmt.Printf("device %d:\n", i)
36+
fmt.Printf("%s\n", device.Details())
37+
}
38+
}

go.mod

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@ module github.com/NVIDIA/go-gpuallocator
22

33
go 1.20
44

5-
require github.com/NVIDIA/go-nvlib v0.0.0-20231116150931-9fd385bace0d
5+
require github.com/NVIDIA/go-nvlib v0.0.0-20240109130712-11603560817a
66

7-
require github.com/NVIDIA/go-nvml v0.12.0-1.0.20231020145430-e06766c5e74f // indirect
7+
require (
8+
github.com/NVIDIA/go-nvml v0.12.0-1.0.20231020145430-e06766c5e74f // indirect
9+
github.com/google/uuid v1.4.0 // indirect
10+
)
811

912
replace (
1013
k8s.io/api => k8s.io/api v0.18.2

go.sum

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
github.com/NVIDIA/go-nvlib v0.0.0-20231116150931-9fd385bace0d h1:XxRHS7eNkZVcPpZZmUcoT4oO8FEcoYKn06sooQh5niU=
2-
github.com/NVIDIA/go-nvlib v0.0.0-20231116150931-9fd385bace0d/go.mod h1:HPFNPAYqQeoos58MKUboWsdZMu71EzSQrbmd+QBRD40=
1+
github.com/NVIDIA/go-nvlib v0.0.0-20240109130712-11603560817a h1:EH7wiaq9+NYDgCBJEcGa3HTO2Sz6dRlmO2y9yMxA5jE=
2+
github.com/NVIDIA/go-nvlib v0.0.0-20240109130712-11603560817a/go.mod h1:U82N6/xKp6OnoqpALBH0C5SO59Buu4sX1Z3rQtBsBKQ=
33
github.com/NVIDIA/go-nvml v0.12.0-1.0.20231020145430-e06766c5e74f h1:FTblgO87K1vPB8tcwM5EOFpFf6UpsrlDpErPm25mFWE=
44
github.com/NVIDIA/go-nvml v0.12.0-1.0.20231020145430-e06766c5e74f/go.mod h1:7ruy85eOM73muOc/I37euONSwEyFqZsv5ED9AogD4G0=
55
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
66
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
77
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
8+
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
9+
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
810
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
911
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
1012
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

gpuallocator/device.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,36 +98,37 @@ func (o *deviceListBuilder) build() (DeviceList, error) {
9898
_ = o.nvmllib.Shutdown()
9999
}()
100100

101+
nvmlDevices, err := o.devicelib.GetDevices()
102+
if err != nil {
103+
return nil, fmt.Errorf("failed to get devices: %v", err)
104+
}
105+
101106
var devices DeviceList
102-
err := o.devicelib.VisitDevices(func(i int, d device.Device) error {
107+
for i, d := range nvmlDevices {
103108
device, err := newDevice(i, d)
104109
if err != nil {
105-
return fmt.Errorf("failed to construct linked device: %v", err)
110+
return nil, fmt.Errorf("failed to construct linked device: %v", err)
106111
}
107112
devices = append(devices, device)
108-
return nil
109-
})
110-
if err != nil {
111-
return nil, fmt.Errorf("failed to get devices: %v", err)
112113
}
113114

114-
for i, d1 := range devices {
115-
for j, d2 := range devices {
115+
for i, d1 := range nvmlDevices {
116+
for j, d2 := range nvmlDevices {
116117
if i != j {
117118
p2plink, err := links.GetP2PLink(d1, d2)
118119
if err != nil {
119120
return nil, fmt.Errorf("error getting P2PLink for devices (%v, %v): %v", i, j, err)
120121
}
121122
if p2plink != links.P2PLinkUnknown {
122-
d1.Links[d2.Index] = append(d1.Links[d2.Index], P2PLink{d2, p2plink})
123+
devices[i].Links[j] = append(devices[i].Links[j], P2PLink{devices[j], p2plink})
123124
}
124125

125126
nvlink, err := links.GetNVLink(d1, d2)
126127
if err != nil {
127128
return nil, fmt.Errorf("error getting NVLink for devices (%v, %v): %v", i, j, err)
128129
}
129130
if nvlink != links.P2PLinkUnknown {
130-
d1.Links[d2.Index] = append(d1.Links[d2.Index], P2PLink{d2, nvlink})
131+
devices[i].Links[j] = append(devices[i].Links[j], P2PLink{devices[j], nvlink})
131132
}
132133
}
133134
}

internal/links/device.go

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,62 @@ const (
5656
EighteenNVLINKLinks
5757
)
5858

59+
// String returns the string representation of the P2PLink type.
60+
func (l P2PLinkType) String() string {
61+
switch l {
62+
case P2PLinkCrossCPU:
63+
return "P2PLinkCrossCPU"
64+
case P2PLinkSameCPU:
65+
return "P2PLinkSameCPU"
66+
case P2PLinkHostBridge:
67+
return "P2PLinkHostBridge"
68+
case P2PLinkMultiSwitch:
69+
return "P2PLinkMultiSwitch"
70+
case P2PLinkSingleSwitch:
71+
return "P2PLinkSingleSwitch"
72+
case P2PLinkSameBoard:
73+
return "P2PLinkSameBoard"
74+
case SingleNVLINKLink:
75+
return "SingleNVLINKLink"
76+
case TwoNVLINKLinks:
77+
return "TwoNVLINKLinks"
78+
case ThreeNVLINKLinks:
79+
return "ThreeNVLINKLinks"
80+
case FourNVLINKLinks:
81+
return "FourNVLINKLinks"
82+
case FiveNVLINKLinks:
83+
return "FiveNVLINKLinks"
84+
case SixNVLINKLinks:
85+
return "SixNVLINKLinks"
86+
case SevenNVLINKLinks:
87+
return "SevenNVLINKLinks"
88+
case EightNVLINKLinks:
89+
return "EightNVLINKLinks"
90+
case NineNVLINKLinks:
91+
return "NineNVLINKLinks"
92+
case TenNVLINKLinks:
93+
return "TenNVLINKLinks"
94+
case ElevenNVLINKLinks:
95+
return "ElevenNVLINKLinks"
96+
case TwelveNVLINKLinks:
97+
return "TwelveNVLINKLinks"
98+
case ThirteenNVLINKLinks:
99+
return "ThirteenNVLINKLinks"
100+
case FourteenNVLINKLinks:
101+
return "FourteenNVLINKLinks"
102+
case FifteenNVLINKLinks:
103+
return "FifteenNVLINKLinks"
104+
case SixteenNVLINKLinks:
105+
return "SixteenNVLINKLinks"
106+
case SeventeenNVLINKLinks:
107+
return "SeventeenNVLINKLinks"
108+
case EighteenNVLINKLinks:
109+
return "EighteenNVLINKLinks"
110+
default:
111+
return fmt.Sprintf("UNKOWN (%v)", uint(l))
112+
}
113+
}
114+
59115
// GetP2PLink gets the peer-to-peer connectivity between two devices.
60116
func GetP2PLink(dev1 device.Device, dev2 device.Device) (P2PLinkType, error) {
61117
level, ret := dev1.GetTopologyCommonAncestor(dev2)
@@ -149,23 +205,23 @@ func getAllNvLinkRemotePciInfo(dev device.Device) ([]PciInfo, error) {
149205
var pciInfos []PciInfo
150206
for i := 0; i < nvml.NVLINK_MAX_LINKS; i++ {
151207
state, ret := dev.GetNvLinkState(i)
152-
if ret == nvml.ERROR_NOT_SUPPORTED {
208+
if ret == nvml.ERROR_NOT_SUPPORTED || ret == nvml.ERROR_INVALID_ARGUMENT {
153209
continue
154210
}
155211
if ret != nvml.SUCCESS {
156212
return nil, fmt.Errorf("failed to get nvlink state: %v", ret)
157213
}
158-
159-
if state == nvml.FEATURE_ENABLED {
160-
pciInfo, ret := dev.GetNvLinkRemotePciInfo(i)
161-
if ret == nvml.ERROR_NOT_SUPPORTED {
162-
continue
163-
}
164-
if ret != nvml.SUCCESS {
165-
return nil, fmt.Errorf("failed to get remote pci info: %v", ret)
166-
}
167-
pciInfos = append(pciInfos, PciInfo(pciInfo))
214+
if state != nvml.FEATURE_ENABLED {
215+
continue
216+
}
217+
pciInfo, ret := dev.GetNvLinkRemotePciInfo(i)
218+
if ret == nvml.ERROR_NOT_SUPPORTED || ret == nvml.ERROR_INVALID_ARGUMENT {
219+
continue
220+
}
221+
if ret != nvml.SUCCESS {
222+
return nil, fmt.Errorf("failed to get remote pci info: %v", ret)
168223
}
224+
pciInfos = append(pciInfos, PciInfo(pciInfo))
169225
}
170226

171227
return pciInfos, nil

vendor/github.com/NVIDIA/go-nvlib/pkg/nvlib/device/identifier.go

Lines changed: 94 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/device.go

Lines changed: 8 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)