From 68cd598056d2e32eac6c7a6903dac23db6e11cad Mon Sep 17 00:00:00 2001 From: woshikedayaa Date: Thu, 26 Dec 2024 21:47:48 +0800 Subject: [PATCH 1/6] feat(init): disable sudo flag --- cmd/run.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index 16f2fc5b4e..66ff76348e 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -59,9 +59,9 @@ func init() { runCmd.PersistentFlags().StringVar(&logFile, "logfile", "", "Log file to write. Empty means writing to stdout and stderr.") runCmd.PersistentFlags().IntVar(&logFileMaxSize, "logfile-maxsize", 30, "Unit: MB. The maximum size in megabytes of the log file before it gets rotated.") runCmd.PersistentFlags().IntVar(&logFileMaxBackups, "logfile-maxbackups", 3, "The maximum number of old log files to retain.") - runCmd.PersistentFlags().BoolVarP(&disableTimestamp, "disable-timestamp", "", false, "Disable timestamp.") - runCmd.PersistentFlags().BoolVarP(&disablePidFile, "disable-pidfile", "", false, "Not generate /var/run/dae.pid.") - + runCmd.PersistentFlags().BoolVar(&disableTimestamp, "disable-timestamp", false, "Disable timestamp.") + runCmd.PersistentFlags().BoolVar(&disablePidFile, "disable-pidfile", false, "Not generate /var/run/dae.pid.") + runCmd.PersistentFlags().BoolVar(&disableAuthSudo, "disable-sudo", false, "Disable sudo prompt ,may cause startup failure due to insufficient permissions") rand.Shuffle(len(CheckNetworkLinks), func(i, j int) { CheckNetworkLinks[i], CheckNetworkLinks[j] = CheckNetworkLinks[j], CheckNetworkLinks[i] }) @@ -74,6 +74,7 @@ var ( logFileMaxBackups int disableTimestamp bool disablePidFile bool + disableAuthSudo bool runCmd = &cobra.Command{ Use: "run", @@ -82,7 +83,9 @@ var ( if cfgFile == "" { logrus.Fatalln("Argument \"--config\" or \"-c\" is required but not provided.") } - + if disableAuthSudo && os.Getuid() != 0 { + logrus.Fatalln("Auto-sudo is disabled and current user is not root.") + } // Require "sudo" if necessary. internal.AutoSu() From bfe4d3ac7be294b03ffa88e1d2be8063b1ac282d Mon Sep 17 00:00:00 2001 From: woshikedayaa Date: Fri, 27 Dec 2024 11:41:56 +0800 Subject: [PATCH 2/6] feat(init): linux desktop polkit integer --- cmd/internal/su.go | 79 +++- cmd/reload.go | 2 +- cmd/run.go | 886 +++++++++++++++++++++-------------------- cmd/sysdump.go | 18 +- control/dns_control.go | 10 +- trace/trace.go | 61 ++- 6 files changed, 554 insertions(+), 502 deletions(-) diff --git a/cmd/internal/su.go b/cmd/internal/su.go index 1151c3bc46..27b8ea48ca 100644 --- a/cmd/internal/su.go +++ b/cmd/internal/su.go @@ -18,20 +18,14 @@ func AutoSu() { if os.Getuid() == 0 { return } - program := filepath.Base(os.Args[0]) - pathSudo, err := exec.LookPath("sudo") - if err != nil { - // skip + path, arg := tryDesktopSudo() + if path == "" { + path, arg = trySudo() + } + if path == "" { return } - // https://github.com/WireGuard/wireguard-tools/blob/71799a8f6d1450b63071a21cad6ed434b348d3d5/src/wg-quick/linux.bash#L85 - p, err := os.StartProcess(pathSudo, append([]string{ - pathSudo, - "-E", - "-p", - fmt.Sprintf("%v must be run as root. Please enter the password for %%u to continue: ", program), - "--", - }, os.Args...), &os.ProcAttr{ + p, err := os.StartProcess(path, append(arg, os.Args...), &os.ProcAttr{ Files: []*os.File{ os.Stdin, os.Stdout, @@ -47,3 +41,64 @@ func AutoSu() { } os.Exit(stat.ExitCode()) } + +func trySudo() (path string, arg []string) { + pathSudo, err := exec.LookPath("sudo") + if err != nil { + // fallback + var possibleSudoPath = []string{ + "/usr/bin/sudo", "/usr/sbin/sudo", + } + var found = false + for _, v := range possibleSudoPath { + if isExistAndExecutable(v) { + pathSudo = v + found = true + break + } + } + if !found { + return "", nil + } + } + // https://github.com/WireGuard/wireguard-tools/blob/71799a8f6d1450b63071a21cad6ed434b348d3d5/src/wg-quick/linux.bash#L85 + return pathSudo, []string{ + pathSudo, + "-E", + "-p", + fmt.Sprintf("%v must be run as root. Please enter the password for %%u to continue: ", filepath.Base(os.Args[0])), + "--", + } +} + +func tryDesktopSudo() (path string, arg []string) { + // https://specifications.freedesktop.org/desktop-entry-spec/latest + desktop := os.Getenv("XDG_CURRENT_DESKTOP") + if desktop != "" { + var possible = []string{"pkexec"} + for _, v := range possible { + path, err := exec.LookPath(v) + if err != nil { + continue + } + if isExistAndExecutable(path) { + switch v { + case "pkexec": + return path, []string{path, "--keep-cwd", "--user", "root"} + } + } + } + } + return "", nil +} + +func isExistAndExecutable(path string) bool { + st, err := os.Stat(path) + if err == nil { + // https://stackoverflow.com/questions/60128401/how-to-check-if-a-file-is-executable-in-go + if st.Mode()&0o111 == 0o111 { + return true + } + } + return false +} diff --git a/cmd/reload.go b/cmd/reload.go index 4f8815e85f..d6d18616d6 100644 --- a/cmd/reload.go +++ b/cmd/reload.go @@ -38,7 +38,7 @@ var ( Use: "reload [pid]", Short: "To reload config file without interrupt connections.", Run: func(cmd *cobra.Command, args []string) { - internal.AutoSu() + internal.AutoSu() if len(args) == 0 { _pid, err := os.ReadFile(PidFilePath) if err != nil { diff --git a/cmd/run.go b/cmd/run.go index 66ff76348e..2a42ef4743 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -6,477 +6,479 @@ package cmd import ( - "context" - "errors" - "fmt" - "math/rand/v2" - "net" - "net/http" - "os" - "os/signal" - "path/filepath" - "runtime" - "strconv" - "strings" - "syscall" - "time" - - "github.com/daeuniverse/outbound/netproxy" - "github.com/daeuniverse/outbound/protocol/direct" - "gopkg.in/natefinch/lumberjack.v2" - - _ "net/http/pprof" - - "github.com/daeuniverse/dae/cmd/internal" - "github.com/daeuniverse/dae/common" - "github.com/daeuniverse/dae/common/consts" - "github.com/daeuniverse/dae/common/subscription" - "github.com/daeuniverse/dae/config" - "github.com/daeuniverse/dae/control" - "github.com/daeuniverse/dae/pkg/config_parser" - "github.com/daeuniverse/dae/pkg/logger" - "github.com/mohae/deepcopy" - "github.com/okzk/sdnotify" - "github.com/sirupsen/logrus" - "github.com/spf13/cobra" + "context" + "errors" + "fmt" + "math/rand/v2" + "net" + "net/http" + "os" + "os/signal" + "path/filepath" + "runtime" + "strconv" + "strings" + "syscall" + "time" + + "github.com/daeuniverse/outbound/netproxy" + "github.com/daeuniverse/outbound/protocol/direct" + "gopkg.in/natefinch/lumberjack.v2" + + _ "net/http/pprof" + + "github.com/daeuniverse/dae/cmd/internal" + "github.com/daeuniverse/dae/common" + "github.com/daeuniverse/dae/common/consts" + "github.com/daeuniverse/dae/common/subscription" + "github.com/daeuniverse/dae/config" + "github.com/daeuniverse/dae/control" + "github.com/daeuniverse/dae/pkg/config_parser" + "github.com/daeuniverse/dae/pkg/logger" + "github.com/mohae/deepcopy" + "github.com/okzk/sdnotify" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" ) const ( - PidFilePath = "/var/run/dae.pid" - SignalProgressFilePath = "/var/run/dae.progress" + PidFilePath = "/var/run/dae.pid" + SignalProgressFilePath = "/var/run/dae.progress" ) var ( - CheckNetworkLinks = []string{ - "http://edge.microsoft.com/captiveportal/generate_204", - "http://www.gstatic.com/generate_204", - "http://www.qualcomm.cn/generate_204", - } + CheckNetworkLinks = []string{ + "http://edge.microsoft.com/captiveportal/generate_204", + "http://www.gstatic.com/generate_204", + "http://www.qualcomm.cn/generate_204", + } ) func init() { - runCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "Config file of dae.(required)") - runCmd.PersistentFlags().StringVar(&logFile, "logfile", "", "Log file to write. Empty means writing to stdout and stderr.") - runCmd.PersistentFlags().IntVar(&logFileMaxSize, "logfile-maxsize", 30, "Unit: MB. The maximum size in megabytes of the log file before it gets rotated.") - runCmd.PersistentFlags().IntVar(&logFileMaxBackups, "logfile-maxbackups", 3, "The maximum number of old log files to retain.") - runCmd.PersistentFlags().BoolVar(&disableTimestamp, "disable-timestamp", false, "Disable timestamp.") - runCmd.PersistentFlags().BoolVar(&disablePidFile, "disable-pidfile", false, "Not generate /var/run/dae.pid.") - runCmd.PersistentFlags().BoolVar(&disableAuthSudo, "disable-sudo", false, "Disable sudo prompt ,may cause startup failure due to insufficient permissions") - rand.Shuffle(len(CheckNetworkLinks), func(i, j int) { - CheckNetworkLinks[i], CheckNetworkLinks[j] = CheckNetworkLinks[j], CheckNetworkLinks[i] - }) + runCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "Config file of dae.(required)") + runCmd.PersistentFlags().StringVar(&logFile, "logfile", "", "Log file to write. Empty means writing to stdout and stderr.") + runCmd.PersistentFlags().IntVar(&logFileMaxSize, "logfile-maxsize", 30, "Unit: MB. The maximum size in megabytes of the log file before it gets rotated.") + runCmd.PersistentFlags().IntVar(&logFileMaxBackups, "logfile-maxbackups", 3, "The maximum number of old log files to retain.") + runCmd.PersistentFlags().BoolVar(&disableTimestamp, "disable-timestamp", false, "Disable timestamp.") + runCmd.PersistentFlags().BoolVar(&disablePidFile, "disable-pidfile", false, "Not generate /var/run/dae.pid.") + runCmd.PersistentFlags().BoolVar(&disableAuthSudo, "disable-sudo", false, "Disable sudo prompt ,may cause startup failure due to insufficient permissions") + rand.Shuffle(len(CheckNetworkLinks), func(i, j int) { + CheckNetworkLinks[i], CheckNetworkLinks[j] = CheckNetworkLinks[j], CheckNetworkLinks[i] + }) } var ( - cfgFile string - logFile string - logFileMaxSize int - logFileMaxBackups int - disableTimestamp bool - disablePidFile bool - disableAuthSudo bool - - runCmd = &cobra.Command{ - Use: "run", - Short: "To run dae in the foreground.", - Run: func(cmd *cobra.Command, args []string) { - if cfgFile == "" { - logrus.Fatalln("Argument \"--config\" or \"-c\" is required but not provided.") - } - if disableAuthSudo && os.Getuid() != 0 { - logrus.Fatalln("Auto-sudo is disabled and current user is not root.") - } - // Require "sudo" if necessary. - internal.AutoSu() - - // Read config from --config cfgFile. - conf, includes, err := readConfig(cfgFile) - if err != nil { - logrus.WithFields(logrus.Fields{ - "err": err, - }).Fatalln("Failed to read config") - } - - var logOpts *lumberjack.Logger - if logFile != "" { - logOpts = &lumberjack.Logger{ - Filename: logFile, - MaxSize: logFileMaxSize, - MaxAge: 0, - MaxBackups: logFileMaxBackups, - LocalTime: true, - Compress: true, - } - } - log := logrus.New() - logger.SetLogger(log, conf.Global.LogLevel, disableTimestamp, logOpts) - logger.SetLogger(logrus.StandardLogger(), conf.Global.LogLevel, disableTimestamp, logOpts) - - log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) - if err := Run(log, conf, []string{filepath.Dir(cfgFile)}); err != nil { - log.Fatalln(err) - } - }, - } + cfgFile string + logFile string + logFileMaxSize int + logFileMaxBackups int + disableTimestamp bool + disablePidFile bool + disableAuthSudo bool + + runCmd = &cobra.Command{ + Use: "run", + Short: "To run dae in the foreground.", + Run: func(cmd *cobra.Command, args []string) { + if cfgFile == "" { + logrus.Fatalln("Argument \"--config\" or \"-c\" is required but not provided.") + } + if disableAuthSudo && os.Getuid() != 0 { + logrus.Fatalln("Auto-sudo is disabled and current user is not root.") + } + // Require "sudo" if necessary. + if !disableAuthSudo { + internal.AutoSu() + } + + // Read config from --config cfgFile. + conf, includes, err := readConfig(cfgFile) + if err != nil { + logrus.WithFields(logrus.Fields{ + "err": err, + }).Fatalln("Failed to read config") + } + + var logOpts *lumberjack.Logger + if logFile != "" { + logOpts = &lumberjack.Logger{ + Filename: logFile, + MaxSize: logFileMaxSize, + MaxAge: 0, + MaxBackups: logFileMaxBackups, + LocalTime: true, + Compress: true, + } + } + log := logrus.New() + logger.SetLogger(log, conf.Global.LogLevel, disableTimestamp, logOpts) + logger.SetLogger(logrus.StandardLogger(), conf.Global.LogLevel, disableTimestamp, logOpts) + + log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) + if err := Run(log, conf, []string{filepath.Dir(cfgFile)}); err != nil { + log.Fatalln(err) + } + }, + } ) func Run(log *logrus.Logger, conf *config.Config, externGeoDataDirs []string) (err error) { - // Remove AbortFile at beginning. - _ = os.Remove(AbortFile) - - // New ControlPlane. - c, err := newControlPlane(log, nil, nil, conf, externGeoDataDirs) - if err != nil { - return err - } - - var pprofServer *http.Server - if conf.Global.PprofPort != 0 { - pprofAddr := fmt.Sprintf("localhost:%d", conf.Global.PprofPort) - pprofServer = &http.Server{Addr: pprofAddr, Handler: nil} - go pprofServer.ListenAndServe() - } - - // Serve tproxy TCP/UDP server util signals. - var listener *control.Listener - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGILL, syscall.SIGUSR1, syscall.SIGUSR2) - go func() { - readyChan := make(chan bool, 1) - go func() { - <-readyChan - sdnotify.Ready() - if !disablePidFile { - _ = os.WriteFile(PidFilePath, []byte(strconv.Itoa(os.Getpid())), 0644) - } - _ = os.WriteFile(SignalProgressFilePath, []byte{consts.ReloadDone}, 0644) - }() - control.GetDaeNetns().With(func() error { - if listener, err = c.ListenAndServe(readyChan, conf.Global.TproxyPort); err != nil { - log.Errorln("ListenAndServe:", err) - } - return err - }) - sigs <- nil - }() - reloading := false - reloadingErr := error(nil) - isSuspend := false - abortConnections := false + // Remove AbortFile at beginning. + _ = os.Remove(AbortFile) + + // New ControlPlane. + c, err := newControlPlane(log, nil, nil, conf, externGeoDataDirs) + if err != nil { + return err + } + + var pprofServer *http.Server + if conf.Global.PprofPort != 0 { + pprofAddr := fmt.Sprintf("localhost:%d", conf.Global.PprofPort) + pprofServer = &http.Server{Addr: pprofAddr, Handler: nil} + go pprofServer.ListenAndServe() + } + + // Serve tproxy TCP/UDP server util signals. + var listener *control.Listener + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGILL, syscall.SIGUSR1, syscall.SIGUSR2) + go func() { + readyChan := make(chan bool, 1) + go func() { + <-readyChan + sdnotify.Ready() + if !disablePidFile { + _ = os.WriteFile(PidFilePath, []byte(strconv.Itoa(os.Getpid())), 0644) + } + _ = os.WriteFile(SignalProgressFilePath, []byte{consts.ReloadDone}, 0644) + }() + control.GetDaeNetns().With(func() error { + if listener, err = c.ListenAndServe(readyChan, conf.Global.TproxyPort); err != nil { + log.Errorln("ListenAndServe:", err) + } + return err + }) + sigs <- nil + }() + reloading := false + reloadingErr := error(nil) + isSuspend := false + abortConnections := false loop: - for sig := range sigs { - switch sig { - case nil: - if reloading { - if listener == nil { - // Failed to listen. Exit. - break loop - } - // Serve. - reloading = false - log.Warnln("[Reload] Serve") - readyChan := make(chan bool, 1) - go func() { - if err := c.Serve(readyChan, listener); err != nil { - log.Errorln("ListenAndServe:", err) - } - sigs <- nil - }() - <-readyChan - sdnotify.Ready() - if reloadingErr == nil { - _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadDone}, []byte("\nOK")...), 0644) - } else { - _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+reloadingErr.Error())...), 0644) - } - log.Warnln("[Reload] Finished") - } else { - // Listening error. - break loop - } - case syscall.SIGUSR2: - isSuspend = true - fallthrough - case syscall.SIGUSR1: - // Reload signal. - if isSuspend { - log.Warnln("[Reload] Received suspend signal; prepare to suspend") - } else { - log.Warnln("[Reload] Received reload signal; prepare to reload") - } - sdnotify.Reloading() - _ = os.WriteFile(SignalProgressFilePath, []byte{consts.ReloadProcessing}, 0644) - reloadingErr = nil - - // Load new config. - abortConnections = os.Remove(AbortFile) == nil - log.Warnln("[Reload] Load new config") - var newConf *config.Config - if isSuspend { - isSuspend = false - newConf, err = emptyConfig() - if err != nil { - log.WithFields(logrus.Fields{ - "err": err, - }).Errorln("[Reload] Failed to reload") - sdnotify.Ready() - _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+err.Error())...), 0644) - continue - } - newConf.Global = deepcopy.Copy(conf.Global).(config.Global) - newConf.Global.WanInterface = nil - newConf.Global.LanInterface = nil - newConf.Global.LogLevel = "warning" - } else { - var includes []string - newConf, includes, err = readConfig(cfgFile) - if err != nil { - log.WithFields(logrus.Fields{ - "err": err, - }).Errorln("[Reload] Failed to reload") - sdnotify.Ready() - _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+err.Error())...), 0644) - continue - } - log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) - } - // New logger. - oldLogOutput := log.Out - log = logrus.New() - logger.SetLogger(log, newConf.Global.LogLevel, disableTimestamp, nil) - logger.SetLogger(logrus.StandardLogger(), newConf.Global.LogLevel, disableTimestamp, nil) - log.SetOutput(oldLogOutput) // FIXME: THIS IS A HACK. - logrus.SetOutput(oldLogOutput) - - // New control plane. - obj := c.EjectBpf() - var dnsCache map[string]*control.DnsCache - if conf.Dns.IpVersionPrefer == newConf.Dns.IpVersionPrefer { - // Only keep dns cache when ip version preference not change. - dnsCache = c.CloneDnsCache() - } - log.Warnln("[Reload] Load new control plane") - newC, err := newControlPlane(log, obj, dnsCache, newConf, externGeoDataDirs) - if err != nil { - reloadingErr = err - log.WithFields(logrus.Fields{ - "err": err, - }).Errorln("[Reload] Failed to reload; try to roll back configuration") - // Load last config back. - newC, err = newControlPlane(log, obj, dnsCache, conf, externGeoDataDirs) - if err != nil { - sdnotify.Stopping() - obj.Close() - c.Close() - log.WithFields(logrus.Fields{ - "err": err, - }).Fatalln("[Reload] Failed to roll back configuration") - } - newConf = conf - log.Errorln("[Reload] Last reload failed; rolled back configuration") - } else { - log.Warnln("[Reload] Stopped old control plane") - } - - // Inject bpf objects into the new control plane life-cycle. - newC.InjectBpf(obj) - - // Prepare new context. - oldC := c - c = newC - conf = newConf - reloading = true - - // Ready to close. - if abortConnections { - oldC.AbortConnections() - } - oldC.Close() - - if pprofServer != nil { - pprofServer.Shutdown(context.Background()) - pprofServer = nil - } - if newConf.Global.PprofPort != 0 { - pprofAddr := fmt.Sprintf("localhost:%d", conf.Global.PprofPort) - pprofServer = &http.Server{Addr: pprofAddr, Handler: nil} - go pprofServer.ListenAndServe() - } - case syscall.SIGHUP: - // Ignore. - continue - default: - log.Infof("Received signal: %v", sig.String()) - break loop - } - } - defer os.Remove(PidFilePath) - defer control.GetDaeNetns().Close() - if e := c.Close(); e != nil { - return fmt.Errorf("close control plane: %w", e) - } - return nil + for sig := range sigs { + switch sig { + case nil: + if reloading { + if listener == nil { + // Failed to listen. Exit. + break loop + } + // Serve. + reloading = false + log.Warnln("[Reload] Serve") + readyChan := make(chan bool, 1) + go func() { + if err := c.Serve(readyChan, listener); err != nil { + log.Errorln("ListenAndServe:", err) + } + sigs <- nil + }() + <-readyChan + sdnotify.Ready() + if reloadingErr == nil { + _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadDone}, []byte("\nOK")...), 0644) + } else { + _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+reloadingErr.Error())...), 0644) + } + log.Warnln("[Reload] Finished") + } else { + // Listening error. + break loop + } + case syscall.SIGUSR2: + isSuspend = true + fallthrough + case syscall.SIGUSR1: + // Reload signal. + if isSuspend { + log.Warnln("[Reload] Received suspend signal; prepare to suspend") + } else { + log.Warnln("[Reload] Received reload signal; prepare to reload") + } + sdnotify.Reloading() + _ = os.WriteFile(SignalProgressFilePath, []byte{consts.ReloadProcessing}, 0644) + reloadingErr = nil + + // Load new config. + abortConnections = os.Remove(AbortFile) == nil + log.Warnln("[Reload] Load new config") + var newConf *config.Config + if isSuspend { + isSuspend = false + newConf, err = emptyConfig() + if err != nil { + log.WithFields(logrus.Fields{ + "err": err, + }).Errorln("[Reload] Failed to reload") + sdnotify.Ready() + _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+err.Error())...), 0644) + continue + } + newConf.Global = deepcopy.Copy(conf.Global).(config.Global) + newConf.Global.WanInterface = nil + newConf.Global.LanInterface = nil + newConf.Global.LogLevel = "warning" + } else { + var includes []string + newConf, includes, err = readConfig(cfgFile) + if err != nil { + log.WithFields(logrus.Fields{ + "err": err, + }).Errorln("[Reload] Failed to reload") + sdnotify.Ready() + _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+err.Error())...), 0644) + continue + } + log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) + } + // New logger. + oldLogOutput := log.Out + log = logrus.New() + logger.SetLogger(log, newConf.Global.LogLevel, disableTimestamp, nil) + logger.SetLogger(logrus.StandardLogger(), newConf.Global.LogLevel, disableTimestamp, nil) + log.SetOutput(oldLogOutput) // FIXME: THIS IS A HACK. + logrus.SetOutput(oldLogOutput) + + // New control plane. + obj := c.EjectBpf() + var dnsCache map[string]*control.DnsCache + if conf.Dns.IpVersionPrefer == newConf.Dns.IpVersionPrefer { + // Only keep dns cache when ip version preference not change. + dnsCache = c.CloneDnsCache() + } + log.Warnln("[Reload] Load new control plane") + newC, err := newControlPlane(log, obj, dnsCache, newConf, externGeoDataDirs) + if err != nil { + reloadingErr = err + log.WithFields(logrus.Fields{ + "err": err, + }).Errorln("[Reload] Failed to reload; try to roll back configuration") + // Load last config back. + newC, err = newControlPlane(log, obj, dnsCache, conf, externGeoDataDirs) + if err != nil { + sdnotify.Stopping() + obj.Close() + c.Close() + log.WithFields(logrus.Fields{ + "err": err, + }).Fatalln("[Reload] Failed to roll back configuration") + } + newConf = conf + log.Errorln("[Reload] Last reload failed; rolled back configuration") + } else { + log.Warnln("[Reload] Stopped old control plane") + } + + // Inject bpf objects into the new control plane life-cycle. + newC.InjectBpf(obj) + + // Prepare new context. + oldC := c + c = newC + conf = newConf + reloading = true + + // Ready to close. + if abortConnections { + oldC.AbortConnections() + } + oldC.Close() + + if pprofServer != nil { + pprofServer.Shutdown(context.Background()) + pprofServer = nil + } + if newConf.Global.PprofPort != 0 { + pprofAddr := fmt.Sprintf("localhost:%d", conf.Global.PprofPort) + pprofServer = &http.Server{Addr: pprofAddr, Handler: nil} + go pprofServer.ListenAndServe() + } + case syscall.SIGHUP: + // Ignore. + continue + default: + log.Infof("Received signal: %v", sig.String()) + break loop + } + } + defer os.Remove(PidFilePath) + defer control.GetDaeNetns().Close() + if e := c.Close(); e != nil { + return fmt.Errorf("close control plane: %w", e) + } + return nil } func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*control.DnsCache, conf *config.Config, externGeoDataDirs []string) (c *control.ControlPlane, err error) { - // Deep copy to prevent modification. - conf = deepcopy.Copy(conf).(*config.Config) - - /// Get tag -> nodeList mapping. - tagToNodeList := map[string][]string{} - if len(conf.Node) > 0 { - for _, node := range conf.Node { - tagToNodeList[""] = append(tagToNodeList[""], string(node)) - } - } - // Resolve subscriptions to nodes. - resolvingfailed := false - if !conf.Global.DisableWaitingNetwork { - epo := 5 * time.Second - client := http.Client{ - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { - conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) - if err != nil { - return nil, err - } - return &netproxy.FakeNetConn{ - Conn: conn, - LAddr: nil, - RAddr: nil, - }, nil - }, - }, - Timeout: epo, - } - log.Infoln("Waiting for network...") - for i := 0; ; i++ { - resp, err := client.Get(CheckNetworkLinks[i%len(CheckNetworkLinks)]) - if err != nil { - log.Debugln("CheckNetwork:", err) - var neterr net.Error - if errors.As(err, &neterr) && neterr.Timeout() { - // Do not sleep. - continue - } - time.Sleep(epo) - continue - } - resp.Body.Close() - if resp.StatusCode >= 200 && resp.StatusCode < 500 { - break - } - log.Infof("Bad status: %v (%v)", resp.Status, resp.StatusCode) - time.Sleep(epo) - } - log.Infoln("Network online.") - } - if len(conf.Subscription) > 0 { - log.Infoln("Fetching subscriptions...") - } - client := http.Client{ - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { - conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) - if err != nil { - return nil, err - } - return &netproxy.FakeNetConn{ - Conn: conn, - LAddr: nil, - RAddr: nil, - }, nil - }, - }, - Timeout: 30 * time.Second, - } - for _, sub := range conf.Subscription { - tag, nodes, err := subscription.ResolveSubscription(log, &client, filepath.Dir(cfgFile), string(sub)) - if err != nil { - log.Warnf(`failed to resolve subscription "%v": %v`, sub, err) - resolvingfailed = true - } - if len(nodes) > 0 { - tagToNodeList[tag] = append(tagToNodeList[tag], nodes...) - } - } - if len(tagToNodeList) == 0 { - if resolvingfailed { - log.Warnln("No node found because all subscription resolving failed.") - } else { - log.Warnln("No node found.") - } - } - - if len(conf.Global.LanInterface) == 0 && len(conf.Global.WanInterface) == 0 { - log.Warnln("No interface to bind.") - } - - if err = preprocessWanInterfaceAuto(conf); err != nil { - return nil, err - } - - c, err = control.NewControlPlane( - log, - bpf, - dnsCache, - tagToNodeList, - conf.Group, - &conf.Routing, - &conf.Global, - &conf.Dns, - externGeoDataDirs, - ) - if err != nil { - return nil, err - } - // Call GC to release memory. - runtime.GC() - - return c, nil + // Deep copy to prevent modification. + conf = deepcopy.Copy(conf).(*config.Config) + + /// Get tag -> nodeList mapping. + tagToNodeList := map[string][]string{} + if len(conf.Node) > 0 { + for _, node := range conf.Node { + tagToNodeList[""] = append(tagToNodeList[""], string(node)) + } + } + // Resolve subscriptions to nodes. + resolvingfailed := false + if !conf.Global.DisableWaitingNetwork { + epo := 5 * time.Second + client := http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { + conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) + if err != nil { + return nil, err + } + return &netproxy.FakeNetConn{ + Conn: conn, + LAddr: nil, + RAddr: nil, + }, nil + }, + }, + Timeout: epo, + } + log.Infoln("Waiting for network...") + for i := 0; ; i++ { + resp, err := client.Get(CheckNetworkLinks[i%len(CheckNetworkLinks)]) + if err != nil { + log.Debugln("CheckNetwork:", err) + var neterr net.Error + if errors.As(err, &neterr) && neterr.Timeout() { + // Do not sleep. + continue + } + time.Sleep(epo) + continue + } + resp.Body.Close() + if resp.StatusCode >= 200 && resp.StatusCode < 500 { + break + } + log.Infof("Bad status: %v (%v)", resp.Status, resp.StatusCode) + time.Sleep(epo) + } + log.Infoln("Network online.") + } + if len(conf.Subscription) > 0 { + log.Infoln("Fetching subscriptions...") + } + client := http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { + conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) + if err != nil { + return nil, err + } + return &netproxy.FakeNetConn{ + Conn: conn, + LAddr: nil, + RAddr: nil, + }, nil + }, + }, + Timeout: 30 * time.Second, + } + for _, sub := range conf.Subscription { + tag, nodes, err := subscription.ResolveSubscription(log, &client, filepath.Dir(cfgFile), string(sub)) + if err != nil { + log.Warnf(`failed to resolve subscription "%v": %v`, sub, err) + resolvingfailed = true + } + if len(nodes) > 0 { + tagToNodeList[tag] = append(tagToNodeList[tag], nodes...) + } + } + if len(tagToNodeList) == 0 { + if resolvingfailed { + log.Warnln("No node found because all subscription resolving failed.") + } else { + log.Warnln("No node found.") + } + } + + if len(conf.Global.LanInterface) == 0 && len(conf.Global.WanInterface) == 0 { + log.Warnln("No interface to bind.") + } + + if err = preprocessWanInterfaceAuto(conf); err != nil { + return nil, err + } + + c, err = control.NewControlPlane( + log, + bpf, + dnsCache, + tagToNodeList, + conf.Group, + &conf.Routing, + &conf.Global, + &conf.Dns, + externGeoDataDirs, + ) + if err != nil { + return nil, err + } + // Call GC to release memory. + runtime.GC() + + return c, nil } func preprocessWanInterfaceAuto(params *config.Config) error { - // preprocess "auto". - ifs := make([]string, 0, len(params.Global.WanInterface)+2) - for _, ifname := range params.Global.WanInterface { - if ifname == "auto" { - defaultIfs, err := common.GetDefaultIfnames() - if err != nil { - return fmt.Errorf("failed to convert 'auto': %w", err) - } - ifs = append(ifs, defaultIfs...) - } else { - ifs = append(ifs, ifname) - } - } - params.Global.WanInterface = common.Deduplicate(ifs) - return nil + // preprocess "auto". + ifs := make([]string, 0, len(params.Global.WanInterface)+2) + for _, ifname := range params.Global.WanInterface { + if ifname == "auto" { + defaultIfs, err := common.GetDefaultIfnames() + if err != nil { + return fmt.Errorf("failed to convert 'auto': %w", err) + } + ifs = append(ifs, defaultIfs...) + } else { + ifs = append(ifs, ifname) + } + } + params.Global.WanInterface = common.Deduplicate(ifs) + return nil } func readConfig(cfgFile string) (conf *config.Config, includes []string, err error) { - merger := config.NewMerger(cfgFile) - sections, includes, err := merger.Merge() - if err != nil { - return nil, nil, err - } - if conf, err = config.New(sections); err != nil { - return nil, nil, err - } - return conf, includes, nil + merger := config.NewMerger(cfgFile) + sections, includes, err := merger.Merge() + if err != nil { + return nil, nil, err + } + if conf, err = config.New(sections); err != nil { + return nil, nil, err + } + return conf, includes, nil } func emptyConfig() (conf *config.Config, err error) { - sections, err := config_parser.Parse(`global{} routing{}`) - if err != nil { - return nil, err - } - if conf, err = config.New(sections); err != nil { - return nil, err - } - return conf, nil + sections, err := config_parser.Parse(`global{} routing{}`) + if err != nil { + return nil, err + } + if conf, err = config.New(sections); err != nil { + return nil, err + } + return conf, nil } func init() { - rootCmd.AddCommand(runCmd) + rootCmd.AddCommand(runCmd) } diff --git a/cmd/sysdump.go b/cmd/sysdump.go index 997f43c8cf..f7395a02a3 100644 --- a/cmd/sysdump.go +++ b/cmd/sysdump.go @@ -7,18 +7,18 @@ package cmd import ( "bytes" - "io/ioutil" "fmt" + "io/ioutil" "os" "os/exec" "path/filepath" "strings" - "time" + "time" - "github.com/vishvananda/netlink" - "github.com/spf13/cobra" "github.com/mholt/archiver/v3" "github.com/shirou/gopsutil/v4/net" + "github.com/spf13/cobra" + "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) @@ -46,7 +46,7 @@ func dumpNetworkInfo() { dumpNetfilter(tempDir) dumpIPTables(tempDir) - tarFile := fmt.Sprintf("dae-sysdump.%d.tar.gz",time.Now().Unix()) + tarFile := fmt.Sprintf("dae-sysdump.%d.tar.gz", time.Now().Unix()) if err := archiver.Archive([]string{tempDir}, tarFile); err != nil { fmt.Printf("Failed to create tar archive: %v\n", err) return @@ -55,7 +55,6 @@ func dumpNetworkInfo() { fmt.Printf("System network information collected and saved to %s\n", tarFile) } - // Translate scope enum into semantic words func scopeToString(scope netlink.Scope) string { switch scope { @@ -74,7 +73,6 @@ func scopeToString(scope netlink.Scope) string { } } - // Translate protocol enum into semantic words func protocolToString(proto int) string { switch proto { @@ -157,7 +155,6 @@ func typeToString(typ int) string { } } - func dumpRouting(outputDir string) { routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) if err != nil { @@ -232,7 +229,6 @@ func dumpNetInterfaces(outputDir string) { ioutil.WriteFile(filepath.Join(outputDir, "interfaces.txt"), buffer.Bytes(), 0644) } - func dumpSysctl(outputDir string) { sysctlPath := "/proc/sys/net" var buffer bytes.Buffer @@ -281,12 +277,12 @@ func dumpIPTables(outputDir string) { ioutil.WriteFile(filepath.Join(outputDir, "iptables.txt"), output, 0644) } - ip6tables := exec.Command("ip6tables-save","-c") + ip6tables := exec.Command("ip6tables-save", "-c") output, err = ip6tables.CombinedOutput() if err != nil { fmt.Printf("Failed to get ip6tables: %v\n", err) } else { - ioutil.WriteFile(filepath.Join(outputDir, "ip6tables.txt"), output, 0644) + ioutil.WriteFile(filepath.Join(outputDir, "ip6tables.txt"), output, 0644) } } diff --git a/control/dns_control.go b/control/dns_control.go index 5435814738..82245713f2 100644 --- a/control/dns_control.go +++ b/control/dns_control.go @@ -76,8 +76,8 @@ type DnsController struct { fixedDomainTtl map[string]int // mutex protects the dnsCache. - dnsCacheMu sync.Mutex - dnsCache map[string]*DnsCache + dnsCacheMu sync.Mutex + dnsCache map[string]*DnsCache dnsForwarderCacheMu sync.Mutex dnsForwarderCache map[string]DnsForwarder } @@ -113,9 +113,9 @@ func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsCont bestDialerChooser: option.BestDialerChooser, timeoutExceedCallback: option.TimeoutExceedCallback, - fixedDomainTtl: option.FixedDomainTtl, - dnsCacheMu: sync.Mutex{}, - dnsCache: make(map[string]*DnsCache), + fixedDomainTtl: option.FixedDomainTtl, + dnsCacheMu: sync.Mutex{}, + dnsCache: make(map[string]*DnsCache), dnsForwarderCacheMu: sync.Mutex{}, dnsForwarderCache: make(map[string]DnsForwarder), }, nil diff --git a/trace/trace.go b/trace/trace.go index b801ecbf37..ae6c816e64 100644 --- a/trace/trace.go +++ b/trace/trace.go @@ -11,9 +11,9 @@ import ( "encoding/binary" "errors" "fmt" - "slices" "net" "os" + "slices" "syscall" "unsafe" @@ -278,43 +278,42 @@ func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfre logrus.Debugf("failed to parse ringbuf event: %+v", err) continue } - if skb2events[event.Skb]==nil { - skb2events[event.Skb] = []bpfEvent{} + if skb2events[event.Skb] == nil { + skb2events[event.Skb] = []bpfEvent{} } - skb2events[event.Skb] = append(skb2events[event.Skb],event) + skb2events[event.Skb] = append(skb2events[event.Skb], event) - - sym := NearestSymbol(event.Pc); - if skb2symNames[event.Skb]==nil { + sym := NearestSymbol(event.Pc) + if skb2symNames[event.Skb] == nil { skb2symNames[event.Skb] = []string{} } - skb2symNames[event.Skb] = append(skb2symNames[event.Skb],sym.Name) + skb2symNames[event.Skb] = append(skb2symNames[event.Skb], sym.Name) switch sym.Name { - case "__kfree_skb","kfree_skbmem": - // most skb end in the call of kfree_skbmem - if !dropOnly || slices.Contains(skb2symNames[event.Skb],"kfree_skb_reason") { - // trace dropOnly with drop reason or all skb - for _,skb_ev := range skb2events[event.Skb] { - fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", skb_ev.Skb, skb_ev.Mark, skb_ev.Netns, skb_ev.Ifindex, TrimNull(string(skb_ev.Ifname[:])), skb_ev.Pid, TrimNull(string(skb_ev.Pname[:]))) - if event.L3Proto == syscall.ETH_P_IP { - fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(skb_ev.Saddr[:4]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:4]).String(), Ntohs(skb_ev.Dport)) - } else { - fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(skb_ev.Saddr[:]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:]).String(), Ntohs(skb_ev.Dport)) - } - if event.L4Proto == syscall.IPPROTO_TCP { - fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(skb_ev.TcpFlags)) - } - fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen) - sym := NearestSymbol(skb_ev.Pc) - fmt.Fprintf(writer, "%s", sym.Name) - if sym.Name == "kfree_skb_reason" { - fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[skb_ev.SecondParam]) - } - fmt.Fprintf(writer, "\n") + case "__kfree_skb", "kfree_skbmem": + // most skb end in the call of kfree_skbmem + if !dropOnly || slices.Contains(skb2symNames[event.Skb], "kfree_skb_reason") { + // trace dropOnly with drop reason or all skb + for _, skb_ev := range skb2events[event.Skb] { + fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", skb_ev.Skb, skb_ev.Mark, skb_ev.Netns, skb_ev.Ifindex, TrimNull(string(skb_ev.Ifname[:])), skb_ev.Pid, TrimNull(string(skb_ev.Pname[:]))) + if event.L3Proto == syscall.ETH_P_IP { + fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(skb_ev.Saddr[:4]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:4]).String(), Ntohs(skb_ev.Dport)) + } else { + fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(skb_ev.Saddr[:]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:]).String(), Ntohs(skb_ev.Dport)) + } + if event.L4Proto == syscall.IPPROTO_TCP { + fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(skb_ev.TcpFlags)) } + fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen) + sym := NearestSymbol(skb_ev.Pc) + fmt.Fprintf(writer, "%s", sym.Name) + if sym.Name == "kfree_skb_reason" { + fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[skb_ev.SecondParam]) + } + fmt.Fprintf(writer, "\n") + } delete(skb2events, event.Skb) - delete(skb2symNames, event.Skb) - } + delete(skb2symNames, event.Skb) + } } } } From 530dbb84411de7b9a89d34a2bcae3708420bb9aa Mon Sep 17 00:00:00 2001 From: woshikedayaa Date: Fri, 27 Dec 2024 14:18:12 +0800 Subject: [PATCH 3/6] feat(init): remove sudo fallback --- cmd/internal/su.go | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/cmd/internal/su.go b/cmd/internal/su.go index 27b8ea48ca..af4df5fb52 100644 --- a/cmd/internal/su.go +++ b/cmd/internal/su.go @@ -44,22 +44,8 @@ func AutoSu() { func trySudo() (path string, arg []string) { pathSudo, err := exec.LookPath("sudo") - if err != nil { - // fallback - var possibleSudoPath = []string{ - "/usr/bin/sudo", "/usr/sbin/sudo", - } - var found = false - for _, v := range possibleSudoPath { - if isExistAndExecutable(v) { - pathSudo = v - found = true - break - } - } - if !found { - return "", nil - } + if err != nil || !isExistAndExecutable(pathSudo) { + return "", nil } // https://github.com/WireGuard/wireguard-tools/blob/71799a8f6d1450b63071a21cad6ed434b348d3d5/src/wg-quick/linux.bash#L85 return pathSudo, []string{ @@ -93,12 +79,14 @@ func tryDesktopSudo() (path string, arg []string) { } func isExistAndExecutable(path string) bool { + if path == "" { + return false + } + st, err := os.Stat(path) if err == nil { // https://stackoverflow.com/questions/60128401/how-to-check-if-a-file-is-executable-in-go - if st.Mode()&0o111 == 0o111 { - return true - } + return st.Mode()&0o111 == 0o111 } return false } From ee566113d6b9036eb36c4404588b77d12e618993 Mon Sep 17 00:00:00 2001 From: woshikedayaa Date: Fri, 27 Dec 2024 14:22:07 +0800 Subject: [PATCH 4/6] feat(init): remove desktop environment check --- cmd/internal/su.go | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/cmd/internal/su.go b/cmd/internal/su.go index af4df5fb52..bbdea48708 100644 --- a/cmd/internal/su.go +++ b/cmd/internal/su.go @@ -18,7 +18,7 @@ func AutoSu() { if os.Getuid() == 0 { return } - path, arg := tryDesktopSudo() + path, arg := tryPolkit() if path == "" { path, arg = trySudo() } @@ -57,21 +57,17 @@ func trySudo() (path string, arg []string) { } } -func tryDesktopSudo() (path string, arg []string) { - // https://specifications.freedesktop.org/desktop-entry-spec/latest - desktop := os.Getenv("XDG_CURRENT_DESKTOP") - if desktop != "" { - var possible = []string{"pkexec"} - for _, v := range possible { - path, err := exec.LookPath(v) - if err != nil { - continue - } - if isExistAndExecutable(path) { - switch v { - case "pkexec": - return path, []string{path, "--keep-cwd", "--user", "root"} - } +func tryPolkit() (path string, arg []string) { + var possible = []string{"pkexec"} + for _, v := range possible { + path, err := exec.LookPath(v) + if err != nil { + continue + } + if isExistAndExecutable(path) { + switch v { + case "pkexec": + return path, []string{path, "--keep-cwd", "--user", "root"} } } } From fa2ad4b50db5deec9ece65c7feb6387102789bf4 Mon Sep 17 00:00:00 2001 From: woshikedayaa Date: Fri, 27 Dec 2024 14:56:31 +0800 Subject: [PATCH 5/6] fix(init): add compatibility for linux setuid. --- cmd/internal/su.go | 2 +- cmd/run.go | 888 ++++++++++++++++++++++----------------------- 2 files changed, 445 insertions(+), 445 deletions(-) diff --git a/cmd/internal/su.go b/cmd/internal/su.go index bbdea48708..34e477999d 100644 --- a/cmd/internal/su.go +++ b/cmd/internal/su.go @@ -15,7 +15,7 @@ import ( ) func AutoSu() { - if os.Getuid() == 0 { + if os.Geteuid() == 0 { return } path, arg := tryPolkit() diff --git a/cmd/run.go b/cmd/run.go index 2a42ef4743..0c5a033d0d 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -6,479 +6,479 @@ package cmd import ( - "context" - "errors" - "fmt" - "math/rand/v2" - "net" - "net/http" - "os" - "os/signal" - "path/filepath" - "runtime" - "strconv" - "strings" - "syscall" - "time" - - "github.com/daeuniverse/outbound/netproxy" - "github.com/daeuniverse/outbound/protocol/direct" - "gopkg.in/natefinch/lumberjack.v2" - - _ "net/http/pprof" - - "github.com/daeuniverse/dae/cmd/internal" - "github.com/daeuniverse/dae/common" - "github.com/daeuniverse/dae/common/consts" - "github.com/daeuniverse/dae/common/subscription" - "github.com/daeuniverse/dae/config" - "github.com/daeuniverse/dae/control" - "github.com/daeuniverse/dae/pkg/config_parser" - "github.com/daeuniverse/dae/pkg/logger" - "github.com/mohae/deepcopy" - "github.com/okzk/sdnotify" - "github.com/sirupsen/logrus" - "github.com/spf13/cobra" + "context" + "errors" + "fmt" + "math/rand/v2" + "net" + "net/http" + "os" + "os/signal" + "path/filepath" + "runtime" + "strconv" + "strings" + "syscall" + "time" + + "github.com/daeuniverse/outbound/netproxy" + "github.com/daeuniverse/outbound/protocol/direct" + "gopkg.in/natefinch/lumberjack.v2" + + _ "net/http/pprof" + + "github.com/daeuniverse/dae/cmd/internal" + "github.com/daeuniverse/dae/common" + "github.com/daeuniverse/dae/common/consts" + "github.com/daeuniverse/dae/common/subscription" + "github.com/daeuniverse/dae/config" + "github.com/daeuniverse/dae/control" + "github.com/daeuniverse/dae/pkg/config_parser" + "github.com/daeuniverse/dae/pkg/logger" + "github.com/mohae/deepcopy" + "github.com/okzk/sdnotify" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" ) const ( - PidFilePath = "/var/run/dae.pid" - SignalProgressFilePath = "/var/run/dae.progress" + PidFilePath = "/var/run/dae.pid" + SignalProgressFilePath = "/var/run/dae.progress" ) var ( - CheckNetworkLinks = []string{ - "http://edge.microsoft.com/captiveportal/generate_204", - "http://www.gstatic.com/generate_204", - "http://www.qualcomm.cn/generate_204", - } + CheckNetworkLinks = []string{ + "http://edge.microsoft.com/captiveportal/generate_204", + "http://www.gstatic.com/generate_204", + "http://www.qualcomm.cn/generate_204", + } ) func init() { - runCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "Config file of dae.(required)") - runCmd.PersistentFlags().StringVar(&logFile, "logfile", "", "Log file to write. Empty means writing to stdout and stderr.") - runCmd.PersistentFlags().IntVar(&logFileMaxSize, "logfile-maxsize", 30, "Unit: MB. The maximum size in megabytes of the log file before it gets rotated.") - runCmd.PersistentFlags().IntVar(&logFileMaxBackups, "logfile-maxbackups", 3, "The maximum number of old log files to retain.") - runCmd.PersistentFlags().BoolVar(&disableTimestamp, "disable-timestamp", false, "Disable timestamp.") - runCmd.PersistentFlags().BoolVar(&disablePidFile, "disable-pidfile", false, "Not generate /var/run/dae.pid.") - runCmd.PersistentFlags().BoolVar(&disableAuthSudo, "disable-sudo", false, "Disable sudo prompt ,may cause startup failure due to insufficient permissions") - rand.Shuffle(len(CheckNetworkLinks), func(i, j int) { - CheckNetworkLinks[i], CheckNetworkLinks[j] = CheckNetworkLinks[j], CheckNetworkLinks[i] - }) + runCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "Config file of dae.(required)") + runCmd.PersistentFlags().StringVar(&logFile, "logfile", "", "Log file to write. Empty means writing to stdout and stderr.") + runCmd.PersistentFlags().IntVar(&logFileMaxSize, "logfile-maxsize", 30, "Unit: MB. The maximum size in megabytes of the log file before it gets rotated.") + runCmd.PersistentFlags().IntVar(&logFileMaxBackups, "logfile-maxbackups", 3, "The maximum number of old log files to retain.") + runCmd.PersistentFlags().BoolVar(&disableTimestamp, "disable-timestamp", false, "Disable timestamp.") + runCmd.PersistentFlags().BoolVar(&disablePidFile, "disable-pidfile", false, "Not generate /var/run/dae.pid.") + runCmd.PersistentFlags().BoolVar(&disableAuthSudo, "disable-sudo", false, "Disable sudo prompt ,may cause startup failure due to insufficient permissions") + rand.Shuffle(len(CheckNetworkLinks), func(i, j int) { + CheckNetworkLinks[i], CheckNetworkLinks[j] = CheckNetworkLinks[j], CheckNetworkLinks[i] + }) } var ( - cfgFile string - logFile string - logFileMaxSize int - logFileMaxBackups int - disableTimestamp bool - disablePidFile bool - disableAuthSudo bool - - runCmd = &cobra.Command{ - Use: "run", - Short: "To run dae in the foreground.", - Run: func(cmd *cobra.Command, args []string) { - if cfgFile == "" { - logrus.Fatalln("Argument \"--config\" or \"-c\" is required but not provided.") - } - if disableAuthSudo && os.Getuid() != 0 { - logrus.Fatalln("Auto-sudo is disabled and current user is not root.") - } - // Require "sudo" if necessary. - if !disableAuthSudo { - internal.AutoSu() - } - - // Read config from --config cfgFile. - conf, includes, err := readConfig(cfgFile) - if err != nil { - logrus.WithFields(logrus.Fields{ - "err": err, - }).Fatalln("Failed to read config") - } - - var logOpts *lumberjack.Logger - if logFile != "" { - logOpts = &lumberjack.Logger{ - Filename: logFile, - MaxSize: logFileMaxSize, - MaxAge: 0, - MaxBackups: logFileMaxBackups, - LocalTime: true, - Compress: true, - } - } - log := logrus.New() - logger.SetLogger(log, conf.Global.LogLevel, disableTimestamp, logOpts) - logger.SetLogger(logrus.StandardLogger(), conf.Global.LogLevel, disableTimestamp, logOpts) - - log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) - if err := Run(log, conf, []string{filepath.Dir(cfgFile)}); err != nil { - log.Fatalln(err) - } - }, - } + cfgFile string + logFile string + logFileMaxSize int + logFileMaxBackups int + disableTimestamp bool + disablePidFile bool + disableAuthSudo bool + + runCmd = &cobra.Command{ + Use: "run", + Short: "To run dae in the foreground.", + Run: func(cmd *cobra.Command, args []string) { + if cfgFile == "" { + logrus.Fatalln("Argument \"--config\" or \"-c\" is required but not provided.") + } + if disableAuthSudo && os.Geteuid() != 0 { + logrus.Fatalln("Auto-sudo is disabled and current user is not root.") + } + // Require "sudo" if necessary. + if !disableAuthSudo { + internal.AutoSu() + } + + // Read config from --config cfgFile. + conf, includes, err := readConfig(cfgFile) + if err != nil { + logrus.WithFields(logrus.Fields{ + "err": err, + }).Fatalln("Failed to read config") + } + + var logOpts *lumberjack.Logger + if logFile != "" { + logOpts = &lumberjack.Logger{ + Filename: logFile, + MaxSize: logFileMaxSize, + MaxAge: 0, + MaxBackups: logFileMaxBackups, + LocalTime: true, + Compress: true, + } + } + log := logrus.New() + logger.SetLogger(log, conf.Global.LogLevel, disableTimestamp, logOpts) + logger.SetLogger(logrus.StandardLogger(), conf.Global.LogLevel, disableTimestamp, logOpts) + + log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) + if err := Run(log, conf, []string{filepath.Dir(cfgFile)}); err != nil { + log.Fatalln(err) + } + }, + } ) func Run(log *logrus.Logger, conf *config.Config, externGeoDataDirs []string) (err error) { - // Remove AbortFile at beginning. - _ = os.Remove(AbortFile) - - // New ControlPlane. - c, err := newControlPlane(log, nil, nil, conf, externGeoDataDirs) - if err != nil { - return err - } - - var pprofServer *http.Server - if conf.Global.PprofPort != 0 { - pprofAddr := fmt.Sprintf("localhost:%d", conf.Global.PprofPort) - pprofServer = &http.Server{Addr: pprofAddr, Handler: nil} - go pprofServer.ListenAndServe() - } - - // Serve tproxy TCP/UDP server util signals. - var listener *control.Listener - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGILL, syscall.SIGUSR1, syscall.SIGUSR2) - go func() { - readyChan := make(chan bool, 1) - go func() { - <-readyChan - sdnotify.Ready() - if !disablePidFile { - _ = os.WriteFile(PidFilePath, []byte(strconv.Itoa(os.Getpid())), 0644) - } - _ = os.WriteFile(SignalProgressFilePath, []byte{consts.ReloadDone}, 0644) - }() - control.GetDaeNetns().With(func() error { - if listener, err = c.ListenAndServe(readyChan, conf.Global.TproxyPort); err != nil { - log.Errorln("ListenAndServe:", err) - } - return err - }) - sigs <- nil - }() - reloading := false - reloadingErr := error(nil) - isSuspend := false - abortConnections := false + // Remove AbortFile at beginning. + _ = os.Remove(AbortFile) + + // New ControlPlane. + c, err := newControlPlane(log, nil, nil, conf, externGeoDataDirs) + if err != nil { + return err + } + + var pprofServer *http.Server + if conf.Global.PprofPort != 0 { + pprofAddr := fmt.Sprintf("localhost:%d", conf.Global.PprofPort) + pprofServer = &http.Server{Addr: pprofAddr, Handler: nil} + go pprofServer.ListenAndServe() + } + + // Serve tproxy TCP/UDP server util signals. + var listener *control.Listener + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGILL, syscall.SIGUSR1, syscall.SIGUSR2) + go func() { + readyChan := make(chan bool, 1) + go func() { + <-readyChan + sdnotify.Ready() + if !disablePidFile { + _ = os.WriteFile(PidFilePath, []byte(strconv.Itoa(os.Getpid())), 0644) + } + _ = os.WriteFile(SignalProgressFilePath, []byte{consts.ReloadDone}, 0644) + }() + control.GetDaeNetns().With(func() error { + if listener, err = c.ListenAndServe(readyChan, conf.Global.TproxyPort); err != nil { + log.Errorln("ListenAndServe:", err) + } + return err + }) + sigs <- nil + }() + reloading := false + reloadingErr := error(nil) + isSuspend := false + abortConnections := false loop: - for sig := range sigs { - switch sig { - case nil: - if reloading { - if listener == nil { - // Failed to listen. Exit. - break loop - } - // Serve. - reloading = false - log.Warnln("[Reload] Serve") - readyChan := make(chan bool, 1) - go func() { - if err := c.Serve(readyChan, listener); err != nil { - log.Errorln("ListenAndServe:", err) - } - sigs <- nil - }() - <-readyChan - sdnotify.Ready() - if reloadingErr == nil { - _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadDone}, []byte("\nOK")...), 0644) - } else { - _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+reloadingErr.Error())...), 0644) - } - log.Warnln("[Reload] Finished") - } else { - // Listening error. - break loop - } - case syscall.SIGUSR2: - isSuspend = true - fallthrough - case syscall.SIGUSR1: - // Reload signal. - if isSuspend { - log.Warnln("[Reload] Received suspend signal; prepare to suspend") - } else { - log.Warnln("[Reload] Received reload signal; prepare to reload") - } - sdnotify.Reloading() - _ = os.WriteFile(SignalProgressFilePath, []byte{consts.ReloadProcessing}, 0644) - reloadingErr = nil - - // Load new config. - abortConnections = os.Remove(AbortFile) == nil - log.Warnln("[Reload] Load new config") - var newConf *config.Config - if isSuspend { - isSuspend = false - newConf, err = emptyConfig() - if err != nil { - log.WithFields(logrus.Fields{ - "err": err, - }).Errorln("[Reload] Failed to reload") - sdnotify.Ready() - _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+err.Error())...), 0644) - continue - } - newConf.Global = deepcopy.Copy(conf.Global).(config.Global) - newConf.Global.WanInterface = nil - newConf.Global.LanInterface = nil - newConf.Global.LogLevel = "warning" - } else { - var includes []string - newConf, includes, err = readConfig(cfgFile) - if err != nil { - log.WithFields(logrus.Fields{ - "err": err, - }).Errorln("[Reload] Failed to reload") - sdnotify.Ready() - _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+err.Error())...), 0644) - continue - } - log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) - } - // New logger. - oldLogOutput := log.Out - log = logrus.New() - logger.SetLogger(log, newConf.Global.LogLevel, disableTimestamp, nil) - logger.SetLogger(logrus.StandardLogger(), newConf.Global.LogLevel, disableTimestamp, nil) - log.SetOutput(oldLogOutput) // FIXME: THIS IS A HACK. - logrus.SetOutput(oldLogOutput) - - // New control plane. - obj := c.EjectBpf() - var dnsCache map[string]*control.DnsCache - if conf.Dns.IpVersionPrefer == newConf.Dns.IpVersionPrefer { - // Only keep dns cache when ip version preference not change. - dnsCache = c.CloneDnsCache() - } - log.Warnln("[Reload] Load new control plane") - newC, err := newControlPlane(log, obj, dnsCache, newConf, externGeoDataDirs) - if err != nil { - reloadingErr = err - log.WithFields(logrus.Fields{ - "err": err, - }).Errorln("[Reload] Failed to reload; try to roll back configuration") - // Load last config back. - newC, err = newControlPlane(log, obj, dnsCache, conf, externGeoDataDirs) - if err != nil { - sdnotify.Stopping() - obj.Close() - c.Close() - log.WithFields(logrus.Fields{ - "err": err, - }).Fatalln("[Reload] Failed to roll back configuration") - } - newConf = conf - log.Errorln("[Reload] Last reload failed; rolled back configuration") - } else { - log.Warnln("[Reload] Stopped old control plane") - } - - // Inject bpf objects into the new control plane life-cycle. - newC.InjectBpf(obj) - - // Prepare new context. - oldC := c - c = newC - conf = newConf - reloading = true - - // Ready to close. - if abortConnections { - oldC.AbortConnections() - } - oldC.Close() - - if pprofServer != nil { - pprofServer.Shutdown(context.Background()) - pprofServer = nil - } - if newConf.Global.PprofPort != 0 { - pprofAddr := fmt.Sprintf("localhost:%d", conf.Global.PprofPort) - pprofServer = &http.Server{Addr: pprofAddr, Handler: nil} - go pprofServer.ListenAndServe() - } - case syscall.SIGHUP: - // Ignore. - continue - default: - log.Infof("Received signal: %v", sig.String()) - break loop - } - } - defer os.Remove(PidFilePath) - defer control.GetDaeNetns().Close() - if e := c.Close(); e != nil { - return fmt.Errorf("close control plane: %w", e) - } - return nil + for sig := range sigs { + switch sig { + case nil: + if reloading { + if listener == nil { + // Failed to listen. Exit. + break loop + } + // Serve. + reloading = false + log.Warnln("[Reload] Serve") + readyChan := make(chan bool, 1) + go func() { + if err := c.Serve(readyChan, listener); err != nil { + log.Errorln("ListenAndServe:", err) + } + sigs <- nil + }() + <-readyChan + sdnotify.Ready() + if reloadingErr == nil { + _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadDone}, []byte("\nOK")...), 0644) + } else { + _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+reloadingErr.Error())...), 0644) + } + log.Warnln("[Reload] Finished") + } else { + // Listening error. + break loop + } + case syscall.SIGUSR2: + isSuspend = true + fallthrough + case syscall.SIGUSR1: + // Reload signal. + if isSuspend { + log.Warnln("[Reload] Received suspend signal; prepare to suspend") + } else { + log.Warnln("[Reload] Received reload signal; prepare to reload") + } + sdnotify.Reloading() + _ = os.WriteFile(SignalProgressFilePath, []byte{consts.ReloadProcessing}, 0644) + reloadingErr = nil + + // Load new config. + abortConnections = os.Remove(AbortFile) == nil + log.Warnln("[Reload] Load new config") + var newConf *config.Config + if isSuspend { + isSuspend = false + newConf, err = emptyConfig() + if err != nil { + log.WithFields(logrus.Fields{ + "err": err, + }).Errorln("[Reload] Failed to reload") + sdnotify.Ready() + _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+err.Error())...), 0644) + continue + } + newConf.Global = deepcopy.Copy(conf.Global).(config.Global) + newConf.Global.WanInterface = nil + newConf.Global.LanInterface = nil + newConf.Global.LogLevel = "warning" + } else { + var includes []string + newConf, includes, err = readConfig(cfgFile) + if err != nil { + log.WithFields(logrus.Fields{ + "err": err, + }).Errorln("[Reload] Failed to reload") + sdnotify.Ready() + _ = os.WriteFile(SignalProgressFilePath, append([]byte{consts.ReloadError}, []byte("\n"+err.Error())...), 0644) + continue + } + log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) + } + // New logger. + oldLogOutput := log.Out + log = logrus.New() + logger.SetLogger(log, newConf.Global.LogLevel, disableTimestamp, nil) + logger.SetLogger(logrus.StandardLogger(), newConf.Global.LogLevel, disableTimestamp, nil) + log.SetOutput(oldLogOutput) // FIXME: THIS IS A HACK. + logrus.SetOutput(oldLogOutput) + + // New control plane. + obj := c.EjectBpf() + var dnsCache map[string]*control.DnsCache + if conf.Dns.IpVersionPrefer == newConf.Dns.IpVersionPrefer { + // Only keep dns cache when ip version preference not change. + dnsCache = c.CloneDnsCache() + } + log.Warnln("[Reload] Load new control plane") + newC, err := newControlPlane(log, obj, dnsCache, newConf, externGeoDataDirs) + if err != nil { + reloadingErr = err + log.WithFields(logrus.Fields{ + "err": err, + }).Errorln("[Reload] Failed to reload; try to roll back configuration") + // Load last config back. + newC, err = newControlPlane(log, obj, dnsCache, conf, externGeoDataDirs) + if err != nil { + sdnotify.Stopping() + obj.Close() + c.Close() + log.WithFields(logrus.Fields{ + "err": err, + }).Fatalln("[Reload] Failed to roll back configuration") + } + newConf = conf + log.Errorln("[Reload] Last reload failed; rolled back configuration") + } else { + log.Warnln("[Reload] Stopped old control plane") + } + + // Inject bpf objects into the new control plane life-cycle. + newC.InjectBpf(obj) + + // Prepare new context. + oldC := c + c = newC + conf = newConf + reloading = true + + // Ready to close. + if abortConnections { + oldC.AbortConnections() + } + oldC.Close() + + if pprofServer != nil { + pprofServer.Shutdown(context.Background()) + pprofServer = nil + } + if newConf.Global.PprofPort != 0 { + pprofAddr := fmt.Sprintf("localhost:%d", conf.Global.PprofPort) + pprofServer = &http.Server{Addr: pprofAddr, Handler: nil} + go pprofServer.ListenAndServe() + } + case syscall.SIGHUP: + // Ignore. + continue + default: + log.Infof("Received signal: %v", sig.String()) + break loop + } + } + defer os.Remove(PidFilePath) + defer control.GetDaeNetns().Close() + if e := c.Close(); e != nil { + return fmt.Errorf("close control plane: %w", e) + } + return nil } func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*control.DnsCache, conf *config.Config, externGeoDataDirs []string) (c *control.ControlPlane, err error) { - // Deep copy to prevent modification. - conf = deepcopy.Copy(conf).(*config.Config) - - /// Get tag -> nodeList mapping. - tagToNodeList := map[string][]string{} - if len(conf.Node) > 0 { - for _, node := range conf.Node { - tagToNodeList[""] = append(tagToNodeList[""], string(node)) - } - } - // Resolve subscriptions to nodes. - resolvingfailed := false - if !conf.Global.DisableWaitingNetwork { - epo := 5 * time.Second - client := http.Client{ - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { - conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) - if err != nil { - return nil, err - } - return &netproxy.FakeNetConn{ - Conn: conn, - LAddr: nil, - RAddr: nil, - }, nil - }, - }, - Timeout: epo, - } - log.Infoln("Waiting for network...") - for i := 0; ; i++ { - resp, err := client.Get(CheckNetworkLinks[i%len(CheckNetworkLinks)]) - if err != nil { - log.Debugln("CheckNetwork:", err) - var neterr net.Error - if errors.As(err, &neterr) && neterr.Timeout() { - // Do not sleep. - continue - } - time.Sleep(epo) - continue - } - resp.Body.Close() - if resp.StatusCode >= 200 && resp.StatusCode < 500 { - break - } - log.Infof("Bad status: %v (%v)", resp.Status, resp.StatusCode) - time.Sleep(epo) - } - log.Infoln("Network online.") - } - if len(conf.Subscription) > 0 { - log.Infoln("Fetching subscriptions...") - } - client := http.Client{ - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { - conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) - if err != nil { - return nil, err - } - return &netproxy.FakeNetConn{ - Conn: conn, - LAddr: nil, - RAddr: nil, - }, nil - }, - }, - Timeout: 30 * time.Second, - } - for _, sub := range conf.Subscription { - tag, nodes, err := subscription.ResolveSubscription(log, &client, filepath.Dir(cfgFile), string(sub)) - if err != nil { - log.Warnf(`failed to resolve subscription "%v": %v`, sub, err) - resolvingfailed = true - } - if len(nodes) > 0 { - tagToNodeList[tag] = append(tagToNodeList[tag], nodes...) - } - } - if len(tagToNodeList) == 0 { - if resolvingfailed { - log.Warnln("No node found because all subscription resolving failed.") - } else { - log.Warnln("No node found.") - } - } - - if len(conf.Global.LanInterface) == 0 && len(conf.Global.WanInterface) == 0 { - log.Warnln("No interface to bind.") - } - - if err = preprocessWanInterfaceAuto(conf); err != nil { - return nil, err - } - - c, err = control.NewControlPlane( - log, - bpf, - dnsCache, - tagToNodeList, - conf.Group, - &conf.Routing, - &conf.Global, - &conf.Dns, - externGeoDataDirs, - ) - if err != nil { - return nil, err - } - // Call GC to release memory. - runtime.GC() - - return c, nil + // Deep copy to prevent modification. + conf = deepcopy.Copy(conf).(*config.Config) + + /// Get tag -> nodeList mapping. + tagToNodeList := map[string][]string{} + if len(conf.Node) > 0 { + for _, node := range conf.Node { + tagToNodeList[""] = append(tagToNodeList[""], string(node)) + } + } + // Resolve subscriptions to nodes. + resolvingfailed := false + if !conf.Global.DisableWaitingNetwork { + epo := 5 * time.Second + client := http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { + conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) + if err != nil { + return nil, err + } + return &netproxy.FakeNetConn{ + Conn: conn, + LAddr: nil, + RAddr: nil, + }, nil + }, + }, + Timeout: epo, + } + log.Infoln("Waiting for network...") + for i := 0; ; i++ { + resp, err := client.Get(CheckNetworkLinks[i%len(CheckNetworkLinks)]) + if err != nil { + log.Debugln("CheckNetwork:", err) + var neterr net.Error + if errors.As(err, &neterr) && neterr.Timeout() { + // Do not sleep. + continue + } + time.Sleep(epo) + continue + } + resp.Body.Close() + if resp.StatusCode >= 200 && resp.StatusCode < 500 { + break + } + log.Infof("Bad status: %v (%v)", resp.Status, resp.StatusCode) + time.Sleep(epo) + } + log.Infoln("Network online.") + } + if len(conf.Subscription) > 0 { + log.Infoln("Fetching subscriptions...") + } + client := http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { + conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) + if err != nil { + return nil, err + } + return &netproxy.FakeNetConn{ + Conn: conn, + LAddr: nil, + RAddr: nil, + }, nil + }, + }, + Timeout: 30 * time.Second, + } + for _, sub := range conf.Subscription { + tag, nodes, err := subscription.ResolveSubscription(log, &client, filepath.Dir(cfgFile), string(sub)) + if err != nil { + log.Warnf(`failed to resolve subscription "%v": %v`, sub, err) + resolvingfailed = true + } + if len(nodes) > 0 { + tagToNodeList[tag] = append(tagToNodeList[tag], nodes...) + } + } + if len(tagToNodeList) == 0 { + if resolvingfailed { + log.Warnln("No node found because all subscription resolving failed.") + } else { + log.Warnln("No node found.") + } + } + + if len(conf.Global.LanInterface) == 0 && len(conf.Global.WanInterface) == 0 { + log.Warnln("No interface to bind.") + } + + if err = preprocessWanInterfaceAuto(conf); err != nil { + return nil, err + } + + c, err = control.NewControlPlane( + log, + bpf, + dnsCache, + tagToNodeList, + conf.Group, + &conf.Routing, + &conf.Global, + &conf.Dns, + externGeoDataDirs, + ) + if err != nil { + return nil, err + } + // Call GC to release memory. + runtime.GC() + + return c, nil } func preprocessWanInterfaceAuto(params *config.Config) error { - // preprocess "auto". - ifs := make([]string, 0, len(params.Global.WanInterface)+2) - for _, ifname := range params.Global.WanInterface { - if ifname == "auto" { - defaultIfs, err := common.GetDefaultIfnames() - if err != nil { - return fmt.Errorf("failed to convert 'auto': %w", err) - } - ifs = append(ifs, defaultIfs...) - } else { - ifs = append(ifs, ifname) - } - } - params.Global.WanInterface = common.Deduplicate(ifs) - return nil + // preprocess "auto". + ifs := make([]string, 0, len(params.Global.WanInterface)+2) + for _, ifname := range params.Global.WanInterface { + if ifname == "auto" { + defaultIfs, err := common.GetDefaultIfnames() + if err != nil { + return fmt.Errorf("failed to convert 'auto': %w", err) + } + ifs = append(ifs, defaultIfs...) + } else { + ifs = append(ifs, ifname) + } + } + params.Global.WanInterface = common.Deduplicate(ifs) + return nil } func readConfig(cfgFile string) (conf *config.Config, includes []string, err error) { - merger := config.NewMerger(cfgFile) - sections, includes, err := merger.Merge() - if err != nil { - return nil, nil, err - } - if conf, err = config.New(sections); err != nil { - return nil, nil, err - } - return conf, includes, nil + merger := config.NewMerger(cfgFile) + sections, includes, err := merger.Merge() + if err != nil { + return nil, nil, err + } + if conf, err = config.New(sections); err != nil { + return nil, nil, err + } + return conf, includes, nil } func emptyConfig() (conf *config.Config, err error) { - sections, err := config_parser.Parse(`global{} routing{}`) - if err != nil { - return nil, err - } - if conf, err = config.New(sections); err != nil { - return nil, err - } - return conf, nil + sections, err := config_parser.Parse(`global{} routing{}`) + if err != nil { + return nil, err + } + if conf, err = config.New(sections); err != nil { + return nil, err + } + return conf, nil } func init() { - rootCmd.AddCommand(runCmd) + rootCmd.AddCommand(runCmd) } From 73c22bfee2a8ed7520a3618a615edbf741dd26e6 Mon Sep 17 00:00:00 2001 From: woshikedayaa Date: Sat, 28 Dec 2024 11:10:48 +0800 Subject: [PATCH 6/6] feat(init): auto su more fallback --- cmd/internal/su.go | 149 +++++++++++++++++++++++++-------------------- 1 file changed, 83 insertions(+), 66 deletions(-) diff --git a/cmd/internal/su.go b/cmd/internal/su.go index 34e477999d..c87f444677 100644 --- a/cmd/internal/su.go +++ b/cmd/internal/su.go @@ -6,83 +6,100 @@ package internal import ( - "fmt" - "os" - "os/exec" - "path/filepath" - - "github.com/sirupsen/logrus" + "fmt" + "github.com/sirupsen/logrus" + "os" + "os/exec" ) func AutoSu() { - if os.Geteuid() == 0 { - return - } - path, arg := tryPolkit() - if path == "" { - path, arg = trySudo() - } - if path == "" { - return - } - p, err := os.StartProcess(path, append(arg, os.Args...), &os.ProcAttr{ - Files: []*os.File{ - os.Stdin, - os.Stdout, - os.Stderr, - }, - }) - if err != nil { - logrus.Fatal(err) - } - stat, err := p.Wait() - if err != nil { - os.Exit(1) - } - os.Exit(stat.ExitCode()) + if os.Geteuid() == 0 { + return + } + path, arg := trySudo() + if path == "" { + path, arg = tryDoas() + } + if path == "" { + path, arg = tryPolkit() + } + + if path == "" { + return + } + logrus.Infof("use [ %s ] to elevate privileges to run [ %s ]", path, os.Args[0]) + p, err := os.StartProcess(path, append(arg, os.Args...), &os.ProcAttr{ + Files: []*os.File{ + os.Stdin, + os.Stdout, + os.Stderr, + }, + }) + if err != nil { + logrus.Fatal(err) + } + stat, err := p.Wait() + if err != nil { + os.Exit(1) + } + os.Exit(stat.ExitCode()) } func trySudo() (path string, arg []string) { - pathSudo, err := exec.LookPath("sudo") - if err != nil || !isExistAndExecutable(pathSudo) { - return "", nil - } - // https://github.com/WireGuard/wireguard-tools/blob/71799a8f6d1450b63071a21cad6ed434b348d3d5/src/wg-quick/linux.bash#L85 - return pathSudo, []string{ - pathSudo, - "-E", - "-p", - fmt.Sprintf("%v must be run as root. Please enter the password for %%u to continue: ", filepath.Base(os.Args[0])), - "--", - } + pathSudo, err := exec.LookPath("sudo") + if err != nil || !isExistAndExecutable(pathSudo) { + return "", nil + } + // https://github.com/WireGuard/wireguard-tools/blob/71799a8f6d1450b63071a21cad6ed434b348d3d5/src/wg-quick/linux.bash#L85 + return pathSudo, []string{ + pathSudo, + "-E", + "-p", + fmt.Sprintf("Please enter the password for %%u to continue: "), + "--", + } +} + +func tryDoas() (path string, arg []string) { + // https://man.archlinux.org/man/doas.1 + var err error + path, err = exec.LookPath("doas") + if err != nil { + return "", nil + } + return path, []string{path, "-u", "root"} } func tryPolkit() (path string, arg []string) { - var possible = []string{"pkexec"} - for _, v := range possible { - path, err := exec.LookPath(v) - if err != nil { - continue - } - if isExistAndExecutable(path) { - switch v { - case "pkexec": - return path, []string{path, "--keep-cwd", "--user", "root"} - } - } - } - return "", nil + // https://github.com/systemd/systemd/releases/tag/v256 + // introduced run0 which is a polkit wrapper. + var possible = []string{"run0", "pkexec"} + for _, v := range possible { + path, err := exec.LookPath(v) + if err != nil { + continue + } + if isExistAndExecutable(path) { + switch v { + case "run0": + return path, []string{path} + case "pkexec": + return path, []string{path, "--keep-cwd", "--user", "root"} + } + } + } + return "", nil } func isExistAndExecutable(path string) bool { - if path == "" { - return false - } + if path == "" { + return false + } - st, err := os.Stat(path) - if err == nil { - // https://stackoverflow.com/questions/60128401/how-to-check-if-a-file-is-executable-in-go - return st.Mode()&0o111 == 0o111 - } - return false + st, err := os.Stat(path) + if err == nil { + // https://stackoverflow.com/questions/60128401/how-to-check-if-a-file-is-executable-in-go + return st.Mode()&0o111 == 0o111 + } + return false }