diff --git a/pkg/proxy/internal/util.go b/pkg/proxy/internal/util.go index 6543521..059d3c5 100644 --- a/pkg/proxy/internal/util.go +++ b/pkg/proxy/internal/util.go @@ -46,8 +46,37 @@ func writeYaml(file string, data interface{}) error { if err != nil { return fmt.Errorf("failed to marshal yaml: %w", err) } - if err := os.WriteFile(file, b, 0664); err != nil { - return fmt.Errorf("failed to write file: %w", err) + + // NOTE(Hue): This is technically not a good thing to do here, + // because for whatever reason we might *want* to write an empty file. + // And we're also making this function context aware (read coupled) + // which is not a good thing. + // However, since this function is **only** used to write `provider.yaml` file and + // we know we don't want that file to be empty, we can safely return an error here. + // Make sure to remove this check if the above statement is no longer true. + if data == nil || len(b) == 0 { + return fmt.Errorf("empty yaml data") + } + + dir := filepath.Dir(file) + tmpFile, err := os.CreateTemp(dir, "tmp-*") + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.Write(b); err != nil { + tmpFile.Close() + return fmt.Errorf("failed to write to temp file: %w", err) } + + if err := tmpFile.Close(); err != nil { + return fmt.Errorf("failed to close temp file: %w", err) + } + + if err := os.Rename(tmpFile.Name(), file); err != nil { + return fmt.Errorf("failed to rename temp file to target file: %w", err) + } + return nil } diff --git a/pkg/proxy/internal/util_test.go b/pkg/proxy/internal/util_test.go index 12ca763..6d0bbe8 100644 --- a/pkg/proxy/internal/util_test.go +++ b/pkg/proxy/internal/util_test.go @@ -2,7 +2,10 @@ package internal import ( "os" + "sync" "testing" + + . "github.com/onsi/gomega" ) func TestGetDefaultProviderFile(t *testing.T) { @@ -35,3 +38,53 @@ func TestGetDefaultProviderFile(t *testing.T) { }) } } + +func TestWriteFile(t *testing.T) { + t.Run("EmptyData", func(t *testing.T) { + g := NewWithT(t) + + file, err := os.CreateTemp("", "test-*.yaml") + defer os.Remove(file.Name()) + + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(writeYaml(file.Name(), nil)).ToNot(Succeed()) + }) + + t.Run("ValidData", func(t *testing.T) { + g := NewWithT(t) + + file, err := os.CreateTemp("", "test-*.yaml") + defer os.Remove(file.Name()) + + g.Expect(err).ToNot(HaveOccurred()) + + const ( + numWriters = 200 + numIterations = 200 + ) + + var wg sync.WaitGroup + wg.Add(numWriters) + + // The data to write to the file + testData := map[string]interface{}{ + "key": "value", + } + + for i := 0; i < numWriters; i++ { + go func(writerID int) { + defer wg.Done() + + for j := 0; j < numIterations; j++ { + g.Expect(writeYaml(file.Name(), testData)).To(Succeed()) + + content, err := os.ReadFile(file.Name()) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(string(content)).To(Equal("key: value\n")) + } + }(i) + } + + wg.Wait() + }) +}