diff --git a/outbound/wireguard.go b/outbound/wireguard.go index 7805e165a1..d34023653a 100644 --- a/outbound/wireguard.go +++ b/outbound/wireguard.go @@ -180,7 +180,6 @@ func (w *WireGuard) Close() error { if w.pauseCallback != nil { w.pauseManager.UnregisterCallback(w.pauseCallback) } - w.tunDevice.Close() return nil } diff --git a/transport/wireguard/device_stack.go b/transport/wireguard/device_stack.go index 7f57b7c73a..d5770419e2 100644 --- a/transport/wireguard/device_stack.go +++ b/transport/wireguard/device_stack.go @@ -230,17 +230,13 @@ func (w *StackDevice) Events() <-chan wgTun.Event { } func (w *StackDevice) Close() error { - select { - case <-w.done: - return os.ErrClosed - default: - } + close(w.done) + close(w.events) w.stack.Close() for _, endpoint := range w.stack.CleanupEndpoints() { endpoint.Abort() } w.stack.Wait() - close(w.done) return nil } diff --git a/transport/wireguard/device_system.go b/transport/wireguard/device_system.go index 49acc5b90e..2c16c53dfc 100644 --- a/transport/wireguard/device_system.go +++ b/transport/wireguard/device_system.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "os" + "sync" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" @@ -21,14 +22,16 @@ import ( var _ Device = (*SystemDevice)(nil) type SystemDevice struct { - dialer N.Dialer - device tun.Tun - batchDevice tun.LinuxTUN - name string - mtu int - events chan wgTun.Event - addr4 netip.Addr - addr6 netip.Addr + dialer N.Dialer + device tun.Tun + batchDevice tun.LinuxTUN + name string + mtu uint32 + inet4Addresses []netip.Prefix + inet6Addresses []netip.Prefix + gso bool + events chan wgTun.Event + closeOnce sync.Once } func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) { @@ -44,43 +47,17 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes if interfaceName == "" { interfaceName = tun.CalculateInterfaceName("wg") } - tunInterface, err := tun.New(tun.Options{ - Name: interfaceName, - Inet4Address: inet4Addresses, - Inet6Address: inet6Addresses, - MTU: mtu, - GSO: gso, - }) - if err != nil { - return nil, err - } - var inet4Address netip.Addr - var inet6Address netip.Addr - if len(inet4Addresses) > 0 { - inet4Address = inet4Addresses[0].Addr() - } - if len(inet6Addresses) > 0 { - inet6Address = inet6Addresses[0].Addr() - } - var batchDevice tun.LinuxTUN - if gso { - batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN) - if !isBatchTUN { - return nil, E.New("GSO is not supported on current platform") - } - batchDevice = batchTUN - } + return &SystemDevice{ dialer: common.Must1(dialer.NewDefault(router, option.DialerOptions{ BindInterface: interfaceName, })), - device: tunInterface, - batchDevice: batchDevice, - name: interfaceName, - mtu: int(mtu), - events: make(chan wgTun.Event), - addr4: inet4Address, - addr6: inet6Address, + name: interfaceName, + mtu: mtu, + inet4Addresses: inet4Addresses, + inet6Addresses: inet6Addresses, + gso: gso, + events: make(chan wgTun.Event), }, nil } @@ -93,14 +70,39 @@ func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr } func (w *SystemDevice) Inet4Address() netip.Addr { - return w.addr4 + if len(w.inet4Addresses) == 0 { + return netip.Addr{} + } + return w.inet4Addresses[0].Addr() } func (w *SystemDevice) Inet6Address() netip.Addr { - return w.addr6 + if len(w.inet6Addresses) == 0 { + return netip.Addr{} + } + return w.inet6Addresses[0].Addr() } func (w *SystemDevice) Start() error { + tunInterface, err := tun.New(tun.Options{ + Name: w.name, + Inet4Address: w.inet4Addresses, + Inet6Address: w.inet6Addresses, + MTU: w.mtu, + GSO: w.gso, + }) + if err != nil { + return err + } + w.device = tunInterface + if w.gso { + batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN) + if !isBatchTUN { + tunInterface.Close() + return E.New("GSO is not supported on current platform") + } + w.batchDevice = batchTUN + } w.events <- wgTun.EventUp return nil } @@ -143,7 +145,7 @@ func (w *SystemDevice) Flush() error { } func (w *SystemDevice) MTU() (int, error) { - return w.mtu, nil + return int(w.mtu), nil } func (w *SystemDevice) Name() (string, error) { @@ -155,6 +157,7 @@ func (w *SystemDevice) Events() <-chan wgTun.Event { } func (w *SystemDevice) Close() error { + close(w.events) return w.device.Close() }