diff --git a/internal/cmdtmpl/command.go b/internal/cmdtmpl/command.go index 6a7a953..e9c3427 100644 --- a/internal/cmdtmpl/command.go +++ b/internal/cmdtmpl/command.go @@ -252,21 +252,16 @@ add element inet oc-daemon-filter allowdevs { {{.}} } }, defaultTemplate: TrafPolDefaultTemplate, } - case "TrafPolAddPortalPorts": + case "TrafPolSetAllowedPorts": // Remove Portal Ports cl = &CommandList{ Name: name, Commands: []*Command{ - {Line: "{{.Executables.Nft}} -f - add element inet oc-daemon-filter allowports { {{range $i, $port := .TrafficPolicing.PortalPorts}}{{if $i}}, {{end}}{{$port}}{{end}} }"}, - }, - defaultTemplate: TrafPolDefaultTemplate, - } - case "TrafPolRemovePortalPorts": - // Remove Portal Ports - cl = &CommandList{ - Name: name, - Commands: []*Command{ - {Line: "{{.Executables.Nft}} -f - delete element inet oc-daemon-filter allowports { {{range $i, $port := .TrafficPolicing.PortalPorts}}{{if $i}}, {{end}}{{$port}}{{end}} }"}, + {Line: "{{.Executables.Nft}} -f -", + Stdin: `flush set inet oc-daemon-filter allowports +{{range .Ports -}} +add element inet oc-daemon-filter allowports { {{.}} } +{{end}}`}, }, defaultTemplate: TrafPolDefaultTemplate, } diff --git a/internal/cmdtmpl/command_test.go b/internal/cmdtmpl/command_test.go index d72b30a..0ea2bb6 100644 --- a/internal/cmdtmpl/command_test.go +++ b/internal/cmdtmpl/command_test.go @@ -40,8 +40,7 @@ func TestGetCommandList(t *testing.T) { "TrafPolSetAllowedDevices", "TrafPolFlushAllowedHosts", "TrafPolAddAllowedHost", - "TrafPolAddPortalPorts", - "TrafPolRemovePortalPorts", + "TrafPolSetAllowedPorts", "TrafPolCleanup", // VPN Setup @@ -91,8 +90,7 @@ func TestGetCmds(t *testing.T) { // TrafPolSetAllowedDevices", // skip, requires devices "TrafPolFlushAllowedHosts", // "TrafPolAddAllowedHost", // skip, requires host - "TrafPolAddPortalPorts", - "TrafPolRemovePortalPorts", + //"TrafPolSetAllowedPorts", // skip, requires ports "TrafPolCleanup", // VPN Setup @@ -117,6 +115,7 @@ func TestGetCmds(t *testing.T) { // Traffic Policing "TrafPolSetAllowedDevices", "TrafPolAddAllowedHost", + "TrafPolSetAllowedPorts", // VPN Setup "VPNSetupSetExcludes", diff --git a/internal/trafpol/filter.go b/internal/trafpol/filter.go index 60eee3c..3b85616 100644 --- a/internal/trafpol/filter.go +++ b/internal/trafpol/filter.go @@ -132,44 +132,30 @@ func setAllowedIPs(ctx context.Context, conf *daemoncfg.Config, ips []netip.Pref } } -// addPortalPorts adds ports for a captive portal to the allowed ports. -func addPortalPorts(ctx context.Context, conf *daemoncfg.Config) { - cmds, err := cmdtmpl.GetCmds("TrafPolAddPortalPorts", conf) - if err != nil { - log.WithError(err).Error("TrafPol could not get add portal ports commands") - } - for _, c := range cmds { - if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { - log.WithFields(log.Fields{ - "ports": conf.TrafficPolicing.PortalPorts, - "command": c.Cmd, - "args": c.Args, - "stdin": c.Stdin, - "stdout": string(stdout), - "stderr": string(stderr), - "error": err, - }).Error("TrafPol could not run add portal ports command") - } +// setAllowedPorts sets ports (for a captive portal) as the allowed ports. +func setAllowedPorts(ctx context.Context, conf *daemoncfg.Config, ports []uint16) { + data := &struct { + daemoncfg.Config + Ports []uint16 + }{ + Config: *conf, + Ports: ports, } -} - -// removePortalPorts removes ports for a captive portal from the allowed ports. -func removePortalPorts(ctx context.Context, conf *daemoncfg.Config) { - cmds, err := cmdtmpl.GetCmds("TrafPolRemovePortalPorts", conf) + cmds, err := cmdtmpl.GetCmds("TrafPolSetAllowedPorts", data) if err != nil { - log.WithError(err).Error("TrafPol could not get remove portal ports commands") + log.WithError(err).Error("TrafPol could not get set allowed ports commands") } for _, c := range cmds { if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { log.WithFields(log.Fields{ - "ports": conf.TrafficPolicing.PortalPorts, + "ports": ports, "command": c.Cmd, "args": c.Args, "stdin": c.Stdin, "stdout": string(stdout), "stderr": string(stderr), "error": err, - }).Error("TrafPol could not run remove portal ports command") + }).Error("TrafPol could not run set allowed ports command") } } } diff --git a/internal/trafpol/filter_test.go b/internal/trafpol/filter_test.go index 6816408..58d7569 100644 --- a/internal/trafpol/filter_test.go +++ b/internal/trafpol/filter_test.go @@ -39,7 +39,6 @@ func TestFilterFunctionsErrors(_ *testing.T) { }) // portal ports - conf.TrafficPolicing.PortalPorts = []uint16{80, 443} - addPortalPorts(ctx, conf) - removePortalPorts(ctx, conf) + setAllowedPorts(ctx, conf, []uint16{80, 443}) + setAllowedPorts(ctx, conf, []uint16{}) } diff --git a/internal/trafpol/trafpol.go b/internal/trafpol/trafpol.go index 9a9bbd1..7f1c8f1 100644 --- a/internal/trafpol/trafpol.go +++ b/internal/trafpol/trafpol.go @@ -100,7 +100,7 @@ func (t *TrafPol) handleCPDReport(ctx context.Context, report *cpd.Report) { t.resolver.Resolve() // remove ports from allowed ports - removePortalPorts(ctx, t.config) + setAllowedPorts(ctx, t.config, []uint16{}) t.capPortal = false log.WithField("capPortal", t.capPortal).Info("TrafPol changed CPD status") } @@ -109,7 +109,7 @@ func (t *TrafPol) handleCPDReport(ctx context.Context, report *cpd.Report) { // add ports to allowed ports if !t.capPortal { - addPortalPorts(ctx, t.config) + setAllowedPorts(ctx, t.config, t.config.TrafficPolicing.PortalPorts) t.capPortal = true log.WithField("capPortal", t.capPortal).Info("TrafPol changed CPD status") } diff --git a/internal/trafpol/trafpol_test.go b/internal/trafpol/trafpol_test.go index a69e4d6..dd013e0 100644 --- a/internal/trafpol/trafpol_test.go +++ b/internal/trafpol/trafpol_test.go @@ -57,10 +57,10 @@ func TestTrafPolHandleCPDReport(t *testing.T) { var nftMutex sync.Mutex nftCmds := []string{} oldRunCmd := execs.RunCmd - execs.RunCmd = func(_ context.Context, cmd string, _ string, args ...string) ([]byte, []byte, error) { + execs.RunCmd = func(_ context.Context, cmd string, stdin string, args ...string) ([]byte, []byte, error) { nftMutex.Lock() defer nftMutex.Unlock() - nftCmds = append(nftCmds, cmd+" "+strings.Join(args, " ")) + nftCmds = append(nftCmds, cmd+" "+strings.Join(args, " ")+" "+stdin) return nil, nil, nil } defer func() { execs.RunCmd = oldRunCmd }() @@ -86,7 +86,9 @@ func TestTrafPolHandleCPDReport(t *testing.T) { tp.handleCPDReport(ctx, report) want = []string{ - "nft -f - add element inet oc-daemon-filter allowports { 80, 443 }", + "nft -f - flush set inet oc-daemon-filter allowports\n" + + "add element inet oc-daemon-filter allowports { 80 }\n" + + "add element inet oc-daemon-filter allowports { 443 }\n", } got = getNftCmds() if !reflect.DeepEqual(got, want) { @@ -98,8 +100,10 @@ func TestTrafPolHandleCPDReport(t *testing.T) { tp.handleCPDReport(ctx, report) want = []string{ - "nft -f - add element inet oc-daemon-filter allowports { 80, 443 }", - "nft -f - delete element inet oc-daemon-filter allowports { 80, 443 }", + "nft -f - flush set inet oc-daemon-filter allowports\n" + + "add element inet oc-daemon-filter allowports { 80 }\n" + + "add element inet oc-daemon-filter allowports { 443 }\n", + "nft -f - flush set inet oc-daemon-filter allowports\n", } got = getNftCmds() if !reflect.DeepEqual(got, want) {