Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enhance privilege elevation logic #722

Merged
merged 8 commits into from
Dec 31, 2024
130 changes: 93 additions & 37 deletions cmd/internal/su.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,100 @@
package internal

import (
"fmt"
"os"
"os/exec"
"path/filepath"

"github.com/sirupsen/logrus"
"fmt"
"github.com/sirupsen/logrus"
"os"
"os/exec"
)

func AutoSu() {
if os.Getuid() == 0 {
return
}
program := filepath.Base(os.Args[0])
pathSudo, err := exec.LookPath("sudo")
if err != nil {
// skip
return
}
// https://github.com/WireGuard/wireguard-tools/blob/71799a8f6d1450b63071a21cad6ed434b348d3d5/src/wg-quick/linux.bash#L85
p, err := os.StartProcess(pathSudo, append([]string{
pathSudo,
"-E",
"-p",
fmt.Sprintf("%v must be run as root. Please enter the password for %%u to continue: ", program),
"--",
}, os.Args...), &os.ProcAttr{
Files: []*os.File{
os.Stdin,
os.Stdout,
os.Stderr,
},
})
if err != nil {
logrus.Fatal(err)
}
stat, err := p.Wait()
if err != nil {
os.Exit(1)
}
os.Exit(stat.ExitCode())
if os.Geteuid() == 0 {
return
}
path, arg := trySudo()
if path == "" {
path, arg = tryDoas()
}
if path == "" {
path, arg = tryPolkit()
}

if path == "" {
return
}
logrus.Infof("use [ %s ] to elevate privileges to run [ %s ]", path, os.Args[0])
p, err := os.StartProcess(path, append(arg, os.Args...), &os.ProcAttr{
Files: []*os.File{
os.Stdin,
os.Stdout,
os.Stderr,
},
})
if err != nil {
logrus.Fatal(err)
}
stat, err := p.Wait()
if err != nil {
os.Exit(1)
}
os.Exit(stat.ExitCode())
}

func trySudo() (path string, arg []string) {
pathSudo, err := exec.LookPath("sudo")
if err != nil || !isExistAndExecutable(pathSudo) {
return "", nil
}
// https://github.com/WireGuard/wireguard-tools/blob/71799a8f6d1450b63071a21cad6ed434b348d3d5/src/wg-quick/linux.bash#L85
return pathSudo, []string{
pathSudo,
"-E",
"-p",
fmt.Sprintf("Please enter the password for %%u to continue: "),
"--",
}
}

func tryDoas() (path string, arg []string) {
// https://man.archlinux.org/man/doas.1
var err error
path, err = exec.LookPath("doas")
if err != nil {
return "", nil
}
return path, []string{path, "-u", "root"}
}

func tryPolkit() (path string, arg []string) {
// https://github.com/systemd/systemd/releases/tag/v256
// introduced run0 which is a polkit wrapper.
var possible = []string{"run0", "pkexec"}
for _, v := range possible {
path, err := exec.LookPath(v)
if err != nil {
continue
}
if isExistAndExecutable(path) {
switch v {
case "run0":
return path, []string{path}
case "pkexec":
return path, []string{path, "--keep-cwd", "--user", "root"}
}
}
}
return "", nil
}

func isExistAndExecutable(path string) bool {
if path == "" {
return false
}

st, err := os.Stat(path)
if err == nil {
// https://stackoverflow.com/questions/60128401/how-to-check-if-a-file-is-executable-in-go
return st.Mode()&0o111 == 0o111
}
return false
}
2 changes: 1 addition & 1 deletion cmd/reload.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ var (
Use: "reload [pid]",
Short: "To reload config file without interrupt connections.",
Run: func(cmd *cobra.Command, args []string) {
internal.AutoSu()
internal.AutoSu()
if len(args) == 0 {
_pid, err := os.ReadFile(PidFilePath)
if err != nil {
Expand Down
15 changes: 10 additions & 5 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ func init() {
runCmd.PersistentFlags().StringVar(&logFile, "logfile", "", "Log file to write. Empty means writing to stdout and stderr.")
runCmd.PersistentFlags().IntVar(&logFileMaxSize, "logfile-maxsize", 30, "Unit: MB. The maximum size in megabytes of the log file before it gets rotated.")
runCmd.PersistentFlags().IntVar(&logFileMaxBackups, "logfile-maxbackups", 3, "The maximum number of old log files to retain.")
runCmd.PersistentFlags().BoolVarP(&disableTimestamp, "disable-timestamp", "", false, "Disable timestamp.")
runCmd.PersistentFlags().BoolVarP(&disablePidFile, "disable-pidfile", "", false, "Not generate /var/run/dae.pid.")

runCmd.PersistentFlags().BoolVar(&disableTimestamp, "disable-timestamp", false, "Disable timestamp.")
runCmd.PersistentFlags().BoolVar(&disablePidFile, "disable-pidfile", false, "Not generate /var/run/dae.pid.")
runCmd.PersistentFlags().BoolVar(&disableAuthSudo, "disable-sudo", false, "Disable sudo prompt ,may cause startup failure due to insufficient permissions")
rand.Shuffle(len(CheckNetworkLinks), func(i, j int) {
CheckNetworkLinks[i], CheckNetworkLinks[j] = CheckNetworkLinks[j], CheckNetworkLinks[i]
})
Expand All @@ -74,6 +74,7 @@ var (
logFileMaxBackups int
disableTimestamp bool
disablePidFile bool
disableAuthSudo bool

runCmd = &cobra.Command{
Use: "run",
Expand All @@ -82,9 +83,13 @@ var (
if cfgFile == "" {
logrus.Fatalln("Argument \"--config\" or \"-c\" is required but not provided.")
}

if disableAuthSudo && os.Geteuid() != 0 {
logrus.Fatalln("Auto-sudo is disabled and current user is not root.")
}
// Require "sudo" if necessary.
internal.AutoSu()
if !disableAuthSudo {
internal.AutoSu()
}

// Read config from --config cfgFile.
conf, includes, err := readConfig(cfgFile)
Expand Down
18 changes: 7 additions & 11 deletions cmd/sysdump.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ package cmd

import (
"bytes"
"io/ioutil"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"time"

"github.com/vishvananda/netlink"
"github.com/spf13/cobra"
"github.com/mholt/archiver/v3"
"github.com/shirou/gopsutil/v4/net"
"github.com/spf13/cobra"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)

Expand Down Expand Up @@ -46,7 +46,7 @@ func dumpNetworkInfo() {
dumpNetfilter(tempDir)
dumpIPTables(tempDir)

tarFile := fmt.Sprintf("dae-sysdump.%d.tar.gz",time.Now().Unix())
tarFile := fmt.Sprintf("dae-sysdump.%d.tar.gz", time.Now().Unix())
if err := archiver.Archive([]string{tempDir}, tarFile); err != nil {
fmt.Printf("Failed to create tar archive: %v\n", err)
return
Expand All @@ -55,7 +55,6 @@ func dumpNetworkInfo() {
fmt.Printf("System network information collected and saved to %s\n", tarFile)
}


// Translate scope enum into semantic words
func scopeToString(scope netlink.Scope) string {
switch scope {
Expand All @@ -74,7 +73,6 @@ func scopeToString(scope netlink.Scope) string {
}
}


// Translate protocol enum into semantic words
func protocolToString(proto int) string {
switch proto {
Expand Down Expand Up @@ -157,7 +155,6 @@ func typeToString(typ int) string {
}
}


func dumpRouting(outputDir string) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
Expand Down Expand Up @@ -232,7 +229,6 @@ func dumpNetInterfaces(outputDir string) {
ioutil.WriteFile(filepath.Join(outputDir, "interfaces.txt"), buffer.Bytes(), 0644)
}


func dumpSysctl(outputDir string) {
sysctlPath := "/proc/sys/net"
var buffer bytes.Buffer
Expand Down Expand Up @@ -281,12 +277,12 @@ func dumpIPTables(outputDir string) {
ioutil.WriteFile(filepath.Join(outputDir, "iptables.txt"), output, 0644)
}

ip6tables := exec.Command("ip6tables-save","-c")
ip6tables := exec.Command("ip6tables-save", "-c")
output, err = ip6tables.CombinedOutput()
if err != nil {
fmt.Printf("Failed to get ip6tables: %v\n", err)
} else {
ioutil.WriteFile(filepath.Join(outputDir, "ip6tables.txt"), output, 0644)
ioutil.WriteFile(filepath.Join(outputDir, "ip6tables.txt"), output, 0644)
}
}

Expand Down
10 changes: 5 additions & 5 deletions control/dns_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ type DnsController struct {

fixedDomainTtl map[string]int
// mutex protects the dnsCache.
dnsCacheMu sync.Mutex
dnsCache map[string]*DnsCache
dnsCacheMu sync.Mutex
dnsCache map[string]*DnsCache
dnsForwarderCacheMu sync.Mutex
dnsForwarderCache map[string]DnsForwarder
}
Expand Down Expand Up @@ -113,9 +113,9 @@ func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsCont
bestDialerChooser: option.BestDialerChooser,
timeoutExceedCallback: option.TimeoutExceedCallback,

fixedDomainTtl: option.FixedDomainTtl,
dnsCacheMu: sync.Mutex{},
dnsCache: make(map[string]*DnsCache),
fixedDomainTtl: option.FixedDomainTtl,
dnsCacheMu: sync.Mutex{},
dnsCache: make(map[string]*DnsCache),
dnsForwarderCacheMu: sync.Mutex{},
dnsForwarderCache: make(map[string]DnsForwarder),
}, nil
Expand Down
61 changes: 30 additions & 31 deletions trace/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
"encoding/binary"
"errors"
"fmt"
"slices"
"net"
"os"
"slices"
"syscall"
"unsafe"

Expand Down Expand Up @@ -278,43 +278,42 @@ func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfre
logrus.Debugf("failed to parse ringbuf event: %+v", err)
continue
}
if skb2events[event.Skb]==nil {
skb2events[event.Skb] = []bpfEvent{}
if skb2events[event.Skb] == nil {
skb2events[event.Skb] = []bpfEvent{}
}
skb2events[event.Skb] = append(skb2events[event.Skb],event)
skb2events[event.Skb] = append(skb2events[event.Skb], event)


sym := NearestSymbol(event.Pc);
if skb2symNames[event.Skb]==nil {
sym := NearestSymbol(event.Pc)
if skb2symNames[event.Skb] == nil {
skb2symNames[event.Skb] = []string{}
}
skb2symNames[event.Skb] = append(skb2symNames[event.Skb],sym.Name)
skb2symNames[event.Skb] = append(skb2symNames[event.Skb], sym.Name)
switch sym.Name {
case "__kfree_skb","kfree_skbmem":
// most skb end in the call of kfree_skbmem
if !dropOnly || slices.Contains(skb2symNames[event.Skb],"kfree_skb_reason") {
// trace dropOnly with drop reason or all skb
for _,skb_ev := range skb2events[event.Skb] {
fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", skb_ev.Skb, skb_ev.Mark, skb_ev.Netns, skb_ev.Ifindex, TrimNull(string(skb_ev.Ifname[:])), skb_ev.Pid, TrimNull(string(skb_ev.Pname[:])))
if event.L3Proto == syscall.ETH_P_IP {
fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(skb_ev.Saddr[:4]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:4]).String(), Ntohs(skb_ev.Dport))
} else {
fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(skb_ev.Saddr[:]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:]).String(), Ntohs(skb_ev.Dport))
}
if event.L4Proto == syscall.IPPROTO_TCP {
fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(skb_ev.TcpFlags))
}
fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen)
sym := NearestSymbol(skb_ev.Pc)
fmt.Fprintf(writer, "%s", sym.Name)
if sym.Name == "kfree_skb_reason" {
fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[skb_ev.SecondParam])
}
fmt.Fprintf(writer, "\n")
case "__kfree_skb", "kfree_skbmem":
// most skb end in the call of kfree_skbmem
if !dropOnly || slices.Contains(skb2symNames[event.Skb], "kfree_skb_reason") {
// trace dropOnly with drop reason or all skb
for _, skb_ev := range skb2events[event.Skb] {
fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", skb_ev.Skb, skb_ev.Mark, skb_ev.Netns, skb_ev.Ifindex, TrimNull(string(skb_ev.Ifname[:])), skb_ev.Pid, TrimNull(string(skb_ev.Pname[:])))
if event.L3Proto == syscall.ETH_P_IP {
fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(skb_ev.Saddr[:4]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:4]).String(), Ntohs(skb_ev.Dport))
} else {
fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(skb_ev.Saddr[:]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:]).String(), Ntohs(skb_ev.Dport))
}
if event.L4Proto == syscall.IPPROTO_TCP {
fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(skb_ev.TcpFlags))
}
fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen)
sym := NearestSymbol(skb_ev.Pc)
fmt.Fprintf(writer, "%s", sym.Name)
if sym.Name == "kfree_skb_reason" {
fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[skb_ev.SecondParam])
}
fmt.Fprintf(writer, "\n")
}
delete(skb2events, event.Skb)
delete(skb2symNames, event.Skb)
}
delete(skb2symNames, event.Skb)
}
}
}
}
Loading