Skip to content

Commit

Permalink
koordlet: fix prediction restore for node priorities (#1749)
Browse files Browse the repository at this point in the history
Signed-off-by: saintube <[email protected]>
  • Loading branch information
saintube authored Nov 21, 2023
1 parent 08b6e89 commit dccef94
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 22 deletions.
15 changes: 13 additions & 2 deletions pkg/koordlet/prediction/predict_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,24 @@ func (p *peakPredictServer) restoreModels() (unknownUIDs []UIDType) {
}

knownUIDs := make(map[UIDType]bool)
// pods checkpoints
pods := p.informer.ListPods()
for _, pod := range pods {
knownUIDs[UIDType(pod.UID)] = true
podUID := p.uidGenerator.Pod(pod)
knownUIDs[podUID] = true
}
// node checkpoint
node := p.informer.GetNode()
if node != nil {
knownUIDs[UIDType(node.UID)] = true
nodeUID := p.uidGenerator.Node()
knownUIDs[nodeUID] = true
}
// node items checkpoints (priority classes)
systemUID := p.uidGenerator.NodeItem(SystemItemID)
knownUIDs[systemUID] = true
for _, priorityClass := range extension.KnownPriorityClasses {
priorityUID := p.uidGenerator.NodeItem(string(priorityClass))
knownUIDs[priorityUID] = true
}

for _, checkpoint := range checkpoints {
Expand Down
57 changes: 41 additions & 16 deletions pkg/koordlet/prediction/predict_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,29 +460,54 @@ func TestDoCheckpoint_RestoreModels(t *testing.T) {
},
},
}
testModels := map[UIDType]*PredictModel{
UIDType(DefaultNodeID): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
UIDType("__node-sys__"): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
UIDType("__node-koord-prod__"): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
UIDType("__node-koord-mid__"): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
UIDType("__node-koord-batch__"): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
UIDType("__node-koord-free__"): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
UIDType("__node-__"): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
UIDType("pod1"): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
UIDType("pod2"): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
}
cfg := NewDefaultConfig()
cfg.ModelCheckpointMaxPerStep = 5
cfg.ModelCheckpointMaxPerStep = 20
predictServer := &peakPredictServer{
cfg: cfg,
hasSynced: &atomic.Bool{},
informer: &mockInformer{Pods: pods, Node: node},
metricServer: &mockMetricServer{},
uidGenerator: &generator{},
clock: mockClock,
models: map[UIDType]*PredictModel{
UIDType("node1"): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
UIDType("pod1"): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
UIDType("pod2"): {
CPU: makeTestHistogram(),
Memory: makeTestHistogram(),
},
},
models: testModels,
checkpointer: NewFileCheckpointer(tempDir),
}
predictServer.hasSynced.Store(true)
Expand All @@ -491,7 +516,7 @@ func TestDoCheckpoint_RestoreModels(t *testing.T) {
// clear the models in memory and restore it
predictServer.models = make(map[UIDType]*PredictModel)
predictServer.restoreModels()
assert.Equal(t, 3, len(predictServer.models), "restore models from checkpoint")
assert.Equal(t, len(testModels), len(predictServer.models), "restore models from checkpoint")

// mock another model and restore it to unknownUIDs
predictServer.models["unknown"] = &PredictModel{
Expand Down
8 changes: 4 additions & 4 deletions pkg/koordlet/prediction/prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import (
const (
DefaultNodeID = "__node__"
DefaultNodeItemIDFmt = "__node-%s__"
SystemItemID = "sys" // node item ID for the system overhead which is not counted in any pod
AllPodsItemID = "all-pods"
SystemItemID = "sys" // node item ID for the system overhead which is not counted in any pod
AllPodsItemID = "all-pods" // not stored for now, just used for calculating the sys
)

type UIDType string
Expand Down Expand Up @@ -99,8 +99,8 @@ func (i *informer) HasSynced() bool {
func (i *informer) ListPods() []*v1.Pod {
pods := i.statesInformer.GetAllPods()
result := make([]*v1.Pod, len(pods))
for i := range pods {
result[i] = pods[i].Pod
for j := range pods {
result[j] = pods[j].Pod
}
return result
}
Expand Down

0 comments on commit dccef94

Please sign in to comment.