diff --git a/cmd/csi-provisioner/csi-provisioner.go b/cmd/csi-provisioner/csi-provisioner.go index ab1e96616..1d10f28ac 100644 --- a/cmd/csi-provisioner/csi-provisioner.go +++ b/cmd/csi-provisioner/csi-provisioner.go @@ -302,9 +302,9 @@ func main() { // TODO: metrics for the queue?! workqueue.NewNamedRateLimitingQueue(rateLimiter, "csistoragecapacity"), *controller, + namespace, topologyInformer, factory.Storage().V1().StorageClasses(), - -1, /* let API server generate names */ ) } diff --git a/go.sum b/go.sum index 048f2d65e..4e517c60a 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,7 @@ github.com/containerd/ttrpc v1.0.0/go.mod h1:PvCDdDGpgqzQIzDW1TphrGLssLDZp2GuS+X github.com/containerd/typeurl v0.0.0-20180627222232-a93fcdb778cd/go.mod h1:Cm3kwCdlkCfMSHURc+r6fwoGH6/F1hH3S4sg0rLFWPc= github.com/containerd/typeurl v1.0.0/go.mod h1:Cm3kwCdlkCfMSHURc+r6fwoGH6/F1hH3S4sg0rLFWPc= github.com/containernetworking/cni v0.7.1/go.mod h1:LGwApLUm2FpoOfxTDEeq8T9ipbpZ61X79hmU3w8FmsY= +github.com/containernetworking/cni v0.8.0/go.mod h1:LGwApLUm2FpoOfxTDEeq8T9ipbpZ61X79hmU3w8FmsY= github.com/coredns/corefile-migration v1.0.8/go.mod h1:OFwBp/Wc9dJt5cAZzHWMNhK1r5L0p0jDwIBc6j8NC8E= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= @@ -176,6 +177,8 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logr/logr v0.1.0 h1:M1Tv3VzNlEHg6uyACnRdtrploV2P7wZqH8BoQMtz0cg= github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= +github.com/go-logr/logr v0.2.0 h1:QvGt2nLcHH0WK9orKa+ppBPAxREcH364nPUedEpK0TY= +github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU= github.com/go-openapi/analysis v0.0.0-20180825180245-b006789cd277/go.mod h1:k70tL6pCuVxPJOHXQ+wIac1FUrvNkHolPie/cLEU6hI= github.com/go-openapi/analysis v0.17.0/go.mod h1:IowGgpVeD0vNm45So8nr+IcQ3pxVtpRoBWb8PVZO0ik= github.com/go-openapi/analysis v0.18.0/go.mod h1:IowGgpVeD0vNm45So8nr+IcQ3pxVtpRoBWb8PVZO0ik= @@ -847,6 +850,8 @@ k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I= k8s.io/klog/v2 v2.0.0/go.mod h1:PBfzABfn139FHAV07az/IF9Wp1bkk3vpT2XSJ76fSDE= k8s.io/klog/v2 v2.1.0 h1:X3+Mru/L3jy4BI4vcAYkHvL6PyU+QBsuhEqwlI4mgkA= k8s.io/klog/v2 v2.1.0/go.mod h1:PBfzABfn139FHAV07az/IF9Wp1bkk3vpT2XSJ76fSDE= +k8s.io/klog/v2 v2.2.0 h1:XRvcwJozkgZ1UQJmfMGpvRthQHOvihEhYtDfAaxMz/A= +k8s.io/klog/v2 v2.2.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= k8s.io/kube-aggregator v0.19.0-beta.2/go.mod h1:CdbKxeXJJQNMWL8ZeZyBHORXFNLKlQT2BhgM7M2vk78= k8s.io/kube-controller-manager v0.19.0-beta.2/go.mod h1:gEPC5sAWKdjiw8aNnrCi7Np4H3q5vKK//m5ZNwbilMA= k8s.io/kube-openapi v0.0.0-20200427153329-656914f816f9 h1:5NC2ITmvg8RoxoH0wgmL4zn4VZqXGsKbxrikjaQx6s4= diff --git a/pkg/capacity/controller.go b/pkg/capacity/controller.go index 7cc080009..1d2d7b46c 100644 --- a/pkg/capacity/controller.go +++ b/pkg/capacity/controller.go @@ -77,9 +77,9 @@ type Controller struct { client kubernetes.Interface queue workqueue.RateLimitingInterface owner metav1.OwnerReference + ownerNamespace string topologyInformer topology.Informer scInformer storageinformersv1.StorageClassInformer - enumerateObjects int // capacities contains one entry for each object that is supposed // to exist. 
@@ -114,9 +114,9 @@ func NewCentralCapacityController( client kubernetes.Interface, queue workqueue.RateLimitingInterface, owner metav1.OwnerReference, + ownerNamespace string, topologyInformer topology.Informer, scInformer storageinformersv1.StorageClassInformer, - enumerateObjects int, ) *Controller { c := &Controller{ csiController: csiController, @@ -124,10 +124,10 @@ func NewCentralCapacityController( client: client, queue: queue, owner: owner, + ownerNamespace: ownerNamespace, topologyInformer: topologyInformer, scInformer: scInformer, capacities: map[workItem]*storagev1alpha1.CSIStorageCapacity{}, - enumerateObjects: enumerateObjects, } // Now register for changes. Depending on the implementation of the informers, @@ -398,16 +398,16 @@ func (c *Controller) syncCapacity(ctx context.Context, item workItem) error { capacity.Capacity = quantity var err error klog.V(5).Infof("Capacity Controller: updating %s for %+v, new capacity %v", capacity.Name, item, quantity) - capacity, err = c.client.StorageV1alpha1().CSIStorageCapacities().Update(ctx, capacity, metav1.UpdateOptions{}) + capacity, err = c.client.StorageV1alpha1().CSIStorageCapacities(capacity.Namespace).Update(ctx, capacity, metav1.UpdateOptions{}) if err != nil && apierrs.IsConflict(err) { // Handle the case where we had a stale copy of the object. Can only happen // when someone else was making changes to it, which should be rare. - capacity, err = c.client.StorageV1alpha1().CSIStorageCapacities().Get(ctx, capacity.Name, metav1.GetOptions{}) + capacity, err = c.client.StorageV1alpha1().CSIStorageCapacities(capacity.Namespace).Get(ctx, capacity.Name, metav1.GetOptions{}) if err != nil { return fmt.Errorf("getting fresh copy of CSIStorageCapacity for %+v: %v", item, err) } capacity.Capacity = quantity - capacity, err = c.client.StorageV1alpha1().CSIStorageCapacities().Update(ctx, capacity, metav1.UpdateOptions{}) + capacity, err = c.client.StorageV1alpha1().CSIStorageCapacities(capacity.Namespace).Update(ctx, capacity, metav1.UpdateOptions{}) } if err != nil { return fmt.Errorf("update CSIStorageCapacity for %+v: %v", item, err) @@ -423,16 +423,9 @@ func (c *Controller) syncCapacity(ctx context.Context, item workItem) error { NodeTopology: item.segment.GetLabelSelector(), Capacity: quantity, } - if c.enumerateObjects >= 0 { - // Workaround for testing with a fake client: it doesn't - // set the name, so we have to make up something ourselves. 
- c.enumerateObjects++ - capacity.Name = fmt.Sprintf("csisc-test-%d", c.enumerateObjects) - capacity.GenerateName = "" - } var err error klog.V(5).Infof("Capacity Controller: creating new object for %+v, new capacity %v", item, quantity) - capacity, err = c.client.StorageV1alpha1().CSIStorageCapacities().Create(ctx, capacity, metav1.CreateOptions{}) + capacity, err = c.client.StorageV1alpha1().CSIStorageCapacities(c.ownerNamespace).Create(ctx, capacity, metav1.CreateOptions{}) if err != nil { return fmt.Errorf("create CSIStorageCapacity for %+v: %v", item, err) } @@ -455,12 +448,12 @@ func (c *Controller) syncCapacity(ctx context.Context, item workItem) error { func (c *Controller) deleteCapacity(ctx context.Context, capacity *storagev1alpha1.CSIStorageCapacity) error { klog.V(5).Infof("Capacity Controller: removing CSIStorageCapacity %s", capacity.Name) - return c.client.StorageV1alpha1().CSIStorageCapacities().Delete(ctx, capacity.Name, metav1.DeleteOptions{}) + return c.client.StorageV1alpha1().CSIStorageCapacities(capacity.Namespace).Delete(ctx, capacity.Name, metav1.DeleteOptions{}) } func (c *Controller) syncCSIStorageObjects(ctx context.Context) error { klog.V(3).Infof("Capacity Controller: syncing CSIStorageCapacity objects") - capacities, err := c.client.StorageV1alpha1().CSIStorageCapacities().List(ctx, metav1.ListOptions{}) + capacities, err := c.client.StorageV1alpha1().CSIStorageCapacities(c.ownerNamespace).List(ctx, metav1.ListOptions{}) if err != nil { return err } diff --git a/pkg/capacity/controller_test.go b/pkg/capacity/controller_test.go index aa2335bee..9818a9d13 100644 --- a/pkg/capacity/controller_test.go +++ b/pkg/capacity/controller_test.go @@ -31,10 +31,13 @@ import ( storagev1alpha1 "k8s.io/api/storage/v1alpha1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" + krand "k8s.io/apimachinery/pkg/util/rand" utilruntime "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/client-go/informers" fakeclientset "k8s.io/client-go/kubernetes/fake" + ktesting "k8s.io/client-go/testing" "k8s.io/client-go/util/workqueue" "k8s.io/klog" ) @@ -44,7 +47,8 @@ func init() { } const ( - driverName = "test-driver" + driverName = "test-driver" + ownerNamespace = "testns" ) var ( @@ -153,6 +157,8 @@ func TestController(t *testing.T) { // TODO: multiple segments, multiple classes, both // TODO: remove stale objects // TODO: update tests - remove segment, remove class + // TODO: reuse existing CSIStorageClasses - must check that the exact same objects are used, not just something semantically equivalent! + // TODO: check that modifications by others are reverted (deleting an object, adding one, modifying capacity, modifying owner) } for name, tc := range testcases { @@ -160,29 +166,25 @@ func TestController(t *testing.T) { t.Run(name, func(t *testing.T) { // There is no good way to shut down the controller. It spawns // various goroutines and some of them (in particular shared informer) - // become very unhappy ("close on close channel") when using a context + // become very unhappy ("close on closed channel") when using a context // that gets cancelled. Therefore we just keep everything running. ctx := context.Background() - // ctx, cancel := context.WithCancel(context.Background()) - // defer cancel() var objects []runtime.Object objects = append(objects, makeCapacities(tc.initialCapacities)...) objects = append(objects, makeSCs(tc.initialSCs)...) 
clientSet := fakeclientset.NewSimpleClientset(objects...) + clientSet.PrependReactor("create", "csistoragecapacities", generateNameReactor) c := fakeController(ctx, clientSet, &tc.storage, &tc.topology) - - // We don't know when the controller is in a quiesence state, - // we can only give it some time and then check. - // TODO (?): use some mock queue and process items until the queue is empty? - go c.Run(ctx, 1) - time.Sleep(time.Second) - + c.prepare(ctx) + if err := process(ctx, c); err != nil { + t.Fatalf("unexpected error: %v", err) + } matched := map[testCapacity]bool{} for _, expected := range tc.expectedCapacities { matched[expected] = false } - actualCapacities, err := clientSet.StorageV1alpha1().CSIStorageCapacities().List(ctx, metav1.ListOptions{}) + actualCapacities, err := clientSet.StorageV1alpha1().CSIStorageCapacities(ownerNamespace).List(ctx, metav1.ListOptions{}) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -216,6 +218,16 @@ func TestController(t *testing.T) { } +// generateNameReactor implements the logic required for the GenerateName field to work when using +// the fake client. Add it with client.PrependReactor to your fake client. +func generateNameReactor(action ktesting.Action) (handled bool, ret runtime.Object, err error) { + s := action.(ktesting.CreateAction).GetObject().(*storagev1alpha1.CSIStorageCapacity) + if s.Name == "" && s.GenerateName != "" { + s.Name = fmt.Sprintf("%s-%s", s.GenerateName, krand.String(16)) + } + return false, nil, nil +} + func fakeController(ctx context.Context, client *fakeclientset.Clientset, storage CSICapacityClient, topologyInformer topology.Informer) *Controller { utilruntime.ReallyCrash = false // avoids os.Exit after "close of closed channel" in shared informer code @@ -227,19 +239,68 @@ func fakeController(ctx context.Context, client *fakeclientset.Clientset, storag rateLimiter := workqueue.NewItemExponentialFailureRateLimiter(time.Second, 2*time.Second) queue := workqueue.NewNamedRateLimitingQueue(rateLimiter, "items") - // Not needed? - // informerFactory.WaitForCacheSync(ctx.Done()) - // go informerFactory.Start(ctx.Done()) - - return NewCentralCapacityController( + c := NewCentralCapacityController( storage, driverName, client, queue, owner, + ownerNamespace, topologyInformer, scInformer, - 0 /* enumerate objects */) + ) + + // This ensures that the informers are running and up-to-date. + go informerFactory.Start(ctx.Done()) + informerFactory.WaitForCacheSync(ctx.Done()) + + return c +} + +// process handles work items until the queue is empty and the informers are synced. +func process(ctx context.Context, c *Controller) error { + for { + if c.queue.Len() == 0 { + done, err := storageClassesSynced(ctx, c) + if err != nil { + return fmt.Errorf("check storage classes: %v", err) + } + if done { + return nil + } + } + // There's no atomic "try to get a work item". Let's + // check one more time before potentially blocking + // in c.queue.Get().
+ if c.queue.Len() > 0 { + c.processNextWorkItem(ctx) + } + } +} + +func storageClassesSynced(ctx context.Context, c *Controller) (bool, error) { + actualStorageClasses, err := c.client.StorageV1().StorageClasses().List(ctx, metav1.ListOptions{}) + if err != nil { + return false, err + } + informerStorageClasses, err := c.scInformer.Lister().List(labels.Everything()) + if len(informerStorageClasses) != len(actualStorageClasses.Items) { + return false, nil + } + if len(informerStorageClasses) > 0 && !func() bool { + for _, actualStorageClass := range actualStorageClasses.Items { + for _, informerStorageClass := range informerStorageClasses { + if reflect.DeepEqual(actualStorageClass, *informerStorageClass) { + return true + } + } + } + return false + }() { + return false, nil + } + + return true, nil } const ( diff --git a/pkg/capacity/topology/doc.go b/pkg/capacity/topology/doc.go index 79a154221..6adfde6a4 100644 --- a/pkg/capacity/topology/doc.go +++ b/pkg/capacity/topology/doc.go @@ -19,115 +19,3 @@ limitations under the License. // which does that based on the CSINodeDriver.TopologyKeys and the // corresponding labels for the nodes. package topology - -import ( - "context" - "sort" - "strings" - - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" -) - -// Segment represents a topology segment. Entries are always sorted by -// key and keys are unique. In contrast to a map, segments therefore -// can be compared efficiently. A nil segment matches no nodes -// nodes in a cluster, an empty segment all of them. -type Segment []SegmentEntry - -var _ sort.Interface = Segment{} - -func (s Segment) String() string { - var parts []string - for _, entry := range s { - parts = append(parts, entry.String()) - } - return "{" + strings.Join(parts, ", ") + "}" -} - -// Compare returns -1 if s is considered smaller than the other segment (less keys, -// keys and/or values smaller), 0 if equal and 1 otherwise. -func (s Segment) Compare(other Segment) int { - if len(s) < len(other) { - return -1 - } - if len(s) > len(other) { - return 1 - } - for i := 0; i < len(s); i++ { - cmp := s[i].Compare(other[i]) - if cmp != 0 { - return cmp - } - } - return 0 -} - -func (s Segment) Len() int { return len(s) } -func (s Segment) Less(i, j int) bool { return s[i].Compare(s[j]) < 0 } -func (s Segment) Swap(i, j int) { - entry := s[i] - s[i] = s[j] - s[j] = entry -} - -// SegmentEntry represents one topology key/value pair. -type SegmentEntry struct { - Key, Value string -} - -func (se SegmentEntry) String() string { - return se.Key + ": " + se.Value -} - -// Compare returns -1 if se is considered smaller than the other segment entry (key or value smaller), -// 0 if equal and 1 otherwise. -func (se SegmentEntry) Compare(other SegmentEntry) int { - cmp := strings.Compare(se.Key, other.Key) - if cmp != 0 { - return cmp - } - return strings.Compare(se.Value, other.Value) -} - -// GetLabelSelector returns a LabelSelector with the key/value entries -// as label match criteria. -func (s Segment) GetLabelSelector() *metav1.LabelSelector { - return &metav1.LabelSelector{ - MatchLabels: s.GetLabelMap(), - } -} - -// GetLabelMap returns nil if the Segment itself is nil, -// otherwise a map with all key/value pairs. 
-func (s Segment) GetLabelMap() map[string]string { - if s == nil { - return nil - } - labels := map[string]string{} - for _, entry := range s { - labels[entry.Key] = entry.Value - } - return labels -} - -// Informer keeps a list of discovered topology segments and can -// notify one or more clients when it discovers changes. Segments -// are identified by their address and guaranteed to be unique. -type Informer interface { - // AddCallback ensures that the function is called each time - // changes to the list of segments are detected. It also gets - // called immediately when adding the callback and there are - // already some known segments. - AddCallback(cb Callback) - - // List returns all known segments, in no particular order. - List() []*Segment - - // Run starts watching for changes. - Run(ctx context.Context) - - // HasSynced returns true once all segments have been found. - HasSynced() bool -} - -type Callback func(added []*Segment, removed []*Segment) diff --git a/pkg/capacity/topology/nodes.go b/pkg/capacity/topology/nodes.go index 1d8012139..7a24ea73e 100644 --- a/pkg/capacity/topology/nodes.go +++ b/pkg/capacity/topology/nodes.go @@ -25,14 +25,12 @@ import ( "reflect" "sort" "sync" - "time" v1 "k8s.io/api/core/v1" storagev1 "k8s.io/api/storage/v1" apierrs "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/labels" utilruntime "k8s.io/apimachinery/pkg/util/runtime" - "k8s.io/apimachinery/pkg/util/wait" coreinformersv1 "k8s.io/client-go/informers/core/v1" storageinformersv1 "k8s.io/client-go/informers/storage/v1" "k8s.io/client-go/kubernetes" @@ -61,19 +59,29 @@ func NewNodeTopology( // immediately, but it is better to let the input data settle // a bit and just remember that there is work to be done. nodeHandler := cache.ResourceEventHandlerFuncs{ - AddFunc: func(obj interface{}) { queue.Add("") }, + AddFunc: func(obj interface{}) { + klog.V(5).Infof("capacity topology: new node: %s", obj.(*v1.Node).Name) + queue.Add("") + }, UpdateFunc: func(oldObj interface{}, newObj interface{}) { if reflect.DeepEqual(oldObj.(*v1.Node).Labels, newObj.(*v1.Node).Labels) { // Shortcut: labels haven't changed, no need to sync. return } + klog.V(5).Infof("capacity topology: updated node: %s", newObj.(*v1.Node).Name) + queue.Add("") + }, + DeleteFunc: func(obj interface{}) { + klog.V(5).Infof("capacity topology: removed node: %s", obj.(*v1.Node).Name) queue.Add("") }, - DeleteFunc: func(obj interface{}) { queue.Add("") }, } nodeInformer.Informer().AddEventHandler(nodeHandler) csiNodeHandler := cache.ResourceEventHandlerFuncs{ - AddFunc: func(obj interface{}) { queue.Add("") }, + AddFunc: func(obj interface{}) { + klog.V(5).Infof("capacity topology: new CSINode: %s", obj.(*storagev1.CSINode).Name) + queue.Add("") + }, UpdateFunc: func(oldObj interface{}, newObj interface{}) { oldKeys := nt.driverTopologyKeys(oldObj.(*storagev1.CSINode)) newKeys := nt.driverTopologyKeys(newObj.(*storagev1.CSINode)) @@ -81,9 +89,13 @@ func NewNodeTopology( // Shortcut: keys haven't changed, no need to sync. 
return } + klog.V(5).Infof("capacity topology: updated CSINode: %s", newObj.(*storagev1.CSINode).Name) + queue.Add("") + }, + DeleteFunc: func(obj interface{}) { + klog.V(5).Infof("capacity topology: removed CSINode: %s", obj.(*storagev1.CSINode).Name) queue.Add("") }, - DeleteFunc: func(obj interface{}) { queue.Add("") }, } csiNodeInformer.Informer().AddEventHandler(csiNodeHandler) @@ -140,9 +152,7 @@ func (nt *nodeTopology) List() []*Segment { func (nt *nodeTopology) Run(ctx context.Context) { go nt.nodeInformer.Informer().Run(ctx.Done()) go nt.csiNodeInformer.Informer().Run(ctx.Done()) - go wait.Until(func() { - nt.runWorker(ctx) - }, time.Second, ctx.Done()) + go nt.runWorker(ctx) klog.Info("Started node topology informer") <-ctx.Done() diff --git a/pkg/capacity/topology/nodes_test.go b/pkg/capacity/topology/nodes_test.go index 1a40b42b3..a57bf668d 100644 --- a/pkg/capacity/topology/nodes_test.go +++ b/pkg/capacity/topology/nodes_test.go @@ -18,6 +18,8 @@ package topology import ( "context" + "fmt" + "reflect" "sort" "testing" "time" @@ -25,7 +27,9 @@ import ( v1 "k8s.io/api/core/v1" storagev1 "k8s.io/api/storage/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/informers" fakeclientset "k8s.io/client-go/kubernetes/fake" "k8s.io/client-go/util/workqueue" @@ -64,6 +68,16 @@ var ( {networkStorageKeys[1], "NY"}, {networkStorageKeys[2], "1"}, } + networkStorageLabels2 = map[string]string{ + networkStorageKeys[0]: "US", + networkStorageKeys[1]: "NY", + networkStorageKeys[2]: "2", + } + networkStorage2 = &Segment{ + {networkStorageKeys[0], "US"}, + {networkStorageKeys[1], "NY"}, + {networkStorageKeys[2], "2"}, + } ) func removeNode(t *testing.T, client *fakeclientset.Clientset, nodeName string) { @@ -232,6 +246,25 @@ func TestNodeTopology(t *testing.T) { }, expectedSegments: []*Segment{localStorageNode1, networkStorage}, }, + "partial-match": { + initialNodes: []testNode{ + { + name: node1, + driverKeys: map[string][]string{ + driverName: networkStorageKeys, + }, + labels: networkStorageLabels, + }, + { + name: node2, + driverKeys: map[string][]string{ + driverName: networkStorageKeys, + }, + labels: networkStorageLabels2, + }, + }, + expectedSegments: []*Segment{networkStorage, networkStorage2}, + }, "unsorted-keys": { initialNodes: []testNode{ { @@ -346,13 +379,15 @@ func TestNodeTopology(t *testing.T) { for name, tc := range testcases { // Not run in parallel. That doesn't work well in combination with global logging. t.Run(name, func(t *testing.T) { - // There is no good way to shut down the controller. It spawns + // There is no good way to shut down the informers. They spawn // various goroutines and some of them (in particular shared informer) - // become very unhappy ("close on close channel") when using a context + // become very unhappy ("close on closed channel") when using a context // that gets cancelled. Therefore we just keep everything running. + // + // The informers also catch up with changes made via the client API + // asynchronously. To ensure expected input for sync(), we wait until + // the content of the informers is identical to what is currently stored. ctx := context.Background() - // ctx, cancel := context.WithCancel(context.Background()) - // defer cancel() testDriverName := tc.driverName if testDriverName == "" { @@ -363,42 +398,50 @@ func TestNodeTopology(t *testing.T) { objects = append(objects, makeNodes(tc.initialNodes)...) 
clientSet := fakeclientset.NewSimpleClientset(objects...) nt := fakeNodeTopology(ctx, testDriverName, clientSet) - - added, removed, _ := addTestCallback(nt) - go nt.Run(ctx) - - time.Sleep(1 * time.Second) - if len(added) != len(tc.expectedSegments) { - t.Errorf("unexpected added segments reported via callback: %v", added) - } - if len(removed) != 0 { - t.Errorf("unexpected removed segments reported via callback: %v", removed) + if err := waitForInformers(ctx, nt); err != nil { + t.Fatalf("unexpected error: %v", err) } - validate(t, nt, tc.expectedSegments) + validate(t, nt, tc.expectedSegments, nil, tc.expectedSegments) if tc.update != nil { - added, removed, _ := addTestCallback(nt) tc.update(t, clientSet) - time.Sleep(5 * time.Second) - for segment := range added { - if containsSegment(tc.expectedSegments, segment) || !containsSegment(tc.expectedUpdatedSegments, segment) { - t.Errorf("unexpected added segments during update: %v", added) + if err := waitForInformers(ctx, nt); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Determine the expected changes based on the delta. + var expectedAdded, expectedRemoved []*Segment + for _, segment := range tc.expectedUpdatedSegments { + if !containsSegment(tc.expectedSegments, segment) { + expectedAdded = append(expectedAdded, segment) } } - for segment := range removed { - if !containsSegment(tc.expectedSegments, segment) || containsSegment(tc.expectedUpdatedSegments, segment) { - t.Errorf("unexpected removed segments during update: %v", removed) + for _, segment := range tc.expectedSegments { + if !containsSegment(tc.expectedUpdatedSegments, segment) { + expectedRemoved = append(expectedRemoved, segment) } } - validate(t, nt, tc.expectedUpdatedSegments) + validate(t, nt, expectedAdded, expectedRemoved, tc.expectedUpdatedSegments) } }) } } -func addTestCallback(nt *nodeTopology) (added, removed map[*Segment]bool, called *bool) { - added = map[*Segment]bool{} - removed = map[*Segment]bool{} +type segmentsFound map[*Segment]bool + +func (sf segmentsFound) Found() []*Segment { + var found []*Segment + for key, value := range sf { + if value { + found = append(found, key) + } + } + return found +} + +func addTestCallback(nt *nodeTopology) (added, removed segmentsFound, called *bool) { + added = segmentsFound{} + removed = segmentsFound{} called = new(bool) nt.AddCallback(func(a, r []*Segment) { *called = true @@ -445,22 +488,81 @@ func fakeNodeTopology(ctx context.Context, testDriverName string, client *fakecl return nt } -func validate(t *testing.T, nt *nodeTopology, expected []*Segment) { - // No changes expected. 
+func waitForInformers(ctx context.Context, nt *nodeTopology) error { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + err := wait.PollImmediateUntil(time.Millisecond, func() (bool, error) { + actualNodes, err := nt.client.CoreV1().Nodes().List(ctx, metav1.ListOptions{}) + if err != nil { + return false, err + } + informerNodes, err := nt.nodeInformer.Lister().List(labels.Everything()) + if len(informerNodes) != len(actualNodes.Items) { + return false, nil + } + if len(informerNodes) > 0 && !func() bool { + for _, actualNode := range actualNodes.Items { + for _, informerNode := range informerNodes { + if reflect.DeepEqual(actualNode, *informerNode) { + return true + } + } + } + return false + }() { + return false, nil + } + + actualCSINodes, err := nt.client.StorageV1().CSINodes().List(ctx, metav1.ListOptions{}) + if err != nil { + return false, err + } + informerCSINodes, err := nt.csiNodeInformer.Lister().List(labels.Everything()) + if len(informerCSINodes) != len(actualCSINodes.Items) { + return false, nil + } + if len(informerCSINodes) > 0 && !func() bool { + for _, actualCSINode := range actualCSINodes.Items { + for _, informerCSINode := range informerCSINodes { + if reflect.DeepEqual(actualCSINode, *informerCSINode) { + return true + } + } + } + return false + }() { + return false, nil + } + + return true, nil + }, ctx.Done()) + if err != nil { + return fmt.Errorf("get informers in sync: %v", err) + } + return nil +} + +func validate(t *testing.T, nt *nodeTopology, expectedAdded, expectedRemoved, expectedAll []*Segment) { added, removed, called := addTestCallback(nt) nt.sync(context.Background()) - if *called { - t.Errorf("sync should not have invoked callbacks") - } - if len(added) > 0 { - t.Errorf("unexpeced added segments: %v", added) + expectedChanges := len(expectedAdded) > 0 || len(expectedRemoved) > 0 + if expectedChanges && !*called { + t.Error("change callback not invoked") } - if len(removed) > 0 { - t.Errorf("unexpected removed segments: %v", removed) + if !expectedChanges && *called { + t.Error("change callback invoked unexpectedly") } + validateSegments(t, "added", added.Found(), expectedAdded) + validateSegments(t, "removed", removed.Found(), expectedRemoved) + validateSegments(t, "final", nt.List(), expectedAll) - actual := nt.List() + if t.Failed() { + t.FailNow() + } +} +func validateSegments(t *testing.T, what string, actual, expected []*Segment) { // We can just compare the string representation because that covers all // relevant content of the segments and is readable. found := map[string]bool{} @@ -470,7 +572,7 @@ func validate(t *testing.T, nt *nodeTopology, expected []*Segment) { for _, str := range segmentsToStrings(actual) { _, exists := found[str] if !exists { - t.Errorf("unexpected segment: %s", str) + t.Errorf("unexpected %s segment: %s", what, str) t.Fail() continue } @@ -478,13 +580,10 @@ func validate(t *testing.T, nt *nodeTopology, expected []*Segment) { } for str, matched := range found { if !matched { - t.Errorf("expected segment not found: %s", str) + t.Errorf("expected %s segment not found: %s", what, str) t.Fail() } } - if t.Failed() { - t.FailNow() - } } func segmentsToStrings(segments []*Segment) []string {
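
The capacity controller now needs the namespace that main() passes in. A minimal sketch of one common way to obtain it, assuming the deployment exposes the pod's own namespace through a NAMESPACE environment variable set via the downward API (fieldRef: metadata.namespace); the variable name and the lookupOwnNamespace helper are illustrative assumptions, not taken from this diff:

package main

import (
	"fmt"
	"os"
)

// lookupOwnNamespace returns the namespace the provisioner pod runs in,
// assuming (hypothetically) that the deployment sets NAMESPACE via the
// downward API and the binary reads it at startup.
func lookupOwnNamespace() (string, error) {
	namespace := os.Getenv("NAMESPACE")
	if namespace == "" {
		return "", fmt.Errorf("NAMESPACE environment variable not set")
	}
	return namespace, nil
}

func main() {
	namespace, err := lookupOwnNamespace()
	fmt.Println(namespace, err)
}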
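
The reworked test relies on two behaviours that the following self-contained sketch exercises together: the v1alpha1 CSIStorageCapacity client is namespaced, and the fake clientset only honours GenerateName when a reactor like generateNameReactor is registered for the plural resource "csistoragecapacities". The namespace "testns" and the "csisc-" prefix are example values only:

package main

import (
	"context"
	"fmt"

	storagev1alpha1 "k8s.io/api/storage/v1alpha1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	krand "k8s.io/apimachinery/pkg/util/rand"
	fakeclientset "k8s.io/client-go/kubernetes/fake"
	ktesting "k8s.io/client-go/testing"
)

func main() {
	clientSet := fakeclientset.NewSimpleClientset()

	// Without this reactor the fake object tracker stores the object with an
	// empty name, so only the first GenerateName-based Create would succeed.
	clientSet.PrependReactor("create", "csistoragecapacities",
		func(action ktesting.Action) (bool, runtime.Object, error) {
			obj := action.(ktesting.CreateAction).GetObject().(*storagev1alpha1.CSIStorageCapacity)
			if obj.Name == "" && obj.GenerateName != "" {
				obj.Name = obj.GenerateName + krand.String(16)
			}
			// Not handled: the default reactor still stores the (now named) object.
			return false, nil, nil
		})

	capacity := &storagev1alpha1.CSIStorageCapacity{
		ObjectMeta: metav1.ObjectMeta{GenerateName: "csisc-"},
	}
	// Namespaced call, mirroring CSIStorageCapacities(c.ownerNamespace) in the controller.
	created, err := clientSet.StorageV1alpha1().CSIStorageCapacities("testns").Create(
		context.Background(), capacity, metav1.CreateOptions{})
	if err != nil {
		panic(err)
	}
	fmt.Println(created.Name) // e.g. csisc-<16 random characters>
}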