diff --git a/cmd/minivpn/iface.go b/cmd/minivpn/iface.go new file mode 100644 index 00000000..4102d938 --- /dev/null +++ b/cmd/minivpn/iface.go @@ -0,0 +1,30 @@ +package main + +import ( + "fmt" + "net" +) + +func getInterfaceByIP(ipAddr string) (*net.Interface, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + for _, iface := range interfaces { + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + for _, addr := range addrs { + if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.String() == ipAddr { + return &iface, nil + } + } + } + } + + return nil, fmt.Errorf("interface with IP %s not found", ipAddr) +} diff --git a/cmd/minivpn/log.go b/cmd/minivpn/log.go new file mode 100644 index 00000000..bf4a5c22 --- /dev/null +++ b/cmd/minivpn/log.go @@ -0,0 +1,79 @@ +package main + +import ( + "fmt" + "io" + "os" + "sync" + "time" + + "github.com/apex/log" +) + +// Default handler outputting to stderr. +var Default = NewHandler(os.Stderr) + +// start time. +var start = time.Now() + +// colors. +const ( + none = 0 + red = 31 + green = 32 + yellow = 33 + blue = 34 + gray = 37 +) + +// Colors mapping. +var Colors = [...]int{ + log.DebugLevel: gray, + log.InfoLevel: blue, + log.WarnLevel: yellow, + log.ErrorLevel: red, + log.FatalLevel: red, +} + +// Strings mapping. +var Strings = [...]string{ + log.DebugLevel: "DEBUG", + log.InfoLevel: "INFO", + log.WarnLevel: "WARN", + log.ErrorLevel: "ERROR", + log.FatalLevel: "FATAL", +} + +// Handler implementation. +type Handler struct { + mu sync.Mutex + Writer io.Writer +} + +// New handler. +func NewHandler(w io.Writer) *Handler { + return &Handler{ + Writer: w, + } +} + +// HandleLog implements log.Handler. +func (h *Handler) HandleLog(e *log.Entry) error { + color := Colors[e.Level] + level := Strings[e.Level] + names := e.Fields.Names() + + h.mu.Lock() + defer h.mu.Unlock() + + ts := time.Since(start) + fmt.Fprintf(h.Writer, "\033[%dm%6s\033[0m[%10v] %-25s", color, level, ts, e.Message) + + for _, name := range names { + fmt.Fprintf(h.Writer, " \033[%dm%s\033[0m=%v", color, name, e.Fields.Get(name)) + } + + fmt.Fprintln(h.Writer) + + return nil +} diff --git a/cmd/minivpn/main.go b/cmd/minivpn/main.go index 1ecff11a..a61d56ea 100644 --- a/cmd/minivpn/main.go +++ b/cmd/minivpn/main.go @@ -2,140 +2,205 @@ package main import ( "context" + "encoding/json" + "flag" "fmt" - "io" + "net" "os" + "os/exec" "time" + "github.com/Doridian/water" "github.com/apex/log" - "github.com/pborman/getopt/v2" + "github.com/jackpal/gateway" "github.com/ooni/minivpn/extras/ping" - "github.com/ooni/minivpn/vpn" + "github.com/ooni/minivpn/internal/model" + "github.com/ooni/minivpn/internal/networkio" + "github.com/ooni/minivpn/internal/runtimex" + "github.com/ooni/minivpn/internal/tracex" + "github.com/ooni/minivpn/internal/tun" ) -var ( - startTime = time.Now() - extraTimeoutSeconds = 10 * time.Second -) +func runCmd(binaryPath string, args ...string) { + cmd := exec.Command(binaryPath, args...) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + cmd.Stdin = os.Stdin + err := cmd.Run() + if nil != err { + log.WithError(err).Warn("error running /sbin/ip") + } +} -func printUsage() { - fmt.Println("valid commands: ping, proxy") - getopt.Usage() - os.Exit(0) +func runIP(args ...string) { + runCmd("/sbin/ip", args...) } -func timeoutSecondsFromCount(count int) time.Duration { - waitOnLastOne := 2 * time.Second - return time.Duration(count)*time.Second + waitOnLastOne +func runRoute(args ...string) { + runCmd("/sbin/route", args...) +} +type config struct { + configPath string + doPing bool + doTrace bool + skipRoute bool + timeout int } -// RunPinger takes an Option object, starts a Client, and runs a Pinger against -// the passed target, for a number count of packets. -func RunPinger(opt *vpn.Options, target string, count uint32) error { - c := int(count) - ctx, cancel := context.WithTimeout(context.Background(), timeoutSecondsFromCount(c)+extraTimeoutSeconds) - defer cancel() +func main() { + log.SetLevel(log.DebugLevel) + + cfg := &config{} + flag.StringVar(&cfg.configPath, "config", "", "config file to load") + flag.BoolVar(&cfg.doPing, "ping", false, "if true, do ping and exit (for testing)") + flag.BoolVar(&cfg.doTrace, "trace", false, "if true, do a trace of the handshake and exit (for testing)") + flag.BoolVar(&cfg.skipRoute, "skip-route", false, "if true, exit without setting routes (for testing)") + flag.IntVar(&cfg.timeout, "timeout", 60, "timeout in seconds (default=60)") + flag.Parse() + + if cfg.configPath == "" { + fmt.Println("[error] need config path") + os.Exit(1) + } - tunnel := vpn.NewClientFromOptions(opt) - err := tunnel.Start(ctx) - if err != nil { - return err + log.SetHandler(NewHandler(os.Stderr)) + log.SetLevel(log.DebugLevel) + + opts := []model.Option{ + model.WithConfigFile(cfg.configPath), + model.WithLogger(log.Log), } - pinger := ping.New(target, tunnel) - pinger.Count = c - pinger.Timeout = timeoutSecondsFromCount(c) - err = pinger.Run(ctx) - if err != nil { - return err + start := time.Now() + + if cfg.doTrace { + opts = append(opts, model.WithHandshakeTracer(tracex.NewTracer(start))) } - pinger.PrintStats() - return nil -} + config := model.NewConfig(opts...) -func main() { - optConfig := getopt.StringLong("config", 'c', "", "Configuration file") - optServer := getopt.StringLong("server", 's', "", "VPN Server to connect to") - optTarget := getopt.StringLong("target", 't', "8.8.8.8", "Target for ICMP Ping") - optCount := getopt.Uint32Long("count", 'n', uint32(3), "Stop after sending these many ECHO_REQUEST packets") - optVerbosity := getopt.Uint16Long("verbosity", 'v', uint16(4), "Verbosity level (1 to 5, 1 is lowest)") + // connect to the server + dialer := networkio.NewDialer(log.Log, &net.Dialer{}) + ctx := context.Background() - helpFlag := getopt.Bool('h', "Display help") + proto := config.Remote().Protocol + addr := config.Remote().AddrPort - getopt.Parse() - args := getopt.Args() + conn, err := dialer.DialContext(ctx, proto, addr) + if err != nil { + log.WithError(err).Fatal("dialer.DialContext") + } - if len(args) != 1 { - printUsage() + // The TLS will expire in 60 seconds by default, but we can pass + // a shorter timeout. + ctx, cancel := context.WithTimeout(ctx, time.Duration(cfg.timeout)*time.Second) + defer cancel() + // create a vpn tun Device + tunnel, err := tun.StartTUN(ctx, conn, config) + if err != nil { + log.WithError(err).Fatal("init error") + return + } + log.Infof("Local IP: %s\n", tunnel.LocalAddr()) + log.Infof("Gateway: %s\n", tunnel.RemoteAddr()) + + fmt.Println("initialization-sequence-completed") + fmt.Printf("elapsed: %v\n", time.Since(start)) + + if cfg.doTrace { + trace := config.Tracer().Trace() + jsonData, err := json.MarshalIndent(trace, "", " ") + runtimex.PanicOnError(err, "cannot serialize trace") + fileName := "handshake-trace.json" + os.WriteFile(fileName, jsonData, 0644) + fmt.Println("trace written to", fileName) + os.Exit(0) } - if *helpFlag || (*optServer == "" && *optConfig == "") { - printUsage() + if cfg.doPing { + pinger := ping.New("8.8.8.8", tunnel) + count := 5 + pinger.Count = count + + err = pinger.Run(context.Background()) + if err != nil { + pinger.PrintStats() + log.WithError(err).Fatal("ping error") + } + pinger.PrintStats() + os.Exit(0) } - var opts *vpn.Options - - verbosityLevel := log.InfoLevel - switch *optVerbosity { - case uint16(1): - verbosityLevel = log.FatalLevel - case uint16(2): - verbosityLevel = log.ErrorLevel - case uint16(3): - verbosityLevel = log.WarnLevel - case uint16(4): - verbosityLevel = log.InfoLevel - case uint16(5): - verbosityLevel = log.DebugLevel - default: - verbosityLevel = log.DebugLevel + if cfg.skipRoute { + os.Exit(0) } - logger := &log.Logger{Level: verbosityLevel, Handler: &logHandler{Writer: os.Stderr}} - logger.Debugf("config file: %s", *optConfig) + // create a tun interface on the OS + iface, err := water.New(water.Config{DeviceType: water.TUN}) + runtimex.PanicOnError(err, "unable to open tun interface") + + // TODO: investigate what's the maximum working MTU, additionally get it from flag. + MTU := 1420 + iface.SetMTU(MTU) + + localAddr := tunnel.LocalAddr().String() + remoteAddr := tunnel.RemoteAddr().String() + netMask := tunnel.NetMask() - opts, err := vpn.NewOptionsFromFilePath(*optConfig) + // discover local gateway IP, we need it to add a route to our remote via our network gw + defaultGatewayIP, err := gateway.DiscoverGateway() if err != nil { - fmt.Println("fatal: " + err.Error()) - os.Exit(1) + log.Warn("could not discover default gateway IP, routes might be broken") } - opts.Log = logger - - switch args[0] { - case "ping": - err = RunPinger(opts, *optTarget, *optCount) - if err != nil { - logger.Error(err.Error()) - } - case "proxy": - // not actively tested at the moment - ListenAndServeSocks(opts) - default: - printUsage() + defaultInterfaceIP, err := gateway.DiscoverInterface() + if err != nil { + log.Warn("could not discover default route interface IP, routes might be broken") + } + defaultInterface, err := getInterfaceByIP(defaultInterfaceIP.String()) + if err != nil { + log.Warn("could not get default route interface, routes might be broken") } -} - -type logHandler struct { - io.Writer -} -func (h *logHandler) HandleLog(e *log.Entry) (err error) { - var s string - if e.Level == log.DebugLevel { - s = fmt.Sprintf("%s", e.Message) - } else if e.Level == log.ErrorLevel { - s = fmt.Sprintf("[%14.6f] %s", time.Since(startTime).Seconds(), e.Message) - } else { - s = fmt.Sprintf("[%14.6f] <%s> %s", time.Since(startTime).Seconds(), e.Level, e.Message) + if defaultGatewayIP != nil && defaultInterface != nil { + log.Infof("route add %s gw %v dev %s", config.Remote().IPAddr, defaultGatewayIP, defaultInterface.Name) + runRoute("add", config.Remote().IPAddr, "gw", defaultGatewayIP.String(), defaultInterface.Name) } - if len(e.Fields) > 0 { - s += fmt.Sprintf(": %+v", e.Fields) + + // we want the network CIDR for setting up the routes + network := &net.IPNet{ + IP: net.ParseIP(localAddr).Mask(netMask), + Mask: netMask, } - s += "\n" - _, err = h.Writer.Write([]byte(s)) - return + + // configure the interface and bring it up + runIP("addr", "add", localAddr, "dev", iface.Name()) + runIP("link", "set", "dev", iface.Name(), "up") + runRoute("add", remoteAddr, "gw", localAddr) + runRoute("add", "-net", network.String(), "dev", iface.Name()) + runIP("route", "add", "default", "via", remoteAddr, "dev", iface.Name()) + + go func() { + for { + packet := make([]byte, 2000) + n, err := iface.Read(packet) + if err != nil { + log.WithError(err).Fatal("error reading from tun") + } + tunnel.Write(packet[:n]) + } + }() + go func() { + for { + packet := make([]byte, 2000) + n, err := tunnel.Read(packet) + if err != nil { + log.WithError(err).Fatal("error reading from tun") + } + iface.Write(packet[:n]) + } + }() + select {} } diff --git a/cmd/minivpn/proxy.go b/cmd/minivpn/proxy.go deleted file mode 100644 index cc207a89..00000000 --- a/cmd/minivpn/proxy.go +++ /dev/null @@ -1,80 +0,0 @@ -package main - -import ( - "errors" - "fmt" - "net" - "os" - "runtime" - "syscall" - - socks5 "github.com/armon/go-socks5" - "github.com/ooni/minivpn/vpn" -) - -const ( - socksPort = "8080" - socksIP = "127.0.0.1" -) - -// ListenAndServeSocks configures a vpn dialer, and configures and runs a -// socks5 server to use dialer.DialContext. The vpn dialer will initialize the tunnel -// upon receiving the first proxied request, and will reuse the same session -// for all further requests. -func ListenAndServeSocks(opts *vpn.Options) { - port := os.Getenv("LPORT") - if port == "" { - port = socksPort - } - ip := os.Getenv("LHOST") - if ip == "" { - ip = socksIP - } - dialer, err := vpn.StartNewTunDialerFromOptions(opts, &net.Dialer{}) - if err != nil { - panic(err) - } - conf := &socks5.Config{ - Dial: dialer.DialContext, - } - server, err := socks5.New(conf) - if err != nil { - panic(err) - } - - addr := net.JoinHostPort(ip, port) - fmt.Printf("[+] Starting socks5 proxy at %s\n", addr) - if err := server.ListenAndServe("tcp", addr); err != nil { - if isErrorAddressAlreadyInUse(err) { - fmt.Printf("[!] Address %s already in use\n", addr) - for i := 1; i < 1e4; i++ { - addr := net.JoinHostPort(ip, fmt.Sprintf("%d", i+1024)) - fmt.Println("[+] Trying to listen on", addr) - if err := server.ListenAndServe("tcp", addr); err != nil { - continue - } - } - } else { - panic(err) - } - } -} - -func isErrorAddressAlreadyInUse(err error) bool { - var eOsSyscall *os.SyscallError - if !errors.As(err, &eOsSyscall) { - return false - } - var errErrno syscall.Errno // doesn't need a "*" (ptr) because it's already a ptr (uintptr) - if !errors.As(eOsSyscall, &errErrno) { - return false - } - if errErrno == syscall.EADDRINUSE { - return true - } - const WSAEADDRINUSE = 10048 - if runtime.GOOS == "windows" && errErrno == WSAEADDRINUSE { - return true - } - return false -}