From 3cf38fa3b1e20e9195782ac7fe2622cca69f99af Mon Sep 17 00:00:00 2001 From: hwipl <33433250+hwipl@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:17:38 +0100 Subject: [PATCH] Add command templates Add a command templates package for running external programs from OC-Daemon. It allows for defining named lists of commands that are called from specific points in OC-Daemon. For example, the command list "SplitRoutingSetupRouting" is called when Split Routing configures the routing. Each command in such a list allows for generating the command line and standard input from Go templates. The templates use the daemon configuration (and for some command lists additional information like addresses) as data input. Signed-off-by: hwipl <33433250+hwipl@users.noreply.github.com> --- internal/cmdtmpl/command.go | 654 +++++++++++++++++++++++++++++++ internal/cmdtmpl/command_test.go | 146 +++++++ internal/daemon/daemon.go | 2 +- internal/execs/execs.go | 56 --- internal/execs/execs_test.go | 184 --------- internal/splitrt/filter.go | 244 ++---------- internal/splitrt/route.go | 156 -------- internal/splitrt/splitrt.go | 102 ++--- internal/splitrt/splitrt_test.go | 23 +- internal/trafpol/filter.go | 421 ++++++++------------ internal/trafpol/filter_test.go | 18 +- internal/trafpol/trafpol.go | 24 +- internal/trafpol/trafpol_test.go | 19 +- internal/vpnsetup/vpnsetup.go | 299 +++++++------- 14 files changed, 1238 insertions(+), 1110 deletions(-) create mode 100644 internal/cmdtmpl/command.go create mode 100644 internal/cmdtmpl/command_test.go delete mode 100644 internal/splitrt/route.go diff --git a/internal/cmdtmpl/command.go b/internal/cmdtmpl/command.go new file mode 100644 index 00000000..2b49db78 --- /dev/null +++ b/internal/cmdtmpl/command.go @@ -0,0 +1,654 @@ +// Package cmdtmpl contains command lists for external commands with templates. +package cmdtmpl + +import ( + "bytes" + "context" + "fmt" + "strings" + "text/template" + + "github.com/telekom-mms/oc-daemon/internal/execs" +) + +// Command consists of a command line to be executed and an optional Stdin to +// be passed to the command on execution. +type Command struct { + Line string + Stdin string +} + +// CommandList is a list of Commands. +type CommandList struct { + Name string + Commands []*Command + + defaultTemplate string + template *template.Template +} + +// executeTemplate executes the template on data and returns the resulting +// output as string. +func (cl *CommandList) executeTemplate(tmpl string, data any) (string, error) { + t, err := cl.template.Clone() + if err != nil { + return "", err + } + t, err = t.Parse(tmpl) + if err != nil { + return "", err + } + buf := &bytes.Buffer{} + if err := t.Execute(buf, data); err != nil { + return "", err + } + + s := buf.String() + return s, nil +} + +// SplitRoutingDefaultTemplate is the default template for Split Routing. +const SplitRoutingDefaultTemplate = ` +{{- define "SplitRoutingRules"}} +table inet oc-daemon-routing { + # set for ipv4 excludes + set excludes4 { + type ipv4_addr + flags interval + } + + # set for ipv6 excludes + set excludes6 { + type ipv6_addr + flags interval + } + + chain preraw { + type filter hook prerouting priority raw; policy accept; + + # add drop rules for non-local traffic from other devices to + # tunnel network addresses here + {{if .VPNConfig.IPv4.IsValid}} + iifname != {{.VPNConfig.Device.Name}} ip daddr {{.VPNConfig.IPv4}} fib saddr type != local counter drop + {{end}} + {{if .VPNConfig.IPv6.IsValid}} + iifname != {{.VPNConfig.Device.Name}} ip6 daddr {{.VPNConfig.IPv6}} fib saddr type != local counter drop + {{end}} + } + + chain splitrouting { + # restore mark from conntracking + ct mark != 0 meta mark set ct mark counter + meta mark != 0 counter accept + + # mark packets in exclude sets + ip daddr @excludes4 counter meta mark set {{.SplitRouting.FirewallMark}} + ip6 daddr @excludes6 counter meta mark set {{.SplitRouting.FirewallMark}} + + # save mark in conntraction + ct mark set meta mark counter + } + + chain premangle { + type filter hook prerouting priority mangle; policy accept; + + # handle split routing + counter jump splitrouting + } + + chain output { + type route hook output priority mangle; policy accept; + + # handle split routing + counter jump splitrouting + } + + chain postmangle { + type filter hook postrouting priority mangle; policy accept; + + # save mark in conntracking + meta mark {{.SplitRouting.FirewallMark}} ct mark set meta mark counter + } + + chain postrouting { + type nat hook postrouting priority srcnat; policy accept; + + # masquerare tunnel/exclude traffic to make sure the source IP + # matches the outgoing interface + ct mark {{.SplitRouting.FirewallMark}} counter masquerade + } + + chain rejectipversion { + # used to reject unsupported ip version on the tunnel device + + # make sure exclude traffic is not filtered + ct mark {{.SplitRouting.FirewallMark}} counter accept + + # use tcp reset and icmp admin prohibited + meta l4proto tcp counter reject with tcp reset + counter reject with icmpx admin-prohibited + } + + chain rejectforward { + type filter hook forward priority filter; policy accept; + + # reject unsupported ip versions when forwarding packets, + # add matching jump rule to rejectipversion if necessary + {{if .VPNConfig.IPv4.IsValid}} + meta oifname {{.VPNConfig.Device.Name}} meta nfproto ipv6 counter jump rejectipversion + {{end}} + {{if .VPNConfig.IPv6.IsValid}} + meta oifname {{.VPNConfig.Device.Name}} meta nfproto ipv4 counter jump rejectipversion + {{end}} + } + + chain rejectoutput { + type filter hook output priority filter; policy accept; + + # reject unsupported ip versions when sending packets, + # add matching jump rule to rejectipversion if necessary + {{if .VPNConfig.IPv4.IsValid}} + meta oifname {{.VPNConfig.Device.Name}} meta nfproto ipv6 counter jump rejectipversion + {{end}} + {{if .VPNConfig.IPv6.IsValid}} + meta oifname {{.VPNConfig.Device.Name}} meta nfproto ipv4 counter jump rejectipversion + {{end}} + } +} +{{end -}} +` + +// getCommandListSplitRouting returns the command list identified by name for SplitRouting. +func getCommandListSplitRouting(name string) *CommandList { + var cl *CommandList + switch name { + case "SplitRoutingSetupRouting": + // Setup Routing + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Nft}} -f -", Stdin: `{{template "SplitRoutingRules" .}}`}, + {Line: "{{.Executables.IP}} -4 route add 0.0.0.0/0 dev {{.VPNConfig.Device.Name}} table {{.SplitRouting.RoutingTable}}"}, + {Line: "{{.Executables.IP}} -4 rule add iif {{.VPNConfig.Device.Name}} table main pref {{.SplitRouting.RulePriority1}}"}, + {Line: "{{.Executables.IP}} -4 rule add not fwmark {{.SplitRouting.FirewallMark}} table {{.SplitRouting.RoutingTable}} pref {{.SplitRouting.RulePriority2}}"}, + {Line: "{{.Executables.Sysctl}} -q net.ipv4.conf.all.src_valid_mark=1"}, + {Line: "{{.Executables.IP}} -6 route add ::/0 dev {{.VPNConfig.Device.Name}} table {{.SplitRouting.RoutingTable}}"}, + {Line: "{{.Executables.IP}} -6 rule add iif {{.VPNConfig.Device.Name}} table main pref {{.SplitRouting.RulePriority1}}"}, + {Line: "{{.Executables.IP}} -6 rule add not fwmark {{.SplitRouting.FirewallMark}} table {{.SplitRouting.RoutingTable}} pref {{.SplitRouting.RulePriority2}}"}, + }, + defaultTemplate: SplitRoutingDefaultTemplate, + } + case "SplitRoutingTeardownRouting": + // Teardown Routing + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.IP}} -4 rule delete table {{.SplitRouting.RoutingTable}}"}, + {Line: "{{.Executables.IP}} -4 rule delete iif {{.VPNConfig.Device.Name}} table main"}, + {Line: "{{.Executables.IP}} -6 rule delete table {{.SplitRouting.RoutingTable}}"}, + {Line: "{{.Executables.IP}} -6 rule delete iif {{.VPNConfig.Device.Name}} table main"}, + {Line: "{{.Executables.Nft}} -f - delete table inet oc-daemon-routing"}, + }, + defaultTemplate: SplitRoutingDefaultTemplate, + } + case "SplitRoutingSetExcludes": + // Set Excludes + cl = &CommandList{ + Name: name, + Commands: []*Command{ + // flush existing entries + // add entries + {Line: "{{.Executables.Nft}} -f -", + Stdin: `flush set inet oc-daemon-routing excludes4 +flush set inet oc-daemon-routing excludes6 +{{range .Addresses -}} +{{if .Addr.Is6 -}} +add element inet oc-daemon-routing excludes6 { {{.}} } +{{else -}} +add element inet oc-daemon-routing excludes4 { {{.}} } +{{end -}} +{{end}}`}, + }, + defaultTemplate: SplitRoutingDefaultTemplate, + } + case "SplitRoutingCleanup": + // Cleanup + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.IP}} -4 rule delete pref {{.SplitRouting.RulePriority1}}"}, + {Line: "{{.Executables.IP}} -4 rule delete pref {{.SplitRouting.RulePriority2}}"}, + {Line: "{{.Executables.IP}} -6 rule delete pref {{.SplitRouting.RulePriority1}}"}, + {Line: "{{.Executables.IP}} -6 rule delete pref {{.SplitRouting.RulePriority2}}"}, + {Line: "{{.Executables.IP}} -4 route flush table {{.SplitRouting.RoutingTable}}"}, + {Line: "{{.Executables.IP}} -6 route flush table {{.SplitRouting.RoutingTable}}"}, + {Line: "{{.Executables.Nft}} -f - delete table inet oc-daemon-routing"}, + }, + defaultTemplate: SplitRoutingDefaultTemplate, + } + default: + return nil + + } + + cl.template = template.Must(template.New("Template").Parse(cl.defaultTemplate)) + return cl +} + +// TrafPolDefaultTemplate is the default template for Traffic Policing. +const TrafPolDefaultTemplate = ` +{{- define "TrafPolRules"}} +table inet oc-daemon-filter { + # set for allowed devices + set allowdevs { + type ifname + elements = { lo } + } + + # set for allowed ipv4 hosts + set allowhosts4 { + type ipv4_addr + flags interval + } + + # set for allowed ipv6 hosts + set allowhosts6 { + type ipv6_addr + flags interval + } + + # set for allowed ports + set allowports { + type inet_service + elements = { 53 } + } + + chain input { + type filter hook input priority 0; policy drop; + + # accept related traffic + ct state established,related counter accept + + # accept traffic on allowed devices, e.g., lo + iifname @allowdevs counter accept + + # accept ICMP traffic + icmp type { + echo-reply, + destination-unreachable, + source-quench, + redirect, + time-exceeded, + parameter-problem, + timestamp-reply, + info-reply, + address-mask-reply, + router-advertisement, + } counter accept + + # accept ICMPv6 traffic otherwise IPv6 connectivity breaks + icmpv6 type { + destination-unreachable, + packet-too-big, + time-exceeded, + echo-reply, + mld-listener-query, + mld-listener-report, + mld2-listener-report, + mld-listener-done, + nd-router-advert, + nd-neighbor-solicit, + nd-neighbor-advert, + ind-neighbor-solicit, + ind-neighbor-advert, + nd-redirect, + parameter-problem, + router-renumbering + } counter accept + + # accept DHCPv4 traffic + udp dport 68 udp sport 67 counter accept + + # accept DHCPv6 traffic + udp dport 546 udp sport 547 counter accept + } + + chain output { + type filter hook output priority 0; policy drop; + + # accept related traffic + ct state established,related counter accept + + # accept traffic on allowed devices, e.g., lo + oifname @allowdevs counter accept + + # accept traffic to allowed hosts + ip daddr @allowhosts4 counter accept + ip6 daddr @allowhosts6 counter accept + + # accept ICMP traffic + icmp type { + source-quench, + echo-request, + timestamp-request, + info-request, + address-mask-request, + router-solicitation + } counter accept + + # accept ICMPv6 traffic otherwise IPv6 connectivity breaks + icmpv6 type { + echo-request, + mld-listener-report, + mld2-listener-report, + mld-listener-done, + nd-router-solicit, + nd-neighbor-solicit, + nd-neighbor-advert, + ind-neighbor-solicit, + ind-neighbor-advert, + } counter accept + + # accept traffic to allowed ports, e.g., DNS + udp dport @allowports counter accept + tcp dport @allowports counter accept + + # accept DHCPv4 traffic + udp dport 67 udp sport 68 counter accept + + # accept DHCPv6 traffic + udp dport 547 udp sport 546 counter accept + + # reject everything else + counter reject + } + + chain forward { + type filter hook forward priority 0; policy drop; + + # accept related traffic + ct state established,related counter accept + + # accept split exclude traffic + iifname @allowdevs ct mark {{.SplitRouting.FirewallMark}} counter accept + + # accept traffic on allowed devices + iifname @allowdevs oifname @allowdevs counter accept + } +} +{{end}}` + +// getCommandListTrafPol returns the command list identified by name for Traffic Policing. +func getCommandListTrafPol(name string) *CommandList { + var cl *CommandList + switch name { + case "TrafPolSetFilterRules": + // Set Filter Rules + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Nft}} -f -", Stdin: `{{template "TrafPolRules" .}}`}, + }, + defaultTemplate: TrafPolDefaultTemplate, + } + case "TrafPolUnsetFilterRules": + // Unset Filter Rules + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Nft}} -f - delete table inet oc-daemon-filter"}, + }, + defaultTemplate: TrafPolDefaultTemplate, + } + case "TrafPolAddAllowedDevice": + // Add Allowed Device + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Nft}} -f - add element inet oc-daemon-filter allowdevs { {{.Device}} }"}, + }, + defaultTemplate: TrafPolDefaultTemplate, + } + case "TrafPolRemoveAllowedDevice": + // Remove Allowed Device + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Nft}} -f - delete element inet oc-daemon-filter allowdevs { {{.Device}} }"}, + }, + defaultTemplate: TrafPolDefaultTemplate, + } + case "TrafPolFlushAllowedHosts": + // Flush Allowed Hosts + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Nft}} -f - flush set inet oc-daemon-filter allowhosts4"}, + {Line: "{{.Executables.Nft}} -f - flush set inet oc-daemon-filter allowhosts6"}, + }, + defaultTemplate: TrafPolDefaultTemplate, + } + case "TrafPolAddAllowedHost": + // Add Allowed Host + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Nft}} -f -", + Stdin: ` + {{if .AllowedIP.Addr.Is4}} + add element inet oc-daemon-filter allowhosts4 { {{.AllowedIP}} } + {{else}} + add element inet oc-daemon-filter allowhosts6 { {{.AllowedIP}} } + {{end}} + `, + }, + }, + defaultTemplate: TrafPolDefaultTemplate, + } + case "TrafPolAddPortalPorts": + // Remove Portal Ports + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Nft}} -f - add element inet oc-daemon-filter allowports { {{range $i, $port := .TrafficPolicing.PortalPorts}}{{if $i}}, {{end}}{{$port}}{{end}} }"}, + }, + defaultTemplate: TrafPolDefaultTemplate, + } + case "TrafPolRemovePortalPorts": + // Remove Portal Ports + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Nft}} -f - delete element inet oc-daemon-filter allowports { {{range $i, $port := .TrafficPolicing.PortalPorts}}{{if $i}}, {{end}}{{$port}}{{end}} }"}, + }, + defaultTemplate: TrafPolDefaultTemplate, + } + case "TrafPolCleanup": + // Cleanup + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Nft}} -f - delete table inet oc-daemon-filter"}, + }, + defaultTemplate: TrafPolDefaultTemplate, + } + default: + return nil + } + + cl.template = template.Must(template.New("Template").Parse(cl.defaultTemplate)) + return cl +} + +// getCommandListVPNSetup returns the command list identified by name for VPNSetup. +func getCommandListVPNSetup(name string) *CommandList { + var cl *CommandList + switch name { + case "VPNSetupSetupVPNDevice": + // Setup VPN Device + cl = &CommandList{ + Name: name, + Commands: []*Command{ + // set mtu on device + {Line: "{{.Executables.IP}} link set {{.VPNConfig.Device.Name}} mtu {{.VPNConfig.Device.MTU}}"}, + // set device up + {Line: "{{.Executables.IP}} link set {{.VPNConfig.Device.Name}} up"}, + // set ipv4 and ipv6 addresses on device + {Line: "{{if .VPNConfig.IPv4.IsValid}}{{.Executables.IP}} address add {{.VPNConfig.IPv4}} dev {{.VPNConfig.Device.Name}}{{end}}"}, + {Line: "{{if .VPNConfig.IPv6.IsValid}}{{.Executables.IP}} address add {{.VPNConfig.IPv6}} dev {{.VPNConfig.Device.Name}}{{end}}"}, + }, + defaultTemplate: "", + } + case "VPNSetupTeardownVPNDevice": + // Teardown VPN Device + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.IP}} link set {{.VPNConfig.Device.Name}} down"}, + }, + defaultTemplate: "", + } + case "VPNSetupSetupDNSServer": + // Setup DNS server + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Resolvectl}} dns {{.VPNConfig.Device.Name}} {{.DNSProxy.Address}}"}, + }, + defaultTemplate: "", + } + case "VPNSetupSetupDNSDomains": + // Setup DNS domains + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Resolvectl}} domain {{.VPNConfig.Device.Name}} {{.VPNConfig.DNS.DefaultDomain}} ~."}, + }, + defaultTemplate: "", + } + case "VPNSetupSetupDNSDefaultRoute": + // Setup DNS Default Route + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Resolvectl}} default-route {{.VPNConfig.Device}} yes"}, + }, + defaultTemplate: "", + } + case "VPNSetupSetupDNS": + // Setup DNS + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Resolvectl}} dns {{.VPNConfig.Device.Name}} {{.DNSProxy.Address}}"}, + {Line: "{{.Executables.Resolvectl}} domain {{.VPNConfig.Device.Name}} {{.VPNConfig.DNS.DefaultDomain}} ~."}, + {Line: "{{.Executables.Resolvectl}} default-route {{.VPNConfig.Device.Name}} yes"}, + {Line: "{{.Executables.Resolvectl}} flush-caches"}, + {Line: "{{.Executables.Resolvectl}} reset-server-features"}, + }, + defaultTemplate: "", + } + case "VPNSetupTeardownDNS": + // Teardown DNS + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Resolvectl}} revert {{.VPNConfig.Device.Name}}"}, + {Line: "{{.Executables.Resolvectl}} flush-caches"}, + {Line: "{{.Executables.Resolvectl}} reset-server-features"}, + }, + defaultTemplate: "", + } + case "VPNSetupEnsureDNS": + // Ensure DNS + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Resolvectl}} status {{.VPNConfig.Device.Name}} --no-pager"}, + }, + defaultTemplate: "", + } + case "VPNSetupCleanup": + // Cleanup + cl = &CommandList{ + Name: name, + Commands: []*Command{ + {Line: "{{.Executables.Resolvectl}} revert {{.OpenConnect.VPNDevice}}"}, + {Line: "{{.Executables.IP}} link delete {{.OpenConnect.VPNDevice}}"}, + }, + defaultTemplate: "", + } + default: + return nil + } + + cl.template = template.Must(template.New("Template").Parse(cl.defaultTemplate)) + return cl +} + +// getCommandList returns the command list identified by name. +func getCommandList(name string) *CommandList { + if strings.HasPrefix(name, "SplitRouting") { + return getCommandListSplitRouting(name) + } + if strings.HasPrefix(name, "TrafPol") { + return getCommandListTrafPol(name) + } + if strings.HasPrefix(name, "VPNSetup") { + return getCommandListVPNSetup(name) + } + return nil +} + +// Cmd is a command ready to run. +type Cmd struct { + Cmd string + Args []string + Stdin string +} + +// Run runs the command. +func (c *Cmd) Run(ctx context.Context) (stdout, stderr []byte, err error) { + return execs.RunCmd(ctx, c.Cmd, c.Stdin, c.Args...) +} + +// GetCmds returns a list of Cmds ready to run. +func GetCmds(name string, data any) ([]*Cmd, error) { + cl := getCommandList(name) + if cl == nil { + return nil, fmt.Errorf("could not find command list %s", name) + } + var commands []*Cmd + for _, c := range cl.Commands { + // execute template for command line + line, err := cl.executeTemplate(c.Line, data) + if err != nil { + return nil, fmt.Errorf("could not execute template for command line: %w", err) + } + + // execute template for stdin + stdin, err := cl.executeTemplate(c.Stdin, data) + if err != nil { + return nil, fmt.Errorf("could not execute template for stdin: %w", err) + } + + // extract command from command line + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + command := fields[0] + + // extract arguments from command line + args := []string{} + if len(fields) > 1 { + args = fields[1:] + } + commands = append(commands, &Cmd{ + Cmd: command, + Args: args, + Stdin: stdin, + }) + } + return commands, nil +} diff --git a/internal/cmdtmpl/command_test.go b/internal/cmdtmpl/command_test.go new file mode 100644 index 00000000..de318383 --- /dev/null +++ b/internal/cmdtmpl/command_test.go @@ -0,0 +1,146 @@ +package cmdtmpl + +import ( + "context" + "testing" + "text/template" + + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" +) + +// TestExecuteTemplateErrors tests executeTemplate of CommandList, parse error. +func TestExecuteTemplateParseError(t *testing.T) { + cl := &CommandList{ + template: template.Must(template.New("test").Parse("test")), + } + if _, err := cl.executeTemplate("{{ invalid }}", nil); err == nil { + t.Error("invalid template should not parse correctly") + } +} + +// TestGetCommandList tests getCommandList. +func TestGetCommandList(t *testing.T) { + // not existing + for _, name := range []string{ + "SplitRoutingDoesNotExist", + "TrafPolDoesNotExist", + "VPNSetupDoesNotExist", + "DoesNotExist", + } { + cl := getCommandList(name) + if cl != nil { + t.Errorf("command list %s should not exists, got %s", name, cl.Name) + } + } + + // existing + for _, name := range []string{ + // Split Routing + "SplitRoutingSetupRouting", + "SplitRoutingTeardownRouting", + "SplitRoutingSetExcludes", + "SplitRoutingCleanup", + + // Traffic Policing + "TrafPolSetFilterRules", + "TrafPolUnsetFilterRules", + "TrafPolAddAllowedDevice", + "TrafPolRemoveAllowedDevice", + "TrafPolFlushAllowedHosts", + "TrafPolAddAllowedHost", + "TrafPolAddPortalPorts", + "TrafPolRemovePortalPorts", + "TrafPolCleanup", + + // VPN Setup + "VPNSetupSetupVPNDevice", + "VPNSetupTeardownVPNDevice", + "VPNSetupSetupDNSServer", + "VPNSetupSetupDNSDomains", + "VPNSetupSetupDNSDefaultRoute", + "VPNSetupSetupDNS", + "VPNSetupTeardownDNS", + "VPNSetupEnsureDNS", + "VPNSetupCleanup", + } { + cl := getCommandList(name) + if cl.Name != name { + t.Errorf("command list should be %s, got %s", name, cl.Name) + } + } +} + +// TestCmdRun tests Run of Cmd. +func TestCmdRun(t *testing.T) { + cmd := &Cmd{ + Cmd: "echo", + Args: []string{"this", "is", "a", "test"}, + } + stdout, _, err := cmd.Run(context.Background()) + if err != nil { + t.Errorf("unexpected error %s", err) + } + if string(stdout) != "this is a test\n" { + t.Errorf("unexpected stdout: %s", stdout) + } +} + +// TestGetCmds tets GetCmds. +func TestGetCmds(t *testing.T) { + // not existing + if _, err := GetCmds("DoesNotExist", nil); err == nil { + t.Error("not existing command list should return error") + } + + // existing, that only need daemon config as input data + for _, name := range []string{ + // Split Routing + "SplitRoutingSetupRouting", + "SplitRoutingTeardownRouting", + // "SplitRoutingSetExcludes", // skip, requires excludes + "SplitRoutingCleanup", + + // Traffic Policing + "TrafPolSetFilterRules", + "TrafPolUnsetFilterRules", + // TrafPolAddAllowedDevice", // skip, requires device + // "TrafPolRemoveAllowedDevice", // skip, requires device + "TrafPolFlushAllowedHosts", + // "TrafPolAddAllowedHost", // skip, requires host + "TrafPolAddPortalPorts", + "TrafPolRemovePortalPorts", + "TrafPolCleanup", + + // VPN Setup + "VPNSetupSetupVPNDevice", + "VPNSetupTeardownVPNDevice", + "VPNSetupSetupDNSServer", + "VPNSetupSetupDNSDomains", + "VPNSetupSetupDNSDefaultRoute", + "VPNSetupSetupDNS", + "VPNSetupTeardownDNS", + "VPNSetupEnsureDNS", + "VPNSetupCleanup", + } { + if cmds, err := GetCmds(name, daemoncfg.NewConfig()); err != nil || + len(cmds) == 0 { + + t.Errorf("got invalid command list for name %s", name) + } + } + + // existing, with insufficient input data + for _, name := range []string{ + // Split Routing + "SplitRoutingSetExcludes", + + // Traffic Policing + "TrafPolAddAllowedDevice", + "TrafPolRemoveAllowedDevice", + "TrafPolAddAllowedHost", + } { + if _, err := GetCmds(name, daemoncfg.NewConfig()); err == nil { + t.Errorf("insufficient data should return error for list %s", name) + } + } +} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index a6337461..2b4acd6c 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -652,7 +652,7 @@ func (d *Daemon) handleProfileUpdate() error { func (d *Daemon) cleanup(ctx context.Context) { ocrunner.CleanupConnect(d.config.OpenConnect) vpnsetup.Cleanup(ctx, d.config) - trafpol.Cleanup(ctx) + trafpol.Cleanup(ctx, d.config) } // initToken creates the daemon token for client authentication. diff --git a/internal/execs/execs.go b/internal/execs/execs.go index 1c035f8c..b13f2d82 100644 --- a/internal/execs/execs.go +++ b/internal/execs/execs.go @@ -32,62 +32,6 @@ var RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) (std return } -// RunIP runs the "ip" command with args. -func RunIP(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { - return RunCmd(ctx, ip, "", arg...) -} - -// RunIPLink runs the "ip link" command with args. -func RunIPLink(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { - a := append([]string{"link"}, arg...) - return RunIP(ctx, a...) -} - -// RunIPAddress runs the "ip address" command with args. -func RunIPAddress(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { - a := append([]string{"address"}, arg...) - return RunIP(ctx, a...) -} - -// RunIP4Route runs the "ip -4 route" command with args. -func RunIP4Route(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { - a := append([]string{"-4", "route"}, arg...) - return RunIP(ctx, a...) -} - -// RunIP6Route runs the "ip -6 route" command with args. -func RunIP6Route(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { - a := append([]string{"-6", "route"}, arg...) - return RunIP(ctx, a...) -} - -// RunIP4Rule runs the "ip -4 rule" command with args. -func RunIP4Rule(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { - a := append([]string{"-4", "rule"}, arg...) - return RunIP(ctx, a...) -} - -// RunIP6Rule runs the "ip -6 rule" command with args. -func RunIP6Rule(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { - a := append([]string{"-6", "rule"}, arg...) - return RunIP(ctx, a...) -} - -// RunSysctl runs the "sysctl" command with args. -func RunSysctl(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { - return RunCmd(ctx, sysctl, "", arg...) -} - -// RunNft runs the "nft -f -" command and sets stdin to s. -func RunNft(ctx context.Context, s string) (stdout, stderr []byte, err error) { - return RunCmd(ctx, nft, s, "-f", "-") -} - -// RunResolvectl runs the "resolvectl" command with args. -func RunResolvectl(ctx context.Context, arg ...string) (stdout, stderr []byte, err error) { - return RunCmd(ctx, resolvectl, "", arg...) -} - // SetExecutables configures all executables from config. func SetExecutables(config *daemoncfg.Executables) { ip = config.IP diff --git a/internal/execs/execs_test.go b/internal/execs/execs_test.go index 99a2353f..07d729e0 100644 --- a/internal/execs/execs_test.go +++ b/internal/execs/execs_test.go @@ -3,8 +3,6 @@ package execs import ( "context" "path/filepath" - "reflect" - "strings" "testing" "github.com/telekom-mms/oc-daemon/internal/daemoncfg" @@ -43,188 +41,6 @@ func TestRunCmd(t *testing.T) { } } -// TestRunIP tests RunIP. -func TestRunIP(t *testing.T) { - want := []string{"ip address show"} - got := []string{} - - oldRunCmd := RunCmd - RunCmd = func(_ context.Context, cmd string, _ string, arg ...string) ([]byte, []byte, error) { - got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil, nil, nil - } - defer func() { RunCmd = oldRunCmd }() - - _, _, _ = RunIP(context.Background(), "address", "show") - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} - -// TestRunIPLink tests RunIPLink. -func TestRunIPLink(t *testing.T) { - want := []string{"ip link show"} - got := []string{} - - oldRunCmd := RunCmd - RunCmd = func(_ context.Context, cmd string, _ string, arg ...string) ([]byte, []byte, error) { - got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil, nil, nil - } - defer func() { RunCmd = oldRunCmd }() - - _, _, _ = RunIPLink(context.Background(), "show") - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} - -// TestRunIPAddress tests RunIPAddress. -func TestRunIPAddress(t *testing.T) { - want := []string{"ip address show"} - got := []string{} - - oldRunCmd := RunCmd - RunCmd = func(_ context.Context, cmd string, _ string, arg ...string) ([]byte, []byte, error) { - got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil, nil, nil - } - defer func() { RunCmd = oldRunCmd }() - - _, _, _ = RunIPAddress(context.Background(), "show") - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} - -// TestRunIP4Route tests RunIP4Route. -func TestRunIP4Route(t *testing.T) { - want := []string{"ip -4 route show"} - got := []string{} - - oldRunCmd := RunCmd - RunCmd = func(_ context.Context, cmd string, _ string, arg ...string) ([]byte, []byte, error) { - got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil, nil, nil - } - defer func() { RunCmd = oldRunCmd }() - - _, _, _ = RunIP4Route(context.Background(), "show") - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} - -// TestRunIP6Route tests RunIP6Route. -func TestRunIP6Route(t *testing.T) { - want := []string{"ip -6 route show"} - got := []string{} - - oldRunCmd := RunCmd - RunCmd = func(_ context.Context, cmd string, _ string, arg ...string) ([]byte, []byte, error) { - got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil, nil, nil - } - defer func() { RunCmd = oldRunCmd }() - - _, _, _ = RunIP6Route(context.Background(), "show") - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} - -// TestRunIP4Rule tests RunIP4Rule. -func TestRunIP4Rule(t *testing.T) { - want := []string{"ip -4 rule show"} - got := []string{} - - oldRunCmd := RunCmd - RunCmd = func(_ context.Context, cmd string, _ string, arg ...string) ([]byte, []byte, error) { - got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil, nil, nil - } - defer func() { RunCmd = oldRunCmd }() - - _, _, _ = RunIP4Rule(context.Background(), "show") - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} - -// TestRunIP6Rule tests RunIP6Rule. -func TestRunIP6Rule(t *testing.T) { - want := []string{"ip -6 rule show"} - got := []string{} - - oldRunCmd := RunCmd - RunCmd = func(_ context.Context, cmd string, _ string, arg ...string) ([]byte, []byte, error) { - got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil, nil, nil - } - defer func() { RunCmd = oldRunCmd }() - - _, _, _ = RunIP6Rule(context.Background(), "show") - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} - -// TestRunSysctl tests RunSysctl. -func TestRunSysctl(t *testing.T) { - want := []string{"sysctl -q net.ipv4.conf.all.src_valid_mark=1"} - got := []string{} - - oldRunCmd := RunCmd - RunCmd = func(_ context.Context, cmd string, _ string, arg ...string) ([]byte, []byte, error) { - got = append(got, cmd+" "+strings.Join(arg, " ")) - return nil, nil, nil - } - defer func() { RunCmd = oldRunCmd }() - - _, _, _ = RunSysctl(context.Background(), "-q", "net.ipv4.conf.all.src_valid_mark=1") - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} - -// TestRunNft tests RunNft. -func TestRunNft(t *testing.T) { - want := []string{"nft -f - list tables"} - got := []string{} - - oldRunCmd := RunCmd - RunCmd = func(_ context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { - got = append(got, cmd+" "+strings.Join(arg, " ")+" "+s) - return nil, nil, nil - } - defer func() { RunCmd = oldRunCmd }() - - _, _, _ = RunNft(context.Background(), "list tables") - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} - -// TestRunResolvectl tests RunResolvectl. -func TestRunResolvectl(t *testing.T) { - want := []string{"resolvectl dns"} - got := []string{} - - oldRunCmd := RunCmd - RunCmd = func(_ context.Context, cmd string, _ string, arg ...string) ([]byte, []byte, error) { - got = append(got, cmd+" "+strings.Join(arg, " ")) - return []byte("OK"), nil, nil - } - defer func() { RunCmd = oldRunCmd }() - - if b, _, err := RunResolvectl(context.Background(), "dns"); err != nil || string(b) != "OK" { - t.Errorf("invalid return values %s, %v", b, err) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} - // TestSetExecutables tests SetExecutables. func TestSetExecutables(t *testing.T) { old := daemoncfg.NewExecutables() diff --git a/internal/splitrt/filter.go b/internal/splitrt/filter.go index 482077e2..663af249 100644 --- a/internal/splitrt/filter.go +++ b/internal/splitrt/filter.go @@ -2,231 +2,37 @@ package splitrt import ( "context" - "fmt" "net/netip" - "strings" log "github.com/sirupsen/logrus" - "github.com/telekom-mms/oc-daemon/internal/execs" + "github.com/telekom-mms/oc-daemon/internal/cmdtmpl" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) -// setRoutingRules sets the basic nftables rules for routing. -func setRoutingRules(ctx context.Context, fwMark string) { - const routeRules = ` -table inet oc-daemon-routing { - # set for ipv4 excludes - set excludes4 { - type ipv4_addr - flags interval - } - - # set for ipv6 excludes - set excludes6 { - type ipv6_addr - flags interval - } - - chain preraw { - type filter hook prerouting priority raw; policy accept; - - # add drop rules for non-local traffic from other devices to - # tunnel network addresses here - } - - chain splitrouting { - # restore mark from conntracking - ct mark != 0 meta mark set ct mark counter - meta mark != 0 counter accept - - # mark packets in exclude sets - ip daddr @excludes4 counter meta mark set $FWMARK - ip6 daddr @excludes6 counter meta mark set $FWMARK - - # save mark in conntraction - ct mark set meta mark counter - } - - chain premangle { - type filter hook prerouting priority mangle; policy accept; - - # handle split routing - counter jump splitrouting - } - - chain output { - type route hook output priority mangle; policy accept; - - # handle split routing - counter jump splitrouting - } - - chain postmangle { - type filter hook postrouting priority mangle; policy accept; - - # save mark in conntracking - meta mark $FWMARK ct mark set meta mark counter - } - - chain postrouting { - type nat hook postrouting priority srcnat; policy accept; - - # masquerare tunnel/exclude traffic to make sure the source IP - # matches the outgoing interface - ct mark $FWMARK counter masquerade - } - - chain rejectipversion { - # used to reject unsupported ip version on the tunnel device - - # make sure exclude traffic is not filtered - ct mark $FWMARK counter accept - - # use tcp reset and icmp admin prohibited - meta l4proto tcp counter reject with tcp reset - counter reject with icmpx admin-prohibited - } - - chain rejectforward { - type filter hook forward priority filter; policy accept; - - # reject unsupported ip versions when forwarding packets, - # add matching jump rule to rejectipversion if necessary - } - - chain rejectoutput { - type filter hook output priority filter; policy accept; - - # reject unsupported ip versions when sending packets, - # add matching jump rule to rejectipversion if necessary - } -} -` - r := strings.NewReplacer("$FWMARK", fwMark) - rules := r.Replace(routeRules) - if stdout, stderr, err := execs.RunNft(ctx, rules); err != nil { - log.WithError(err).WithFields(log.Fields{ - "fwMark": fwMark, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error setting routing rules") - } -} - -// unsetRoutingRules removes the nftables rules for routing. -func unsetRoutingRules(ctx context.Context) { - if stdout, stderr, err := execs.RunNft(ctx, "delete table inet oc-daemon-routing"); err != nil { - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error unsetting routing rules") - } -} - -// addLocalAddresses adds rules for device and its family (ip, ip6) addresses, -// that drop non-local traffic from other devices to device's network -// addresses; used to filter non-local traffic to vpn addresses. -func addLocalAddresses(ctx context.Context, device, family string, addresses []netip.Prefix) { - nftconf := "" - for _, addr := range addresses { - if !addr.IsValid() { - continue - } - nftconf += "add rule inet oc-daemon-routing preraw iifname != " - nftconf += fmt.Sprintf("%s %s daddr %s ", device, family, addr) - nftconf += "fib saddr type != local counter drop\n" - } - - if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).WithFields(log.Fields{ - "device": device, - "family": family, - "addresses": addresses, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error adding local addresses") - } -} - -// addLocalAddressesIPv4 adds rules for device and its addresses, that drop -// non-local traffic from other devices to device's network addresses; used to -// filter non-local traffic to vpn addresses. -func addLocalAddressesIPv4(ctx context.Context, device string, addresses []netip.Prefix) { - addLocalAddresses(ctx, device, "ip", addresses) -} - -// addLocalAddressesIPv6 adds rules for device and its addresses, that drop -// non-local traffic from other devices to device's network addresses; used to -// filter non-local traffic to vpn addresses. -func addLocalAddressesIPv6(ctx context.Context, device string, addresses []netip.Prefix) { - addLocalAddresses(ctx, device, "ip6", addresses) -} - -// rejectIPVersion adds rules for the tunnel device to reject an unsupported ip -// version ("ipv6" or "ipv4"). -func rejectIPVersion(ctx context.Context, device, version string) { - nftconf := "" - for _, chain := range []string{"rejectforward", "rejectoutput"} { - nftconf += fmt.Sprintf("add rule inet oc-daemon-routing %s ", - chain) - nftconf += fmt.Sprintf("meta oifname %s meta nfproto %s ", - device, version) - nftconf += "counter jump rejectipversion\n" - } - - if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).WithFields(log.Fields{ - "device": device, - "version": version, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error setting ip version reject rules") - } -} - -// rejectIPv6 adds rules for the tunnel device that reject IPv6 traffic on it; -// used to avoid sending IPv6 packets over a tunnel that only supports IPv4. -func rejectIPv6(ctx context.Context, device string) { - rejectIPVersion(ctx, device, "ipv6") -} - -// rejectIPv4 adds rules for the tunnel device that reject IPv4 traffic on it; -// used to avoid sending IPv4 packets over a tunnel that only supports IPv6. -func rejectIPv4(ctx context.Context, device string) { - rejectIPVersion(ctx, device, "ipv4") -} - // setExcludes resets the excludes to addresses in netfilter. -func setExcludes(ctx context.Context, addresses []netip.Prefix) { - // flush existing entries - nftconf := "" - nftconf += "flush set inet oc-daemon-routing excludes4\n" - nftconf += "flush set inet oc-daemon-routing excludes6\n" - - // add entries - for _, a := range addresses { - set := "excludes4" - if a.Addr().Is6() { - set = "excludes6" +func setExcludes(ctx context.Context, conf *daemoncfg.Config, addresses []netip.Prefix) { + data := &struct { + daemoncfg.Config + Addresses []netip.Prefix + }{ + Config: *conf, + Addresses: addresses, + } + cmds, err := cmdtmpl.GetCmds("SplitRoutingSetExcludes", data) + if err != nil { + log.WithError(err).Error("SplitRouting could not get set excludes commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil { + log.WithFields(log.Fields{ + "addresses": addresses, + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("SplitRouting could not run set excludes command") } - nftconf += fmt.Sprintf( - "add element inet oc-daemon-routing %s { %s }\n", - set, a) - } - - // run command - if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil { - log.WithError(err).WithFields(log.Fields{ - "addresses": addresses, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error setting excludes") - } -} - -// cleanupRoutingRules cleans up the nftables rules for routing after a -// failed shutdown. -func cleanupRoutingRules(ctx context.Context) { - if _, _, err := execs.RunNft(ctx, "delete table inet oc-daemon-routing"); err == nil { - log.Debug("SplitRouting cleaned up nft") } } diff --git a/internal/splitrt/route.go b/internal/splitrt/route.go deleted file mode 100644 index 464764ac..00000000 --- a/internal/splitrt/route.go +++ /dev/null @@ -1,156 +0,0 @@ -package splitrt - -import ( - "context" - - log "github.com/sirupsen/logrus" - "github.com/telekom-mms/oc-daemon/internal/execs" -) - -// addDefaultRouteIPv4 adds default routing for IPv4. -func addDefaultRouteIPv4(ctx context.Context, device, rtTable, rulePrio1, fwMark, rulePrio2 string) { - // set default route - if stdout, stderr, err := execs.RunIP4Route(ctx, "add", "0.0.0.0/0", "dev", device, - "table", rtTable); err != nil { - log.WithError(err).WithFields(log.Fields{ - "device": device, - "rtTable": rtTable, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error setting ipv4 default route") - } - - // set routing rules - if stdout, stderr, err := execs.RunIP4Rule(ctx, "add", "iif", device, "table", "main", - "pref", rulePrio1); err != nil { - log.WithError(err).WithFields(log.Fields{ - "device": device, - "rulePrio1": rulePrio1, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error setting ipv4 routing rule 1") - } - if stdout, stderr, err := execs.RunIP4Rule(ctx, "add", "not", "fwmark", fwMark, - "table", rtTable, "pref", rulePrio2); err != nil { - log.WithError(err).WithFields(log.Fields{ - "fwMark": fwMark, - "rtTable": rtTable, - "rulePrio2": rulePrio2, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error setting ipv4 routing rule 2") - } - - // set src_valid_mark with sysctl - if stdout, stderr, err := execs.RunSysctl(ctx, "-q", - "net.ipv4.conf.all.src_valid_mark=1"); err != nil { - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error setting ipv4 sysctl") - } -} - -// addDefaultRouteIPv6 adds default routing for IPv6. -func addDefaultRouteIPv6(ctx context.Context, device, rtTable, rulePrio1, fwMark, rulePrio2 string) { - // set default route - if stdout, stderr, err := execs.RunIP6Route(ctx, "add", "::/0", "dev", device, "table", - rtTable); err != nil { - log.WithError(err).WithFields(log.Fields{ - "device": device, - "rtTable": rtTable, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error setting ipv6 default route") - } - - // set routing rules - if stdout, stderr, err := execs.RunIP6Rule(ctx, "add", "iif", device, "table", "main", - "pref", rulePrio1); err != nil { - log.WithError(err).WithFields(log.Fields{ - "device": device, - "rulePrio1": rulePrio1, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error setting ipv6 routing rule 1") - } - if stdout, stderr, err := execs.RunIP6Rule(ctx, "add", "not", "fwmark", fwMark, - "table", rtTable, "pref", rulePrio2); err != nil { - log.WithError(err).WithFields(log.Fields{ - "fwMark": fwMark, - "rtTable": rtTable, - "rulePrio2": rulePrio2, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error setting ipv6 routing rule 2") - } -} - -// deleteDefaultRouteIPv4 removes default routing for IPv4. -func deleteDefaultRouteIPv4(ctx context.Context, device, rtTable string) { - // delete routing rules - if stdout, stderr, err := execs.RunIP4Rule(ctx, "delete", "table", rtTable); err != nil { - log.WithError(err).WithFields(log.Fields{ - "rtTable": rtTable, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error deleting ipv4 routing rule 2") - } - if stdout, stderr, err := execs.RunIP4Rule(ctx, "delete", "iif", device, "table", - "main"); err != nil { - log.WithError(err).WithFields(log.Fields{ - "device": device, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error deleting ipv4 routing rule 1") - } -} - -// deleteDefaultRouteIPv6 removes default routing for IPv6. -func deleteDefaultRouteIPv6(ctx context.Context, device, rtTable string) { - // delete routing rules - if stdout, stderr, err := execs.RunIP6Rule(ctx, "delete", "table", rtTable); err != nil { - log.WithError(err).WithFields(log.Fields{ - "rtTable": rtTable, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error deleting ipv6 routing rule 2") - } - if stdout, stderr, err := execs.RunIP6Rule(ctx, "delete", "iif", device, "table", - "main"); err != nil { - log.WithError(err).WithFields(log.Fields{ - "device": device, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("SplitRouting error deleting ipv6 routing rule 1") - } -} - -// cleanupRouting cleans up the routing configuration after a failed shutdown. -func cleanupRouting(ctx context.Context, rtTable, rulePrio1, rulePrio2 string) { - // delete ipv4 routing rules - if _, _, err := execs.RunIP4Rule(ctx, "delete", "pref", rulePrio1); err == nil { - log.Debug("SplitRouting cleaned up ipv4 routing rule 1") - } - if _, _, err := execs.RunIP4Rule(ctx, "delete", "pref", rulePrio2); err == nil { - log.Debug("SplitRouting cleaned up ipv4 routing rule 2") - } - - // delete ipv6 routing rules - if _, _, err := execs.RunIP6Rule(ctx, "delete", "pref", rulePrio1); err == nil { - log.Debug("SplitRouting cleaned up ipv6 routing rule 1") - } - if _, _, err := execs.RunIP6Rule(ctx, "delete", "pref", rulePrio2); err == nil { - log.Debug("SplitRouting cleaned up ipv6 routing rule 2") - } - - // flush ipv4 routing table - if _, _, err := execs.RunIP4Route(ctx, "flush", "table", rtTable); err == nil { - log.Debug("SplitRouting cleaned up ipv4 routing table") - } - - // flush ipv6 routing table - if _, _, err := execs.RunIP6Route(ctx, "flush", "table", rtTable); err == nil { - log.Debug("SplitRouting cleaned up ipv6 routing table") - } -} diff --git a/internal/splitrt/splitrt.go b/internal/splitrt/splitrt.go index f9bb70d6..e0ed16ca 100644 --- a/internal/splitrt/splitrt.go +++ b/internal/splitrt/splitrt.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/telekom-mms/oc-daemon/internal/addrmon" + "github.com/telekom-mms/oc-daemon/internal/cmdtmpl" "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/devmon" "github.com/telekom-mms/oc-daemon/internal/dnsproxy" @@ -62,23 +63,23 @@ type SplitRouting struct { // setupRouting sets up routing using config. func (s *SplitRouting) setupRouting(ctx context.Context) { - // prepare netfilter and excludes - setRoutingRules(ctx, s.config.SplitRouting.FirewallMark) - - // filter non-local traffic to vpn addresses - addLocalAddressesIPv4(ctx, - s.config.VPNConfig.Device.Name, - []netip.Prefix{s.config.VPNConfig.IPv4}) - addLocalAddressesIPv6(ctx, - s.config.VPNConfig.Device.Name, - []netip.Prefix{s.config.VPNConfig.IPv6}) - - // reject unsupported ip versions on vpn - if !s.config.VPNConfig.IPv6.IsValid() { - rejectIPv6(ctx, s.config.VPNConfig.Device.Name) + // set up routing + data := s.config + cmds, err := cmdtmpl.GetCmds("SplitRoutingSetupRouting", data) + if err != nil { + log.WithError(err).Error("SplitRouting could not get setup routing commands") } - if !s.config.VPNConfig.IPv4.IsValid() { - rejectIPv4(ctx, s.config.VPNConfig.Device.Name) + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("SplitRouting could not run setup routing command") + } } // add gateway to static excludes @@ -86,7 +87,7 @@ func (s *SplitRouting) setupRouting(ctx context.Context) { gateway := netip.PrefixFrom(s.config.VPNConfig.Gateway, s.config.VPNConfig.Gateway.BitLen()) if s.excludes.AddStatic(gateway) { - setExcludes(ctx, s.excludes.GetPrefixes()) + setExcludes(ctx, s.config, s.excludes.GetPrefixes()) } } @@ -96,7 +97,7 @@ func (s *SplitRouting) setupRouting(ctx context.Context) { continue } if s.excludes.AddStatic(e) { - setExcludes(ctx, s.excludes.GetPrefixes()) + setExcludes(ctx, s.config, s.excludes.GetPrefixes()) } } @@ -107,34 +108,31 @@ func (s *SplitRouting) setupRouting(ctx context.Context) { continue } if s.excludes.AddStatic(e) { - setExcludes(ctx, s.excludes.GetPrefixes()) + setExcludes(ctx, s.config, s.excludes.GetPrefixes()) } } - - // setup routing - addDefaultRouteIPv4(ctx, - s.config.VPNConfig.Device.Name, - s.config.SplitRouting.RoutingTable, - s.config.SplitRouting.RulePriority1, - s.config.SplitRouting.FirewallMark, - s.config.SplitRouting.RulePriority2) - addDefaultRouteIPv6(ctx, - s.config.VPNConfig.Device.Name, - s.config.SplitRouting.RoutingTable, - s.config.SplitRouting.RulePriority1, - s.config.SplitRouting.FirewallMark, - s.config.SplitRouting.RulePriority2) } // teardownRouting tears down the routing configuration. func (s *SplitRouting) teardownRouting(ctx context.Context) { - deleteDefaultRouteIPv4(ctx, - s.config.VPNConfig.Device.Name, - s.config.SplitRouting.RoutingTable) - deleteDefaultRouteIPv6(ctx, - s.config.VPNConfig.Device.Name, - s.config.SplitRouting.RoutingTable) - unsetRoutingRules(ctx) + // tear down routing + data := s.config + cmds, err := cmdtmpl.GetCmds("SplitRoutingTeardownRouting", data) + if err != nil { + log.WithError(err).Error("SplitRouting could not get teardown routing commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("SplitRouting could not run teardown routing command") + } + } } // excludeSettings returns whether local (virtual) networks should be excluded. @@ -186,7 +184,7 @@ func (s *SplitRouting) updateLocalNetworkExcludes(ctx context.Context) { for _, e := range excludes { if !isIn(e, s.locals.get()) { if s.excludes.AddStatic(e) { - setExcludes(ctx, s.excludes.GetPrefixes()) + setExcludes(ctx, s.config, s.excludes.GetPrefixes()) } } } @@ -195,7 +193,7 @@ func (s *SplitRouting) updateLocalNetworkExcludes(ctx context.Context) { for _, l := range s.locals.get() { if !isIn(l, excludes) { if s.excludes.RemoveStatic(l) { - setExcludes(ctx, s.excludes.GetPrefixes()) + setExcludes(ctx, s.config, s.excludes.GetPrefixes()) } } } @@ -245,7 +243,7 @@ func (s *SplitRouting) handleDNSReport(ctx context.Context, r *dnsproxy.Report) exclude := netip.PrefixFrom(r.IP, r.IP.BitLen()) if s.excludes.AddDynamic(exclude, r.TTL) { - setExcludes(ctx, s.excludes.GetPrefixes()) + setExcludes(ctx, s.config, s.excludes.GetPrefixes()) } } @@ -350,9 +348,17 @@ func NewSplitRouting(config *daemoncfg.Config) *SplitRouting { // Cleanup cleans up old configuration after a failed shutdown. func Cleanup(ctx context.Context, config *daemoncfg.Config) { - cleanupRouting(ctx, - config.SplitRouting.RoutingTable, - config.SplitRouting.RulePriority1, - config.SplitRouting.RulePriority2) - cleanupRoutingRules(ctx) + cmds, err := cmdtmpl.GetCmds("SplitRoutingCleanup", config) + if err != nil { + log.WithError(err).Error("SplitRouting could not get cleanup commands") + } + for _, c := range cmds { + if _, _, err := c.Run(ctx); err == nil { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + }).Debug("SplitRouting cleaned up configuration") + } + } } diff --git a/internal/splitrt/splitrt_test.go b/internal/splitrt/splitrt_test.go index a34869e9..1a110abb 100644 --- a/internal/splitrt/splitrt_test.go +++ b/internal/splitrt/splitrt_test.go @@ -168,22 +168,33 @@ func TestSplitRoutingHandleDNSReport(t *testing.T) { go s.handleDNSReport(ctx, report) <-report.Done() + want := []string{ + "flush set inet oc-daemon-routing excludes4\n" + + "flush set inet oc-daemon-routing excludes6\n" + + "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }\n", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + // test ipv6 + got = []string{} report = dnsproxy.NewReport("example.com", netip.MustParseAddr("2001::1"), 300) go s.handleDNSReport(ctx, report) <-report.Done() - want := []string{ - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n" + - "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }\n", + want = []string{ "flush set inet oc-daemon-routing excludes4\n" + "flush set inet oc-daemon-routing excludes6\n" + "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }\n" + "add element inet oc-daemon-routing excludes6 { 2001::1/128 }\n", + "flush set inet oc-daemon-routing excludes4\n" + + "flush set inet oc-daemon-routing excludes6\n" + + "add element inet oc-daemon-routing excludes6 { 2001::1/128 }\n" + + "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }\n", } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if !reflect.DeepEqual(got[0], want[0]) && !reflect.DeepEqual(got[0], want[1]) { + t.Errorf("got %v, want %v or %v", got[0], want[0], want[1]) } } diff --git a/internal/trafpol/filter.go b/internal/trafpol/filter.go index dadcd314..b15a8474 100644 --- a/internal/trafpol/filter.go +++ b/internal/trafpol/filter.go @@ -3,303 +3,218 @@ package trafpol import ( "context" "errors" - "fmt" "net/netip" - "strconv" - "strings" log "github.com/sirupsen/logrus" - "github.com/telekom-mms/oc-daemon/internal/execs" + "github.com/telekom-mms/oc-daemon/internal/cmdtmpl" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // setFilterRules sets the filter rules. -func setFilterRules(ctx context.Context, fwMark string) { - const filterRules = ` -table inet oc-daemon-filter { - # set for allowed devices - set allowdevs { - type ifname - elements = { lo } - } - - # set for allowed ipv4 hosts - set allowhosts4 { - type ipv4_addr - flags interval - } - - # set for allowed ipv6 hosts - set allowhosts6 { - type ipv6_addr - flags interval - } - - # set for allowed ports - set allowports { - type inet_service - elements = { 53 } - } - - chain input { - type filter hook input priority 0; policy drop; - - # accept related traffic - ct state established,related counter accept - - # accept traffic on allowed devices, e.g., lo - iifname @allowdevs counter accept - - # accept ICMP traffic - icmp type { - echo-reply, - destination-unreachable, - source-quench, - redirect, - time-exceeded, - parameter-problem, - timestamp-reply, - info-reply, - address-mask-reply, - router-advertisement, - } counter accept - - # accept ICMPv6 traffic otherwise IPv6 connectivity breaks - icmpv6 type { - destination-unreachable, - packet-too-big, - time-exceeded, - echo-reply, - mld-listener-query, - mld-listener-report, - mld2-listener-report, - mld-listener-done, - nd-router-advert, - nd-neighbor-solicit, - nd-neighbor-advert, - ind-neighbor-solicit, - ind-neighbor-advert, - nd-redirect, - parameter-problem, - router-renumbering - } counter accept - - # accept DHCPv4 traffic - udp dport 68 udp sport 67 counter accept - - # accept DHCPv6 traffic - udp dport 546 udp sport 547 counter accept - } - - chain output { - type filter hook output priority 0; policy drop; - - # accept related traffic - ct state established,related counter accept - - # accept traffic on allowed devices, e.g., lo - oifname @allowdevs counter accept - - # accept traffic to allowed hosts - ip daddr @allowhosts4 counter accept - ip6 daddr @allowhosts6 counter accept - - # accept ICMP traffic - icmp type { - source-quench, - echo-request, - timestamp-request, - info-request, - address-mask-request, - router-solicitation - } counter accept - - # accept ICMPv6 traffic otherwise IPv6 connectivity breaks - icmpv6 type { - echo-request, - mld-listener-report, - mld2-listener-report, - mld-listener-done, - nd-router-solicit, - nd-neighbor-solicit, - nd-neighbor-advert, - ind-neighbor-solicit, - ind-neighbor-advert, - } counter accept - - # accept traffic to allowed ports, e.g., DNS - udp dport @allowports counter accept - tcp dport @allowports counter accept - - # accept DHCPv4 traffic - udp dport 67 udp sport 68 counter accept - - # accept DHCPv6 traffic - udp dport 547 udp sport 546 counter accept - - # reject everything else - counter reject - } - - chain forward { - type filter hook forward priority 0; policy drop; - - # accept related traffic - ct state established,related counter accept - - # accept split exclude traffic - iifname @allowdevs ct mark $FWMARK counter accept - - # accept traffic on allowed devices - iifname @allowdevs oifname @allowdevs counter accept - } -} -` - r := strings.NewReplacer("$FWMARK", fwMark) - rules := r.Replace(filterRules) - if stdout, stderr, err := execs.RunNft(ctx, rules); err != nil && - !errors.Is(err, context.Canceled) { - - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("TrafPol error setting routing rules") +func setFilterRules(ctx context.Context, config *daemoncfg.Config) { + cmds, err := cmdtmpl.GetCmds("TrafPolSetFilterRules", config) + if err != nil { + log.WithError(err).Error("TrafPol could not get set filter rules commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("TrafPol could not run set filter rules command") + } } } // unsetFilterRules unsets the filter rules. -func unsetFilterRules(ctx context.Context) { - if stdout, stderr, err := execs.RunNft(ctx, - "delete table inet oc-daemon-filter"); err != nil && - !errors.Is(err, context.Canceled) { - - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("TrafPol error unsetting routing rules") +func unsetFilterRules(ctx context.Context, config *daemoncfg.Config) { + cmds, err := cmdtmpl.GetCmds("TrafPolUnsetFilterRules", config) + if err != nil { + log.WithError(err).Error("TrafPol could not get unset filter rules commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("TrafPol could not run unset filter rules command") + } } } // addAllowedDevice adds device to the allowed devices. -func addAllowedDevice(ctx context.Context, device string) { - nftconf := fmt.Sprintf("add element inet oc-daemon-filter allowdevs { %s }", device) - if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil && - !errors.Is(err, context.Canceled) { - - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("TrafPol error adding allowed device") +func addAllowedDevice(ctx context.Context, conf *daemoncfg.Config, device string) { + data := &struct { + daemoncfg.Config + Device string + }{ + Config: *conf, + Device: device, + } + cmds, err := cmdtmpl.GetCmds("TrafPolAddAllowedDevice", data) + if err != nil { + log.WithError(err).Error("TrafPol could not get add allowed device commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "device": device, + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("TrafPol could not run add allowed device command") + } } } // removeAllowedDevice removes device from the allowed devices. -func removeAllowedDevice(ctx context.Context, device string) { - nftconf := fmt.Sprintf("delete element inet oc-daemon-filter allowdevs { %s }", device) - if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil && - !errors.Is(err, context.Canceled) { - - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("TrafPol error removing allowed device") +func removeAllowedDevice(ctx context.Context, conf *daemoncfg.Config, device string) { + data := &struct { + daemoncfg.Config + Device string + }{ + Config: *conf, + Device: device, + } + cmds, err := cmdtmpl.GetCmds("TrafPolRemoveAllowedDevice", data) + if err != nil { + log.WithError(err).Error("TrafPol could not get remove allowed device commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "device": device, + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("TrafPol could not run remove allowed device command") + } } } // setAllowedIPs set the allowed hosts. -func setAllowedIPs(ctx context.Context, ips []netip.Prefix) { +func setAllowedIPs(ctx context.Context, conf *daemoncfg.Config, ips []netip.Prefix) { // we perform all nft commands separately here and not as one atomic // operation to avoid issues where the whole update fails because nft // runs into "file exists" errors even though we remove duplicates from // ips before calling this function and we flush the existing entries - if stdout, stderr, err := execs.RunNft(ctx, - "flush set inet oc-daemon-filter allowhosts4"); err != nil && - !errors.Is(err, context.Canceled) { - - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("TrafPol error flushing allowed ipv4s") + // flush allowed hosts + cmds, err := cmdtmpl.GetCmds("TrafPolFlushAllowedHosts", conf) + if err != nil { + log.WithError(err).Error("TrafPol could not get flush allowed hosts commands") } - if stdout, stderr, err := execs.RunNft(ctx, - "flush set inet oc-daemon-filter allowhosts6"); err != nil && - !errors.Is(err, context.Canceled) { - - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("TrafPol error flushing allowed ipv6s") + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("TrafPol could not run flush allowed hosts command") + } } - fmt4 := "add element inet oc-daemon-filter allowhosts4 { %s }" - fmt6 := "add element inet oc-daemon-filter allowhosts6 { %s }" + // add allowed hosts for _, ip := range ips { - if ip.Addr().Is4() { - // ipv4 address - nftconf := fmt.Sprintf(fmt4, ip) - if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil && - !errors.Is(err, context.Canceled) { - - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("TrafPol error adding allowed ipv4") - } - } else { - // ipv6 address - nftconf := fmt.Sprintf(fmt6, ip) - if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil && - !errors.Is(err, context.Canceled) { - - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("TrafPol error adding allowed ipv6") + data := &struct { + daemoncfg.Config + AllowedIP netip.Prefix + }{ + Config: *conf, + AllowedIP: ip, + } + cmds, err := cmdtmpl.GetCmds("TrafPolAddAllowedHost", data) + if err != nil { + log.WithError(err).Error("TrafPol could not get add allowed host commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "host": ip, + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("TrafPol could not run add allowed host command") } } } } -// portsToString returns ports as string. -func portsToString(ports []uint16) string { - s := []string{} - for _, port := range ports { - s = append(s, strconv.FormatUint(uint64(port), 10)) - } - return strings.Join(s, ", ") -} - // addPortalPorts adds ports for a captive portal to the allowed ports. -func addPortalPorts(ctx context.Context, ports []uint16) { - p := portsToString(ports) - nftconf := fmt.Sprintf("add element inet oc-daemon-filter allowports { %s }", p) - if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil && - !errors.Is(err, context.Canceled) { - - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("TrafPol error adding portal ports") +func addPortalPorts(ctx context.Context, conf *daemoncfg.Config) { + cmds, err := cmdtmpl.GetCmds("TrafPolAddPortalPorts", conf) + if err != nil { + log.WithError(err).Error("TrafPol could not get add portal ports commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "ports": conf.TrafficPolicing.PortalPorts, + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("TrafPol could not run add portal ports command") + } } } // removePortalPorts removes ports for a captive portal from the allowed ports. -func removePortalPorts(ctx context.Context, ports []uint16) { - p := portsToString(ports) - nftconf := fmt.Sprintf("delete element inet oc-daemon-filter allowports { %s }", p) - if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil && - !errors.Is(err, context.Canceled) { - - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("TrafPol error removing portal ports") +func removePortalPorts(ctx context.Context, conf *daemoncfg.Config) { + cmds, err := cmdtmpl.GetCmds("TrafPolRemovePortalPorts", conf) + if err != nil { + log.WithError(err).Error("TrafPol could not get remove portal ports commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "ports": conf.TrafficPolicing.PortalPorts, + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("TrafPol could not run remove portal ports command") + } } } // cleanupFilterRules cleans up the filter rules after a failed shutdown. -func cleanupFilterRules(ctx context.Context) { - if _, _, err := execs.RunNft(ctx, "delete table inet oc-daemon-filter"); err == nil { - log.Debug("TrafPol cleaned up nft") +func cleanupFilterRules(ctx context.Context, conf *daemoncfg.Config) { + cmds, err := cmdtmpl.GetCmds("TrafPolCleanup", conf) + if err != nil { + log.WithError(err).Error("TrafPol could not get cleanup commands") + } + for _, c := range cmds { + if _, _, err := c.Run(ctx); err == nil { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + }).Warn("TrafPol cleaned up configuration") + } } } diff --git a/internal/trafpol/filter_test.go b/internal/trafpol/filter_test.go index ad6ceb00..9d4ca746 100644 --- a/internal/trafpol/filter_test.go +++ b/internal/trafpol/filter_test.go @@ -6,6 +6,7 @@ import ( "net/netip" "testing" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/execs" ) @@ -20,20 +21,23 @@ func TestFilterFunctionsErrors(_ *testing.T) { ctx := context.Background() // filter rules - setFilterRules(ctx, "123") - unsetFilterRules(ctx) + conf := daemoncfg.NewConfig() + conf.SplitRouting.FirewallMark = "123" + setFilterRules(ctx, conf) + unsetFilterRules(ctx, conf) // allowed devices - addAllowedDevice(ctx, "eth0") - removeAllowedDevice(ctx, "eth0") + addAllowedDevice(ctx, conf, "eth0") + removeAllowedDevice(ctx, conf, "eth0") // allowed IPs - setAllowedIPs(ctx, []netip.Prefix{ + setAllowedIPs(ctx, conf, []netip.Prefix{ netip.MustParsePrefix("192.168.1.1/32"), netip.MustParsePrefix("2000::1/128"), }) // portal ports - addPortalPorts(ctx, []uint16{80, 443}) - removePortalPorts(ctx, []uint16{80, 443}) + conf.TrafficPolicing.PortalPorts = []uint16{80, 443} + addPortalPorts(ctx, conf) + removePortalPorts(ctx, conf) } diff --git a/internal/trafpol/trafpol.go b/internal/trafpol/trafpol.go index 95181dd1..0f702166 100644 --- a/internal/trafpol/trafpol.go +++ b/internal/trafpol/trafpol.go @@ -71,12 +71,12 @@ func (t *TrafPol) handleDeviceUpdate(ctx context.Context, u *devmon.Update) { // skip when removing devices. if u.Add && u.Type != "device" { if t.allowDevs.Add(u.Device) { - addAllowedDevice(ctx, u.Device) + addAllowedDevice(ctx, t.config, u.Device) } return } if t.allowDevs.Remove(u.Device) { - removeAllowedDevice(ctx, u.Device) + removeAllowedDevice(ctx, t.config, u.Device) } } @@ -100,7 +100,7 @@ func (t *TrafPol) handleCPDReport(ctx context.Context, report *cpd.Report) { t.resolver.Resolve() // remove ports from allowed ports - removePortalPorts(ctx, t.config.TrafficPolicing.PortalPorts) + removePortalPorts(ctx, t.config) t.capPortal = false log.WithField("capPortal", t.capPortal).Info("TrafPol changed CPD status") } @@ -109,7 +109,7 @@ func (t *TrafPol) handleCPDReport(ctx context.Context, report *cpd.Report) { // add ports to allowed ports if !t.capPortal { - addPortalPorts(ctx, t.config.TrafficPolicing.PortalPorts) + addPortalPorts(ctx, t.config) t.capPortal = true log.WithField("capPortal", t.capPortal).Info("TrafPol changed CPD status") } @@ -147,7 +147,7 @@ func (t *TrafPol) handleResolverUpdate(ctx context.Context, update *ResolvedName t.allowNames.Add(update.Name, update.IPs) // set new filter rules - setAllowedIPs(ctx, t.getAllowedHostsIPs()) + setAllowedIPs(ctx, t.config, t.getAllowedHostsIPs()) } // handleAddressCommand handles an address command. @@ -169,7 +169,7 @@ func (t *TrafPol) handleAddressCommand(ctx context.Context, cmd *trafPolCmd) { } // set new filter rules - setAllowedIPs(ctx, t.getAllowedHostsIPs()) + setAllowedIPs(ctx, t.config, t.getAllowedHostsIPs()) // added/removed successfully cmd.ok = true @@ -201,7 +201,7 @@ func (t *TrafPol) handleCommand(ctx context.Context, cmd *trafPolCmd) { // start starts the traffic policing component. func (t *TrafPol) start(ctx context.Context) { defer close(t.loopDone) - defer unsetFilterRules(ctx) + defer unsetFilterRules(ctx, t.config) defer t.resolver.Stop() defer t.cpd.Stop() defer t.devmon.Stop() @@ -250,10 +250,10 @@ func (t *TrafPol) Start() error { ctx := context.Background() // set firewall config - setFilterRules(ctx, t.config.SplitRouting.FirewallMark) + setFilterRules(ctx, t.config) // set filter rules - setAllowedIPs(ctx, t.getAllowedHostsIPs()) + setAllowedIPs(ctx, t.config, t.getAllowedHostsIPs()) // start resolver for allowed names t.resolver.Start() @@ -284,7 +284,7 @@ cleanup_dnsmon: cleanup_devmon: t.cpd.Stop() t.resolver.Stop() - unsetFilterRules(ctx) + unsetFilterRules(ctx, t.config) return err } @@ -409,6 +409,6 @@ func NewTrafPol(config *daemoncfg.Config) *TrafPol { } // Cleanup cleans up old configuration after a failed shutdown. -func Cleanup(ctx context.Context) { - cleanupFilterRules(ctx) +func Cleanup(ctx context.Context, conf *daemoncfg.Config) { + cleanupFilterRules(ctx, conf) } diff --git a/internal/trafpol/trafpol_test.go b/internal/trafpol/trafpol_test.go index 12b2e8b3..a69e4d66 100644 --- a/internal/trafpol/trafpol_test.go +++ b/internal/trafpol/trafpol_test.go @@ -6,6 +6,7 @@ import ( "reflect" "slices" "sort" + "strings" "sync" "testing" @@ -56,10 +57,10 @@ func TestTrafPolHandleCPDReport(t *testing.T) { var nftMutex sync.Mutex nftCmds := []string{} oldRunCmd := execs.RunCmd - execs.RunCmd = func(_ context.Context, _ string, s string, _ ...string) ([]byte, []byte, error) { + execs.RunCmd = func(_ context.Context, cmd string, _ string, args ...string) ([]byte, []byte, error) { nftMutex.Lock() defer nftMutex.Unlock() - nftCmds = append(nftCmds, s) + nftCmds = append(nftCmds, cmd+" "+strings.Join(args, " ")) return nil, nil, nil } defer func() { execs.RunCmd = oldRunCmd }() @@ -85,7 +86,7 @@ func TestTrafPolHandleCPDReport(t *testing.T) { tp.handleCPDReport(ctx, report) want = []string{ - "add element inet oc-daemon-filter allowports { 80, 443 }", + "nft -f - add element inet oc-daemon-filter allowports { 80, 443 }", } got = getNftCmds() if !reflect.DeepEqual(got, want) { @@ -97,8 +98,8 @@ func TestTrafPolHandleCPDReport(t *testing.T) { tp.handleCPDReport(ctx, report) want = []string{ - "add element inet oc-daemon-filter allowports { 80, 443 }", - "delete element inet oc-daemon-filter allowports { 80, 443 }", + "nft -f - add element inet oc-daemon-filter allowports { 80, 443 }", + "nft -f - delete element inet oc-daemon-filter allowports { 80, 443 }", } got = getNftCmds() if !reflect.DeepEqual(got, want) { @@ -295,14 +296,14 @@ func TestNewTrafPol(t *testing.T) { // TestCleanup tests Cleanup. func TestCleanup(t *testing.T) { want := []string{ - "delete table inet oc-daemon-filter", + "nft -f - delete table inet oc-daemon-filter", } got := []string{} - execs.RunCmd = func(_ context.Context, _ string, s string, _ ...string) ([]byte, []byte, error) { - got = append(got, s) + execs.RunCmd = func(_ context.Context, cmd string, _ string, args ...string) ([]byte, []byte, error) { + got = append(got, cmd+" "+strings.Join(args, " ")) return nil, nil, nil } - Cleanup(context.Background()) + Cleanup(context.Background(), daemoncfg.NewConfig()) if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } diff --git a/internal/vpnsetup/vpnsetup.go b/internal/vpnsetup/vpnsetup.go index 3588a1cd..251de10a 100644 --- a/internal/vpnsetup/vpnsetup.go +++ b/internal/vpnsetup/vpnsetup.go @@ -3,15 +3,15 @@ package vpnsetup import ( "context" - "net/netip" - "strconv" + "errors" + "slices" "strings" "time" log "github.com/sirupsen/logrus" + "github.com/telekom-mms/oc-daemon/internal/cmdtmpl" "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/dnsproxy" - "github.com/telekom-mms/oc-daemon/internal/execs" "github.com/telekom-mms/oc-daemon/internal/splitrt" ) @@ -51,73 +51,43 @@ type VPNSetup struct { } // setupVPNDevice sets up the vpn device with config. -func setupVPNDevice(ctx context.Context, c *daemoncfg.Config) { - // set mtu on device - mtu := strconv.Itoa(c.VPNConfig.Device.MTU) - if stdout, stderr, err := execs.RunIPLink( - ctx, "set", c.VPNConfig.Device.Name, "mtu", mtu); err != nil { - - log.WithError(err).WithFields(log.Fields{ - "device": c.VPNConfig.Device.Name, - "mtu": mtu, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("Daemon could not set mtu on device") - return - } - - // set device up - if stdout, stderr, err := execs.RunIPLink( - ctx, "set", c.VPNConfig.Device.Name, "up"); err != nil { - - log.WithError(err).WithFields(log.Fields{ - "device": c.VPNConfig.Device.Name, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("Daemon could not set device up") - return - } - - // set ipv4 and ipv6 addresses on device - setupIP := func(a netip.Prefix) { - dev := c.VPNConfig.Device.Name - addr := a.String() - if stdout, stderr, err := execs.RunIPAddress( - ctx, "add", addr, "dev", dev); err != nil { - - log.WithError(err).WithFields(log.Fields{ - "device": dev, - "ip": addr, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("Daemon could not set ip on device") - return +func setupVPNDevice(ctx context.Context, config *daemoncfg.Config) { + cmds, err := cmdtmpl.GetCmds("VPNSetupSetupVPNDevice", config) + if err != nil { + log.WithError(err).Error("VPNSetup could not get setup VPN device commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("VPNSetup could not run setup VPN device command") } - - } - - if c.VPNConfig.IPv4.IsValid() { - setupIP(c.VPNConfig.IPv4) - } - if c.VPNConfig.IPv6.IsValid() { - setupIP(c.VPNConfig.IPv6) } } // teardownVPNDevice tears down the configured vpn device. -func teardownVPNDevice(ctx context.Context, c *daemoncfg.Config) { - // set device down - if stdout, stderr, err := execs.RunIPLink( - ctx, "set", c.VPNConfig.Device.Name, "down"); err != nil { - - log.WithError(err).WithFields(log.Fields{ - "device": c.VPNConfig.Device.Name, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("Daemon could not set device down") - return +func teardownVPNDevice(ctx context.Context, config *daemoncfg.Config) { + cmds, err := cmdtmpl.GetCmds("VPNSetupTeardownVPNDevice", config) + if err != nil { + log.WithError(err).Error("VPNSetup could not get teardown VPN device commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("VPNSetup could not run teardown VPN device command") + } } - } // setupRouting sets up routing using config. @@ -142,45 +112,61 @@ func (v *VPNSetup) teardownRouting() { // setupDNSServer sets the DNS server. func (v *VPNSetup) setupDNSServer(ctx context.Context, config *daemoncfg.Config) { - device := config.VPNConfig.Device.Name - if stdout, stderr, err := execs.RunResolvectl( - ctx, "dns", device, config.DNSProxy.Address); err != nil { - - log.WithError(err).WithFields(log.Fields{ - "device": device, - "server": config.DNSProxy.Address, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("VPNSetup error setting dns server") + cmds, err := cmdtmpl.GetCmds("VPNSetupSetupDNSServer", config) + if err != nil { + log.WithError(err).Error("VPNSetup could not get setup DNS server commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("VPNSetup could not run setup DNS server command") + } } } // setupDNSDomains sets the DNS domains. func (v *VPNSetup) setupDNSDomains(ctx context.Context, config *daemoncfg.Config) { - device := config.VPNConfig.Device.Name - if stdout, stderr, err := execs.RunResolvectl( - ctx, "domain", device, config.VPNConfig.DNS.DefaultDomain, "~."); err != nil { - - log.WithError(err).WithFields(log.Fields{ - "device": device, - "domain": config.VPNConfig.DNS.DefaultDomain, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("VPNSetup error setting dns domains") + cmds, err := cmdtmpl.GetCmds("VPNSetupSetupDNSDomains", config) + if err != nil { + log.WithError(err).Error("VPNSetup could not get setup DNS domains commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("VPNSetup could not run setup DNS domains command") + } } } // setupDNSDefaultRoute sets the DNS default route. func (v *VPNSetup) setupDNSDefaultRoute(ctx context.Context, config *daemoncfg.Config) { - device := config.VPNConfig.Device.Name - if stdout, stderr, err := execs.RunResolvectl( - ctx, "default-route", device, "yes"); err != nil { - - log.WithError(err).WithFields(log.Fields{ - "device": device, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("VPNSetup error setting dns default route") + cmds, err := cmdtmpl.GetCmds("VPNSetupSetupDNSDefaultRoute", config) + if err != nil { + log.WithError(err).Error("VPNSetup could not get setup DNS default route commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("VPNSetup could not run setup DNS default route command") + } } } @@ -198,31 +184,21 @@ func (v *VPNSetup) setupDNS(ctx context.Context, config *daemoncfg.Config) { v.dnsProxy.SetWatches(excludes) // update dns configuration of host - - // set dns server for device - v.setupDNSServer(ctx, config) - - // set domains for device - // this includes "~." to use this device for all domains - v.setupDNSDomains(ctx, config) - - // set default route for device - v.setupDNSDefaultRoute(ctx, config) - - // flush dns caches - if stdout, stderr, err := execs.RunResolvectl(ctx, "flush-caches"); err != nil { - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("VPNSetup error flushing dns caches during setup") - } - - // reset learnt server features - if stdout, stderr, err := execs.RunResolvectl(ctx, "reset-server-features"); err != nil { - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("VPNSetup error resetting server features during setup") + cmds, err := cmdtmpl.GetCmds("VPNSetupSetupDNS", config) + if err != nil { + log.WithError(err).Error("VPNSetup could not get setup DNS commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("VPNSetup could not run setup DNS command") + } } } @@ -238,32 +214,21 @@ func (v *VPNSetup) teardownDNS(ctx context.Context, config *daemoncfg.Config) { v.dnsProxy.SetWatches([]string{}) // update dns configuration of host - - // undo device dns configuration - if stdout, stderr, err := execs.RunResolvectl( - ctx, "revert", config.VPNConfig.Device.Name); err != nil { - - log.WithError(err).WithFields(log.Fields{ - "device": config.VPNConfig.Device.Name, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("VPNSetup error reverting dns configuration") - } - - // flush dns caches - if stdout, stderr, err := execs.RunResolvectl(ctx, "flush-caches"); err != nil { - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("VPNSetup error flushing dns caches during teardown") - } - - // reset learnt server features - if stdout, stderr, err := execs.RunResolvectl(ctx, "reset-server-features"); err != nil { - log.WithError(err).WithFields(log.Fields{ - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("VPNSetup error resetting server features during teardown") + cmds, err := cmdtmpl.GetCmds("VPNSetupTeardownDNS", config) + if err != nil { + log.WithError(err).Error("VPNSetup could not get teardown DNS commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("VPNSetup could not run teardown DNS command") + } } } @@ -320,15 +285,25 @@ func (v *VPNSetup) ensureDNS(ctx context.Context, config *daemoncfg.Config) bool log.Debug("VPNSetup checking DNS settings") // get dns settings - device := config.VPNConfig.Device.Name - stdout, stderr, err := execs.RunResolvectl(ctx, "status", device, "--no-pager") + cmds, err := cmdtmpl.GetCmds("VPNSetupEnsureDNS", config) if err != nil { - log.WithError(err).WithFields(log.Fields{ - "device": device, - "stdout": string(stdout), - "stderr": string(stderr), - }).Error("VPNSetup error getting DNS settings") - return false + log.WithError(err).Error("VPNSetup could not get ensure DNS commands") + } + var stdout []byte + for _, c := range cmds { + sout, serr, err := c.Run(ctx) + if err != nil { + log.WithFields(log.Fields{ + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(sout), + "stderr": string(serr), + }).Error("VPNSetup could not run ensure DNS command") + return false + } + // collect output + stdout = slices.Concat(stdout, sout) } // parse and check dns settings line by line @@ -574,13 +549,19 @@ func NewVPNSetup(dnsProxy *dnsproxy.Proxy) *VPNSetup { func Cleanup(ctx context.Context, config *daemoncfg.Config) { // dns, device, split routing vpnDevice := config.OpenConnect.VPNDevice - if _, _, err := execs.RunResolvectl(ctx, "revert", vpnDevice); err == nil { - log.WithField("device", vpnDevice). - Warn("VPNSetup cleaned up dns config") - } - if _, _, err := execs.RunIPLink(ctx, "delete", vpnDevice); err == nil { - log.WithField("device", vpnDevice). - Warn("VPNSetup cleaned up vpn device") + cmds, err := cmdtmpl.GetCmds("VPNSetupCleanup", config) + if err != nil { + log.WithError(err).Error("VPNSetup could not get cleanup commands") + } + for _, c := range cmds { + if _, _, err := c.Run(ctx); err == nil { + log.WithFields(log.Fields{ + "device": vpnDevice, + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + }).Warn("VPNSetup cleaned up configuration") + } } splitrt.Cleanup(ctx, config) }