diff --git a/pkg/cloudmeta/metadata.go b/pkg/cloudmeta/metadata.go index 15aee9a17..7a2e7c37b 100644 --- a/pkg/cloudmeta/metadata.go +++ b/pkg/cloudmeta/metadata.go @@ -87,23 +87,31 @@ func (cloud *CloudMeta) GetInstanceMetadata(ctx context.Context) (InstanceMetada go func(provider CloudMetadataProvider) { meta, err := cloud.runWithTimeout(ctx, provider) - select { - case <-ctx.Done(): + if ctx.Err() != nil { return - case results <- msg{meta: meta, err: err}: } + + results <- msg{meta: meta, err: err} }(provider) } // Return the first non error result or wait until all providers return err. var mErr error for range len(cloud.providers) { - res := <-results - if res.err != nil { - mErr = multierr.Append(mErr, res.err) - continue + select { + case <-ctx.Done(): + return InstanceMetadata{}, ctx.Err() + case res := <-results: + // Additional context check just in case messages in results and in ctx.Done channels were available at the same time. + if err := ctx.Err(); err != nil { + return InstanceMetadata{}, err + } + if res.err != nil { + mErr = multierr.Append(mErr, res.err) + continue + } + return res.meta, nil } - return res.meta, nil } return InstanceMetadata{}, mErr } diff --git a/pkg/cloudmeta/metadata_test.go b/pkg/cloudmeta/metadata_test.go index a2d5b782d..fe9fb7dde 100644 --- a/pkg/cloudmeta/metadata_test.go +++ b/pkg/cloudmeta/metadata_test.go @@ -77,7 +77,8 @@ func TestGetInstanceMetadata(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { cloudmeta := &CloudMeta{ - providers: tc.providers, + providers: tc.providers, + providerTimeout: 1 * time.Second, } meta, err := cloudmeta.GetInstanceMetadata(context.Background()) @@ -101,6 +102,31 @@ func TestGetInstanceMetadata(t *testing.T) { } } +func TestGetInstanceMetadataWithCancelledContext(t *testing.T) { + cloudmeta := &CloudMeta{ + providers: []CloudMetadataProvider{ + newTestProvider(t, "test_provider_1", "x-test-1", 1*time.Second, nil), + }, + providerTimeout: 100 * time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + _ = time.AfterFunc(50*time.Millisecond, cancel) + + meta, err := cloudmeta.GetInstanceMetadata(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + + if meta.CloudProvider != "" { + t.Fatalf("meta.CloudProvider should be empty, got %s", meta.CloudProvider) + } + + if meta.InstanceType != "" { + t.Fatalf("meta.InstanceType should be empty, got %s", meta.InstanceType) + } +} + func newTestProvider(t *testing.T, providerName, instanceType string, latency time.Duration, err error) *testProvider { t.Helper() @@ -120,7 +146,11 @@ type testProvider struct { } func (tp testProvider) Metadata(ctx context.Context) (InstanceMetadata, error) { - time.Sleep(tp.latency) + select { + case <-time.After(tp.latency): + case <-ctx.Done(): + return InstanceMetadata{}, ctx.Err() + } if tp.err != nil { return InstanceMetadata{}, tp.err @@ -128,5 +158,5 @@ func (tp testProvider) Metadata(ctx context.Context) (InstanceMetadata, error) { return InstanceMetadata{ CloudProvider: tp.name, InstanceType: tp.instanceType, - }, nil + }, ctx.Err() }