Skip to content

Commit

Permalink
Merge pull request #120 from telekom-mms/feature/switch-to-netip-inte…
Browse files Browse the repository at this point in the history
…rnally

Switch to package netip internally
  • Loading branch information
hwipl authored Aug 28, 2024
2 parents 2afa099 + 6a89fb4 commit 96e12e5
Show file tree
Hide file tree
Showing 22 changed files with 199 additions and 199 deletions.
14 changes: 11 additions & 3 deletions internal/addrmon/addrmon.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package addrmon

import (
"fmt"
"net"
"net/netip"

log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
Expand All @@ -12,7 +12,7 @@ import (
// Update is an address update.
type Update struct {
Add bool
Address net.IPNet
Address netip.Prefix
Index int
}

Expand Down Expand Up @@ -72,8 +72,16 @@ func (a *AddrMon) start() {
}

// forward event as address update
ip, ok := netip.AddrFromSlice(e.LinkAddress.IP)
if !ok || !ip.IsValid() {
log.WithField("LinkAddress", e.LinkAddress).
Error("AddrMon got invalid IP in addr event")
continue
}
ones, _ := e.LinkAddress.Mask.Size()
addr := netip.PrefixFrom(ip, ones)
u := &Update{
Address: e.LinkAddress,
Address: addr,
Index: e.LinkIndex,
Add: e.NewAddr,
}
Expand Down
8 changes: 7 additions & 1 deletion internal/addrmon/addrmon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package addrmon

import (
"log"
"net"
"testing"

"github.com/vishvananda/netlink"
Expand Down Expand Up @@ -44,7 +45,12 @@ func TestAddrMonStartStop(t *testing.T) {
// helper function for AddrUpdates
addrUpdates := func(updates chan netlink.AddrUpdate, done chan struct{}) {
for {
up := netlink.AddrUpdate{}
up := netlink.AddrUpdate{
LinkAddress: net.IPNet{
IP: net.IPv4(192, 168, 1, 1),
Mask: net.CIDRMask(24, 32),
},
}
select {
case updates <- up:
case <-done:
Expand Down
13 changes: 8 additions & 5 deletions internal/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"reflect"
"slices"
"strconv"
Expand Down Expand Up @@ -66,7 +67,7 @@ type Daemon struct {
disableTrafPol bool

// serverIP is the IP address of the current VPN server
serverIP net.IP
serverIP netip.Addr

// serverIPAllowed indicates whether server IP was added to
// the allowed addresses
Expand Down Expand Up @@ -307,7 +308,9 @@ func (d *Daemon) connectVPN(login *logininfo.LoginInfo) {
}

// set server address
d.serverIP = net.ParseIP(strings.Trim(login.Host, "[]"))
if serverIP, err := netip.ParseAddr(strings.Trim(login.Host, "[]")); err == nil {
d.serverIP = serverIP
}

// update status
d.setStatusOCRunning(true)
Expand All @@ -316,7 +319,7 @@ func (d *Daemon) connectVPN(login *logininfo.LoginInfo) {
d.setStatusConnectionState(vpnstatus.ConnectionStateConnecting)

// add server address to allowed addrs in trafpol
if d.trafpol != nil && d.serverIP != nil {
if d.trafpol != nil && d.serverIP.IsValid() {
d.serverIPAllowed = d.trafpol.AddAllowedAddr(d.serverIP)
}

Expand Down Expand Up @@ -542,7 +545,7 @@ func (d *Daemon) handleRunnerDisconnect() {
if d.trafpol != nil && d.serverIPAllowed {
d.trafpol.RemoveAllowedAddr(d.serverIP)
}
d.serverIP = nil
d.serverIP = netip.Addr{}
d.serverIPAllowed = false
}

Expand Down Expand Up @@ -743,7 +746,7 @@ func (d *Daemon) startTrafPol() error {
d.setStatusTrafPolState(vpnstatus.TrafPolStateActive)
d.setStatusAllowedHosts(c.AllowedHosts)

if d.serverIP != nil {
if d.serverIP.IsValid() {
// VPN connection active, allow server IP
d.serverIPAllowed = d.trafpol.AddAllowedAddr(d.serverIP)
}
Expand Down
17 changes: 15 additions & 2 deletions internal/dnsproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dnsproxy

import (
"math/rand"
"net/netip"

"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -101,7 +102,13 @@ func (p *Proxy) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
log.Error("DNS-Proxy received invalid A record in reply")
return
}
report := NewReport(rr.Hdr.Name, rr.A, rr.Hdr.Ttl)
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.WithField("A", rr.A).
Error("DNS-Proxy received invalid IP in A record in reply")
return
}
report := NewReport(rr.Hdr.Name, addr, rr.Hdr.Ttl)
p.sendReport(report)
p.waitReport(report)
}
Expand All @@ -114,7 +121,13 @@ func (p *Proxy) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
log.Error("DNS-Proxy received invalid AAAA record in reply")
return
}
report := NewReport(rr.Hdr.Name, rr.AAAA, rr.Hdr.Ttl)
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.WithField("AAAA", rr.AAAA).
Error("DNS-Proxy received invalid IP in AAAA record in reply")
return
}
report := NewReport(rr.Hdr.Name, addr, rr.Hdr.Ttl)
p.sendReport(report)
p.waitReport(report)
}
Expand Down
9 changes: 5 additions & 4 deletions internal/dnsproxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dnsproxy
import (
"errors"
"net"
"net/netip"
"testing"

"github.com/miekg/dns"
Expand Down Expand Up @@ -127,8 +128,8 @@ func TestProxyHandleRequest(t *testing.T) {
if r.Name != "example.com." {
t.Errorf("invalid domain name: %s", r.Name)
}
if !r.IP.Equal(net.IPv4(127, 0, 0, 1)) &&
!r.IP.Equal(net.ParseIP("::1")) {
if r.IP != netip.MustParseAddr("127.0.0.1") &&
r.IP != netip.MustParseAddr("::1") {
t.Errorf("invalid IP: %s", r.IP)
}
}
Expand Down Expand Up @@ -205,8 +206,8 @@ func TestProxyHandleRequestRecords(t *testing.T) {
t.Fatalf("invalid reports for run %d: %v", i, reports)
}
for _, r := range reports {
if !r.IP.Equal(net.ParseIP("127.0.0.1")) &&
!r.IP.Equal(net.ParseIP("::1")) {
if r.IP != netip.MustParseAddr("127.0.0.1") &&
r.IP != netip.MustParseAddr("::1") {

t.Errorf("invalid report for run %d: %v", i, r)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/dnsproxy/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package dnsproxy

import (
"fmt"
"net"
"net/netip"
)

// Report is a report for a watched domain.
type Report struct {
Name string
IP net.IP
IP netip.Addr
TTL uint32

// done is used to signal that the report has been handled by
Expand All @@ -32,7 +32,7 @@ func (r *Report) Done() <-chan struct{} {
}

// NewReport returns a new report with domain name, IP and TTL.
func NewReport(name string, ip net.IP, ttl uint32) *Report {
func NewReport(name string, ip netip.Addr, ttl uint32) *Report {
return &Report{
Name: name,
IP: ip,
Expand Down
10 changes: 5 additions & 5 deletions internal/dnsproxy/report_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package dnsproxy

import (
"net"
"net/netip"
"testing"
)

// TestReportString tests String of Report.
func TestReportString(t *testing.T) {
name := "example.com."
ip := net.IPv4(192, 168, 1, 1)
ip := netip.MustParseAddr("192.168.1.1")
ttl := uint32(300)
r := NewReport(name, ip, ttl)

Expand All @@ -22,7 +22,7 @@ func TestReportString(t *testing.T) {
// TestReportDone tests Wait and Done of Report.
func TestReportWaitDone(_ *testing.T) {
name := "example.com."
ip := net.IPv4(192, 168, 1, 1)
ip := netip.MustParseAddr("192.168.1.1")
ttl := uint32(300)
r := NewReport(name, ip, ttl)

Expand All @@ -33,7 +33,7 @@ func TestReportWaitDone(_ *testing.T) {
// TestNewReport tests NewReport.
func TestNewReport(t *testing.T) {
name := "example.com."
ip := net.IPv4(192, 168, 1, 1)
ip := netip.MustParseAddr("192.168.1.1")
ttl := uint32(300)
r := NewReport(name, ip, ttl)

Expand All @@ -43,7 +43,7 @@ func TestNewReport(t *testing.T) {
if r.Name != name {
t.Errorf("got %s, want %s", r.Name, name)
}
if !r.IP.Equal(ip) {
if r.IP != ip {
t.Errorf("got %s, want %s", r.IP, ip)
}
if r.TTL != ttl {
Expand Down
6 changes: 3 additions & 3 deletions internal/splitrt/addresses.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package splitrt

import (
"net"
"net/netip"

"github.com/telekom-mms/oc-daemon/internal/addrmon"
)
Expand Down Expand Up @@ -51,9 +51,9 @@ func (a *Addresses) Remove(addr *addrmon.Update) {
}

// Get returns the addresses of the device identified by index.
func (a *Addresses) Get(index int) (addrs []*net.IPNet) {
func (a *Addresses) Get(index int) (addrs []netip.Prefix) {
for _, x := range a.m[index] {
addrs = append(addrs, &x.Address)
addrs = append(addrs, x.Address)
}
return
}
Expand Down
18 changes: 9 additions & 9 deletions internal/splitrt/addresses_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package splitrt

import (
"net"
"net/netip"
"reflect"
"testing"

Expand All @@ -10,14 +10,14 @@ import (

// getTestAddrMonUpdate returns an AddrMon update for testing.
func getTestAddrMonUpdate(t *testing.T, addr string) *addrmon.Update {
_, ipnet, err := net.ParseCIDR(addr)
prefix, err := netip.ParsePrefix(addr)
if err != nil {
t.Fatal(err)
}

return &addrmon.Update{
Add: true,
Address: *ipnet,
Address: prefix,
Index: 1,
}
}
Expand Down Expand Up @@ -72,16 +72,16 @@ func TestAddressesGet(t *testing.T) {
update2 := getTestAddrMonUpdate(t, "192.168.2.0/24")

// get empty
var want []*net.IPNet
var want []netip.Prefix
got := a.Get(1)
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}

// get with one address
a.Add(update1)
want = []*net.IPNet{
&update1.Address,
want = []netip.Prefix{
update1.Address,
}
got = a.Get(1)
if !reflect.DeepEqual(got, want) {
Expand All @@ -97,9 +97,9 @@ func TestAddressesGet(t *testing.T) {

// get with multiple addresses
a.Add(update2)
want = []*net.IPNet{
&update1.Address,
&update2.Address,
want = []netip.Prefix{
update1.Address,
update2.Address,
}
got = a.Get(1)
if !reflect.DeepEqual(got, want) {
Expand Down
Loading

0 comments on commit 96e12e5

Please sign in to comment.