diff --git a/pkg/wireguard/wireguard.go b/pkg/wireguard/wireguard.go index c6407a01..f2977bd7 100644 --- a/pkg/wireguard/wireguard.go +++ b/pkg/wireguard/wireguard.go @@ -1,6 +1,7 @@ package wireguard import ( + "errors" "fmt" "net" "os/exec" @@ -88,7 +89,7 @@ func createLinkUsingUserspaceImpl(iface string, wgUserspaceImplementationFallbac } -func createLinkUsingKernalModule(iface string) error { +func createLinkUsingKernelModule(iface string) error { // link not created wgLink := &netlink.GenericLink{ LinkAttrs: netlink.LinkAttrs{ @@ -104,49 +105,25 @@ func createLinkUsingKernalModule(iface string) error { return nil } -func SyncLink(_ agent.State, iface string, wgUserspaceImplementationFallback string, wgUseUserspaceImpl bool) error { - _, err := netlink.LinkByName(iface) +func SyncLink(state agent.State, iface string, wgUserspaceImplementationFallback string, wgUseUserspaceImpl bool) error { + link, err := netlink.LinkByName(iface) if err != nil { - if _, ok := err.(netlink.LinkNotFoundError); !ok { - return err - } - } + if errors.As(err, &netlink.LinkNotFoundError{}) { + if wgUseUserspaceImpl { + if err := createLinkUsingUserspaceImpl(iface, wgUserspaceImplementationFallback); err != nil { + return fmt.Errorf("create link using user space impl: %w", err) + } - if _, ok := err.(netlink.LinkNotFoundError); ok { - if wgUseUserspaceImpl { - err = createLinkUsingUserspaceImpl(iface, wgUserspaceImplementationFallback) + } else if err := createLinkUsingKernelModule(iface); err != nil { + fmt.Printf("could not create link using kernel module, will attempt to fallback to user space implementation: %v\n", err) - if err != nil { - return err + // Fallback to user space implementation. + wgUseUserspaceImpl = true } - } else { - err = createLinkUsingKernalModule(iface) - - if err != nil { - err = createLinkUsingUserspaceImpl(iface, wgUserspaceImplementationFallback) - - if err != nil { - return err - } - } - } - - // TODO: Can this be removed? - link, err := netlink.LinkByName(iface) - if err != nil { - return err - } - if err := netlink.LinkSetUp(link); err != nil { - return err - } - } - - link, err := netlink.LinkByName(iface) - if err != nil { - if _, ok := err.(netlink.LinkNotFoundError); !ok { - return err + return SyncLink(state, iface, wgUserspaceImplementationFallback, wgUseUserspaceImpl) } + return fmt.Errorf("link by name: %w", err) } addresses, err := netlink.AddrList(link, syscall.AF_INET) @@ -165,7 +142,7 @@ func SyncLink(_ agent.State, iface string, wgUserspaceImplementationFallback str } if err := netlink.LinkSetUp(link); err != nil { - return err + return fmt.Errorf("link set up: %w", err) } return nil }