diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 9495c77593b..24542adb1a9 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "net" - "os" "strings" "testing" "time" @@ -453,8 +452,9 @@ func TestInitK8sNodeLocalConfig(t *testing.T) { expectedNodeConfig.NodeTransportIPv6Addr = tt.transportInterface.ipV6Net mockGetIPNetDeviceByCIDRs(t, tt.transportInterface.ipV4Net, tt.transportInterface.ipV6Net, tt.transportInterface.iface) } + + t.Setenv(env.NodeNameEnvKey, nodeName) mockGetIPNetDeviceFromIP(t, nodeIPNet, ipDevice) - mockNodeNameEnv(t, nodeName) mockGetNodeTimeout(t, 100*time.Millisecond) err := initializer.initK8sNodeLocalConfig(nodeName) @@ -479,11 +479,6 @@ func mockGetIPNetDeviceFromIP(t *testing.T, ipNet *net.IPNet, ipDevice *net.Inte t.Cleanup(func() { getIPNetDeviceFromIP = prevGetIPNetDeviceFromIP }) } -func mockNodeNameEnv(t *testing.T, name string) { - _ = os.Setenv(env.NodeNameEnvKey, name) - t.Cleanup(func() { os.Unsetenv(env.NodeNameEnvKey) }) -} - func mockGetNodeTimeout(t *testing.T, timeout time.Duration) { prevTimeout := getNodeTimeout getNodeTimeout = timeout @@ -826,11 +821,6 @@ func TestSetOVSDatapath(t *testing.T) { } } -func mockIPsecPSKEnv(t *testing.T, name string) { - os.Setenv(ipsecPSKEnvKey, name) - t.Cleanup(func() { os.Unsetenv(ipsecPSKEnvKey) }) -} - func TestReadIPSecPSK(t *testing.T) { tests := []struct { name string @@ -855,7 +845,7 @@ func TestReadIPSecPSK(t *testing.T) { }, } if tt.isIPsecPSK { - mockIPsecPSKEnv(t, "key") + t.Setenv(ipsecPSKEnvKey, "key") } err := initializer.readIPSecPSK() diff --git a/pkg/agent/nodeportlocal/npl_agent_test.go b/pkg/agent/nodeportlocal/npl_agent_test.go index a3bcdb43b86..53b33f61a01 100644 --- a/pkg/agent/nodeportlocal/npl_agent_test.go +++ b/pkg/agent/nodeportlocal/npl_agent_test.go @@ -21,7 +21,6 @@ import ( "context" "encoding/json" "fmt" - "os" "sync" "testing" "time" @@ -214,7 +213,7 @@ func (tc *testConfig) withCustomPodPortRulesExpectations(fn customizePodPortRule } func setUp(t *testing.T, tc *testConfig, objects ...runtime.Object) *testData { - os.Setenv("NODE_NAME", defaultNodeName) + t.Setenv("NODE_NAME", defaultNodeName) mockCtrl := gomock.NewController(t) @@ -302,7 +301,6 @@ func setUpWithTestServiceAndPod(t *testing.T, tc *testConfig, customNodePort *in func (t *testData) tearDown() { close(t.stopCh) t.wg.Wait() - os.Unsetenv("NODE_NAME") } func (t *testData) pollForPodAnnotation(podName string, found bool) ([]types.NPLAnnotation, error) { diff --git a/pkg/antctl/raw/set/flowaggregator/command_test.go b/pkg/antctl/raw/set/flowaggregator/command_test.go index b48656a5227..83ffd289cd2 100644 --- a/pkg/antctl/raw/set/flowaggregator/command_test.go +++ b/pkg/antctl/raw/set/flowaggregator/command_test.go @@ -156,10 +156,8 @@ func TestUpdateRunE(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { cmd := NewFlowAggregatorSetCommand() - os.Setenv("POD_NAMESPACE", tc.podNamespace) - os.Setenv("FA_CONFIG_MAP_NAME", tc.configMapName) - defer os.Unsetenv("POD_NAMESPACE") - defer os.Unsetenv("FA_CONFIG_MAP_NAME") + t.Setenv("POD_NAMESPACE", tc.podNamespace) + t.Setenv("FA_CONFIG_MAP_NAME", tc.configMapName) err := updateRunE(cmd, tc.args) if tc.expectedErr != "" { assert.ErrorContains(t, err, tc.expectedErr) diff --git a/pkg/apiserver/handlers/featuregates/handler_test.go b/pkg/apiserver/handlers/featuregates/handler_test.go index b39ef17485e..172b3a8a242 100644 --- a/pkg/apiserver/handlers/featuregates/handler_test.go +++ b/pkg/apiserver/handlers/featuregates/handler_test.go @@ -18,7 +18,6 @@ import ( "encoding/json" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/assert" @@ -160,8 +159,8 @@ func TestHandleFunc(t *testing.T) { }, ) - os.Setenv("POD_NAME", "antrea-controller-wotqiwth") - os.Setenv("ANTREA_CONFIG_MAP_NAME", "antrea-config-aswieut") + t.Setenv("POD_NAME", "antrea-controller-wotqiwth") + t.Setenv("ANTREA_CONFIG_MAP_NAME", "antrea-config-aswieut") handler := HandleFunc(fakeClient) req, err := http.NewRequest(http.MethodGet, "", nil) diff --git a/pkg/flowaggregator/exporter/clickhouse_test.go b/pkg/flowaggregator/exporter/clickhouse_test.go index 75703b9c1e5..c0a0f3f9a43 100644 --- a/pkg/flowaggregator/exporter/clickhouse_test.go +++ b/pkg/flowaggregator/exporter/clickhouse_test.go @@ -16,7 +16,6 @@ package exporter import ( "database/sql" - "os" "testing" "time" @@ -30,10 +29,8 @@ import ( ) func TestClickHouse_UpdateOptions(t *testing.T) { - os.Setenv("CH_USERNAME", "default") - os.Setenv("CH_PASSWORD", "default") - defer os.Unsetenv("CH_USERNAME") - defer os.Unsetenv("CH_PASSWORD") + t.Setenv("CH_USERNAME", "default") + t.Setenv("CH_PASSWORD", "default") PrepareClickHouseConnectionSaved := clickhouseclient.PrepareClickHouseConnection clickhouseclient.PrepareClickHouseConnection = func(input clickhouseclient.ClickHouseConfig) (*sql.DB, error) { return nil, nil diff --git a/pkg/util/env/env_test.go b/pkg/util/env/env_test.go index fce3fff66fb..bdc14c0ab88 100644 --- a/pkg/util/env/env_test.go +++ b/pkg/util/env/env_test.go @@ -33,22 +33,19 @@ func TestGetNodeName(t *testing.T) { } for k, v := range testTable { - compareNodeName(k, v, t) - } -} - -func compareNodeName(k, v string, t *testing.T) { - if k != "" { - _ = os.Setenv(NodeNameEnvKey, k) - defer os.Unsetenv(NodeNameEnvKey) - } - nodeName, err := GetNodeName() - if err != nil { - t.Errorf("Failure with expected name %s: %v", k, err) - return - } - if nodeName != v { - t.Errorf("Failed to retrieve nodename, want: %s, get: %s", v, nodeName) + t.Run("nodeName: "+k, func(t *testing.T) { + if k != "" { + t.Setenv(NodeNameEnvKey, k) + } + nodeName, err := GetNodeName() + if err != nil { + t.Errorf("Failure with expected name %s: %v", k, err) + return + } + if nodeName != v { + t.Errorf("Failed to retrieve nodename, want: %s, get: %s", v, nodeName) + } + }) } } @@ -60,18 +57,15 @@ func TestGetPodName(t *testing.T) { } for k, v := range testTable { - comparePodName(k, v, t) - } -} - -func comparePodName(k, v string, t *testing.T) { - if k != "" { - _ = os.Setenv(podNameEnvKey, k) - defer os.Unsetenv(podNameEnvKey) - } - podName := GetPodName() - if podName != v { - t.Errorf("Failed to retrieve pod name, want: %s, get: %s", v, podName) + t.Run("podName: "+k, func(t *testing.T) { + if k != "" { + t.Setenv(podNameEnvKey, k) + } + podName := GetPodName() + if podName != v { + t.Errorf("Failed to retrieve pod name, want: %s, get: %s", v, podName) + } + }) } } @@ -82,13 +76,14 @@ func TestGetAntreaConfigMapName(t *testing.T) { } for k, v := range testTable { - if k != "" { - _ = os.Setenv(antreaConfigMapEnvKey, k) - defer os.Unsetenv(antreaConfigMapEnvKey) - } - configMapName := GetAntreaConfigMapName() - if configMapName != v { - t.Errorf("Failed to retrieve antrea configmap name, want: %s, get: %s", v, configMapName) - } + t.Run("config: "+k, func(t *testing.T) { + if k != "" { + t.Setenv(antreaConfigMapEnvKey, k) + } + configMapName := GetAntreaConfigMapName() + if configMapName != v { + t.Errorf("Failed to retrieve antrea configmap name, want: %s, get: %s", v, configMapName) + } + }) } } diff --git a/pkg/util/k8s/client_test.go b/pkg/util/k8s/client_test.go index 4dae583598d..0b94a1a707a 100644 --- a/pkg/util/k8s/client_test.go +++ b/pkg/util/k8s/client_test.go @@ -108,8 +108,8 @@ func TestOverrideKubeAPIServer(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - defer setEnvDuringTest(kubeServiceHostEnvKey, originalHost)() - defer setEnvDuringTest(kubeServicePortEnvKey, originalPort)() + t.Setenv(kubeServiceHostEnvKey, originalHost) + t.Setenv(kubeServicePortEnvKey, originalPort) OverrideKubeAPIServer(tt.kubeAPIServerOverride) assert.Equal(t, tt.expectHost, os.Getenv(kubeServiceHostEnvKey)) @@ -117,11 +117,3 @@ func TestOverrideKubeAPIServer(t *testing.T) { }) } } - -func setEnvDuringTest(key, value string) func() { - originalValue := os.Getenv(key) - os.Setenv(key, value) - return func() { - os.Setenv(key, originalValue) - } -}