Skip to content

Commit

Permalink
fix(portforward): trigger after VPN restart
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Sep 28, 2023
1 parent a194906 commit d4df872
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 79 deletions.
48 changes: 27 additions & 21 deletions internal/portforward/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

type Loop struct {
// State
settings service.Settings
settings Settings
settingsMutex sync.RWMutex
service Service
// Fixed injected objets
Expand All @@ -28,16 +28,20 @@ type Loop struct {
runCtx context.Context //nolint:containedctx
runCancel context.CancelFunc
runDone <-chan struct{}
updateTrigger chan<- service.Settings
updateTrigger chan<- Settings
updatedResult <-chan error
}

func NewLoop(settings settings.PortForwarding, routing Routing,
client *http.Client, portAllower PortAllower,
logger Logger, uid, gid int) *Loop {
return &Loop{
settings: service.Settings{
UserSettings: settings,
settings: Settings{
VPNIsUp: ptrTo(false),
Service: service.Settings{
Enabled: settings.Enabled,
Filepath: *settings.Filepath,
},
},
routing: routing,
client: client,
Expand All @@ -57,24 +61,22 @@ func (l *Loop) Start(_ context.Context) (runError <-chan error, _ error) {
runDone := make(chan struct{})
l.runDone = runDone

updateTrigger := make(chan service.Settings)
updateTrigger := make(chan Settings)
l.updateTrigger = updateTrigger
updateResult := make(chan error)
l.updatedResult = updateResult
runErrorCh := make(chan error)

go l.run(l.runCtx, runDone, runErrorCh,
l.settings, updateTrigger, updateResult)
go l.run(l.runCtx, runDone, runErrorCh, updateTrigger, updateResult)

return runErrorCh, nil
}

func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
runErrorCh chan<- error, initialSettings service.Settings,
updateTrigger <-chan service.Settings, updateResult chan<- error) {
runErrorCh chan<- error, updateTrigger <-chan Settings,
updateResult chan<- error) {
defer close(runDone)

settings := initialSettings
var serviceRunError <-chan error
for {
updateReceived := false
Expand All @@ -83,18 +85,20 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
// Stop call takes care of stopping the service
return
case partialUpdate := <-updateTrigger:
updatedSettings, err := settings.UpdateWith(partialUpdate)
updatedSettings, err := l.settings.updateWith(partialUpdate)
if err != nil {
updateResult <- err
continue
}
settings = updatedSettings
updateReceived = true
l.settingsMutex.Lock()
l.settings = updatedSettings
l.settingsMutex.Unlock()
case err := <-serviceRunError:
l.logger.Error(err.Error())
}

firstRun := l.service == nil
firstRun := serviceRunError == nil
if !firstRun {
err := l.service.Stop()
if err != nil {
Expand All @@ -103,7 +107,11 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
}
}

l.service = service.New(settings, l.routing, l.client,
serviceSettings := l.settings.Service.Copy()
// Only enable port forward if the VPN tunnel is up
*serviceSettings.Enabled = *serviceSettings.Enabled && *l.settings.VPNIsUp

l.service = service.New(serviceSettings, l.routing, l.client,
l.portAllower, l.logger, l.uid, l.gid)

var err error
Expand All @@ -119,16 +127,10 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
}
return
}

// Service is created and started successfully, so update
// the settings for external calls such as GetSettings.
l.settingsMutex.Lock()
l.settings = settings
l.settingsMutex.Unlock()
}
}

func (l *Loop) UpdateWith(partialUpdate service.Settings) (err error) {
func (l *Loop) UpdateWith(partialUpdate Settings) (err error) {
select {
case l.updateTrigger <- partialUpdate:
select {
Expand Down Expand Up @@ -159,3 +161,7 @@ func (l *Loop) GetPortForwarded() (port uint16) {
}
return l.service.GetPortForwarded()
}

func ptrTo[T any](value T) *T {
return &value
}
2 changes: 1 addition & 1 deletion internal/portforward/service/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

func (s *Service) writePortForwardedFile(port uint16) (err error) {
filepath := *s.settings.UserSettings.Filepath
filepath := s.settings.Filepath
s.logger.Info("writing port file " + filepath)
const perms = os.FileMode(0644)
err = os.WriteFile(filepath, []byte(fmt.Sprint(port)), perms)
Expand Down
56 changes: 19 additions & 37 deletions internal/portforward/service/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,69 +4,51 @@ import (
"errors"
"fmt"

"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gosettings"
)

type Settings struct {
UserSettings settings.PortForwarding
Enabled *bool
PortForwarder PortForwarder
Filepath string
Interface string // needed for PIA and ProtonVPN, tun0 for example
ServerName string // needed for PIA
VPNProvider string // used to validate new settings
}

// UpdateWith deep copies the receiving settings, overrides the copy with
// fields set in the partialUpdate argument, validates the new settings
// and returns them if they are valid, or returns an error otherwise.
// In all cases, the receiving settings are unmodified.
func (s Settings) UpdateWith(partialUpdate Settings) (updatedSettings Settings, err error) {
updatedSettings = s.copy()
updatedSettings.overrideWith(partialUpdate)
err = updatedSettings.validate()
if err != nil {
return updatedSettings, fmt.Errorf("validating new settings: %w", err)
}
return updatedSettings, nil
}

func (s Settings) copy() (copied Settings) {
copied.UserSettings = s.UserSettings.Copy()
func (s Settings) Copy() (copied Settings) {
copied.Enabled = gosettings.CopyPointer(s.Enabled)
copied.PortForwarder = s.PortForwarder
copied.Filepath = s.Filepath
copied.Interface = s.Interface
copied.ServerName = s.ServerName
copied.VPNProvider = s.VPNProvider
return copied
}

func (s *Settings) overrideWith(update Settings) {
s.UserSettings.OverrideWith(update.UserSettings)
func (s *Settings) OverrideWith(update Settings) {
s.Enabled = gosettings.OverrideWithPointer(s.Enabled, update.Enabled)
s.PortForwarder = gosettings.OverrideWithInterface(s.PortForwarder, update.PortForwarder)
s.Filepath = gosettings.OverrideWithString(s.Filepath, update.Filepath)
s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface)
s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName)
s.VPNProvider = gosettings.OverrideWithString(s.VPNProvider, update.VPNProvider)
}

var (
ErrVPNProviderNotSet = errors.New("VPN provider not set")
ErrServerNameNotSet = errors.New("server name not set")
ErrPortForwarderNotSet = errors.New("port forwarder not set")
ErrGatewayNotSet = errors.New("gateway not set")
ErrInterfaceNotSet = errors.New("interface not set")
ErrServerNameNotSet = errors.New("server name not set")
ErrFilepathNotSet = errors.New("file path not set")
ErrInterfaceNotSet = errors.New("interface not set")
)

func (s *Settings) validate() (err error) {
func (s *Settings) Validate() (err error) {
switch {
case s.VPNProvider == "":
return fmt.Errorf("%w", ErrVPNProviderNotSet)
case s.VPNProvider == providers.PrivateInternetAccess && s.ServerName == "":
return fmt.Errorf("%w", ErrServerNameNotSet)
case s.PortForwarder == nil:
return fmt.Errorf("%w", ErrPortForwarderNotSet)
// Port forwarder can be nil when the loop updates
// to stop the service.
case s.Filepath == "":
return fmt.Errorf("%w", ErrFilepathNotSet)
case s.Interface == "":
return fmt.Errorf("%w", ErrInterfaceNotSet)
case s.PortForwarder.Name() == providers.PrivateInternetAccess && s.ServerName == "":
return fmt.Errorf("%w", ErrServerNameNotSet)
}

return s.UserSettings.Validate(s.VPNProvider)
return nil
}
4 changes: 3 additions & 1 deletion internal/portforward/service/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()

if !*s.settings.UserSettings.Enabled {
if !*s.settings.Enabled {
return nil, nil //nolint:nilnil
}

Expand Down Expand Up @@ -64,6 +64,8 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
if !crashed { // stopped by Stop call
return
}
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()
_ = s.cleanup()
runError <- err
}(keepPortCtx, s.settings.PortForwarder, obj, runErrorCh, keepPortDoneCh)
Expand Down
3 changes: 2 additions & 1 deletion internal/portforward/service/stop.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ func (s *Service) Stop() (err error) {
serviceNotRunning := s.port == 0
s.portMutex.RUnlock()
if serviceNotRunning {
// TODO replace with goservices.ErrAlreadyStopped
return nil
}

Expand All @@ -36,7 +37,7 @@ func (s *Service) cleanup() (err error) {

s.port = 0

filepath := *s.settings.UserSettings.Filepath
filepath := s.settings.Filepath
s.logger.Info("removing port file " + filepath)
err = os.Remove(filepath)
if err != nil {
Expand Down
43 changes: 43 additions & 0 deletions internal/portforward/settings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package portforward

import (
"github.com/qdm12/gluetun/internal/portforward/service"
"github.com/qdm12/gosettings"
)

type Settings struct {
// VPNIsUp can be optionally set to signal the loop
// the VPN is up (true) or down (false). If left to nil,
// it is assumed the VPN is in the same previous state.
VPNIsUp *bool
Service service.Settings
}

// updateWith deep copies the receiving settings, overrides the copy with
// fields set in the partialUpdate argument, validates the new settings
// and returns them if they are valid, or returns an error otherwise.
// In all cases, the receiving settings are unmodified.
func (s Settings) updateWith(partialUpdate Settings) (updated Settings, err error) {
updated = s.copy()
updated.overrideWith(partialUpdate)
err = updated.validate()
if err != nil {
return updated, err
}
return updated, nil
}

func (s Settings) copy() (copied Settings) {
copied.VPNIsUp = gosettings.CopyPointer(s.VPNIsUp)
copied.Service = s.Service.Copy()
return copied
}

func (s *Settings) overrideWith(update Settings) {
s.VPNIsUp = gosettings.OverrideWithPointer(s.VPNIsUp, update.VPNIsUp)
s.Service.OverrideWith(update.Service)
}

func (s Settings) validate() (err error) {
return s.Service.Validate()
}
4 changes: 2 additions & 2 deletions internal/vpn/cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"errors"
)

func (l *Loop) cleanup(vpnProvider string) {
func (l *Loop) cleanup() {
for _, vpnPort := range l.vpnInputPorts {
err := l.fw.RemoveAllowedPort(context.Background(), vpnPort)
if err != nil {
Expand All @@ -18,7 +18,7 @@ func (l *Loop) cleanup(vpnProvider string) {
l.logger.Error("clearing public IP data: " + err.Error())
}

err = l.stopPortForwarding(vpnProvider)
err = l.stopPortForwarding()
if err != nil {
portForwardingAlreadyStopped := errors.Is(err, context.Canceled)
if !portForwardingAlreadyStopped {
Expand Down
2 changes: 1 addition & 1 deletion internal/vpn/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/netlink"
portforward "github.com/qdm12/gluetun/internal/portforward/service"
portforward "github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/provider/utils"
)
Expand Down
23 changes: 11 additions & 12 deletions internal/vpn/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"errors"
"fmt"

"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/portforward/service"
pfutils "github.com/qdm12/gluetun/internal/provider/utils"
)
Expand All @@ -23,21 +23,20 @@ func getPortForwarder(provider Provider, providers Providers, //nolint:ireturn
}

func (l *Loop) startPortForwarding(data tunnelUpData) (err error) {
partialUpdate := service.Settings{
PortForwarder: data.portForwarder,
Interface: data.vpnIntf,
ServerName: data.serverName,
VPNProvider: data.portForwarder.Name(),
partialUpdate := portforward.Settings{
VPNIsUp: ptrTo(true),
Service: service.Settings{
PortForwarder: data.portForwarder,
Interface: data.vpnIntf,
ServerName: data.serverName,
},
}
return l.portForward.UpdateWith(partialUpdate)
}

func (l *Loop) stopPortForwarding(vpnProvider string) (err error) {
partialUpdate := service.Settings{
VPNProvider: vpnProvider,
UserSettings: settings.PortForwarding{
Enabled: ptrTo(false),
},
func (l *Loop) stopPortForwarding() (err error) {
partialUpdate := portforward.Settings{
VPNIsUp: ptrTo(false),
}
return l.portForward.UpdateWith(partialUpdate)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/vpn/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case <-tunnelReady:
go l.onTunnelUp(openvpnCtx, tunnelUpData)
case <-ctx.Done():
l.cleanup(portForwarder.Name())
l.cleanup()
openvpnCancel()
<-waitError
close(waitError)
return
case <-l.stop:
l.userTrigger = true
l.logger.Info("stopping")
l.cleanup(portForwarder.Name())
l.cleanup()
openvpnCancel()
<-waitError
// do not close waitError or the waitError
Expand All @@ -92,7 +92,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
case err := <-waitError: // unexpected error
l.statusManager.Lock() // prevent SetStatus from running in parallel

l.cleanup(portForwarder.Name())
l.cleanup()
openvpnCancel()
l.statusManager.SetStatus(constants.Crashed)
l.logAndWait(ctx, err)
Expand Down

0 comments on commit d4df872

Please sign in to comment.