From 5496f676dadf35127b16d2ef8dd2608b7026edc7 Mon Sep 17 00:00:00 2001
From: Anton Myagkov
Date: Tue, 8 Oct 2024 17:18:17 +0000
Subject: [PATCH] fix: handle kubelet restart

---
Review notes (ignored by git am):
  * hasDeprecatedEndpoint used filepath.SplitList, which splits PATH-style
    lists on the list separator and never yields the "v2" path element, so
    every mount conflict was treated as a deprecated endpoint. The socket
    path is now inspected element by element.
  * hasDeprecatedEndpoint now returns false when the volume has no endpoint.
  * The blob ids on the node.go "index" line predate these review changes.

 cloud/blockstore/tests/csi_driver/test.py     | 51 ++++++++++++++++
 .../tools/csi_driver/internal/driver/node.go  | 78 ++++++++++++++++++++++++-
 2 files changed, 126 insertions(+), 3 deletions(-)

diff --git a/cloud/blockstore/tests/csi_driver/test.py b/cloud/blockstore/tests/csi_driver/test.py
index c4354e5db37..287c06a5ca0 100644
--- a/cloud/blockstore/tests/csi_driver/test.py
+++ b/cloud/blockstore/tests/csi_driver/test.py
@@ -489,3 +489,54 @@ def test_publish_volume_twice_on_the_same_node():
         with called_process_error_logged():
             env.csi.delete_volume(volume_name)
         cleanup_after_test(env)
+
+
+def test_restart_kubelet_with_old_format_endpoint():
+    env, run = init()
+    try:
+        volume_name = "example-disk"
+        volume_size = 1024 ** 3
+        pod_name1 = "example-pod-1"
+        pod_id1 = "deadbeef1"
+        env.csi.create_volume(name=volume_name, size=volume_size)
+
+        # skip stage to create endpoint with old format
+        env.csi.publish_volume(pod_id1, volume_name, pod_name1)
+        env.csi.stage_volume(volume_name)
+        env.csi.publish_volume(pod_id1, volume_name, pod_name1)
+    except subprocess.CalledProcessError as e:
+        log_called_process_error(e)
+        raise
+    finally:
+        with called_process_error_logged():
+            env.csi.unpublish_volume(pod_id1, volume_name)
+        with called_process_error_logged():
+            env.csi.unstage_volume(volume_name)
+        with called_process_error_logged():
+            env.csi.delete_volume(volume_name)
+        cleanup_after_test(env)
+
+
+def test_restart_kubelet_with_new_format_endpoint():
+    env, run = init()
+    try:
+        volume_name = "example-disk"
+        volume_size = 1024 ** 3
+        pod_name1 = "example-pod-1"
+        pod_id1 = "deadbeef1"
+        env.csi.create_volume(name=volume_name, size=volume_size)
+        env.csi.stage_volume(volume_name)
+        env.csi.publish_volume(pod_id1, volume_name, pod_name1)
+        env.csi.stage_volume(volume_name)
+        env.csi.publish_volume(pod_id1, volume_name, pod_name1)
+    except subprocess.CalledProcessError as e:
+        log_called_process_error(e)
+        raise
+    finally:
+        with called_process_error_logged():
+            env.csi.unpublish_volume(pod_id1, volume_name)
+        with called_process_error_logged():
+            env.csi.unstage_volume(volume_name)
+        with called_process_error_logged():
+            env.csi.delete_volume(volume_name)
+        cleanup_after_test(env)
diff --git a/cloud/blockstore/tools/csi_driver/internal/driver/node.go b/cloud/blockstore/tools/csi_driver/internal/driver/node.go
index 03a28180ca4..91d9e3cf8f7 100644
--- a/cloud/blockstore/tools/csi_driver/internal/driver/node.go
+++ b/cloud/blockstore/tools/csi_driver/internal/driver/node.go
@@ -650,12 +650,84 @@ func (s *nodeService) nodePublishDiskAsFilesystem(
 	return nil
 }
 
+// IsMountConflictError reports whether err wraps an NBS client error with
+// the E_MOUNT_CONFLICT code.
+func (s *nodeService) IsMountConflictError(err error) bool {
+	// errors.As returns false for a nil err, so no separate nil check is needed.
+	var clientErr *nbsclient.ClientError
+	return errors.As(err, &clientErr) &&
+		clientErr.Code == nbsclient.E_MOUNT_CONFLICT
+}
+
+// socketPathHasElement reports whether one of the elements of path is exactly
+// equal to element.
+func socketPathHasElement(path string, element string) bool {
+	// filepath.SplitList cannot be used here: it splits PATH-style lists on
+	// the list separator, not a single path on the path separator. Walk from
+	// the last element up to the root instead.
+	current := filepath.Clean(path)
+	for {
+		if filepath.Base(current) == element {
+			return true
+		}
+		parent := filepath.Dir(current)
+		if parent == current {
+			return false
+		}
+		current = parent
+	}
+}
+
+// hasDeprecatedEndpoint reports whether the volume is served only by
+// endpoints whose socket path is in the old format, i.e. has no "v2" element.
+// It returns false when the volume has no endpoint at all or when at least
+// one of its endpoints is in the new format.
+func (s *nodeService) hasDeprecatedEndpoint(
+	ctx context.Context,
+	volumeId string) (bool, error) {
+
+	listEndpointsResp, err := s.nbsClient.ListEndpoints(
+		ctx, &nbsapi.TListEndpointsRequest{},
+	)
+	if err != nil {
+		log.Printf("List endpoints failed %v", err)
+		return false, err
+	}
+
+	found := false
+	for _, endpoint := range listEndpointsResp.Endpoints {
+		if endpoint.DiskId != volumeId {
+			continue
+		}
+		if socketPathHasElement(endpoint.UnixSocketPath, "v2") {
+			return false, nil
+		}
+		found = true
+	}
+
+	return found, nil
+}
+
 func (s *nodeService) nodeStageDiskAsFilesystem(
 	ctx context.Context,
 	req *csi.NodeStageVolumeRequest) error {
 
-	resp, err := s.startNbsEndpointForNBD(ctx, s.nodeID, req.VolumeId, req.VolumeContext)
+	// The "v2" element marks endpoints started by stage; hasDeprecatedEndpoint
+	// relies on it to tell them apart from old-format ones.
+	instanceId := filepath.Join("v2", s.nodeID)
+	resp, err := s.startNbsEndpointForNBD(ctx, instanceId, req.VolumeId, req.VolumeContext)
 	if err != nil {
+		if s.IsMountConflictError(err) {
+			// A mount conflict caused by an old-format endpoint of this
+			// volume is not an error: report success without staging again.
+			deprecatedEndpoint, listErr := s.hasDeprecatedEndpoint(ctx, req.VolumeId)
+			if listErr != nil {
+				return listErr
+			}
+			if deprecatedEndpoint {
+				return nil
+			}
+		}
 		return fmt.Errorf("failed to start NBS endpoint: %w", err)
 	}
 
@@ -855,7 +927,7 @@ func (s *nodeService) nodeUnstageVolume(
 		return err
 	}
 
-	endpointDir := s.getEndpointDir(s.nodeID, req.VolumeId)
+	endpointDir := s.getEndpointDir(filepath.Join("v2", s.nodeID), req.VolumeId)
 	if s.nbsClient != nil {
 		_, err := s.nbsClient.StopEndpoint(ctx, &nbsapi.TStopEndpointRequest{
 			UnixSocketPath: filepath.Join(endpointDir, nbsSocketName),
@@ -1300,7 +1372,7 @@ func (s *nodeService) NodeExpandVolume(
 	endpointDirOld := s.getEndpointDir(podId, req.VolumeId)
 	unixSocketPathOld := filepath.Join(endpointDirOld, nbsSocketName)
 
-	endpointDirNew := s.getEndpointDir(s.nodeID, req.VolumeId)
+	endpointDirNew := s.getEndpointDir(filepath.Join("v2", s.nodeID), req.VolumeId)
 	unixSocketPathNew := filepath.Join(endpointDirNew, nbsSocketName)
 
 	listEndpointsResp, err := s.nbsClient.ListEndpoints(