From 660f4881eaead969307b37b3b8b63627949e16fd Mon Sep 17 00:00:00 2001 From: Henry Avetisyan Date: Tue, 21 May 2024 15:30:37 -0700 Subject: [PATCH] keep state when key/cert were backed up for restore in case of failure Signed-off-by: Henry Avetisyan --- libs/go/sia/util/util.go | 66 +++++++++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 21 deletions(-) diff --git a/libs/go/sia/util/util.go b/libs/go/sia/util/util.go index 6fc13659c91..3f4e6d9e3c7 100644 --- a/libs/go/sia/util/util.go +++ b/libs/go/sia/util/util.go @@ -600,7 +600,7 @@ func Copy(sourceFile, destFile string, perm os.FileMode) error { } return os.WriteFile(destFile, sourceBytes, perm) } - // source file does not exist to take back up of. so no error + // source file does not exist to take backup of so no error return nil } @@ -852,20 +852,24 @@ func SaveRoleCertKey(key, cert []byte, keyFile, certFile, svcKeyFile, roleName s backUpKeyFile := fmt.Sprintf("%s/%s.key.pem", backupDir, roleName) backUpCertFile := fmt.Sprintf("%s/%s.cert.pem", backupDir, roleName) - // if we're not givena role key file, it means we're re-using our service private key + // if we're not given a role key file, it means we're re-using our service private key // thus there is no need to update any files + filesBackedUp := false if keyFile != "" { if rotateKey { err = EnsureBackUpDir(backupDir) if err != nil { return err } - // taking back up of key and cert - log.Printf("taking back up of cert: %s to %s and key: %s to %s\n", certFile, backUpCertFile, keyFile, backUpKeyFile) - err = CopyCertKeyFile(keyFile, backUpKeyFile, certFile, backUpCertFile, os.FileMode(fileMode), fileDirectUpdate) - if err != nil { - log.Printf("Error while taking back up %v\n", err) - return err + // taking backup of key and cert + if FileExists(keyFile) || FileExists(certFile) { + log.Printf("taking backup of cert: %s to %s and key: %s to %s\n", certFile, backUpCertFile, keyFile, backUpKeyFile) + err = CopyCertKeyFile(keyFile, backUpKeyFile, certFile, backUpCertFile, os.FileMode(fileMode), fileDirectUpdate) + if err != nil { + log.Printf("Error while taking backup %v\n", err) + return err + } + filesBackedUp = true } //write the new key and x509KeyPair to disk log.Printf("writing new key file: %s to disk\n", keyFile) @@ -887,7 +891,7 @@ func SaveRoleCertKey(key, cert []byte, keyFile, certFile, svcKeyFile, roleName s UpdateKeyOwnership(keyFile, uid, gid, os.FileMode(fileMode), fileDirectUpdate) } } else { - // since we're using our service key file, let's set the key file as such + // since we're using our service key file, let's set the key file as such, // so we can load and validate the x509KeyPair later in this method keyFile = svcKeyFile } @@ -903,20 +907,28 @@ func SaveRoleCertKey(key, cert []byte, keyFile, certFile, svcKeyFile, roleName s x509KeyPair, err = tls.LoadX509KeyPair(certFile, keyFile) if err != nil { log.Printf("x509KeyPair: %s, key: %s do not match, error: %v\n", certFile, keyFile, err) - return CopyCertKeyFile(backUpKeyFile, keyFile, backUpCertFile, certFile, os.FileMode(fileMode), fileDirectUpdate) + // restore the original contents only if we had successfully backed up the files + if filesBackedUp { + err = CopyCertKeyFile(backUpKeyFile, keyFile, backUpCertFile, certFile, os.FileMode(fileMode), fileDirectUpdate) + } + return err } _, err = x509.ParseCertificate(x509KeyPair.Certificate[0]) if err != nil { log.Printf("x509KeyPair: %s, key: %s, unable to parse cert, error: %v\n", certFile, keyFile, err) - return CopyCertKeyFile(backUpKeyFile, keyFile, backUpCertFile, certFile, os.FileMode(fileMode), fileDirectUpdate) + // restore the original contents only if we had successfully backed up the files + if filesBackedUp { + err = CopyCertKeyFile(backUpKeyFile, keyFile, backUpCertFile, certFile, os.FileMode(fileMode), fileDirectUpdate) + } + return err } return nil } -// SaveServiceCertKey writes the key and cert to disk and takes back up of existing key and cert if rotateKey is true -// this method is only called when we're refreshing the service certificate. during service registeration we directly +// SaveServiceCertKey writes the key and cert to disk and takes backup of existing key and cert if rotateKey is true +// this method is only called when we're refreshing the service certificate. during service registration we directly // update key/cert/ca-cert files func SaveServiceCertKey(key, cert []byte, keyFile, certFile, serviceName string, uid, gid, fileMode int, rotateKey bool, backupDir string, fileDirectUpdate bool) error { // perform validation of x509KeyPair pair match before writing to disk @@ -932,17 +944,21 @@ func SaveServiceCertKey(key, cert []byte, keyFile, certFile, serviceName string, backUpKeyFile := fmt.Sprintf("%s/%s.key.pem", backupDir, serviceName) backUpCertFile := fmt.Sprintf("%s/%s.cert.pem", backupDir, serviceName) + filesBackedUp := false if rotateKey { err = EnsureBackUpDir(backupDir) if err != nil { return err } - // taking back up of key and cert - log.Printf("taking back up of cert: %s to %s and key: %s to %s\n", certFile, backUpCertFile, keyFile, backUpKeyFile) - err = CopyCertKeyFile(keyFile, backUpKeyFile, certFile, backUpCertFile, os.FileMode(fileMode), fileDirectUpdate) - if err != nil { - log.Printf("Error while taking back up %v\n", err) - return err + // taking backup of key and cert + if FileExists(keyFile) || FileExists(certFile) { + log.Printf("taking backup of cert: %s to %s and key: %s to %s\n", certFile, backUpCertFile, keyFile, backUpKeyFile) + err = CopyCertKeyFile(keyFile, backUpKeyFile, certFile, backUpCertFile, os.FileMode(fileMode), fileDirectUpdate) + if err != nil { + log.Printf("Error while taking backup %v\n", err) + return err + } + filesBackedUp = true } //write the new key and x509KeyPair to disk log.Printf("writing new key file: %s to disk\n", keyFile) @@ -966,13 +982,21 @@ func SaveServiceCertKey(key, cert []byte, keyFile, certFile, serviceName string, x509KeyPair, err = tls.LoadX509KeyPair(certFile, keyFile) if err != nil { log.Printf("x509KeyPair: %s, key: %s do not match, error: %v\n", certFile, keyFile, err) - return CopyCertKeyFile(backUpKeyFile, keyFile, backUpCertFile, certFile, os.FileMode(fileMode), fileDirectUpdate) + // restore the original contents only if we had successfully backed up the files + if filesBackedUp { + err = CopyCertKeyFile(backUpKeyFile, keyFile, backUpCertFile, certFile, os.FileMode(fileMode), fileDirectUpdate) + } + return err } _, err = x509.ParseCertificate(x509KeyPair.Certificate[0]) if err != nil { log.Printf("x509KeyPair: %s, key: %s, unable to parse cert, error: %v\n", certFile, keyFile, err) - return CopyCertKeyFile(backUpKeyFile, keyFile, backUpCertFile, certFile, os.FileMode(fileMode), fileDirectUpdate) + // restore the original contents only if we had successfully backed up the files + if filesBackedUp { + err = CopyCertKeyFile(backUpKeyFile, keyFile, backUpCertFile, certFile, os.FileMode(fileMode), fileDirectUpdate) + } + return err } return nil