diff --git a/cmd/minivpn/iface.go b/cmd/minivpn/iface.go new file mode 100644 index 00000000..4102d938 --- /dev/null +++ b/cmd/minivpn/iface.go @@ -0,0 +1,30 @@ +package main + +import ( + "fmt" + "net" +) + +func getInterfaceByIP(ipAddr string) (*net.Interface, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + for _, iface := range interfaces { + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + for _, addr := range addrs { + if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.String() == ipAddr { + return &iface, nil + } + } + } + } + + return nil, fmt.Errorf("interface with IP %s not found", ipAddr) +} diff --git a/cmd/minivpn/log.go b/cmd/minivpn/log.go new file mode 100644 index 00000000..bf4a5c22 --- /dev/null +++ b/cmd/minivpn/log.go @@ -0,0 +1,79 @@ +package main + +import ( + "fmt" + "io" + "os" + "sync" + "time" + + "github.com/apex/log" +) + +// Default handler outputting to stderr. +var Default = NewHandler(os.Stderr) + +// start time. +var start = time.Now() + +// colors. +const ( + none = 0 + red = 31 + green = 32 + yellow = 33 + blue = 34 + gray = 37 +) + +// Colors mapping. +var Colors = [...]int{ + log.DebugLevel: gray, + log.InfoLevel: blue, + log.WarnLevel: yellow, + log.ErrorLevel: red, + log.FatalLevel: red, +} + +// Strings mapping. +var Strings = [...]string{ + log.DebugLevel: "DEBUG", + log.InfoLevel: "INFO", + log.WarnLevel: "WARN", + log.ErrorLevel: "ERROR", + log.FatalLevel: "FATAL", +} + +// Handler implementation. +type Handler struct { + mu sync.Mutex + Writer io.Writer +} + +// New handler. +func NewHandler(w io.Writer) *Handler { + return &Handler{ + Writer: w, + } +} + +// HandleLog implements log.Handler. +func (h *Handler) HandleLog(e *log.Entry) error { + color := Colors[e.Level] + level := Strings[e.Level] + names := e.Fields.Names() + + h.mu.Lock() + defer h.mu.Unlock() + + ts := time.Since(start) + fmt.Fprintf(h.Writer, "\033[%dm%6s\033[0m[%10v] %-25s", color, level, ts, e.Message) + + for _, name := range names { + fmt.Fprintf(h.Writer, " \033[%dm%s\033[0m=%v", color, name, e.Fields.Get(name)) + } + + fmt.Fprintln(h.Writer) + + return nil +} diff --git a/cmd/minivpn/main.go b/cmd/minivpn/main.go index 1ecff11a..059e713f 100644 --- a/cmd/minivpn/main.go +++ b/cmd/minivpn/main.go @@ -2,140 +2,210 @@ package main import ( "context" + "encoding/json" + "flag" "fmt" - "io" + "net" "os" + "os/exec" "time" + "github.com/Doridian/water" "github.com/apex/log" - "github.com/pborman/getopt/v2" + "github.com/jackpal/gateway" "github.com/ooni/minivpn/extras/ping" - "github.com/ooni/minivpn/vpn" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/networkio" + "github.com/ooni/minivpn/internal/runtimex" + "github.com/ooni/minivpn/internal/tun" + "github.com/ooni/minivpn/pkg/tracex" ) -var ( - startTime = time.Now() - extraTimeoutSeconds = 10 * time.Second -) - -func printUsage() { - fmt.Println("valid commands: ping, proxy") - getopt.Usage() - os.Exit(0) +func runCmd(binaryPath string, args ...string) { + cmd := exec.Command(binaryPath, args...) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + cmd.Stdin = os.Stdin + err := cmd.Run() + if nil != err { + log.WithError(err).Warn("error running /sbin/ip") + } } -func timeoutSecondsFromCount(count int) time.Duration { - waitOnLastOne := 2 * time.Second - return time.Duration(count)*time.Second + waitOnLastOne +func runIP(args ...string) { + runCmd("/sbin/ip", args...) +} +func runRoute(args ...string) { + runCmd("/sbin/route", args...) } -// RunPinger takes an Option object, starts a Client, and runs a Pinger against -// the passed target, for a number count of packets. -func RunPinger(opt *vpn.Options, target string, count uint32) error { - c := int(count) - ctx, cancel := context.WithTimeout(context.Background(), timeoutSecondsFromCount(c)+extraTimeoutSeconds) - defer cancel() +type config struct { + configPath string + doPing bool + doTrace bool + skipRoute bool + timeout int +} - tunnel := vpn.NewClientFromOptions(opt) - err := tunnel.Start(ctx) - if err != nil { - return err +func main() { + log.SetLevel(log.DebugLevel) + + cfg := &config{} + flag.StringVar(&cfg.configPath, "config", "", "config file to load") + flag.BoolVar(&cfg.doPing, "ping", false, "if true, do ping and exit (for testing)") + flag.BoolVar(&cfg.doTrace, "trace", false, "if true, do a trace of the handshake and exit (for testing)") + flag.BoolVar(&cfg.skipRoute, "skip-route", false, "if true, exit without setting routes (for testing)") + flag.IntVar(&cfg.timeout, "timeout", 60, "timeout in seconds (default=60)") + flag.Parse() + + if cfg.configPath == "" { + fmt.Println("[error] need config path") + os.Exit(1) } - pinger := ping.New(target, tunnel) - pinger.Count = c - pinger.Timeout = timeoutSecondsFromCount(c) - err = pinger.Run(ctx) - if err != nil { - return err - } - pinger.PrintStats() + log.SetHandler(NewHandler(os.Stderr)) + log.SetLevel(log.DebugLevel) - return nil -} + opts := []model.Option{ + model.WithConfigFile(cfg.configPath), + model.WithLogger(log.Log), + } -func main() { - optConfig := getopt.StringLong("config", 'c', "", "Configuration file") - optServer := getopt.StringLong("server", 's', "", "VPN Server to connect to") - optTarget := getopt.StringLong("target", 't', "8.8.8.8", "Target for ICMP Ping") - optCount := getopt.Uint32Long("count", 'n', uint32(3), "Stop after sending these many ECHO_REQUEST packets") - optVerbosity := getopt.Uint16Long("verbosity", 'v', uint16(4), "Verbosity level (1 to 5, 1 is lowest)") + start := time.Now() + + var tracer *tracex.Tracer + if cfg.doTrace { + tracer = tracex.NewTracer(start) + opts = append(opts, model.WithHandshakeTracer(tracer)) + defer func() { + trace := tracer.Trace() + jsonData, err := json.MarshalIndent(trace, "", " ") + runtimex.PanicOnError(err, "cannot serialize trace") + fileName := fmt.Sprintf("handshake-trace-%s.json", time.Now().Format("2006-01-02-15:05:00")) + os.WriteFile(fileName, jsonData, 0644) + fmt.Println("trace written to", fileName) + }() + } - helpFlag := getopt.Bool('h', "Display help") + config := model.NewConfig(opts...) - getopt.Parse() - args := getopt.Args() + // connect to the server + dialer := networkio.NewDialer(log.Log, &net.Dialer{}) + ctx := context.Background() - if len(args) != 1 { - printUsage() + proto := config.Remote().Protocol + addr := config.Remote().AddrPort + conn, err := dialer.DialContext(ctx, proto, addr) + if err != nil { + log.WithError(err).Error("dialer.DialContext") + return } - if *helpFlag || (*optServer == "" && *optConfig == "") { - printUsage() - } + // The TLS will expire in 60 seconds by default, but we can pass + // a shorter timeout. + ctx, cancel := context.WithTimeout(ctx, time.Duration(cfg.timeout)*time.Second) + defer cancel() - var opts *vpn.Options - - verbosityLevel := log.InfoLevel - switch *optVerbosity { - case uint16(1): - verbosityLevel = log.FatalLevel - case uint16(2): - verbosityLevel = log.ErrorLevel - case uint16(3): - verbosityLevel = log.WarnLevel - case uint16(4): - verbosityLevel = log.InfoLevel - case uint16(5): - verbosityLevel = log.DebugLevel - default: - verbosityLevel = log.DebugLevel + // create a vpn tun Device + tunnel, err := tun.StartTUN(ctx, conn, config) + if err != nil { + log.WithError(err).Error("init error") + return } + log.Infof("Local IP: %s\n", tunnel.LocalAddr()) + log.Infof("Gateway: %s\n", tunnel.RemoteAddr()) - logger := &log.Logger{Level: verbosityLevel, Handler: &logHandler{Writer: os.Stderr}} - logger.Debugf("config file: %s", *optConfig) + fmt.Println("initialization-sequence-completed") + fmt.Printf("elapsed: %v\n", time.Since(start)) - opts, err := vpn.NewOptionsFromFilePath(*optConfig) - if err != nil { - fmt.Println("fatal: " + err.Error()) - os.Exit(1) + if cfg.doTrace { + return } - opts.Log = logger - switch args[0] { - case "ping": - err = RunPinger(opts, *optTarget, *optCount) + if cfg.doPing { + pinger := ping.New("8.8.8.8", tunnel) + count := 5 + pinger.Count = count + + err = pinger.Run(context.Background()) if err != nil { - logger.Error(err.Error()) + pinger.PrintStats() + log.WithError(err).Fatal("ping error") } - case "proxy": - // not actively tested at the moment - ListenAndServeSocks(opts) - default: - printUsage() + pinger.PrintStats() + os.Exit(0) } -} -type logHandler struct { - io.Writer -} + if cfg.skipRoute { + os.Exit(0) + } -func (h *logHandler) HandleLog(e *log.Entry) (err error) { - var s string - if e.Level == log.DebugLevel { - s = fmt.Sprintf("%s", e.Message) - } else if e.Level == log.ErrorLevel { - s = fmt.Sprintf("[%14.6f] %s", time.Since(startTime).Seconds(), e.Message) - } else { - s = fmt.Sprintf("[%14.6f] <%s> %s", time.Since(startTime).Seconds(), e.Level, e.Message) + // create a tun interface on the OS + iface, err := water.New(water.Config{DeviceType: water.TUN}) + runtimex.PanicOnError(err, "unable to open tun interface") + + // TODO: investigate what's the maximum working MTU, additionally get it from flag. + MTU := 1420 + iface.SetMTU(MTU) + + localAddr := tunnel.LocalAddr().String() + remoteAddr := tunnel.RemoteAddr().String() + netMask := tunnel.NetMask() + + // discover local gateway IP, we need it to add a route to our remote via our network gw + defaultGatewayIP, err := gateway.DiscoverGateway() + if err != nil { + log.Warn("could not discover default gateway IP, routes might be broken") } - if len(e.Fields) > 0 { - s += fmt.Sprintf(": %+v", e.Fields) + defaultInterfaceIP, err := gateway.DiscoverInterface() + if err != nil { + log.Warn("could not discover default route interface IP, routes might be broken") } - s += "\n" - _, err = h.Writer.Write([]byte(s)) - return + defaultInterface, err := getInterfaceByIP(defaultInterfaceIP.String()) + if err != nil { + log.Warn("could not get default route interface, routes might be broken") + } + + if defaultGatewayIP != nil && defaultInterface != nil { + log.Infof("route add %s gw %v dev %s", config.Remote().IPAddr, defaultGatewayIP, defaultInterface.Name) + runRoute("add", config.Remote().IPAddr, "gw", defaultGatewayIP.String(), defaultInterface.Name) + } + + // we want the network CIDR for setting up the routes + network := &net.IPNet{ + IP: net.ParseIP(localAddr).Mask(netMask), + Mask: netMask, + } + + // configure the interface and bring it up + runIP("addr", "add", localAddr, "dev", iface.Name()) + runIP("link", "set", "dev", iface.Name(), "up") + runRoute("add", remoteAddr, "gw", localAddr) + runRoute("add", "-net", network.String(), "dev", iface.Name()) + runIP("route", "add", "default", "via", remoteAddr, "dev", iface.Name()) + + go func() { + for { + packet := make([]byte, 2000) + n, err := iface.Read(packet) + if err != nil { + log.WithError(err).Fatal("error reading from tun") + } + tunnel.Write(packet[:n]) + } + }() + go func() { + for { + packet := make([]byte, 2000) + n, err := tunnel.Read(packet) + if err != nil { + log.WithError(err).Fatal("error reading from tun") + } + iface.Write(packet[:n]) + } + }() + select {} } diff --git a/cmd/minivpn/proxy.go b/cmd/minivpn/proxy.go deleted file mode 100644 index cc207a89..00000000 --- a/cmd/minivpn/proxy.go +++ /dev/null @@ -1,80 +0,0 @@ -package main - -import ( - "errors" - "fmt" - "net" - "os" - "runtime" - "syscall" - - socks5 "github.com/armon/go-socks5" - "github.com/ooni/minivpn/vpn" -) - -const ( - socksPort = "8080" - socksIP = "127.0.0.1" -) - -// ListenAndServeSocks configures a vpn dialer, and configures and runs a -// socks5 server to use dialer.DialContext. The vpn dialer will initialize the tunnel -// upon receiving the first proxied request, and will reuse the same session -// for all further requests. -func ListenAndServeSocks(opts *vpn.Options) { - port := os.Getenv("LPORT") - if port == "" { - port = socksPort - } - ip := os.Getenv("LHOST") - if ip == "" { - ip = socksIP - } - dialer, err := vpn.StartNewTunDialerFromOptions(opts, &net.Dialer{}) - if err != nil { - panic(err) - } - conf := &socks5.Config{ - Dial: dialer.DialContext, - } - server, err := socks5.New(conf) - if err != nil { - panic(err) - } - - addr := net.JoinHostPort(ip, port) - fmt.Printf("[+] Starting socks5 proxy at %s\n", addr) - if err := server.ListenAndServe("tcp", addr); err != nil { - if isErrorAddressAlreadyInUse(err) { - fmt.Printf("[!] Address %s already in use\n", addr) - for i := 1; i < 1e4; i++ { - addr := net.JoinHostPort(ip, fmt.Sprintf("%d", i+1024)) - fmt.Println("[+] Trying to listen on", addr) - if err := server.ListenAndServe("tcp", addr); err != nil { - continue - } - } - } else { - panic(err) - } - } -} - -func isErrorAddressAlreadyInUse(err error) bool { - var eOsSyscall *os.SyscallError - if !errors.As(err, &eOsSyscall) { - return false - } - var errErrno syscall.Errno // doesn't need a "*" (ptr) because it's already a ptr (uintptr) - if !errors.As(eOsSyscall, &errErrno) { - return false - } - if errErrno == syscall.EADDRINUSE { - return true - } - const WSAEADDRINUSE = 10048 - if runtime.GOOS == "windows" && errErrno == WSAEADDRINUSE { - return true - } - return false -} diff --git a/extras/ping/ping.go b/extras/ping/ping.go index 652693fe..455e8dc1 100644 --- a/extras/ping/ping.go +++ b/extras/ping/ping.go @@ -396,6 +396,7 @@ func (p *Pinger) runLoop(recvCh <-chan *packet) error { } if p.Count > 0 && p.PacketsRecv >= p.Count { + p.done <- true return nil } } @@ -483,13 +484,16 @@ func (p *Pinger) recvICMP(recv chan<- *packet) error { case <-p.done: return nil default: + if p.PacketsRecv >= p.Count { + return nil + } buf := make([]byte, 512) if err := p.conn.SetReadDeadline(time.Now().Add(delay)); err != nil { return fmt.Errorf("%w: %s", errCannotSetReadDeadline, err) } n, err := p.conn.Read(buf) if err != nil { - var netErr *net.OpError + var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { // Read timeout delay = expBackoff.Get()