diff --git a/morph.go b/morph.go
index 2f3dbbe..e3c89bb 100644
--- a/morph.go
+++ b/morph.go
@@ -17,7 +17,6 @@ import (
 	"os/exec"
 	"path/filepath"
 	"strings"
-	"syscall"
 )
 
 // This is set at build time via -ldflags magic
@@ -368,25 +367,11 @@ func execDeploy(hosts []nix.Host) (string, error) {
 		}
 
 		if deployReboot {
-			if cmd, err := sshContext.Cmd(&host, "sudo", "reboot"); cmd != nil {
-				fmt.Fprint(os.Stderr, "Asking host to reboot ... ")
-				if err = cmd.Run(); err != nil {
-					// Here we assume that exit code 255 means: "SSH connection got disconnected",
-					// which is OK for a reboot - sshd may close active connections before we disconnect after all
-					if exitErr, ok := err.(*exec.ExitError); ok {
-						if status, ok := exitErr.Sys().(syscall.WaitStatus); ok && status.ExitStatus() == 255 {
-							fmt.Fprintln(os.Stderr, "Remote host disconnected.")
-							err = nil
-						}
-					}
-				}
-			}
-
+			err = host.Reboot(sshContext)
 			if err != nil {
+				fmt.Fprintln(os.Stderr, "Reboot failed")
 				return "", err
 			}
-
-			fmt.Fprintln(os.Stderr)
 		}
 
 		if !skipHealthChecks {
@@ -578,10 +563,10 @@ func buildHosts(hosts []nix.Host) (resultPath string, err error) {
 		nixBuildTargets = fmt.Sprintf("{ \"out\" = %s; }", nixBuildTarget)
 	}
 
-	ctx := getNixContext()
+	ctx := getNixContext()
 	resultPath, err = ctx.BuildMachines(deploymentPath, hosts, nixBuildArg, nixBuildTargets)
 
-	if err != nil {
+	if err != nil {
 		return
 	}
 
diff --git a/nix/nix.go b/nix/nix.go
index 001872a..7b4ee44 100644
--- a/nix/nix.go
+++ b/nix/nix.go
@@ -12,6 +12,8 @@ import (
 	"os"
 	"os/exec"
 	"path/filepath"
+	"syscall"
+	"time"
 )
 
 type Host struct {
@@ -36,6 +38,75 @@ func (host *Host) GetHealthChecks() healthchecks.HealthChecks {
 	return host.HealthChecks
 }
 
+// Reboot asks the host to reboot over SSH ("sudo reboot") and then waits for it
+// to come back up. Completion is detected by polling the host's boot ID until it
+// differs from the value read before the reboot; if the boot ID cannot be read
+// up front, a warning is printed and the wait is skipped. Progress is reported
+// on stderr.
+func (host *Host) Reboot(sshContext *ssh.SSHContext) error {
+	oldBootID, err := sshContext.GetBootID(host)
+	// If the host doesn't support getting boot IDs for some reason, warn about it, and skip the comparison
+	skipBootIDComparison := err != nil
+	if skipBootIDComparison {
+		fmt.Fprintf(os.Stderr, "Error getting boot ID (this is used to determine when the reboot is complete): %v\n", err)
+		fmt.Fprintf(os.Stderr, "This makes it impossible to detect when the host has rebooted, so health checks might pass before the host has rebooted.\n")
+	}
+
+	cmd, err := sshContext.Cmd(host, "sudo", "reboot")
+	if err != nil {
+		// The reboot command could not be set up, so nothing was sent to the host.
+		// Bail out rather than printing "OK" and waiting for a boot ID that will never change.
+		return err
+	}
+
+	if cmd != nil {
+		fmt.Fprint(os.Stderr, "Asking host to reboot ... ")
+		if err = cmd.Run(); err != nil {
+			// Here we assume that exit code 255 means: "SSH connection got disconnected",
+			// which is OK for a reboot - sshd may close active connections before we disconnect after all
+			if exitErr, ok := err.(*exec.ExitError); ok {
+				if status, ok := exitErr.Sys().(syscall.WaitStatus); ok && status.ExitStatus() == 255 {
+					fmt.Fprintln(os.Stderr, "Remote host disconnected.")
+					err = nil
+				}
+			}
+		}
+
+		if err != nil {
+			fmt.Fprintln(os.Stderr, "Failed")
+			return err
+		}
+	}
+
+	fmt.Fprintln(os.Stderr, "OK")
+
+	if skipBootIDComparison {
+		return nil
+	}
+
+	fmt.Fprint(os.Stderr, "Waiting for host to come online ")
+
+	// Wait for the host to get a new boot ID. These IDs should be unique for each boot,
+	// meaning a reboot will have been completed when the boot ID has changed.
+	// NOTE: there is no upper bound on this wait; a host that never comes back blocks here until interrupted.
+	for {
+		fmt.Fprint(os.Stderr, ".")
+
+		// Errors are deliberately ignored: plenty are expected while the host is offline,
+		// and the successful read above shows that the host supports boot IDs.
+		newBootID, _ := sshContext.GetBootID(host)
+
+		if newBootID != "" && oldBootID != newBootID {
+			fmt.Fprintln(os.Stderr, " OK")
+			break
+		}
+
+		time.Sleep(2 * time.Second)
+	}
+
+	return nil
+}
+
 func (ctx *NixContext) GetMachines(deploymentPath string) (hosts []Host, err error) {
 
 	args := []string{"eval",
diff --git a/ssh/ssh.go b/ssh/ssh.go
index 6732c11..045ba2d 100644
--- a/ssh/ssh.go
+++ b/ssh/ssh.go
@@ -13,6 +13,7 @@ import (
 	"path/filepath"
 	"strings"
 	"syscall"
+	"time"
 )
 
 type Context interface {
@@ -237,6 +238,29 @@ func (ctx *SSHContext) ActivateConfiguration(host Host, configuration string, ac
 	return nil
 }
 
+// GetBootID reads /proc/sys/kernel/random/boot_id on the host over SSH and returns
+// its contents with surrounding whitespace trimmed. The command runs under a
+// 5 second timeout, and its stderr is forwarded to this process's stderr.
+func (sshCtx *SSHContext) GetBootID(host Host) (string, error) {
+	ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
+	defer cancel()
+	cmd, err := sshCtx.CmdContext(ctx, host, "cat", "/proc/sys/kernel/random/boot_id")
+	if err != nil {
+		return "", err
+	}
+
+	var stdout bytes.Buffer
+	cmd.Stdout = &stdout
+	cmd.Stderr = os.Stderr
+
+	err = cmd.Run()
+	if err != nil {
+		return "", err
+	}
+
+	return strings.TrimSpace(stdout.String()), nil
+}
+
 func (ctx *SSHContext) MakeTempFile(host Host) (path string, err error) {
 	cmd, _ := ctx.Cmd(host, "mktemp")
 