Skip to content

Commit

Permalink
Merge pull request #6 from muzzammilshahid/ensure-setup
Browse files Browse the repository at this point in the history
Ensure wireguard installation
  • Loading branch information
muzzammilshahid authored Oct 1, 2024
2 parents 9bf47ad + d0aecda commit 3cb9dda
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 1 deletion.
4 changes: 4 additions & 0 deletions cmd/http-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ func main() {
log.Fatalln("You need to run this service as root")
}

if err := wireguard_admin_service.EnsureWireguardInstallation(); err != nil {
log.Fatalln(err)
}

r := gin.Default()

const qrDir = "./qr-codes"
Expand Down
180 changes: 179 additions & 1 deletion wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"math/big"
"net"
"os"
"os/exec"
"regexp"
Expand All @@ -17,9 +18,11 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

const wireguardParamsFile = "/etc/wireguard/params"

// AddUser adds new wireguard user.
func AddUser(clientName string) error {
params, err := godotenv.Read("/etc/wireguard/params")
params, err := godotenv.Read(wireguardParamsFile)
if err != nil {
return fmt.Errorf("error loading .env file: %w", err)
}
Expand Down Expand Up @@ -244,3 +247,178 @@ func syncWireGuardConfig(serverWgNic string) error {

return nil
}

func EnsureWireguardInstallation() error {
if _, err := os.Stat(wireguardParamsFile); err == nil {
return nil
}

if err := os.MkdirAll("/etc/wireguard", 0600); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}

serverPrivKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return fmt.Errorf("failed to generate client private key: %w", err)
}
serverPubKey := serverPrivKey.PublicKey()

serverPublicIP, err := getServerPubIP()
if err != nil {
return fmt.Errorf("failed to get server public IP: %w", err)
}

pubNIC, err := getDefaultNetworkInterface()
if err != nil {
return fmt.Errorf("failed to get default network interface: %w", err)
}

serverPort, err := getRandomPort()
if err != nil {
return fmt.Errorf("failed to get server port: %w", err)
}

err = runCommand("apt-get", []string{"install", "-y", "wireguard", "iptables", "resolvconf", "qrencode"}, "")
if err != nil {
return fmt.Errorf("failed to install wireguard: %w", err)
}

if err := os.WriteFile(wireguardParamsFile, []byte(fmt.Sprintf(`SERVER_PUB_IP=%s
SERVER_PUB_NIC=%s
SERVER_WG_NIC=wg0
SERVER_WG_IPV4=10.66.66.1
SERVER_WG_IPV6=fd42:42:42::1
SERVER_PORT=%v
SERVER_PRIV_KEY=%s
SERVER_PUB_KEY=%s
CLIENT_DNS_1=1.1.1.1
CLIENT_DNS_2=1.0.0.1
ALLOWED_IPS=0.0.0.0/0,::/0
`, serverPublicIP, pubNIC, serverPort, serverPrivKey, serverPubKey)), 0600); err != nil {
return fmt.Errorf("failed to write file to %s: %w", wireguardParamsFile, err)
}

err = os.WriteFile("/etc/wireguard/wg0.conf", []byte(fmt.Sprintf(`[Interface]
Address = 10.66.66.1/24,fd42:42:42::1/64
ListenPort = %v
PrivateKey = %s
PostUp = iptables -I INPUT -p udp --dport %v -j ACCEPT
PostUp = iptables -I FORWARD -i %s -o wg0 -j ACCEPT
PostUp = iptables -I FORWARD -i wg0 -j ACCEPT
PostUp = iptables -t nat -A POSTROUTING -o %s -j MASQUERADE
PostUp = ip6tables -I FORWARD -i wg0 -j ACCEPT
PostUp = ip6tables -t nat -A POSTROUTING -o %s -j MASQUERADE
PostDown = iptables -D INPUT -p udp --dport %v -j ACCEPT
PostDown = iptables -D FORWARD -i %s -o wg0 -j ACCEPT
PostDown = iptables -D FORWARD -i wg0 -j ACCEPT
PostDown = iptables -t nat -D POSTROUTING -o %s -j MASQUERADE
PostDown = ip6tables -D FORWARD -i wg0 -j ACCEPT
PostDown = ip6tables -t nat -D POSTROUTING -o %s -j MASQUERADE
`, serverPort, serverPrivKey, serverPort, pubNIC, pubNIC, pubNIC, serverPort, pubNIC, pubNIC, pubNIC)), 0600)
if err != nil {
return fmt.Errorf("failed to write file to %s: %w", wireguardParamsFile, err)
}

err = os.WriteFile("/etc/sysctl.d/wg.conf", []byte(`net.ipv4.ip_forward = 1
net.ipv6.conf.all.forwarding = 1`), 0600)
if err != nil {
return fmt.Errorf("failed to write file to /etc/sysctl.d/wg.conf: %w", err)
}

if err = runCommand("sysctl", []string{"--system"}, ""); err != nil {
return fmt.Errorf("failed to run command: %w", err)
}

if err = runCommand("systemctl", []string{"start", "wg-quick@wg0"}, ""); err != nil {
return fmt.Errorf("failed to start wg-quick: %w", err)
}

if err = runCommand("systemctl", []string{"enable", "wg-quick@wg0"}, ""); err != nil {
return fmt.Errorf("failed to enable wg-quick: %w", err)
}

return nil
}

func getServerPubIP() (string, error) {
interfaces, err := net.Interfaces()
if err != nil {
return "", fmt.Errorf("error fetching network interfaces %w", err)
}

for _, iface := range interfaces {
addrs, err := iface.Addrs()
if err != nil {
return "", fmt.Errorf("error fetching addresses for interface: %w", err)
}

for _, addr := range addrs {
ip, _, err := net.ParseCIDR(addr.String())
if err != nil {
fmt.Println("Error parsing CIDR:", err)
continue
}

if ip.IsGlobalUnicast() {
if ip.To4() != nil {
return ip.String(), nil
} else if ip.To4() == nil {
return ip.String(), nil
}
}
}
}

return "", fmt.Errorf("no IP address found")
}

// getDefaultNetworkInterface returns the name of the default network interface.
func getDefaultNetworkInterface() (string, error) {
interfaces, err := net.Interfaces()
if err != nil {
return "", err
}

// Iterate through interfaces to find the one with an IPv4 address and set as default
for _, iface := range interfaces {
// Skip down or loopback interfaces
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
continue
}

addrs, err := iface.Addrs()
if err != nil {
return "", err
}

for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}

// Ensure it is a non-loopback IPv4 address
if ip != nil && ip.To4() != nil {
return iface.Name, nil
}
}
}

return "", fmt.Errorf("no default network interface found")
}

func getRandomPort() (int, error) {
const minPort = 49152
const maxPort = 65535

rangeSize := maxPort - minPort + 1
n, err := rand.Int(rand.Reader, big.NewInt(int64(rangeSize)))
if err != nil {
return 0, err
}

return int(n.Int64()) + minPort, nil
}

0 comments on commit 3cb9dda

Please sign in to comment.