From b1113e56205dbbec6d6392cd1339f6f5253ccc34 Mon Sep 17 00:00:00 2001 From: Bhoopesh Date: Sat, 26 Oct 2024 03:13:08 +0530 Subject: [PATCH] fix: use rand no for temp file Signed-off-by: Bhoopesh --- sztp-agent/pkg/secureagent/status.go | 1 - sztp-agent/pkg/secureagent/status_test.go | 35 ++++++++++++++++++++--- sztp-agent/pkg/secureagent/utils.go | 17 ++++++++--- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/sztp-agent/pkg/secureagent/status.go b/sztp-agent/pkg/secureagent/status.go index 6738c27..4922a45 100644 --- a/sztp-agent/pkg/secureagent/status.go +++ b/sztp-agent/pkg/secureagent/status.go @@ -169,7 +169,6 @@ func (a *Agent) updateStageStatus(status *Status, stageType StageType, isStart b return fmt.Errorf("unknown stage: %s", stage) } - // Update the current stage if isStart { status.Stage = stage + "-in-progress" } else { diff --git a/sztp-agent/pkg/secureagent/status_test.go b/sztp-agent/pkg/secureagent/status_test.go index afdf405..75c9470 100644 --- a/sztp-agent/pkg/secureagent/status_test.go +++ b/sztp-agent/pkg/secureagent/status_test.go @@ -4,9 +4,36 @@ // Package secureagent implements the secure agent package secureagent -import "testing" +import ( + "testing" +) + +const StatusTestContent = `{ + "init": {"errors": [], "start": 1729891263, "end": 0}, + "downloading-file": {"errors": [], "start": 0, "end": 0}, + "pending-reboot": {"errors": [], "start": 0, "end": 0}, + "parsing": {"errors": [], "start": 0, "end": 0}, + "onboarding": {"errors": [], "start": 0, "end": 0}, + "redirect": {"errors": [], "start": 0, "end": 0}, + "boot-image": {"errors": [], "start": 1729891263, "end": 1729891263}, + "pre-script": {"errors": [], "start": 1729891264, "end": 1729891264}, + "config": {"errors": [], "start": 1729891264, "end": 1729891264}, + "post-script": {"errors": [], "start": 1729891264, "end": 1729891264}, + "bootstrap": {"errors": [], "start": 1729891263, "end": 1729891264}, + "is-completed": {"errors": [], "start": 1729891263, "end": 1729891264}, + "informational": "", + "stage": "is-completed-completed" +}` + +const ResultTestContent = `{ + "errors": ["error1", "error2"], +}` func TestAgent_RunCommandStatus(t *testing.T) { + testStatusFile := "/tmp/sztp/status.json" + testResultFile := "/tmp/sztp/result.json" + testSymLinkDir := "/tmp/symlink" + type fields struct { BootstrapURL string SerialNumber string @@ -44,9 +71,9 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON: ProgressJSON{}, BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{}, - StatusFilePath: "/var/lib/sztp/status.json", - ResultFilePath: "/var/lib/sztp/result.json", - SymLinkDir: "/run/sztp", + StatusFilePath: testStatusFile, + ResultFilePath: testResultFile, + SymLinkDir: testSymLinkDir, }, }, } diff --git a/sztp-agent/pkg/secureagent/utils.go b/sztp-agent/pkg/secureagent/utils.go index 23ad4f6..21aa3c7 100644 --- a/sztp-agent/pkg/secureagent/utils.go +++ b/sztp-agent/pkg/secureagent/utils.go @@ -9,6 +9,7 @@ Copyright (C) 2022 Red Hat. package secureagent import ( + "crypto/rand" "crypto/sha256" "encoding/json" "fmt" @@ -90,7 +91,9 @@ func calculateSHA256File(filePath string) (string, error) { } func saveToFile(data interface{}, filePath string) error { - tempPath := filePath + ".tmp" + filePath = filepath.Clean(filePath) + random, _ := rand.Prime(rand.Reader, 64) + tempPath := fmt.Sprintf("%s.%d.tmp", filePath, random) // rand number to avoid conflicts when multiple agents are running tempPath = filepath.Clean(tempPath) file, err := os.Create(tempPath) if err != nil { @@ -108,7 +111,11 @@ func saveToFile(data interface{}, filePath string) error { } // Atomic move of temp file to replace the original. - return os.Rename(tempPath, filePath) + if err := os.Rename(tempPath, filePath); err != nil { + return fmt.Errorf("failed to rename %s to %s: %v", tempPath, filePath, err) + } + + return nil } func ensureDirExists(dir string) error { @@ -127,11 +134,13 @@ func ensureFileExists(filePath string) error { return err } + fmt.Printf("Checking if file %s exists...\n", filePath) + if _, err := os.Stat(filePath); os.IsNotExist(err) { filePath = filepath.Clean(filePath) file, err := os.Create(filePath) if err != nil { - return fmt.Errorf("failed to create file %s: %v", filePath, err) + return fmt.Errorf("[ERROR] failed to create file %s: %v", filePath, err) } defer func() { if err := file.Close(); err != nil { @@ -157,7 +166,7 @@ func createSymlink(targetFile, linkFile string) error { // Check if linkFile exists and is a symlink to targetFile if existingTarget, err := os.Readlink(linkFile); err == nil { if existingTarget == targetFile { - return nil // Symlink already points to the target; skip creation + return nil // Symlink already points to the target -> skip creation } // Remove the existing file (even if it's a wrong symlink or regular file) if err := os.Remove(linkFile); err != nil {