Skip to content

Commit fc0bd7a

Browse files
authored
Merge pull request #1969 from mindprince/fix-race
Initialize NVML on demand.
2 parents f834c0f + 2ce4161 commit fc0bd7a

File tree

2 files changed

+35
-28
lines changed

2 files changed

+35
-28
lines changed

accelerators/nvidia.go

Lines changed: 22 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,10 @@ import (
3131
)
3232

3333
type NvidiaManager struct {
34-
sync.RWMutex
34+
sync.Mutex
35+
36+
// true if there are NVIDIA devices present on the node
37+
devicesPresent bool
3538

3639
// true if the NVML library (libnvidia-ml.so.1) was loaded successfully
3740
nvmlInitialized bool
@@ -51,20 +54,9 @@ func (nm *NvidiaManager) Setup() {
5154
return
5255
}
5356

54-
nm.initializeNVML()
55-
if nm.nvmlInitialized {
56-
return
57-
}
58-
go func() {
59-
glog.V(2).Info("Starting goroutine to initialize NVML")
60-
// TODO: use globalHousekeepingInterval
61-
for range time.Tick(time.Minute) {
62-
nm.initializeNVML()
63-
if nm.nvmlInitialized {
64-
return
65-
}
66-
}
67-
}()
57+
nm.devicesPresent = true
58+
59+
initializeNVML(nm)
6860
}
6961

7062
// detectDevices returns true if a device with given pci id is present on the node.
@@ -91,20 +83,18 @@ func detectDevices(vendorId string) bool {
9183
}
9284

9385
// initializeNVML initializes the NVML library and sets up the nvmlDevices map.
94-
func (nm *NvidiaManager) initializeNVML() {
86+
// This is defined as a variable to help in testing.
87+
var initializeNVML = func(nm *NvidiaManager) {
9588
if err := gonvml.Initialize(); err != nil {
9689
// This is under a logging level because otherwise we may cause
9790
// log spam if the drivers/nvml is not installed on the system.
9891
glog.V(4).Infof("Could not initialize NVML: %v", err)
9992
return
10093
}
94+
nm.nvmlInitialized = true
10195
numDevices, err := gonvml.DeviceCount()
10296
if err != nil {
10397
glog.Warningf("GPU metrics would not be available. Failed to get the number of nvidia devices: %v", err)
104-
nm.Lock()
105-
// Even though we won't have GPU metrics, the library was initialized and should be shutdown when exiting.
106-
nm.nvmlInitialized = true
107-
nm.Unlock()
10898
return
10999
}
110100
glog.V(1).Infof("NVML initialized. Number of nvidia devices: %v", numDevices)
@@ -122,10 +112,6 @@ func (nm *NvidiaManager) initializeNVML() {
122112
}
123113
nm.nvidiaDevices[int(minorNumber)] = device
124114
}
125-
nm.Lock()
126-
// Doing this at the end to avoid race in accessing nvidiaDevices in GetCollector.
127-
nm.nvmlInitialized = true
128-
nm.Unlock()
129115
}
130116

131117
// Destroy shuts down NVML.
@@ -139,12 +125,21 @@ func (nm *NvidiaManager) Destroy() {
139125
// present in the devices.list file in the given devicesCgroupPath.
140126
func (nm *NvidiaManager) GetCollector(devicesCgroupPath string) (AcceleratorCollector, error) {
141127
nc := &NvidiaCollector{}
142-
nm.RLock()
128+
129+
if !nm.devicesPresent {
130+
return nc, nil
131+
}
132+
// Makes sure that we don't call initializeNVML() concurrently and
133+
// that we only call initializeNVML() when it's not initialized.
134+
nm.Lock()
135+
if !nm.nvmlInitialized {
136+
initializeNVML(nm)
137+
}
143138
if !nm.nvmlInitialized || len(nm.nvidiaDevices) == 0 {
144-
nm.RUnlock()
139+
nm.Unlock()
145140
return nc, nil
146141
}
147-
nm.RUnlock()
142+
nm.Unlock()
148143
nvidiaMinorNumbers, err := parseDevicesCgroup(devicesCgroupPath)
149144
if err != nil {
150145
return nc, err

accelerators/nvidia_test.go

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,20 +71,32 @@ func TestGetCollector(t *testing.T) {
7171
return []int{2, 3}, nil
7272
}
7373
parseDevicesCgroup = mockParser
74+
originalInitializeNVML := initializeNVML
75+
initializeNVML = func(_ *NvidiaManager) {}
7476
defer func() {
7577
parseDevicesCgroup = originalParser
78+
initializeNVML = originalInitializeNVML
7679
}()
7780

7881
nm := &NvidiaManager{}
7982

80-
// When nvmlInitialized is false, empty collector should be returned.
83+
// When devicesPresent is false, empty collector should be returned.
8184
ac, err := nm.GetCollector("does-not-matter")
8285
assert.Nil(t, err)
8386
assert.NotNil(t, ac)
8487
nc, ok := ac.(*NvidiaCollector)
8588
assert.True(t, ok)
8689
assert.Equal(t, 0, len(nc.Devices))
8790

91+
// When nvmlInitialized is false, empty collector should be returned.
92+
nm.devicesPresent = true
93+
ac, err = nm.GetCollector("does-not-matter")
94+
assert.Nil(t, err)
95+
assert.NotNil(t, ac)
96+
nc, ok = ac.(*NvidiaCollector)
97+
assert.True(t, ok)
98+
assert.Equal(t, 0, len(nc.Devices))
99+
88100
// When nvidiaDevices is empty, empty collector should be returned.
89101
nm.nvmlInitialized = true
90102
ac, err = nm.GetCollector("does-not-matter")

0 commit comments

Comments
 (0)