Skip to content

Commit

Permalink
Refactor into separate packages & add tests. (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
kzemek authored Mar 24, 2024
1 parent 201d278 commit 02c6d29
Show file tree
Hide file tree
Showing 14 changed files with 902 additions and 324 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test-startup.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Go
name: Test systemd

on:
push:
Expand All @@ -18,7 +18,7 @@ jobs:
go-version: "1.21"

- name: Build
run: go build -v ./...
run: go build -v

- name: Install go-mmproxy
run: |
Expand Down
31 changes: 31 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Test

on:
push:
branches: ["main"]
pull_request:
branches: ["main"]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: "1.21"

- name: Build
run: go build -v

- name: Prepare ip routes
run: |
sudo ip rule add from 127.0.0.1/8 iif lo table 123
sudo ip route add local 0.0.0.0/0 dev lo table 123
sudo ip -6 rule add from ::1/128 iif lo table 123
sudo ip -6 route add local ::/0 dev lo table 123
- name: Test
run: sudo go test -v -timeout 30s ./tests
24 changes: 0 additions & 24 deletions buffers.go

This file was deleted.

28 changes: 28 additions & 0 deletions buffers/buffers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2019 Path Network, Inc. All rights reserved.
// Copyright 2024 Konrad Zemek <[email protected]>
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package buffers

import (
"math"
"sync"
)

var buffers sync.Pool

func init() {
buffers.New = func() any {
slice := make([]byte, math.MaxUint16)
return &slice
}
}

func Get() []byte {
return *buffers.Get().(*[]byte)
}

func Put(buf []byte) {
buffers.Put(&buf)
}
141 changes: 72 additions & 69 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,62 +1,58 @@
// Copyright 2019 Path Network, Inc. All rights reserved.
// Copyright 2024 Konrad Zemek <[email protected]>
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
"bufio"
"context"
"flag"
"log/slog"
"net"
"net/netip"
"os"
"syscall"
"time"

"github.com/kzemek/go-mmproxy/tcp"
"github.com/kzemek/go-mmproxy/udp"
"github.com/kzemek/go-mmproxy/utils"
)

type options struct {
Protocol string
ListenAddrStr string
TargetAddr4Str string
TargetAddr6Str string
ListenAddr netip.AddrPort
TargetAddr4 netip.AddrPort
TargetAddr6 netip.AddrPort
Mark int
Verbose int
allowedSubnetsPath string
AllowedSubnets []*net.IPNet
Listeners int
Logger *slog.Logger
udpCloseAfter int
UDPCloseAfter time.Duration
}
var protocolStr string
var listenAddrStr string
var targetAddr4Str string
var targetAddr6Str string
var allowedSubnetsPath string
var udpCloseAfterInt int
var listeners int

var Opts options
var opts utils.Options

func init() {
flag.StringVar(&Opts.Protocol, "p", "tcp", "Protocol that will be proxied: tcp, udp")
flag.StringVar(&Opts.ListenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on")
flag.StringVar(&Opts.TargetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to")
flag.StringVar(&Opts.TargetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to")
flag.IntVar(&Opts.Mark, "mark", 0, "The mark that will be set on outbound packets")
flag.IntVar(&Opts.Verbose, "v", 0, `0 - no logging of individual connections
flag.StringVar(&protocolStr, "p", "tcp", "Protocol that will be proxied: tcp, udp")
flag.StringVar(&listenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on")
flag.StringVar(&targetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to")
flag.StringVar(&targetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to")
flag.IntVar(&opts.Mark, "mark", 0, "The mark that will be set on outbound packets")
flag.IntVar(&opts.Verbose, "v", 0, `0 - no logging of individual connections
1 - log errors occurring in individual connections
2 - log all state changes of individual connections`)
flag.StringVar(&Opts.allowedSubnetsPath, "allowed-subnets", "",
flag.StringVar(&allowedSubnetsPath, "allowed-subnets", "",
"Path to a file that contains allowed subnets of the proxy servers")
flag.IntVar(&Opts.Listeners, "listeners", 1,
flag.IntVar(&listeners, "listeners", 1,
"Number of listener sockets that will be opened for the listen address (Linux 3.9+)")
flag.IntVar(&Opts.udpCloseAfter, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up")
flag.IntVar(&udpCloseAfterInt, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up")
}

func listen(listenerNum int, errors chan<- error) {
logger := Opts.Logger.With(slog.Int("listenerNum", listenerNum),
slog.String("protocol", Opts.Protocol), slog.String("listenAdr", Opts.ListenAddr.String()))
func listen(ctx context.Context, listenerNum int, parentLogger *slog.Logger, listenErrors chan<- error) {
logger := parentLogger.With(slog.Int("listenerNum", listenerNum),
slog.String("protocol", protocolStr), slog.String("listenAdr", opts.ListenAddr.String()))

listenConfig := net.ListenConfig{}
if Opts.Listeners > 1 {
if listeners > 1 {
listenConfig.Control = func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
soReusePort := 15
Expand All @@ -67,15 +63,15 @@ func listen(listenerNum int, errors chan<- error) {
}
}

if Opts.Protocol == "tcp" {
tcpListen(&listenConfig, logger, errors)
if opts.Protocol == utils.TCP {
tcp.Listen(ctx, &listenConfig, &opts, logger, listenErrors)
} else {
udpListen(&listenConfig, logger, errors)
udp.Listen(ctx, &listenConfig, &opts, logger, listenErrors)
}
}

func loadAllowedSubnets() error {
file, err := os.Open(Opts.allowedSubnetsPath)
func loadAllowedSubnets(logger *slog.Logger) error {
file, err := os.Open(allowedSubnetsPath)
if err != nil {
return err
}
Expand All @@ -84,12 +80,12 @@ func loadAllowedSubnets() error {

scanner := bufio.NewScanner(file)
for scanner.Scan() {
_, ipNet, err := net.ParseCIDR(scanner.Text())
ipNet, err := netip.ParsePrefix(scanner.Text())
if err != nil {
return err
}
Opts.AllowedSubnets = append(Opts.AllowedSubnets, ipNet)
Opts.Logger.Info("allowed subnet", slog.String("subnet", ipNet.String()))
opts.AllowedSubnets = append(opts.AllowedSubnets, ipNet)
logger.Info("allowed subnet", slog.String("subnet", ipNet.String()))
}

return nil
Expand All @@ -98,72 +94,79 @@ func loadAllowedSubnets() error {
func main() {
flag.Parse()
lvl := slog.LevelInfo
if Opts.Verbose > 0 {
if opts.Verbose > 0 {
lvl = slog.LevelDebug
}
Opts.Logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: lvl}))

if Opts.allowedSubnetsPath != "" {
if err := loadAllowedSubnets(); err != nil {
Opts.Logger.Error("failed to load allowed subnets file", "path", Opts.allowedSubnetsPath, "error", err)
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: lvl}))

if allowedSubnetsPath != "" {
if err := loadAllowedSubnets(logger); err != nil {
logger.Error("failed to load allowed subnets file", "path", allowedSubnetsPath, "error", err)
}
}

if Opts.Protocol != "tcp" && Opts.Protocol != "udp" {
Opts.Logger.Error("--protocol has to be one of udp, tcp", slog.String("protocol", Opts.Protocol))
if protocolStr == "tcp" {
opts.Protocol = utils.TCP
} else if protocolStr == "udp" {
opts.Protocol = utils.UDP
} else {
logger.Error("--protocol has to be one of udp, tcp", slog.String("protocol", protocolStr))
os.Exit(1)
}

if Opts.Mark < 0 {
Opts.Logger.Error("--mark has to be >= 0", slog.Int("mark", Opts.Mark))
if opts.Mark < 0 {
logger.Error("--mark has to be >= 0", slog.Int("mark", opts.Mark))
os.Exit(1)
}

if Opts.Verbose < 0 {
Opts.Logger.Error("-v has to be >= 0", slog.Int("verbose", Opts.Verbose))
if opts.Verbose < 0 {
logger.Error("-v has to be >= 0", slog.Int("verbose", opts.Verbose))
os.Exit(1)
}

if Opts.Listeners < 1 {
Opts.Logger.Error("--listeners has to be >= 1")
if listeners < 1 {
logger.Error("--listeners has to be >= 1")
os.Exit(1)
}

var err error
if Opts.ListenAddr, err = parseHostPort(Opts.ListenAddrStr); err != nil {
Opts.Logger.Error("listen address is malformed", "error", err)
if opts.ListenAddr, err = utils.ParseHostPort(listenAddrStr); err != nil {
logger.Error("listen address is malformed", "error", err)
os.Exit(1)
}

if Opts.TargetAddr4, err = netip.ParseAddrPort(Opts.TargetAddr4Str); err != nil {
Opts.Logger.Error("ipv4 target address is malformed", "error", err)
if opts.TargetAddr4, err = netip.ParseAddrPort(targetAddr4Str); err != nil {
logger.Error("ipv4 target address is malformed", "error", err)
os.Exit(1)
}
if !Opts.TargetAddr4.Addr().Is4() {
Opts.Logger.Error("ipv4 target address is not IPv4")
if !opts.TargetAddr4.Addr().Is4() {
logger.Error("ipv4 target address is not IPv4")
os.Exit(1)
}

if Opts.TargetAddr6, err = netip.ParseAddrPort(Opts.TargetAddr6Str); err != nil {
Opts.Logger.Error("ipv6 target address is malformed", "error", err)
if opts.TargetAddr6, err = netip.ParseAddrPort(targetAddr6Str); err != nil {
logger.Error("ipv6 target address is malformed", "error", err)
os.Exit(1)
}
if !Opts.TargetAddr6.Addr().Is6() {
Opts.Logger.Error("ipv6 target address is not IPv6")
if !opts.TargetAddr6.Addr().Is6() {
logger.Error("ipv6 target address is not IPv6")
os.Exit(1)
}

if Opts.udpCloseAfter < 0 {
Opts.Logger.Error("--close-after has to be >= 0", slog.Int("close-after", Opts.udpCloseAfter))
if udpCloseAfterInt < 0 {
logger.Error("--close-after has to be >= 0", slog.Int("close-after", udpCloseAfterInt))
os.Exit(1)
}
Opts.UDPCloseAfter = time.Duration(Opts.udpCloseAfter) * time.Second
opts.UDPCloseAfter = time.Duration(udpCloseAfterInt) * time.Second

listenErrors := make(chan error, Opts.Listeners)
for i := 0; i < Opts.Listeners; i++ {
go listen(i, listenErrors)
listenErrors := make(chan error, listeners)
ctxs := make([]context.Context, listeners)
for i := range ctxs {
ctxs[i] = context.Background()
go listen(ctxs[i], i, logger, listenErrors)
}
for i := 0; i < Opts.Listeners; i++ {
for range ctxs {
<-listenErrors
}
}
Loading

0 comments on commit 02c6d29

Please sign in to comment.