diff --git a/pkg/process/process.go b/pkg/process/process.go new file mode 100644 index 00000000..086cc3e2 --- /dev/null +++ b/pkg/process/process.go @@ -0,0 +1,119 @@ +package process + +import ( + "fmt" + "os" + "strconv" + "strings" + + "github.com/shirou/gopsutil/v4/process" +) + +type Manager interface { + ReadPidFile() (int, error) + Name() string + PidFilePath() string + Exists() bool + Terminate() error + Kill() error + FindProcess() (*os.Process, error) + WritePidFile(pid int) error +} + +type Process struct { + name string + pidFilePath string + executablePath string +} + +func New(name, pidFilePath, executablePath string) (*Process, error) { + return &Process{name: name, pidFilePath: pidFilePath, executablePath: executablePath}, nil +} + +func (p *Process) Name() string { + return p.name +} + +func (p *Process) PidFilePath() string { + return p.pidFilePath +} + +func (p *Process) ExecutablePath() string { + return p.executablePath +} + +func (p *Process) ReadPidFile() (int, error) { + data, err := os.ReadFile(p.PidFilePath()) + if err != nil { + return -1, err + } + pidStr := strings.TrimSpace(string(data)) + pid, err := strconv.Atoi(pidStr) + if err != nil { + return -1, fmt.Errorf("invalid pid file: %v", err) + } + return pid, nil +} + +func (p *Process) FindProcess() (*process.Process, error) { + pid, err := p.ReadPidFile() + if err != nil { + return nil, fmt.Errorf("cannot find process: %v", err) + } + + exists, err := process.PidExists(int32(pid)) + if err != nil { + return nil, err + } + if !exists { + return nil, fmt.Errorf("process not found") + } + + proc, err := process.NewProcess(int32(pid)) + if err != nil { + return nil, fmt.Errorf("cannot find process: %v", err) + } + if proc == nil { + return nil, fmt.Errorf("process not found") + } + name, err := proc.Name() + if err != nil { + return nil, fmt.Errorf("cannot find process name: %v", err) + } + if name != p.Name() { + return nil, fmt.Errorf("pid %d is stale, and is being used by %s", pid, name) + } + exe, err := proc.Exe() + if err != nil { + return nil, fmt.Errorf("cannot find process exe: %v", err) + } + if exe != p.ExecutablePath() { + return nil, fmt.Errorf("pid %d is stale, and is being used by %s", pid, exe) + } + return proc, nil +} + +func (p *Process) Exists() bool { + proc, err := p.FindProcess() + return err == nil && proc != nil +} + +func (p *Process) Terminate() error { + proc, err := p.FindProcess() + if err != nil { + return fmt.Errorf("cannot find process: %v", err) + } + return proc.Terminate() +} + +func (p *Process) Kill() error { + proc, err := p.FindProcess() + if err != nil { + return fmt.Errorf("cannot find process: %v", err) + } + return proc.Kill() +} + +func (p *Process) WritePidFile(pid int) error { + return os.WriteFile(p.pidFilePath, []byte(strconv.Itoa(pid)), 0600) +} diff --git a/pkg/process/process_test.go b/pkg/process/process_test.go new file mode 100644 index 00000000..d4f30f33 --- /dev/null +++ b/pkg/process/process_test.go @@ -0,0 +1,99 @@ +package process + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + dummyProcessName = "sleep" + dummyProcessArgs = "60" +) + +var ( + dummyProcess *exec.Cmd + managedProcess *Process + pidFilePath = filepath.Join(os.TempDir(), "pid") +) + +func startDummyProcess() error { + dummyProcess = exec.Command(dummyProcessName, dummyProcessArgs) + err := dummyProcess.Start() + if err != nil { + return err + } + return nil +} + +func TestMain(m *testing.M) { + err := startDummyProcess() + if err != nil { + fmt.Fprintln(os.Stderr, "Failed to start process:", err) + os.Exit(1) + } + + managedProcess, err = New(dummyProcessName, pidFilePath, dummyProcess.Path) + if err != nil { + fmt.Fprintln(os.Stderr, "Failed to create process:", err) + os.Exit(1) + } + err = managedProcess.WritePidFile(dummyProcess.Process.Pid) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + exitCode := m.Run() + if dummyProcess.Process != nil { + _ = dummyProcess.Process.Kill() + } + + os.Exit(exitCode) +} + +func TestProcess_Name(t *testing.T) { + assert.Equal(t, dummyProcessName, managedProcess.Name()) +} + +func TestProcess_FindProcess(t *testing.T) { + foundProcess, err := managedProcess.FindProcess() + assert.NoError(t, err) + assert.NotNil(t, foundProcess) + assert.Equal(t, dummyProcess.Process.Pid, int(foundProcess.Pid)) + + assert.True(t, managedProcess.Exists()) +} + +func TestProcess_KillProcess(t *testing.T) { + err := managedProcess.Kill() + assert.NoError(t, err) + assert.False(t, managedProcess.Exists()) + + // Try to kill the non-existent process + // This should result in an error + err = managedProcess.Kill() + assert.Error(t, err) +} + +func TestProcess_FindProcess_InvalidPidFile(t *testing.T) { + tmpfile, err := os.CreateTemp("", "invalid_pid") + require.NoError(t, err) + defer os.Remove(tmpfile.Name()) + + // Write non-numeric content into the file to mimic an invalid pid + _, err = tmpfile.WriteString("non-numeric") + require.NoError(t, err) + tmpfile.Close() + + invalidProcess, err := New("invalid-process", tmpfile.Name(), "invalid-path") + assert.NoError(t, err) + + foundProcess, err := invalidProcess.FindProcess() + assert.Error(t, err) + assert.Nil(t, foundProcess) +}