Skip to content

Commit

Permalink
Merge pull request #139 from telekom-mms/feature/add-set-allowed-ports
Browse files Browse the repository at this point in the history
Add commands for setting allowed ports to TrafPol
  • Loading branch information
hwipl authored Jan 24, 2025
2 parents 01b5c42 + 123734c commit 35a813c
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 51 deletions.
17 changes: 6 additions & 11 deletions internal/cmdtmpl/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,21 +252,16 @@ add element inet oc-daemon-filter allowdevs { {{.}} }
},
defaultTemplate: TrafPolDefaultTemplate,
}
case "TrafPolAddPortalPorts":
case "TrafPolSetAllowedPorts":
// Remove Portal Ports
cl = &CommandList{
Name: name,
Commands: []*Command{
{Line: "{{.Executables.Nft}} -f - add element inet oc-daemon-filter allowports { {{range $i, $port := .TrafficPolicing.PortalPorts}}{{if $i}}, {{end}}{{$port}}{{end}} }"},
},
defaultTemplate: TrafPolDefaultTemplate,
}
case "TrafPolRemovePortalPorts":
// Remove Portal Ports
cl = &CommandList{
Name: name,
Commands: []*Command{
{Line: "{{.Executables.Nft}} -f - delete element inet oc-daemon-filter allowports { {{range $i, $port := .TrafficPolicing.PortalPorts}}{{if $i}}, {{end}}{{$port}}{{end}} }"},
{Line: "{{.Executables.Nft}} -f -",
Stdin: `flush set inet oc-daemon-filter allowports
{{range .Ports -}}
add element inet oc-daemon-filter allowports { {{.}} }
{{end}}`},
},
defaultTemplate: TrafPolDefaultTemplate,
}
Expand Down
7 changes: 3 additions & 4 deletions internal/cmdtmpl/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ func TestGetCommandList(t *testing.T) {
"TrafPolSetAllowedDevices",
"TrafPolFlushAllowedHosts",
"TrafPolAddAllowedHost",
"TrafPolAddPortalPorts",
"TrafPolRemovePortalPorts",
"TrafPolSetAllowedPorts",
"TrafPolCleanup",

// VPN Setup
Expand Down Expand Up @@ -91,8 +90,7 @@ func TestGetCmds(t *testing.T) {
// TrafPolSetAllowedDevices", // skip, requires devices
"TrafPolFlushAllowedHosts",
// "TrafPolAddAllowedHost", // skip, requires host
"TrafPolAddPortalPorts",
"TrafPolRemovePortalPorts",
//"TrafPolSetAllowedPorts", // skip, requires ports
"TrafPolCleanup",

// VPN Setup
Expand All @@ -117,6 +115,7 @@ func TestGetCmds(t *testing.T) {
// Traffic Policing
"TrafPolSetAllowedDevices",
"TrafPolAddAllowedHost",
"TrafPolSetAllowedPorts",

// VPN Setup
"VPNSetupSetExcludes",
Expand Down
38 changes: 12 additions & 26 deletions internal/trafpol/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,44 +132,30 @@ func setAllowedIPs(ctx context.Context, conf *daemoncfg.Config, ips []netip.Pref
}
}

// addPortalPorts adds ports for a captive portal to the allowed ports.
func addPortalPorts(ctx context.Context, conf *daemoncfg.Config) {
cmds, err := cmdtmpl.GetCmds("TrafPolAddPortalPorts", conf)
if err != nil {
log.WithError(err).Error("TrafPol could not get add portal ports commands")
}
for _, c := range cmds {
if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
log.WithFields(log.Fields{
"ports": conf.TrafficPolicing.PortalPorts,
"command": c.Cmd,
"args": c.Args,
"stdin": c.Stdin,
"stdout": string(stdout),
"stderr": string(stderr),
"error": err,
}).Error("TrafPol could not run add portal ports command")
}
// setAllowedPorts sets ports (for a captive portal) as the allowed ports.
func setAllowedPorts(ctx context.Context, conf *daemoncfg.Config, ports []uint16) {
data := &struct {
daemoncfg.Config
Ports []uint16
}{
Config: *conf,
Ports: ports,
}
}

// removePortalPorts removes ports for a captive portal from the allowed ports.
func removePortalPorts(ctx context.Context, conf *daemoncfg.Config) {
cmds, err := cmdtmpl.GetCmds("TrafPolRemovePortalPorts", conf)
cmds, err := cmdtmpl.GetCmds("TrafPolSetAllowedPorts", data)
if err != nil {
log.WithError(err).Error("TrafPol could not get remove portal ports commands")
log.WithError(err).Error("TrafPol could not get set allowed ports commands")
}
for _, c := range cmds {
if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
log.WithFields(log.Fields{
"ports": conf.TrafficPolicing.PortalPorts,
"ports": ports,
"command": c.Cmd,
"args": c.Args,
"stdin": c.Stdin,
"stdout": string(stdout),
"stderr": string(stderr),
"error": err,
}).Error("TrafPol could not run remove portal ports command")
}).Error("TrafPol could not run set allowed ports command")
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions internal/trafpol/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ func TestFilterFunctionsErrors(_ *testing.T) {
})

// portal ports
conf.TrafficPolicing.PortalPorts = []uint16{80, 443}
addPortalPorts(ctx, conf)
removePortalPorts(ctx, conf)
setAllowedPorts(ctx, conf, []uint16{80, 443})
setAllowedPorts(ctx, conf, []uint16{})
}
4 changes: 2 additions & 2 deletions internal/trafpol/trafpol.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (t *TrafPol) handleCPDReport(ctx context.Context, report *cpd.Report) {
t.resolver.Resolve()

// remove ports from allowed ports
removePortalPorts(ctx, t.config)
setAllowedPorts(ctx, t.config, []uint16{})
t.capPortal = false
log.WithField("capPortal", t.capPortal).Info("TrafPol changed CPD status")
}
Expand All @@ -109,7 +109,7 @@ func (t *TrafPol) handleCPDReport(ctx context.Context, report *cpd.Report) {

// add ports to allowed ports
if !t.capPortal {
addPortalPorts(ctx, t.config)
setAllowedPorts(ctx, t.config, t.config.TrafficPolicing.PortalPorts)
t.capPortal = true
log.WithField("capPortal", t.capPortal).Info("TrafPol changed CPD status")
}
Expand Down
14 changes: 9 additions & 5 deletions internal/trafpol/trafpol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ func TestTrafPolHandleCPDReport(t *testing.T) {
var nftMutex sync.Mutex
nftCmds := []string{}
oldRunCmd := execs.RunCmd
execs.RunCmd = func(_ context.Context, cmd string, _ string, args ...string) ([]byte, []byte, error) {
execs.RunCmd = func(_ context.Context, cmd string, stdin string, args ...string) ([]byte, []byte, error) {
nftMutex.Lock()
defer nftMutex.Unlock()
nftCmds = append(nftCmds, cmd+" "+strings.Join(args, " "))
nftCmds = append(nftCmds, cmd+" "+strings.Join(args, " ")+" "+stdin)
return nil, nil, nil
}
defer func() { execs.RunCmd = oldRunCmd }()
Expand All @@ -86,7 +86,9 @@ func TestTrafPolHandleCPDReport(t *testing.T) {
tp.handleCPDReport(ctx, report)

want = []string{
"nft -f - add element inet oc-daemon-filter allowports { 80, 443 }",
"nft -f - flush set inet oc-daemon-filter allowports\n" +
"add element inet oc-daemon-filter allowports { 80 }\n" +
"add element inet oc-daemon-filter allowports { 443 }\n",
}
got = getNftCmds()
if !reflect.DeepEqual(got, want) {
Expand All @@ -98,8 +100,10 @@ func TestTrafPolHandleCPDReport(t *testing.T) {
tp.handleCPDReport(ctx, report)

want = []string{
"nft -f - add element inet oc-daemon-filter allowports { 80, 443 }",
"nft -f - delete element inet oc-daemon-filter allowports { 80, 443 }",
"nft -f - flush set inet oc-daemon-filter allowports\n" +
"add element inet oc-daemon-filter allowports { 80 }\n" +
"add element inet oc-daemon-filter allowports { 443 }\n",
"nft -f - flush set inet oc-daemon-filter allowports\n",
}
got = getNftCmds()
if !reflect.DeepEqual(got, want) {
Expand Down

0 comments on commit 35a813c

Please sign in to comment.