diff --git a/internal/cmdtmpl/command.go b/internal/cmdtmpl/command.go index c3d59fa..6a7a953 100644 --- a/internal/cmdtmpl/command.go +++ b/internal/cmdtmpl/command.go @@ -212,21 +212,16 @@ func getCommandListTrafPol(name string) *CommandList { }, defaultTemplate: TrafPolDefaultTemplate, } - case "TrafPolAddAllowedDevice": - // Add Allowed Device + case "TrafPolSetAllowedDevices": + // Set Allowed Devices cl = &CommandList{ Name: name, Commands: []*Command{ - {Line: "{{.Executables.Nft}} -f - add element inet oc-daemon-filter allowdevs { {{.Device}} }"}, - }, - defaultTemplate: TrafPolDefaultTemplate, - } - case "TrafPolRemoveAllowedDevice": - // Remove Allowed Device - cl = &CommandList{ - Name: name, - Commands: []*Command{ - {Line: "{{.Executables.Nft}} -f - delete element inet oc-daemon-filter allowdevs { {{.Device}} }"}, + {Line: "{{.Executables.Nft}} -f -", + Stdin: `flush set inet oc-daemon-filter allowdevs +{{range .Devices -}} +add element inet oc-daemon-filter allowdevs { {{.}} } +{{end}}`}, }, defaultTemplate: TrafPolDefaultTemplate, } diff --git a/internal/cmdtmpl/command_test.go b/internal/cmdtmpl/command_test.go index 85571a7..d72b30a 100644 --- a/internal/cmdtmpl/command_test.go +++ b/internal/cmdtmpl/command_test.go @@ -37,8 +37,7 @@ func TestGetCommandList(t *testing.T) { // Traffic Policing "TrafPolSetFilterRules", "TrafPolUnsetFilterRules", - "TrafPolAddAllowedDevice", - "TrafPolRemoveAllowedDevice", + "TrafPolSetAllowedDevices", "TrafPolFlushAllowedHosts", "TrafPolAddAllowedHost", "TrafPolAddPortalPorts", @@ -89,8 +88,7 @@ func TestGetCmds(t *testing.T) { // Traffic Policing "TrafPolSetFilterRules", "TrafPolUnsetFilterRules", - // TrafPolAddAllowedDevice", // skip, requires device - // "TrafPolRemoveAllowedDevice", // skip, requires device + // TrafPolSetAllowedDevices", // skip, requires devices "TrafPolFlushAllowedHosts", // "TrafPolAddAllowedHost", // skip, requires host "TrafPolAddPortalPorts", @@ -117,8 +115,7 @@ func TestGetCmds(t *testing.T) { // existing, with insufficient input data for _, name := range []string{ // Traffic Policing - "TrafPolAddAllowedDevice", - "TrafPolRemoveAllowedDevice", + "TrafPolSetAllowedDevices", "TrafPolAddAllowedHost", // VPN Setup diff --git a/internal/trafpol/filter.go b/internal/trafpol/filter.go index b15a847..60eee3c 100644 --- a/internal/trafpol/filter.go +++ b/internal/trafpol/filter.go @@ -50,58 +50,30 @@ func unsetFilterRules(ctx context.Context, config *daemoncfg.Config) { } } -// addAllowedDevice adds device to the allowed devices. -func addAllowedDevice(ctx context.Context, conf *daemoncfg.Config, device string) { +// setAllowedDevices sets devices as allowed devices. +func setAllowedDevices(ctx context.Context, conf *daemoncfg.Config, devices []string) { data := &struct { daemoncfg.Config - Device string + Devices []string }{ - Config: *conf, - Device: device, + Config: *conf, + Devices: devices, } - cmds, err := cmdtmpl.GetCmds("TrafPolAddAllowedDevice", data) + cmds, err := cmdtmpl.GetCmds("TrafPolSetAllowedDevices", data) if err != nil { - log.WithError(err).Error("TrafPol could not get add allowed device commands") + log.WithError(err).Error("TrafPol could not get set allowed devices commands") } for _, c := range cmds { if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { log.WithFields(log.Fields{ - "device": device, + "devices": devices, "command": c.Cmd, "args": c.Args, "stdin": c.Stdin, "stdout": string(stdout), "stderr": string(stderr), "error": err, - }).Error("TrafPol could not run add allowed device command") - } - } -} - -// removeAllowedDevice removes device from the allowed devices. -func removeAllowedDevice(ctx context.Context, conf *daemoncfg.Config, device string) { - data := &struct { - daemoncfg.Config - Device string - }{ - Config: *conf, - Device: device, - } - cmds, err := cmdtmpl.GetCmds("TrafPolRemoveAllowedDevice", data) - if err != nil { - log.WithError(err).Error("TrafPol could not get remove allowed device commands") - } - for _, c := range cmds { - if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { - log.WithFields(log.Fields{ - "device": device, - "command": c.Cmd, - "args": c.Args, - "stdin": c.Stdin, - "stdout": string(stdout), - "stderr": string(stderr), - "error": err, - }).Error("TrafPol could not run remove allowed device command") + }).Error("TrafPol could not run set allowed devices command") } } } diff --git a/internal/trafpol/filter_test.go b/internal/trafpol/filter_test.go index 9d4ca74..6816408 100644 --- a/internal/trafpol/filter_test.go +++ b/internal/trafpol/filter_test.go @@ -27,8 +27,10 @@ func TestFilterFunctionsErrors(_ *testing.T) { unsetFilterRules(ctx, conf) // allowed devices - addAllowedDevice(ctx, conf, "eth0") - removeAllowedDevice(ctx, conf, "eth0") + setAllowedDevices(ctx, conf, []string{"eth0"}) + setAllowedDevices(ctx, conf, []string{"eth0", "eth1"}) + setAllowedDevices(ctx, conf, []string{"eth0"}) + setAllowedDevices(ctx, conf, []string{}) // allowed IPs setAllowedIPs(ctx, conf, []netip.Prefix{ diff --git a/internal/trafpol/trafpol.go b/internal/trafpol/trafpol.go index 0f70216..9a9bbd1 100644 --- a/internal/trafpol/trafpol.go +++ b/internal/trafpol/trafpol.go @@ -71,12 +71,12 @@ func (t *TrafPol) handleDeviceUpdate(ctx context.Context, u *devmon.Update) { // skip when removing devices. if u.Add && u.Type != "device" { if t.allowDevs.Add(u.Device) { - addAllowedDevice(ctx, t.config, u.Device) + setAllowedDevices(ctx, t.config, t.allowDevs.List()) } return } if t.allowDevs.Remove(u.Device) { - removeAllowedDevice(ctx, t.config, u.Device) + setAllowedDevices(ctx, t.config, t.allowDevs.List()) } }