diff --git a/CHANGELOG.md b/CHANGELOG.md
index f763b69aa..ad1714780 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
+## [1.9.4] - 2024-09-09
+
+### Added
+
+- Support UDP dialing with gVisor. (#1181)
+
+### Changed
+
+- Make some Nebula state programmatically available via control object. (#1188)
+- Switch internal representation of IPs to netip, to prepare for IPv6 support
+ in the overlay. (#1173)
+- Minor build and cleanup changes. (#1171, #1164, #1162)
+- Various dependency updates. (#1195, #1190, #1174, #1168, #1167, #1161, #1147, #1146)
+
+### Fixed
+
+- Fix a bug on big endian hosts, like mips. (#1194)
+- Fix a rare panic if a local index collision happens. (#1191)
+- Fix integer wraparound in the calculation of handshake timeouts on 32-bit targets. (#1185)
+
## [1.9.3] - 2024-06-06
### Fixed
@@ -644,7 +664,8 @@ created.)
- Initial public release.
-[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.3...HEAD
+[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD
+[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2
[1.9.1]: https://github.com/slackhq/nebula/releases/tag/v1.9.1
diff --git a/allow_list.go b/allow_list.go
index 9186b2fc7..90e0de231 100644
--- a/allow_list.go
+++ b/allow_list.go
@@ -2,17 +2,16 @@ package nebula
import (
"fmt"
- "net"
+ "net/netip"
"regexp"
- "github.com/slackhq/nebula/cidr"
+ "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
)
type AllowList struct {
// The values of this cidrTree are `bool`, signifying allow/deny
- cidrTree *cidr.Tree6[bool]
+ cidrTree *bart.Table[bool]
}
type RemoteAllowList struct {
@@ -20,7 +19,7 @@ type RemoteAllowList struct {
// Inside Range Specific, keys of this tree are inside CIDRs and values
// are *AllowList
- insideAllowLists *cidr.Tree6[*AllowList]
+ insideAllowLists *bart.Table[*AllowList]
}
type LocalAllowList struct {
@@ -88,7 +87,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
}
- tree := cidr.NewTree6[bool]()
+ tree := new(bart.Table[bool])
// Keep track of the rules we have added for both ipv4 and ipv6
type allowListRules struct {
@@ -122,18 +121,20 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
}
- _, ipNet, err := net.ParseCIDR(rawCIDR)
+ ipNet, err := netip.ParsePrefix(rawCIDR)
if err != nil {
- return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+ return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
}
+ ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
+
// TODO: should we error on duplicate CIDRs in the config?
- tree.AddCIDR(ipNet, value)
+ tree.Insert(ipNet, value)
- maskBits, maskSize := ipNet.Mask.Size()
+ maskBits := ipNet.Bits()
var rules *allowListRules
- if maskSize == 32 {
+ if ipNet.Addr().Is4() {
rules = &rules4
} else {
rules = &rules6
@@ -156,8 +157,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
if !rules4.defaultSet {
if rules4.allValuesMatch {
- _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
- tree.AddCIDR(zeroCIDR, !rules4.allValues)
+ tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
}
@@ -165,8 +165,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
if !rules6.defaultSet {
if rules6.allValuesMatch {
- _, zeroCIDR, _ := net.ParseCIDR("::/0")
- tree.AddCIDR(zeroCIDR, !rules6.allValues)
+ tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
}
@@ -218,13 +217,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
return nameRules, nil
}
-func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
+func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) {
value := c.Get(k)
if value == nil {
return nil, nil
}
- remoteAllowRanges := cidr.NewTree6[*AllowList]()
+ remoteAllowRanges := new(bart.Table[*AllowList])
rawMap, ok := value.(map[interface{}]interface{})
if !ok {
@@ -241,45 +240,27 @@ func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error
return nil, err
}
- _, ipNet, err := net.ParseCIDR(rawCIDR)
+ ipNet, err := netip.ParsePrefix(rawCIDR)
if err != nil {
- return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+ return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
}
- remoteAllowRanges.AddCIDR(ipNet, allowList)
+ remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList)
}
return remoteAllowRanges, nil
}
-func (al *AllowList) Allow(ip net.IP) bool {
- if al == nil {
- return true
- }
-
- _, result := al.cidrTree.MostSpecificContains(ip)
- return result
-}
-
-func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
- if al == nil {
- return true
- }
-
- _, result := al.cidrTree.MostSpecificContainsIpV4(ip)
- return result
-}
-
-func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
+func (al *AllowList) Allow(ip netip.Addr) bool {
if al == nil {
return true
}
- _, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
+ result, _ := al.cidrTree.Lookup(ip)
return result
}
-func (al *LocalAllowList) Allow(ip net.IP) bool {
+func (al *LocalAllowList) Allow(ip netip.Addr) bool {
if al == nil {
return true
}
@@ -301,43 +282,23 @@ func (al *LocalAllowList) AllowName(name string) bool {
return !al.nameRules[0].Allow
}
-func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool {
+func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
if al == nil {
return true
}
return al.AllowList.Allow(ip)
}
-func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool {
+func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
if !al.getInsideAllowList(vpnIp).Allow(ip) {
return false
}
return al.AllowList.Allow(ip)
}
-func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool {
- if al == nil {
- return true
- }
- if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) {
- return false
- }
- return al.AllowList.AllowIpV4(ip)
-}
-
-func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
- if al == nil {
- return true
- }
- if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) {
- return false
- }
- return al.AllowList.AllowIpV6(hi, lo)
-}
-
-func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
+func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
if al.insideAllowLists != nil {
- ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
+ inside, ok := al.insideAllowLists.Lookup(vpnIp)
if ok {
return inside
}
diff --git a/allow_list_test.go b/allow_list_test.go
index 334cb6062..c8b3d08af 100644
--- a/allow_list_test.go
+++ b/allow_list_test.go
@@ -1,11 +1,11 @@
package nebula
import (
- "net"
+ "net/netip"
"regexp"
"testing"
- "github.com/slackhq/nebula/cidr"
+ "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
@@ -18,7 +18,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
"192.168.0.0": true,
}
r, err := newAllowListFromConfig(c, "allowlist", nil)
- assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
+ assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{
@@ -98,26 +98,26 @@ func TestNewAllowListFromConfig(t *testing.T) {
}
func TestAllowList_Allow(t *testing.T) {
- assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
-
- tree := cidr.NewTree6[bool]()
- tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
- tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
- tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
- tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true)
- tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true)
- tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false)
- tree.AddCIDR(cidr.Parse("::1/128"), true)
- tree.AddCIDR(cidr.Parse("::2/128"), false)
+ assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
+
+ tree := new(bart.Table[bool])
+ tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
+ tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false)
+ tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true)
+ tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true)
+ tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true)
+ tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false)
+ tree.Insert(netip.MustParsePrefix("::1/128"), true)
+ tree.Insert(netip.MustParsePrefix("::2/128"), false)
al := &AllowList{cidrTree: tree}
- assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))
- assert.Equal(t, false, al.Allow(net.ParseIP("10.0.0.4")))
- assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42")))
- assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41")))
- assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1")))
- assert.Equal(t, true, al.Allow(net.ParseIP("::1")))
- assert.Equal(t, false, al.Allow(net.ParseIP("::2")))
+ assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1")))
+ assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4")))
+ assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42")))
+ assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41")))
+ assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1")))
+ assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1")))
+ assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2")))
}
func TestLocalAllowList_AllowName(t *testing.T) {
diff --git a/calculated_remote.go b/calculated_remote.go
index 38f5bea25..ae2ed500c 100644
--- a/calculated_remote.go
+++ b/calculated_remote.go
@@ -1,41 +1,36 @@
package nebula
import (
+ "encoding/binary"
"fmt"
"math"
"net"
+ "net/netip"
"strconv"
- "github.com/slackhq/nebula/cidr"
+ "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
)
// This allows us to "guess" what the remote might be for a host while we wait
// for the lighthouse response. See "lighthouse.calculated_remotes" in the
// example config file.
type calculatedRemote struct {
- ipNet net.IPNet
- maskIP iputil.VpnIp
- mask iputil.VpnIp
- port uint32
+ ipNet netip.Prefix
+ mask netip.Prefix
+ port uint32
}
-func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) {
- // Ensure this is an IPv4 mask that we expect
- ones, bits := ipNet.Mask.Size()
- if ones == 0 || bits != 32 {
- return nil, fmt.Errorf("invalid mask: %v", ipNet)
- }
+func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+ masked := maskCidr.Masked()
if port < 0 || port > math.MaxUint16 {
return nil, fmt.Errorf("invalid port: %d", port)
}
return &calculatedRemote{
- ipNet: *ipNet,
- maskIP: iputil.Ip2VpnIp(ipNet.IP),
- mask: iputil.Ip2VpnIp(ipNet.Mask),
- port: uint32(port),
+ ipNet: maskCidr,
+ mask: masked,
+ port: uint32(port),
}, nil
}
@@ -43,21 +38,41 @@ func (c *calculatedRemote) String() string {
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
}
-func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
+func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
// Combine the masked bytes of the "mask" IP with the unmasked bytes
// of the overlay IP
- masked := (c.maskIP & c.mask) | (ip & ^c.mask)
+ if c.ipNet.Addr().Is4() {
+ return c.apply4(ip)
+ }
+ return c.apply6(ip)
+}
+
+func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
+ //TODO: IPV6-WORK this can be less crappy
+ maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
+ mask := binary.BigEndian.Uint32(maskb[:])
+
+ b := c.mask.Addr().As4()
+ maskIp := binary.BigEndian.Uint32(b[:])
+
+ b = ip.As4()
+ intIp := binary.BigEndian.Uint32(b[:])
+
+ return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
+}
- return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
+func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
+ //TODO: IPV6-WORK
+ panic("Can not calculate ipv6 remote addresses")
}
-func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
+func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
value := c.Get(k)
if value == nil {
return nil, nil
}
- calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()
+ calculatedRemotes := new(bart.Table[[]*calculatedRemote])
rawMap, ok := value.(map[any]any)
if !ok {
@@ -69,17 +84,18 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
- _, ipNet, err := net.ParseCIDR(rawCIDR)
+ cidr, err := netip.ParsePrefix(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
+ //TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
entry, err := newCalculatedRemotesListFromConfig(rawValue)
if err != nil {
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
}
- calculatedRemotes.AddCIDR(ipNet, entry)
+ calculatedRemotes.Insert(cidr, entry)
}
return calculatedRemotes, nil
@@ -117,7 +133,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
if !ok {
return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue)
}
- _, ipNet, err := net.ParseCIDR(rawMask)
+ maskCidr, err := netip.ParsePrefix(rawMask)
if err != nil {
return nil, fmt.Errorf("invalid mask: %s", rawMask)
}
@@ -139,5 +155,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
}
- return newCalculatedRemote(ipNet, port)
+ return newCalculatedRemote(maskCidr, port)
}
diff --git a/calculated_remote_test.go b/calculated_remote_test.go
index 2ddebca74..6ff1cb0bd 100644
--- a/calculated_remote_test.go
+++ b/calculated_remote_test.go
@@ -1,27 +1,25 @@
package nebula
import (
- "net"
+ "net/netip"
"testing"
- "github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCalculatedRemoteApply(t *testing.T) {
- _, ipNet, err := net.ParseCIDR("192.168.1.0/24")
+ ipNet, err := netip.ParsePrefix("192.168.1.0/24")
require.NoError(t, err)
c, err := newCalculatedRemote(ipNet, 4242)
require.NoError(t, err)
- input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182})
+ input, err := netip.ParseAddr("10.0.10.182")
+ assert.NoError(t, err)
- expected := &Ip4AndPort{
- Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})),
- Port: 4242,
- }
+ expected, err := netip.ParseAddr("192.168.1.182")
+ assert.NoError(t, err)
- assert.Equal(t, expected, c.Apply(input))
+ assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
}
diff --git a/cidr/parse.go b/cidr/parse.go
deleted file mode 100644
index 74367f6e8..000000000
--- a/cidr/parse.go
+++ /dev/null
@@ -1,10 +0,0 @@
-package cidr
-
-import "net"
-
-// Parse is a convenience function that returns only the IPNet
-// This function ignores errors since it is primarily a test helper, the result could be nil
-func Parse(s string) *net.IPNet {
- _, c, _ := net.ParseCIDR(s)
- return c
-}
diff --git a/cidr/tree4.go b/cidr/tree4.go
deleted file mode 100644
index c5ebe54a7..000000000
--- a/cidr/tree4.go
+++ /dev/null
@@ -1,203 +0,0 @@
-package cidr
-
-import (
- "net"
-
- "github.com/slackhq/nebula/iputil"
-)
-
-type Node[T any] struct {
- left *Node[T]
- right *Node[T]
- parent *Node[T]
- hasValue bool
- value T
-}
-
-type entry[T any] struct {
- CIDR *net.IPNet
- Value T
-}
-
-type Tree4[T any] struct {
- root *Node[T]
- list []entry[T]
-}
-
-const (
- startbit = iputil.VpnIp(0x80000000)
-)
-
-func NewTree4[T any]() *Tree4[T] {
- tree := new(Tree4[T])
- tree.root = &Node[T]{}
- tree.list = []entry[T]{}
- return tree
-}
-
-func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
- bit := startbit
- node := tree.root
- next := tree.root
-
- ip := iputil.Ip2VpnIp(cidr.IP)
- mask := iputil.Ip2VpnIp(cidr.Mask)
-
- // Find our last ancestor in the tree
- for bit&mask != 0 {
- if ip&bit != 0 {
- next = node.right
- } else {
- next = node.left
- }
-
- if next == nil {
- break
- }
-
- bit = bit >> 1
- node = next
- }
-
- // We already have this range so update the value
- if next != nil {
- addCIDR := cidr.String()
- for i, v := range tree.list {
- if addCIDR == v.CIDR.String() {
- tree.list = append(tree.list[:i], tree.list[i+1:]...)
- break
- }
- }
-
- tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
- node.value = val
- node.hasValue = true
- return
- }
-
- // Build up the rest of the tree we don't already have
- for bit&mask != 0 {
- next = &Node[T]{}
- next.parent = node
-
- if ip&bit != 0 {
- node.right = next
- } else {
- node.left = next
- }
-
- bit >>= 1
- node = next
- }
-
- // Final node marks our cidr, set the value
- node.value = val
- node.hasValue = true
- tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
-}
-
-// Contains finds the first match, which may be the least specific
-func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
- bit := startbit
- node := tree.root
-
- for node != nil {
- if node.hasValue {
- return true, node.value
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
-
- }
-
- return false, value
-}
-
-// MostSpecificContains finds the most specific match
-func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
- bit := startbit
- node := tree.root
-
- for node != nil {
- if node.hasValue {
- value = node.value
- ok = true
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
- }
-
- return ok, value
-}
-
-type eachFunc[T any] func(T) bool
-
-// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
-// The final return value will be true if the provided function returned true
-func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
- bit := startbit
- node := tree.root
-
- for node != nil {
- if node.hasValue {
- // If the each func returns true then we can exit the loop
- if each(node.value) {
- return true
- }
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
- }
-
- return false
-}
-
-// GetCIDR returns the entry added by the most recent matching AddCIDR call
-func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
- bit := startbit
- node := tree.root
-
- ip := iputil.Ip2VpnIp(cidr.IP)
- mask := iputil.Ip2VpnIp(cidr.Mask)
-
- // Find our last ancestor in the tree
- for node != nil && bit&mask != 0 {
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit = bit >> 1
- }
-
- if bit&mask == 0 && node != nil {
- value = node.value
- ok = node.hasValue
- }
-
- return ok, value
-}
-
-// List will return all CIDRs and their current values. Do not modify the contents!
-func (tree *Tree4[T]) List() []entry[T] {
- return tree.list
-}
diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go
deleted file mode 100644
index cd17be4dc..000000000
--- a/cidr/tree4_test.go
+++ /dev/null
@@ -1,170 +0,0 @@
-package cidr
-
-import (
- "net"
- "testing"
-
- "github.com/slackhq/nebula/iputil"
- "github.com/stretchr/testify/assert"
-)
-
-func TestCIDRTree_List(t *testing.T) {
- tree := NewTree4[string]()
- tree.AddCIDR(Parse("1.0.0.0/16"), "1")
- tree.AddCIDR(Parse("1.0.0.0/8"), "2")
- tree.AddCIDR(Parse("1.0.0.0/16"), "3")
- tree.AddCIDR(Parse("1.0.0.0/16"), "4")
- list := tree.List()
- assert.Len(t, list, 2)
- assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
- assert.Equal(t, "2", list[0].Value)
- assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
- assert.Equal(t, "4", list[1].Value)
-}
-
-func TestCIDRTree_Contains(t *testing.T) {
- tree := NewTree4[string]()
- tree.AddCIDR(Parse("1.0.0.0/8"), "1")
- tree.AddCIDR(Parse("2.1.0.0/16"), "2")
- tree.AddCIDR(Parse("3.1.1.0/24"), "3")
- tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
- tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
- tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
- tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
- tests := []struct {
- Found bool
- Result interface{}
- IP string
- }{
- {true, "1", "1.0.0.0"},
- {true, "1", "1.255.255.255"},
- {true, "2", "2.1.0.0"},
- {true, "2", "2.1.255.255"},
- {true, "3", "3.1.1.0"},
- {true, "3", "3.1.1.255"},
- {true, "4a", "4.1.1.255"},
- {true, "4a", "4.1.1.1"},
- {true, "5", "240.0.0.0"},
- {true, "5", "255.255.255.255"},
- {false, "", "239.0.0.0"},
- {false, "", "4.1.2.2"},
- }
-
- for _, tt := range tests {
- ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
- assert.Equal(t, tt.Found, ok)
- assert.Equal(t, tt.Result, r)
- }
-
- tree = NewTree4[string]()
- tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
- ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-
- ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-}
-
-func TestCIDRTree_MostSpecificContains(t *testing.T) {
- tree := NewTree4[string]()
- tree.AddCIDR(Parse("1.0.0.0/8"), "1")
- tree.AddCIDR(Parse("2.1.0.0/16"), "2")
- tree.AddCIDR(Parse("3.1.1.0/24"), "3")
- tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
- tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
- tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
- tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
- tests := []struct {
- Found bool
- Result interface{}
- IP string
- }{
- {true, "1", "1.0.0.0"},
- {true, "1", "1.255.255.255"},
- {true, "2", "2.1.0.0"},
- {true, "2", "2.1.255.255"},
- {true, "3", "3.1.1.0"},
- {true, "3", "3.1.1.255"},
- {true, "4a", "4.1.1.255"},
- {true, "4b", "4.1.1.2"},
- {true, "4c", "4.1.1.1"},
- {true, "5", "240.0.0.0"},
- {true, "5", "255.255.255.255"},
- {false, "", "239.0.0.0"},
- {false, "", "4.1.2.2"},
- }
-
- for _, tt := range tests {
- ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
- assert.Equal(t, tt.Found, ok)
- assert.Equal(t, tt.Result, r)
- }
-
- tree = NewTree4[string]()
- tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
- ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-
- ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-}
-
-func TestTree4_GetCIDR(t *testing.T) {
- tree := NewTree4[string]()
- tree.AddCIDR(Parse("1.0.0.0/8"), "1")
- tree.AddCIDR(Parse("2.1.0.0/16"), "2")
- tree.AddCIDR(Parse("3.1.1.0/24"), "3")
- tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
- tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
- tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
- tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
- tests := []struct {
- Found bool
- Result interface{}
- IPNet *net.IPNet
- }{
- {true, "1", Parse("1.0.0.0/8")},
- {true, "2", Parse("2.1.0.0/16")},
- {true, "3", Parse("3.1.1.0/24")},
- {true, "4a", Parse("4.1.1.0/24")},
- {true, "4b", Parse("4.1.1.1/32")},
- {true, "4c", Parse("4.1.2.1/32")},
- {true, "5", Parse("254.0.0.0/4")},
- {false, "", Parse("2.0.0.0/8")},
- }
-
- for _, tt := range tests {
- ok, r := tree.GetCIDR(tt.IPNet)
- assert.Equal(t, tt.Found, ok)
- assert.Equal(t, tt.Result, r)
- }
-}
-
-func BenchmarkCIDRTree_Contains(b *testing.B) {
- tree := NewTree4[string]()
- tree.AddCIDR(Parse("1.1.0.0/16"), "1")
- tree.AddCIDR(Parse("1.2.1.1/32"), "1")
- tree.AddCIDR(Parse("192.2.1.1/32"), "1")
- tree.AddCIDR(Parse("172.2.1.1/32"), "1")
-
- ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
- b.Run("found", func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- tree.Contains(ip)
- }
- })
-
- ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
- b.Run("not found", func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- tree.Contains(ip)
- }
- })
-}
diff --git a/cidr/tree6.go b/cidr/tree6.go
deleted file mode 100644
index 3f2cd2a48..000000000
--- a/cidr/tree6.go
+++ /dev/null
@@ -1,189 +0,0 @@
-package cidr
-
-import (
- "net"
-
- "github.com/slackhq/nebula/iputil"
-)
-
-const startbit6 = uint64(1 << 63)
-
-type Tree6[T any] struct {
- root4 *Node[T]
- root6 *Node[T]
-}
-
-func NewTree6[T any]() *Tree6[T] {
- tree := new(Tree6[T])
- tree.root4 = &Node[T]{}
- tree.root6 = &Node[T]{}
- return tree
-}
-
-func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
- var node, next *Node[T]
-
- cidrIP, ipv4 := isIPV4(cidr.IP)
- if ipv4 {
- node = tree.root4
- next = tree.root4
-
- } else {
- node = tree.root6
- next = tree.root6
- }
-
- for i := 0; i < len(cidrIP); i += 4 {
- ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
- mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
- bit := startbit
-
- // Find our last ancestor in the tree
- for bit&mask != 0 {
- if ip&bit != 0 {
- next = node.right
- } else {
- next = node.left
- }
-
- if next == nil {
- break
- }
-
- bit = bit >> 1
- node = next
- }
-
- // Build up the rest of the tree we don't already have
- for bit&mask != 0 {
- next = &Node[T]{}
- next.parent = node
-
- if ip&bit != 0 {
- node.right = next
- } else {
- node.left = next
- }
-
- bit >>= 1
- node = next
- }
- }
-
- // Final node marks our cidr, set the value
- node.value = val
- node.hasValue = true
-}
-
-// Finds the most specific match
-func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
- var node *Node[T]
-
- wholeIP, ipv4 := isIPV4(ip)
- if ipv4 {
- node = tree.root4
- } else {
- node = tree.root6
- }
-
- for i := 0; i < len(wholeIP); i += 4 {
- ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
- bit := startbit
-
- for node != nil {
- if node.hasValue {
- value = node.value
- ok = true
- }
-
- if bit == 0 {
- break
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
- }
- }
-
- return ok, value
-}
-
-func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
- bit := startbit
- node := tree.root4
-
- for node != nil {
- if node.hasValue {
- value = node.value
- ok = true
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
- }
-
- return ok, value
-}
-
-func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
- ip := hi
- node := tree.root6
-
- for i := 0; i < 2; i++ {
- bit := startbit6
-
- for node != nil {
- if node.hasValue {
- value = node.value
- ok = true
- }
-
- if bit == 0 {
- break
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
- }
-
- ip = lo
- }
-
- return ok, value
-}
-
-func isIPV4(ip net.IP) (net.IP, bool) {
- if len(ip) == net.IPv4len {
- return ip, true
- }
-
- if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
- return ip[12:16], true
- }
-
- return ip, false
-}
-
-func isZeros(p net.IP) bool {
- for i := 0; i < len(p); i++ {
- if p[i] != 0 {
- return false
- }
- }
- return true
-}
diff --git a/cidr/tree6_test.go b/cidr/tree6_test.go
deleted file mode 100644
index eb159ec74..000000000
--- a/cidr/tree6_test.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package cidr
-
-import (
- "encoding/binary"
- "net"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
- tree := NewTree6[string]()
- tree.AddCIDR(Parse("1.0.0.0/8"), "1")
- tree.AddCIDR(Parse("2.1.0.0/16"), "2")
- tree.AddCIDR(Parse("3.1.1.0/24"), "3")
- tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
- tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
- tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
- tree.AddCIDR(Parse("254.0.0.0/4"), "5")
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
-
- tests := []struct {
- Found bool
- Result interface{}
- IP string
- }{
- {true, "1", "1.0.0.0"},
- {true, "1", "1.255.255.255"},
- {true, "2", "2.1.0.0"},
- {true, "2", "2.1.255.255"},
- {true, "3", "3.1.1.0"},
- {true, "3", "3.1.1.255"},
- {true, "4a", "4.1.1.255"},
- {true, "4b", "4.1.1.2"},
- {true, "4c", "4.1.1.1"},
- {true, "5", "240.0.0.0"},
- {true, "5", "255.255.255.255"},
- {true, "6a", "1:2:0:4:1:1:1:1"},
- {true, "6b", "1:2:0:4:5:1:1:1"},
- {true, "6c", "1:2:0:4:5:0:0:0"},
- {false, "", "239.0.0.0"},
- {false, "", "4.1.2.2"},
- }
-
- for _, tt := range tests {
- ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP))
- assert.Equal(t, tt.Found, ok)
- assert.Equal(t, tt.Result, r)
- }
-
- tree = NewTree6[string]()
- tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
- tree.AddCIDR(Parse("::/0"), "cool6")
- ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0"))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-
- ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255"))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-
- ok, r = tree.MostSpecificContains(net.ParseIP("::"))
- assert.True(t, ok)
- assert.Equal(t, "cool6", r)
-
- ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))
- assert.True(t, ok)
- assert.Equal(t, "cool6", r)
-}
-
-func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
- tree := NewTree6[string]()
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
-
- tests := []struct {
- Found bool
- Result interface{}
- IP string
- }{
- {true, "6a", "1:2:0:4:1:1:1:1"},
- {true, "6b", "1:2:0:4:5:1:1:1"},
- {true, "6c", "1:2:0:4:5:0:0:0"},
- }
-
- for _, tt := range tests {
- ip := net.ParseIP(tt.IP)
- hi := binary.BigEndian.Uint64(ip[:8])
- lo := binary.BigEndian.Uint64(ip[8:])
-
- ok, r := tree.MostSpecificContainsIpV6(hi, lo)
- assert.Equal(t, tt.Found, ok)
- assert.Equal(t, tt.Result, r)
- }
-}
diff --git a/connection_manager.go b/connection_manager.go
index 0b277b5c1..d2e861647 100644
--- a/connection_manager.go
+++ b/connection_manager.go
@@ -3,6 +3,8 @@ package nebula
import (
"bytes"
"context"
+ "encoding/binary"
+ "net/netip"
"sync"
"time"
@@ -10,8 +12,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
- "github.com/slackhq/nebula/udp"
)
type trafficDecision int
@@ -224,8 +224,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
var index uint32
- var relayFrom iputil.VpnIp
- var relayTo iputil.VpnIp
+ var relayFrom netip.Addr
+ var relayTo netip.Addr
switch {
case ok && existing.State == Established:
// This relay already exists in newhostinfo, then do nothing.
@@ -235,7 +235,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
index = existing.LocalIndex
switch r.Type {
case TerminalType:
- relayFrom = n.intf.myVpnIp
+ relayFrom = n.intf.myVpnNet.Addr()
relayTo = existing.PeerIp
case ForwardingType:
relayFrom = existing.PeerIp
@@ -260,7 +260,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
}
switch r.Type {
case TerminalType:
- relayFrom = n.intf.myVpnIp
+ relayFrom = n.intf.myVpnNet.Addr()
relayTo = r.PeerIp
case ForwardingType:
relayFrom = r.PeerIp
@@ -270,12 +270,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
}
}
+ //TODO: IPV6-WORK
+ relayFromB := relayFrom.As4()
+ relayToB := relayTo.As4()
+
// Send a CreateRelayRequest to the peer.
req := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index,
- RelayFromIp: uint32(relayFrom),
- RelayToIp: uint32(relayTo),
+ RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]),
+ RelayToIp: binary.BigEndian.Uint32(relayToB[:]),
}
msg, err := req.Marshal()
if err != nil {
@@ -283,8 +287,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
} else {
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(req.RelayFromIp),
- "relayTo": iputil.VpnIp(req.RelayToIp),
+ "relayFrom": req.RelayFromIp,
+ "relayTo": req.RelayToIp,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": newhostinfo.vpnIp}).
@@ -403,7 +407,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.
- if current.vpnIp < n.intf.myVpnIp {
+ if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
// The remotes vpn ip is lower than mine. I will not flip.
@@ -457,12 +461,12 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}
if n.punchy.GetTargetEverything() {
- hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
+ hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, addr)
})
- } else if hostinfo.remote != nil {
+ } else if hostinfo.remote.IsValid() {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
diff --git a/connection_manager_test.go b/connection_manager_test.go
index f50bcf862..5f97cad9d 100644
--- a/connection_manager_test.go
+++ b/connection_manager_test.go
@@ -5,28 +5,26 @@ import (
"crypto/ed25519"
"crypto/rand"
"net"
+ "net/netip"
"testing"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
)
-var vpnIp iputil.VpnIp
-
func newTestLighthouse() *LightHouse {
lh := &LightHouse{
l: test.NewLogger(),
- addrMap: map[iputil.VpnIp]*RemoteList{},
- queryChan: make(chan iputil.VpnIp, 10),
+ addrMap: map[netip.Addr]*RemoteList{},
+ queryChan: make(chan netip.Addr, 10),
}
- lighthouses := map[iputil.VpnIp]struct{}{}
- staticList := map[iputil.VpnIp]struct{}{}
+ lighthouses := map[netip.Addr]struct{}{}
+ staticList := map[netip.Addr]struct{}{}
lh.lighthouses.Store(&lighthouses)
lh.staticList.Store(&staticList)
@@ -37,10 +35,10 @@ func newTestLighthouse() *LightHouse {
func Test_NewConnectionManagerTest(t *testing.T) {
l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
- _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
- _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
- vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
- preferredRanges := []*net.IPNet{localrange}
+ vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+ localrange := netip.MustParsePrefix("10.1.1.1/24")
+ vpnIp := netip.MustParseAddr("172.1.1.2")
+ preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
hostMap := newHostMap(l, vpncidr)
@@ -120,9 +118,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
func Test_NewConnectionManagerTest2(t *testing.T) {
l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
- _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
- _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
- preferredRanges := []*net.IPNet{localrange}
+ vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+ localrange := netip.MustParsePrefix("10.1.1.1/24")
+ vpnIp := netip.MustParseAddr("172.1.1.2")
+ preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
hostMap := newHostMap(l, vpncidr)
@@ -211,9 +210,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
IP: net.IPv4(172, 1, 1, 2),
Mask: net.IPMask{255, 255, 255, 0},
}
- _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
- _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
- preferredRanges := []*net.IPNet{localrange}
+ vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+ localrange := netip.MustParsePrefix("10.1.1.1/24")
+ vpnIp := netip.MustParseAddr("172.1.1.2")
+ preferredRanges := []netip.Prefix{localrange}
hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)
diff --git a/control.go b/control.go
index c227b207b..3468b3536 100644
--- a/control.go
+++ b/control.go
@@ -2,7 +2,7 @@ package nebula
import (
"context"
- "net"
+ "net/netip"
"os"
"os/signal"
"syscall"
@@ -10,9 +10,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay"
- "github.com/slackhq/nebula/udp"
)
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
@@ -21,10 +19,10 @@ import (
type controlEach func(h *HostInfo)
type controlHostLister interface {
- QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo
+ QueryVpnIp(vpnIp netip.Addr) *HostInfo
ForEachIndex(each controlEach)
ForEachVpnIp(each controlEach)
- GetPreferredRanges() []*net.IPNet
+ GetPreferredRanges() []netip.Prefix
}
type Control struct {
@@ -39,15 +37,15 @@ type Control struct {
}
type ControlHostInfo struct {
- VpnIp net.IP `json:"vpnIp"`
+ VpnIp netip.Addr `json:"vpnIp"`
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
- RemoteAddrs []*udp.Addr `json:"remoteAddrs"`
+ RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"`
- CurrentRemote *udp.Addr `json:"currentRemote"`
- CurrentRelaysToMe []iputil.VpnIp `json:"currentRelaysToMe"`
- CurrentRelaysThroughMe []iputil.VpnIp `json:"currentRelaysThroughMe"`
+ CurrentRemote netip.AddrPort `json:"currentRemote"`
+ CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"`
+ CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@@ -131,8 +129,45 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
}
}
+// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
+func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate {
+ if c.f.myVpnNet.Addr() == vpnIp {
+ return c.f.pki.GetCertState().Certificate
+ }
+ hi := c.f.hostMap.QueryVpnIp(vpnIp)
+ if hi == nil {
+ return nil
+ }
+ return hi.GetCert()
+}
+
+// CreateTunnel creates a new tunnel to the given vpn ip.
+func (c *Control) CreateTunnel(vpnIp netip.Addr) {
+ c.f.handshakeManager.StartHandshake(vpnIp, nil)
+}
+
+// PrintTunnel creates a new tunnel to the given vpn ip.
+func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
+ hi := c.f.hostMap.QueryVpnIp(vpnIp)
+ if hi == nil {
+ return nil
+ }
+ chi := copyHostInfo(hi, c.f.hostMap.GetPreferredRanges())
+ return &chi
+}
+
+// QueryLighthouse queries the lighthouse.
+func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
+ hi := c.f.lightHouse.Query(vpnIp)
+ if hi == nil {
+ return nil
+ }
+ return hi.CopyCache()
+}
+
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
-func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
var hl controlHostLister
if pending {
hl = c.f.handshakeManager
@@ -150,19 +185,21 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
}
// SetRemoteForTunnel forces a tunnel to use a specific remote
-func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return nil
}
- hostInfo.SetRemote(addr.Copy())
+ hostInfo.SetRemote(addr)
ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
return &ch
}
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
-func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return false
@@ -205,7 +242,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
}
// Learn which hosts are being used as relays, so we can shut them down last.
- relayingHosts := map[iputil.VpnIp]*HostInfo{}
+ relayingHosts := map[netip.Addr]*HostInfo{}
// Grab the hostMap lock to access the Relays map
c.f.hostMap.Lock()
for _, relayingHost := range c.f.hostMap.Relays {
@@ -236,15 +273,16 @@ func (c *Control) Device() overlay.Device {
return c.f.inside
}
-func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
+func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
chi := ControlHostInfo{
- VpnIp: h.vpnIp.ToIP(),
+ VpnIp: h.vpnIp,
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
CurrentRelaysToMe: h.relayState.CopyRelayIps(),
CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
+ CurrentRemote: h.remote,
}
if h.ConnectionState != nil {
@@ -255,10 +293,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
chi.Cert = c.Copy()
}
- if h.remote != nil {
- chi.CurrentRemote = h.remote.Copy()
- }
-
return chi
}
diff --git a/control_test.go b/control_test.go
index c64a3a4b7..fbf29c060 100644
--- a/control_test.go
+++ b/control_test.go
@@ -2,15 +2,14 @@ package nebula
import (
"net"
+ "net/netip"
"reflect"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
- "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
)
@@ -18,18 +17,19 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
- hm := newHostMap(l, &net.IPNet{})
- hm.preferredRanges.Store(&[]*net.IPNet{})
+ hm := newHostMap(l, netip.Prefix{})
+ hm.preferredRanges.Store(&[]netip.Prefix{})
+
+ remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
+ remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444")
- remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
- remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{
- IP: net.IPv4(1, 2, 3, 4),
+ IP: remote1.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0},
}
ipNet2 := net.IPNet{
- IP: net.ParseIP("1:2:3:4:5:6:7:8"),
+ IP: remote2.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0},
}
@@ -50,8 +50,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}
remotes := NewRemoteList(nil)
- remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
- remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
+ remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
+ remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
+
+ vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
+ assert.True(t, ok)
+
hm.unlockedAddHostInfo(&HostInfo{
remote: remote1,
remotes: remotes,
@@ -60,14 +64,17 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
},
remoteIndexId: 200,
localIndexId: 201,
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: vpnIp,
relayState: RelayState{
- relays: map[iputil.VpnIp]struct{}{},
- relayForByIp: map[iputil.VpnIp]*Relay{},
+ relays: map[netip.Addr]struct{}{},
+ relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}, &Interface{})
+ vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP)
+ assert.True(t, ok)
+
hm.unlockedAddHostInfo(&HostInfo{
remote: remote1,
remotes: remotes,
@@ -76,10 +83,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
},
remoteIndexId: 200,
localIndexId: 201,
- vpnIp: iputil.Ip2VpnIp(ipNet2.IP),
+ vpnIp: vpnIp2,
relayState: RelayState{
- relays: map[iputil.VpnIp]struct{}{},
- relayForByIp: map[iputil.VpnIp]*Relay{},
+ relays: map[netip.Addr]struct{}{},
+ relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}, &Interface{})
@@ -91,27 +98,29 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l: logrus.New(),
}
- thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
+ thi := c.GetHostInfoByVpnIp(vpnIp, false)
expectedInfo := ControlHostInfo{
- VpnIp: net.IPv4(1, 2, 3, 4).To4(),
+ VpnIp: vpnIp,
LocalIndex: 201,
RemoteIndex: 200,
- RemoteAddrs: []*udp.Addr{remote2, remote1},
+ RemoteAddrs: []netip.AddrPort{remote2, remote1},
Cert: crt.Copy(),
MessageCounter: 0,
- CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
- CurrentRelaysToMe: []iputil.VpnIp{},
- CurrentRelaysThroughMe: []iputil.VpnIp{},
+ CurrentRemote: remote1,
+ CurrentRelaysToMe: []netip.Addr{},
+ CurrentRelaysThroughMe: []netip.Addr{},
}
// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
- test.AssertDeepCopyEqual(t, &expectedInfo, thi)
+ assert.EqualValues(t, &expectedInfo, thi)
+ //TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here
+ //test.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() {
- thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
+ thi = c.GetHostInfoByVpnIp(vpnIp2, false)
})
}
diff --git a/control_tester.go b/control_tester.go
index b786ba383..d46540f04 100644
--- a/control_tester.go
+++ b/control_tester.go
@@ -4,14 +4,13 @@
package nebula
import (
- "net"
+ "net/netip"
"github.com/slackhq/nebula/cert"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
)
@@ -50,37 +49,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
// InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
// This is necessary if you did not configure static hosts or are not running a lighthouse
-func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
+func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
c.f.lightHouse.Lock()
- remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
+ remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
- iVpnIp := iputil.Ip2VpnIp(vpnIp)
- if v4 := toAddr.IP.To4(); v4 != nil {
- remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
+ if toAddr.Addr().Is4() {
+ remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
} else {
- remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
+ remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
}
}
// InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp
// This is necessary to inform an initiator of possible relays for communicating with a responder
-func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) {
+func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
c.f.lightHouse.Lock()
- remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
+ remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
- iVpnIp := iputil.Ip2VpnIp(vpnIp)
- uVpnIp := []uint32{}
- for _, rVPnIp := range relayVpnIps {
- uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp)))
- }
-
- remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp)
+ remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps)
}
// GetFromTun will pull a packet off the tun side of nebula
@@ -107,13 +99,14 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
}
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
-func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) {
+func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
+ //TODO: IPV6-WORK
ip := layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
- SrcIP: c.f.inside.Cidr().IP,
- DstIP: toIp,
+ SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(),
+ DstIP: toIp.Unmap().AsSlice(),
}
udp := layers.UDP{
@@ -138,16 +131,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
}
-func (c *Control) GetVpnIp() iputil.VpnIp {
- return c.f.myVpnIp
+func (c *Control) GetVpnIp() netip.Addr {
+ return c.f.myVpnNet.Addr()
}
-func (c *Control) GetUDPAddr() string {
- return c.f.outside.(*udp.TesterConn).Addr.String()
+func (c *Control) GetUDPAddr() netip.AddrPort {
+ return c.f.outside.(*udp.TesterConn).Addr
}
-func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
- hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp))
+func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
+ hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
if hostinfo == nil {
return false
}
@@ -164,6 +157,6 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
return c.f.pki.GetCertState().Certificate
}
-func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
+func (c *Control) ReHandshake(vpnIp netip.Addr) {
c.f.handshakeManager.StartHandshake(vpnIp, nil)
}
diff --git a/dns_server.go b/dns_server.go
index 4e7bb83af..5fea65c47 100644
--- a/dns_server.go
+++ b/dns_server.go
@@ -3,6 +3,7 @@ package nebula
import (
"fmt"
"net"
+ "net/netip"
"strconv"
"strings"
"sync"
@@ -10,7 +11,6 @@ import (
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
)
// This whole thing should be rewritten to use context
@@ -42,19 +42,21 @@ func (d *dnsRecords) Query(data string) string {
}
func (d *dnsRecords) QueryCert(data string) string {
- ip := net.ParseIP(data[:len(data)-1])
- if ip == nil {
+ ip, err := netip.ParseAddr(data[:len(data)-1])
+ if err != nil {
return ""
}
- iip := iputil.Ip2VpnIp(ip)
- hostinfo := d.hostMap.QueryVpnIp(iip)
+
+ hostinfo := d.hostMap.QueryVpnIp(ip)
if hostinfo == nil {
return ""
}
+
q := hostinfo.GetCert()
if q == nil {
return ""
}
+
cert := q.Details
c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
return c
@@ -80,7 +82,11 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
}
case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
- b := net.ParseIP(a)
+ b, err := netip.ParseAddr(a)
+ if err != nil {
+ return
+ }
+
// We don't answer these queries from non nebula nodes or localhost
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go
index 59f1d0e52..3d42a560c 100644
--- a/e2e/handshakes_test.go
+++ b/e2e/handshakes_test.go
@@ -5,7 +5,7 @@ package e2e
import (
"fmt"
- "net"
+ "net/netip"
"testing"
"time"
@@ -13,19 +13,18 @@ import (
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
func BenchmarkHotPath(b *testing.B) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
@@ -35,7 +34,7 @@ func BenchmarkHotPath(b *testing.B) {
r.CancelFlowLogs()
for n := 0; n < b.N; n++ {
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
}
@@ -44,19 +43,19 @@ func BenchmarkHotPath(b *testing.B) {
}
func TestGoodHandshake(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -77,16 +76,16 @@ func TestGoodHandshake(t *testing.T) {
myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct")
- assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
+ assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true)
- assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
@@ -95,20 +94,20 @@ func TestGoodHandshake(t *testing.T) {
}
func TestWrongResponderHandshake(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
// The IPs here are chosen on purpose:
// The current remote handling will sort by preference, public, and then lexically.
// So we need them to have a higher address than evil (we could apply a preference though)
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
- evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
+ evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
// Add their real udp addr, which should be tried after evil.
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl, evilControl)
@@ -120,7 +119,7 @@ func TestWrongResponderHandshake(t *testing.T) {
evilControl.Start()
t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
h := &header.H{}
err := h.Parse(p.Data)
@@ -128,7 +127,7 @@ func TestWrongResponderHandshake(t *testing.T) {
panic(err)
}
- if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
+ if p.To == theirUdpAddr && h.Type == 1 {
return router.RouteAndExit
}
@@ -139,18 +138,18 @@ func TestWrongResponderHandshake(t *testing.T) {
t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true)
- assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
t.Log("Test the tunnel with them")
- assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Flush all packets from all controllers")
r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
- assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil")
- assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil")
+ assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
+ assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
//TODO: assert hostmaps for everyone
@@ -164,13 +163,13 @@ func TestStage1Race(t *testing.T) {
// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
// But will eventually collapse down to a single tunnel
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse and vice versa
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@@ -181,8 +180,8 @@ func TestStage1Race(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake to start on both me and them")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
- theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+ theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
t.Log("Get both stage 1 handshake packets")
myHsForThem := myControl.GetFromUDP(true)
@@ -194,14 +193,14 @@ func TestStage1Race(t *testing.T) {
r.Log("Route until they receive a message packet")
myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
- assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.Log("Their cached packet should be received by me")
theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
- assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
r.Log("Do a bidirectional tunnel test")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myHostmapHosts := myControl.ListHostmapHosts(false)
myHostmapIndexes := myControl.ListHostmapIndexes(false)
@@ -219,7 +218,7 @@ func TestStage1Race(t *testing.T) {
r.Log("Spin until connection manager tears down a tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
}
@@ -241,13 +240,13 @@ func TestStage1Race(t *testing.T) {
}
func TestUncleanShutdownRaceLoser(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@@ -258,28 +257,28 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
theirControl.Start()
r.Log("Trigger a handshake from me to them")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
- assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.Log("Nuke my hostmap")
myHostmap := myControl.GetHostmap()
- myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+ myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again"))
p = r.RouteForAllUntilTxTun(theirControl)
- assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.Log("Assert the tunnel works")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.Log("Wait for the dead index to go away")
start := len(theirControl.GetHostmap().Indexes)
for {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
if len(theirControl.GetHostmap().Indexes) < start {
break
}
@@ -290,13 +289,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
}
func TestUncleanShutdownRaceWinner(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@@ -307,30 +306,30 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirControl.Start()
r.Log("Trigger a handshake from me to them")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
- assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
r.Log("Nuke my hostmap")
theirHostmap := theirControl.GetHostmap()
- theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+ theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
- theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again"))
+ theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again"))
p = r.RouteForAllUntilTxTun(myControl)
- assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
r.Log("Assert the tunnel works")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.Log("Wait for the dead index to go away")
start := len(myControl.GetHostmap().Indexes)
for {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
if len(myControl.GetHostmap().Indexes) < start {
break
}
@@ -341,15 +340,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
}
func TestRelays(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
- relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+ relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
- myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
- relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+ myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+ relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@@ -361,31 +360,31 @@ func TestRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
- assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
}
func TestStage1RaceRelays(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
- relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+ relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
- theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+ myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+ theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
- myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
- theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+ myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+ theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
- relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@@ -397,14 +396,14 @@ func TestStage1RaceRelays(t *testing.T) {
theirControl.Start()
r.Log("Get a tunnel between me and relay")
- assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay")
- assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
- theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+ theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
r.Log("Wait for a packet from them to me")
p := r.RouteForAllUntilTxTun(myControl)
@@ -421,21 +420,21 @@ func TestStage1RaceRelays(t *testing.T) {
func TestStage1RaceRelays2(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
- relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+ relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
l := NewTestLogger()
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
- theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+ myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+ theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
- myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
- theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+ myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+ theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
- relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@@ -448,16 +447,16 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Get a tunnel between me and relay")
l.Info("Get a tunnel between me and relay")
- assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay")
l.Info("Get a tunnel between them and relay")
- assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
l.Info("Trigger a handshake from both them and me via relay to them and me")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
- theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+ theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
@@ -470,7 +469,7 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
t.Log("Wait until we remove extra tunnels")
l.Info("Wait until we remove extra tunnels")
@@ -490,7 +489,7 @@ func TestStage1RaceRelays2(t *testing.T) {
"theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
retries--
@@ -498,7 +497,7 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
myControl.Stop()
theirControl.Stop()
@@ -507,16 +506,17 @@ func TestStage1RaceRelays2(t *testing.T) {
//
////TODO: assert hostmaps
}
+
func TestRehandshakingRelays(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
- relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+ relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
- myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
- relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+ myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+ relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@@ -528,11 +528,11 @@ func TestRehandshakingRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
- assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
@@ -556,8 +556,8 @@ func TestRehandshakingRelays(t *testing.T) {
for {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
- assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
- c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+ assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
+ c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
r.Log("Certificate between my and relay is updated!")
@@ -569,8 +569,8 @@ func TestRehandshakingRelays(t *testing.T) {
for {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
- assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
- c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+ assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
+ c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
r.Log("Certificate between their and relay is updated!")
@@ -581,13 +581,13 @@ func TestRehandshakingRelays(t *testing.T) {
}
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// We should have two hostinfos on all sides
for len(myControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -595,7 +595,7 @@ func TestRehandshakingRelays(t *testing.T) {
for len(theirControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -603,7 +603,7 @@ func TestRehandshakingRelays(t *testing.T) {
for len(relayControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -612,15 +612,15 @@ func TestRehandshakingRelays(t *testing.T) {
func TestRehandshakingRelaysPrimary(t *testing.T) {
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}})
- relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}})
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
+ relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
- myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
- relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+ myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+ relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@@ -632,11 +632,11 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
- assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
@@ -660,8 +660,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
- assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
- c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+ assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
+ c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
r.Log("Certificate between my and relay is updated!")
@@ -673,8 +673,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
- assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
- c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+ assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
+ c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
r.Log("Certificate between their and relay is updated!")
@@ -685,13 +685,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
}
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// We should have two hostinfos on all sides
for len(myControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -699,7 +699,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for len(theirControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -707,7 +707,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for len(relayControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -715,13 +715,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
}
func TestRehandshaking(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
// Put their info in our lighthouse and vice versa
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@@ -732,7 +732,7 @@ func TestRehandshaking(t *testing.T) {
theirControl.Start()
t.Log("Stand up a tunnel between me and them")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
@@ -754,8 +754,8 @@ func TestRehandshaking(t *testing.T) {
myConfig.ReloadConfigString(string(rc))
for {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
- c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+ c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
break
@@ -781,19 +781,19 @@ func TestRehandshaking(t *testing.T) {
r.Log("Spin until there is only 1 tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
}
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
// Make sure the correct tunnel won
- c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+ c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
assert.Contains(t, c.Cert.Details.Groups, "new group")
// We should only have a single tunnel now on both sides
@@ -811,13 +811,13 @@ func TestRehandshaking(t *testing.T) {
func TestRehandshakingLoser(t *testing.T) {
// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
// Should be the one with the new certificate
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
// Put their info in our lighthouse and vice versa
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@@ -828,10 +828,10 @@ func TestRehandshakingLoser(t *testing.T) {
theirControl.Start()
t.Log("Stand up a tunnel between me and them")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
- tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
- tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+ tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
+ tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
@@ -854,8 +854,8 @@ func TestRehandshakingLoser(t *testing.T) {
theirConfig.ReloadConfigString(string(rc))
for {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
- theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+ theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
_, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"]
if theirNewGroup {
@@ -882,19 +882,19 @@ func TestRehandshakingLoser(t *testing.T) {
r.Log("Spin until there is only 1 tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
}
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
// Make sure the correct tunnel won
- theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+ theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group")
// We should only have a single tunnel now on both sides
@@ -912,13 +912,13 @@ func TestRaceRegression(t *testing.T) {
// This test forces stage 1, stage 2, stage 1 to be received by me from them
// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
// caused a cross-linked hostinfo
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Start the servers
myControl.Start()
@@ -932,8 +932,8 @@ func TestRaceRegression(t *testing.T) {
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
t.Log("Start both handshakes")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
- theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+ theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
t.Log("Get both stage 1")
myStage1ForThem := myControl.GetFromUDP(true)
@@ -963,7 +963,7 @@ func TestRaceRegression(t *testing.T) {
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
t.Log("Make sure the tunnel still works")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myControl.Stop()
theirControl.Stop()
diff --git a/e2e/helpers.go b/e2e/helpers.go
index 13146ab71..71df805f8 100644
--- a/e2e/helpers.go
+++ b/e2e/helpers.go
@@ -4,6 +4,7 @@ import (
"crypto/rand"
"io"
"net"
+ "net/netip"
"time"
"github.com/slackhq/nebula/cert"
@@ -12,7 +13,7 @@ import (
)
// NewTestCaCert will generate a CA cert
-func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
@@ -33,11 +34,17 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
}
if len(ips) > 0 {
- nc.Details.Ips = ips
+ nc.Details.Ips = make([]*net.IPNet, len(ips))
+ for i, ip := range ips {
+ nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
+ }
}
if len(subnets) > 0 {
- nc.Details.Subnets = subnets
+ nc.Details.Subnets = make([]*net.IPNet, len(subnets))
+ for i, ip := range subnets {
+ nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
+ }
}
if len(groups) > 0 {
@@ -59,7 +66,7 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
// NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in
-func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
issuer, err := ca.Sha256Sum()
if err != nil {
panic(err)
@@ -74,12 +81,12 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af
}
pub, rawPriv := x25519Keypair()
-
+ ipb := ip.Addr().AsSlice()
nc := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
- Name: name,
- Ips: []*net.IPNet{ip},
- Subnets: subnets,
+ Name: name,
+ Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}},
+ //Subnets: subnets,
Groups: groups,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go
index b05c84a22..527f55bc7 100644
--- a/e2e/helpers_test.go
+++ b/e2e/helpers_test.go
@@ -6,7 +6,7 @@ package e2e
import (
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"testing"
"time"
@@ -19,7 +19,6 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router"
- "github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
@@ -27,15 +26,23 @@ import (
type m map[string]interface{}
// newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) {
+func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
- vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
- copy(vpnIpNet.IP, udpIp)
- vpnIpNet.IP[1] += 128
- udpAddr := net.UDPAddr{
- IP: udpIp,
- Port: 4242,
+ vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
+ if err != nil {
+ panic(err)
+ }
+
+ var udpAddr netip.AddrPort
+ if vpnIpNet.Addr().Is4() {
+ budpIp := vpnIpNet.Addr().As4()
+ budpIp[1] -= 128
+ udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
+ } else {
+ budpIp := vpnIpNet.Addr().As16()
+ budpIp[13] -= 128
+ udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
@@ -67,8 +74,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
// "try_interval": "1s",
//},
"listen": m{
- "host": udpAddr.IP.String(),
- "port": udpAddr.Port,
+ "host": udpAddr.Addr().String(),
+ "port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
@@ -102,7 +109,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
panic(err)
}
- return control, vpnIpNet, &udpAddr, c
+ return control, vpnIpNet, udpAddr, c
}
type doneCb func()
@@ -123,7 +130,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
}
}
-func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) {
+func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -135,23 +142,20 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
}
-func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
+func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
// Get both host infos
- hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false)
+ hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
- hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false)
+ hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
// Check that both vpn and real addr are correct
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
- assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
- assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")
-
- assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A")
- assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B")
+ assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
+ assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
// Check that our indexes match
assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
@@ -174,13 +178,13 @@ func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB
//checkIndexes("hmB", hmB, hAinB)
}
-func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) {
+func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found")
- assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect")
- assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect")
+ assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect")
+ assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect")
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
assert.NotNil(t, udp, "No udp data found")
diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go
index 120be6960..c14ab2e77 100644
--- a/e2e/router/hostmap.go
+++ b/e2e/router/hostmap.go
@@ -5,11 +5,11 @@ package router
import (
"fmt"
+ "net/netip"
"sort"
"strings"
"github.com/slackhq/nebula"
- "github.com/slackhq/nebula/iputil"
)
type edge struct {
@@ -118,14 +118,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
return r, globalLines
}
-func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp {
- keys := make([]iputil.VpnIp, 0, len(hosts))
+func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr {
+ keys := make([]netip.Addr, 0, len(hosts))
for key := range hosts {
keys = append(keys, key)
}
sort.SliceStable(keys, func(i, j int) bool {
- return keys[i] > keys[j]
+ return keys[i].Compare(keys[j]) > 0
})
return keys
diff --git a/e2e/router/router.go b/e2e/router/router.go
index 730853a99..08905705c 100644
--- a/e2e/router/router.go
+++ b/e2e/router/router.go
@@ -6,12 +6,11 @@ package router
import (
"context"
"fmt"
- "net"
+ "net/netip"
"os"
"path/filepath"
"reflect"
"sort"
- "strconv"
"strings"
"sync"
"testing"
@@ -21,7 +20,6 @@ import (
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"golang.org/x/exp/maps"
)
@@ -29,18 +27,18 @@ import (
type R struct {
// Simple map of the ip:port registered on a control to the control
// Basically a router, right?
- controls map[string]*nebula.Control
+ controls map[netip.AddrPort]*nebula.Control
// A map for inbound packets for a control that doesn't know about this address
- inNat map[string]*nebula.Control
+ inNat map[netip.AddrPort]*nebula.Control
// A last used map, if an inbound packet hit the inNat map then
// all return packets should use the same last used inbound address for the outbound sender
// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
- outNat map[string]net.UDPAddr
+ outNat map[string]netip.AddrPort
// A map of vpn ip to the nebula control it belongs to
- vpnControls map[iputil.VpnIp]*nebula.Control
+ vpnControls map[netip.Addr]*nebula.Control
ignoreFlows []ignoreFlow
flow []flowEntry
@@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
}
r := &R{
- controls: make(map[string]*nebula.Control),
- vpnControls: make(map[iputil.VpnIp]*nebula.Control),
- inNat: make(map[string]*nebula.Control),
- outNat: make(map[string]net.UDPAddr),
+ controls: make(map[netip.AddrPort]*nebula.Control),
+ vpnControls: make(map[netip.Addr]*nebula.Control),
+ inNat: make(map[netip.AddrPort]*nebula.Control),
+ outNat: make(map[string]netip.AddrPort),
flow: []flowEntry{},
ignoreFlows: []ignoreFlow{},
fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
@@ -135,7 +133,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
for _, c := range controls {
addr := c.GetUDPAddr()
if _, ok := r.controls[addr]; ok {
- panic("Duplicate listen address: " + addr)
+ panic("Duplicate listen address: " + addr.String())
}
r.vpnControls[c.GetVpnIp()] = c
@@ -165,13 +163,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
// It does not look at the addr attached to the instance.
// If a route is used, this will behave like a NAT for the return path.
// Rewriting the source ip:port to what was last sent to from the origin
-func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
+func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) {
r.Lock()
defer r.Unlock()
- inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))
+ inAddr := netip.AddrPortFrom(ip, port)
if _, ok := r.inNat[inAddr]; ok {
- panic("Duplicate listen address inNat: " + inAddr)
+ panic("Duplicate listen address inNat: " + inAddr.String())
}
r.inNat[inAddr] = c
}
@@ -198,7 +196,7 @@ func (r *R) renderFlow() {
panic(err)
}
- var participants = map[string]struct{}{}
+ var participants = map[netip.AddrPort]struct{}{}
var participantsVals []string
fmt.Fprintln(f, "```mermaid")
@@ -215,7 +213,7 @@ func (r *R) renderFlow() {
continue
}
participants[addr] = struct{}{}
- sanAddr := strings.Replace(addr, ":", "-", 1)
+ sanAddr := strings.Replace(addr.String(), ":", "-", 1)
participantsVals = append(participantsVals, sanAddr)
fmt.Fprintf(
f, " participant %s as Nebula: %s
UDP: %s\n",
@@ -252,9 +250,9 @@ func (r *R) renderFlow() {
fmt.Fprintf(f,
" %s%s%s: %s(%s), index %v, counter: %v\n",
- strings.Replace(p.from.GetUDPAddr(), ":", "-", 1),
+ strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
line,
- strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
+ strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
)
}
@@ -305,7 +303,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
func (r *R) renderHostmaps(title string) {
c := maps.Values(r.controls)
sort.SliceStable(c, func(i, j int) bool {
- return c[i].GetVpnIp() > c[j].GetVpnIp()
+ return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
})
s := renderHostmaps(c...)
@@ -420,10 +418,8 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
// Nope, lets push the sender along
case p := <-udpTx:
- outAddr := sender.GetUDPAddr()
r.Lock()
- inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
- c := r.getControl(outAddr, inAddr, p)
+ c := r.getControl(sender.GetUDPAddr(), p.To, p)
if c == nil {
r.Unlock()
panic("No control for udp tx")
@@ -479,10 +475,7 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
} else {
// we are a udp tx, route and continue
p := rx.Interface().(*udp.Packet)
- outAddr := cm[x].GetUDPAddr()
-
- inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
- c := r.getControl(outAddr, inAddr, p)
+ c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
if c == nil {
r.Unlock()
panic("No control for udp tx")
@@ -509,12 +502,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
panic(err)
}
- outAddr := sender.GetUDPAddr()
- inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
- receiver := r.getControl(outAddr, inAddr, p)
+ receiver := r.getControl(sender.GetUDPAddr(), p.To, p)
if receiver == nil {
r.Unlock()
- panic("Can't route for host: " + inAddr)
+ panic("Can't RouteExitFunc for host: " + p.To.String())
}
e := whatDo(p, receiver)
@@ -590,13 +581,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet
// RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
// finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
// If the router doesn't have the nebula controller for that address, we panic
-func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
+func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) {
if finish == KeepRouting {
finish = RouteAndExit
}
r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
- if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
+ if p.To == toAddr {
return finish
}
@@ -630,13 +621,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
r.Lock()
p := rx.Interface().(*udp.Packet)
-
- outAddr := cm[x].GetUDPAddr()
- inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
- receiver := r.getControl(outAddr, inAddr, p)
+ receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
if receiver == nil {
r.Unlock()
- panic("Can't route for host: " + inAddr)
+ panic("Can't RouteForAllExitFunc for host: " + p.To.String())
}
e := whatDo(p, receiver)
@@ -697,12 +685,10 @@ func (r *R) FlushAll() {
p := rx.Interface().(*udp.Packet)
- outAddr := cm[x].GetUDPAddr()
- inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
- receiver := r.getControl(outAddr, inAddr, p)
+ receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
if receiver == nil {
r.Unlock()
- panic("Can't route for host: " + inAddr)
+ panic("Can't FlushAll for host: " + p.To.String())
}
r.Unlock()
}
@@ -710,28 +696,14 @@ func (r *R) FlushAll() {
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
// This is an internal router function, the caller must hold the lock
-func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
- if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
- p.FromIp = newAddr.IP
- p.FromPort = uint16(newAddr.Port)
+func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
+ if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok {
+ p.From = newAddr
}
c, ok := r.inNat[toAddr]
if ok {
- sHost, sPort, err := net.SplitHostPort(toAddr)
- if err != nil {
- panic(err)
- }
-
- port, err := strconv.Atoi(sPort)
- if err != nil {
- panic(err)
- }
-
- r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{
- IP: net.ParseIP(sHost),
- Port: port,
- }
+ r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr
return c
}
@@ -746,8 +718,9 @@ func (r *R) formatUdpPacket(p *packet) string {
}
from := "unknown"
- if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok {
- from = c.GetUDPAddr()
+ srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
+ if c, ok := r.vpnControls[srcAddr]; ok {
+ from = c.GetUDPAddr().String()
}
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
@@ -759,7 +732,7 @@ func (r *R) formatUdpPacket(p *packet) string {
return fmt.Sprintf(
" %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n",
strings.Replace(from, ":", "-", 1),
- strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
+ strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
udp.SrcPort,
udp.DstPort,
string(data.Payload()),
diff --git a/examples/go_service/main.go b/examples/go_service/main.go
index f46273acf..30178c034 100644
--- a/examples/go_service/main.go
+++ b/examples/go_service/main.go
@@ -4,6 +4,7 @@ import (
"bufio"
"fmt"
"log"
+ "net"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/service"
@@ -54,16 +55,16 @@ pki:
cert: /home/rice/Developer/nebula-config/app.crt
key: /home/rice/Developer/nebula-config/app.key
`
- var config config.C
- if err := config.LoadString(configStr); err != nil {
+ var cfg config.C
+ if err := cfg.LoadString(configStr); err != nil {
return err
}
- service, err := service.New(&config)
+ svc, err := service.New(&cfg)
if err != nil {
return err
}
- ln, err := service.Listen("tcp", ":1234")
+ ln, err := svc.Listen("tcp", ":1234")
if err != nil {
return err
}
@@ -73,16 +74,24 @@ pki:
log.Printf("accept error: %s", err)
break
}
- defer conn.Close()
+ defer func(conn net.Conn) {
+ _ = conn.Close()
+ }(conn)
log.Printf("got connection")
- conn.Write([]byte("hello world\n"))
+ _, err = conn.Write([]byte("hello world\n"))
+ if err != nil {
+ log.Printf("write error: %s", err)
+ }
scanner := bufio.NewScanner(conn)
for scanner.Scan() {
message := scanner.Text()
- fmt.Fprintf(conn, "echo: %q\n", message)
+ _, err = fmt.Fprintf(conn, "echo: %q\n", message)
+ if err != nil {
+ log.Printf("write error: %s", err)
+ }
log.Printf("got message %q", message)
}
@@ -92,8 +101,8 @@ pki:
}
}
- service.Close()
- if err := service.Wait(); err != nil {
+ _ = svc.Close()
+ if err := svc.Wait(); err != nil {
return err
}
return nil
diff --git a/firewall.go b/firewall.go
index 3e760feb3..8a409d25d 100644
--- a/firewall.go
+++ b/firewall.go
@@ -6,23 +6,23 @@ import (
"errors"
"fmt"
"hash/fnv"
- "net"
+ "net/netip"
"reflect"
"strconv"
"strings"
"sync"
"time"
+ "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
)
type FirewallInterface interface {
- AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
+ AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
}
type conn struct {
@@ -52,8 +52,8 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own
- localIps *cidr.Tree4[struct{}]
- assignedCIDR *net.IPNet
+ localIps *bart.Table[struct{}]
+ assignedCIDR netip.Prefix
hasSubnets bool
rules string
@@ -108,7 +108,7 @@ type FirewallRule struct {
Any *firewallLocalCIDR
Hosts map[string]*firewallLocalCIDR
Groups []*firewallGroups
- CIDR *cidr.Tree4[*firewallLocalCIDR]
+ CIDR *bart.Table[*firewallLocalCIDR]
}
type firewallGroups struct {
@@ -122,7 +122,7 @@ type firewallPort map[int32]*FirewallCA
type firewallLocalCIDR struct {
Any bool
- LocalCIDR *cidr.Tree4[struct{}]
+ LocalCIDR *bart.Table[struct{}]
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -144,20 +144,28 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
max = defaultTimeout
}
- localIps := cidr.NewTree4[struct{}]()
- var assignedCIDR *net.IPNet
+ localIps := new(bart.Table[struct{}])
+ var assignedCIDR netip.Prefix
+ var assignedSet bool
for _, ip := range c.Details.Ips {
- ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
- localIps.AddCIDR(ipNet, struct{}{})
+ //TODO: IPV6-WORK the unmap is a bit unfortunate
+ nip, _ := netip.AddrFromSlice(ip.IP)
+ nip = nip.Unmap()
+ nprefix := netip.PrefixFrom(nip, nip.BitLen())
+ localIps.Insert(nprefix, struct{}{})
- if assignedCIDR == nil {
+ if !assignedSet {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
- assignedCIDR = ipNet
+ assignedCIDR = nprefix
+ assignedSet = true
}
}
for _, n := range c.Details.Subnets {
- localIps.AddCIDR(n, struct{}{})
+ nip, _ := netip.AddrFromSlice(n.IP)
+ ones, _ := n.Mask.Size()
+ nip = nip.Unmap()
+ localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
}
return &Firewall{
@@ -237,15 +245,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
}
// AddRule properly creates the in memory rule structure for a firewall table.
-func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
- if ip != nil {
+ if ip.IsValid() {
sIp = ip.String()
}
lIp := ""
- if localIp != nil {
+ if localIp.IsValid() {
lIp = localIp.String()
}
@@ -382,17 +390,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
}
- var cidr *net.IPNet
+ var cidr netip.Prefix
if r.Cidr != "" {
- _, cidr, err = net.ParseCIDR(r.Cidr)
+ cidr, err = netip.ParsePrefix(r.Cidr)
if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
}
}
- var localCidr *net.IPNet
+ var localCidr netip.Prefix
if r.LocalCidr != "" {
- _, localCidr, err = net.ParseCIDR(r.LocalCidr)
+ localCidr, err = netip.ParsePrefix(r.LocalCidr)
if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
}
@@ -421,7 +429,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
// Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil {
- ok, _ := remoteCidr.Contains(fp.RemoteIP)
+ //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
+ _, ok := remoteCidr.Lookup(fp.RemoteIP)
if !ok {
f.metrics(incoming).droppedRemoteIP.Inc(1)
return ErrInvalidRemoteIP
@@ -435,7 +444,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// Make sure we are supposed to be handling this local ip address
- ok, _ := f.localIps.Contains(fp.LocalIP)
+ //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
+ _, ok := f.localIps.Lookup(fp.LocalIP)
if !ok {
f.metrics(incoming).droppedLocalIP.Inc(1)
return ErrInvalidLocalIP
@@ -589,7 +599,6 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock!
func (f *Firewall) evict(p firewall.Packet) {
- //TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn?
conntrack := f.Conntrack
t, ok := conntrack.Conns[p]
@@ -633,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
return false
}
-func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
if startPort > endPort {
return fmt.Errorf("start port was lower than end port")
}
@@ -677,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
return fp[firewall.PortAny].match(p, c, caPool)
}
-func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
+func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR),
Groups: make([]*firewallGroups, 0),
- CIDR: cidr.NewTree4[*firewallLocalCIDR](),
+ CIDR: new(bart.Table[*firewallLocalCIDR]),
}
}
@@ -740,10 +749,10 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return fc.CANames[s.Details.Name].match(p, c)
}
-func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
+func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{
- LocalCIDR: cidr.NewTree4[struct{}](),
+ LocalCIDR: new(bart.Table[struct{}]),
}
}
@@ -780,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
fr.Hosts[host] = nlc
}
- if ip != nil {
- _, nlc := fr.CIDR.GetCIDR(ip)
+ if ip.IsValid() {
+ nlc, _ := fr.CIDR.Get(ip)
if nlc == nil {
nlc = flc()
}
@@ -789,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
if err != nil {
return err
}
- fr.CIDR.AddCIDR(ip, nlc)
+ fr.CIDR.Insert(ip, nlc)
}
return nil
}
-func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
- if len(groups) == 0 && host == "" && ip == nil {
+func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
+ if len(groups) == 0 && host == "" && !ip.IsValid() {
return true
}
@@ -810,7 +819,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
return true
}
- if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) {
+ if ip.IsValid() && ip.Bits() == 0 {
return true
}
@@ -853,24 +862,31 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
}
}
- return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
- return flc.match(p, c)
+ matched := false
+ prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
+ fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
+ if prefix.Contains(p.RemoteIP) && val.match(p, c) {
+ matched = true
+ return false
+ }
+ return true
})
+ return matched
}
-func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
- if localIp == nil {
+func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
+ if !localIp.IsValid() {
if !f.hasSubnets || f.defaultLocalCIDRAny {
flc.Any = true
return nil
}
localIp = f.assignedCIDR
- } else if localIp.Contains(net.IPv4(0, 0, 0, 0)) {
+ } else if localIp.Bits() == 0 {
flc.Any = true
}
- flc.LocalCIDR.AddCIDR(localIp, struct{}{})
+ flc.LocalCIDR.Insert(localIp, struct{}{})
return nil
}
@@ -883,7 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
return true
}
- ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
+ _, ok := flc.LocalCIDR.Lookup(p.LocalIP)
return ok
}
diff --git a/firewall/packet.go b/firewall/packet.go
index dc3270eba..0cd206721 100644
--- a/firewall/packet.go
+++ b/firewall/packet.go
@@ -4,8 +4,7 @@ import (
"encoding/json"
"fmt"
mathrand "math/rand"
-
- "github.com/slackhq/nebula/iputil"
+ "net/netip"
)
type m map[string]interface{}
@@ -21,8 +20,8 @@ const (
)
type Packet struct {
- LocalIP iputil.VpnIp
- RemoteIP iputil.VpnIp
+ LocalIP netip.Addr
+ RemoteIP netip.Addr
LocalPort uint16
RemotePort uint16
Protocol uint8
diff --git a/firewall_test.go b/firewall_test.go
index b5beff61e..4d47e785f 100644
--- a/firewall_test.go
+++ b/firewall_test.go
@@ -5,13 +5,13 @@ import (
"errors"
"math"
"net"
+ "net/netip"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
@@ -65,59 +65,62 @@ func TestFirewall_AddRule(t *testing.T) {
assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules)
- _, ti, _ := net.ParseCIDR("1.2.3.4/32")
+ ti, err := netip.ParsePrefix("1.2.3.4/32")
+ assert.NoError(t, err)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
+ assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
- ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti)
+ _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
+ assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
- ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti)
+ _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
- assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
+ anyIp, err := netip.ParsePrefix("0.0.0.0/0")
+ assert.NoError(t, err)
+
+ assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
- assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
+ assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+ assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
}
func TestFirewall_Drop(t *testing.T) {
@@ -126,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
- LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
- RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+ LocalIP: netip.MustParseAddr("1.2.3.4"),
+ RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
@@ -152,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr("1.2.3.4"),
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
- assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
+ assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
@@ -170,34 +173,34 @@ func TestFirewall_Drop(t *testing.T) {
// test remote mismatch
oldRemote := p.RemoteIP
- p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
+ p.RemoteIP = netip.MustParseAddr("1.2.3.10")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
@@ -207,10 +210,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
TCP: firewallPort{},
}
- _, n, _ := net.ParseCIDR("172.1.1.1/32")
- goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
- _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
- _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
+ pfix := netip.MustParsePrefix("172.1.1.1/32")
+ _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
+ _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
@@ -231,10 +233,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
c := &cert.NebulaCertificate{}
- ip, _, _ := net.ParseCIDR("9.254.254.254/32")
- lip := iputil.Ip2VpnIp(ip)
+ ip := netip.MustParsePrefix("9.254.254.254/32")
for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
+ assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
}
})
@@ -262,7 +263,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
},
}
for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
+ assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
}
})
@@ -286,7 +287,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
},
}
for n := 0; n < b.N; n++ {
- assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
+ assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
}
})
@@ -363,8 +364,8 @@ func TestFirewall_Drop2(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
- LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
- RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+ LocalIP: netip.MustParseAddr("1.2.3.4"),
+ RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
@@ -387,7 +388,7 @@ func TestFirewall_Drop2(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr(ipNet.IP.String()),
}
h.CreateRemoteCIDR(&c)
@@ -406,7 +407,7 @@ func TestFirewall_Drop2(t *testing.T) {
h1.CreateRemoteCIDR(&c1)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
@@ -422,8 +423,8 @@ func TestFirewall_Drop3(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
- LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
- RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+ LocalIP: netip.MustParseAddr("1.2.3.4"),
+ RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
@@ -453,7 +454,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c1,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr(ipNet.IP.String()),
}
h1.CreateRemoteCIDR(&c1)
@@ -468,7 +469,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c2,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr(ipNet.IP.String()),
}
h2.CreateRemoteCIDR(&c2)
@@ -483,13 +484,13 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c3,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr(ipNet.IP.String()),
}
h3.CreateRemoteCIDR(&c3)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
cp := cert.NewCAPool()
// c1 should pass because host match
@@ -508,8 +509,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
- LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
- RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+ LocalIP: netip.MustParseAddr("1.2.3.4"),
+ RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
@@ -534,12 +535,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr(ipNet.IP.String()),
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
@@ -552,7 +553,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@@ -561,7 +562,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@@ -725,13 +726,13 @@ func TestNewFirewallFromConfig(t *testing.T) {
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf)
- assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
+ assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf)
- assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
+ assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups
conf = config.NewC(l)
@@ -747,78 +748,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
- assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
- assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
- assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with cidr
- cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)}
+ cidr := netip.MustParsePrefix("10.0.0.0/8")
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with ca_sha
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test Add error
conf = config.NewC(l)
@@ -871,8 +872,8 @@ type addRuleCall struct {
endPort int32
groups []string
host string
- ip *net.IPNet
- localIp *net.IPNet
+ ip netip.Prefix
+ localIp netip.Prefix
caName string
caSha string
}
@@ -882,7 +883,7 @@ type mockFirewall struct {
nextCallReturn error
}
-func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
mf.lastCall = addRuleCall{
incoming: incoming,
proto: proto,
diff --git a/go.mod b/go.mod
index dc9e01e06..adb2e84cf 100644
--- a/go.mod
+++ b/go.mod
@@ -10,6 +10,7 @@ require (
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
+ github.com/gaissmai/bart v0.11.1
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.2
@@ -22,12 +23,12 @@ require (
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stretchr/testify v1.9.0
github.com/vishvananda/netlink v1.2.1-beta.2
- golang.org/x/crypto v0.24.0
+ golang.org/x/crypto v0.26.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
- golang.org/x/net v0.26.0
- golang.org/x/sync v0.7.0
- golang.org/x/sys v0.21.0
- golang.org/x/term v0.21.0
+ golang.org/x/net v0.28.0
+ golang.org/x/sync v0.8.0
+ golang.org/x/sys v0.24.0
+ golang.org/x/term v0.23.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
@@ -38,6 +39,7 @@ require (
require (
github.com/beorn7/perks v1.0.1 // indirect
+ github.com/bits-and-blooms/bitset v1.13.0 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.1.2 // indirect
diff --git a/go.sum b/go.sum
index 32099f2d1..3afd6cb05 100644
--- a/go.sum
+++ b/go.sum
@@ -14,6 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
+github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
+github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -24,6 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
+github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc=
+github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -147,8 +151,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
-golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
+golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
+golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -167,8 +171,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
-golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
+golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
+golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -176,8 +180,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
+golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -195,11 +199,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
-golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
+golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
-golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
+golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
+golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
diff --git a/handshake_ix.go b/handshake_ix.go
index 95c51f81f..0d54b0175 100644
--- a/handshake_ix.go
+++ b/handshake_ix.go
@@ -1,12 +1,12 @@
package nebula
import (
+ "net/netip"
"time"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
@@ -72,7 +72,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return true
}
-func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
@@ -108,12 +108,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
e.Info("Invalid certificate from host")
return
}
- vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
+
+ vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
+ if !ok {
+ e := f.l.WithError(err).WithField("udpAddr", addr).
+ WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
+
+ if f.l.Level > logrus.DebugLevel {
+ e = e.WithField("cert", remoteCert)
+ }
+
+ e.Info("Invalid vpn ip from host")
+ return
+ }
+
+ vpnIp = vpnIp.Unmap()
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer
- if vpnIp == f.myVpnIp {
+ if vpnIp == f.myVpnNet.Addr() {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
@@ -122,8 +136,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
return
}
- if addr != nil {
- if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) {
+ if addr.IsValid() {
+ if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
@@ -153,13 +167,10 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
TotalPorts: uint32(f.multiPort.TxPorts),
}
}
- if hs.Details.InitiatorMultiPort != nil && hs.Details.InitiatorMultiPort.BasePort != uint32(addr.Port) {
+ if hs.Details.InitiatorMultiPort != nil && hs.Details.InitiatorMultiPort.BasePort != uint32(addr.Port()) {
// The other side sent us a handshake from a different port, make sure
// we send responses back to the BasePort
- addr = &udp.Addr{
- IP: addr.IP,
- Port: uint16(hs.Details.InitiatorMultiPort.BasePort),
- }
+ addr = netip.AddrPortFrom(addr.Addr(), uint16(hs.Details.InitiatorMultiPort.BasePort))
}
hostinfo := &HostInfo{
@@ -172,8 +183,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
multiportTx: multiportTx,
multiportRx: multiportRx,
relayState: RelayState{
- relays: map[iputil.VpnIp]struct{}{},
- relayForByIp: map[iputil.VpnIp]*Relay{},
+ relays: map[netip.Addr]struct{}{},
+ relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
@@ -246,7 +257,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
case ErrAlreadySeen:
if hostinfo.multiportRx {
// The other host is sending to us with multiport, so only grab the IP
- addr.Port = hostinfo.remote.Port
+ addr = netip.AddrPortFrom(addr.Addr(), hostinfo.remote.Port())
}
// Update remote if preferred
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
@@ -257,7 +268,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
- if addr != nil {
+ if addr.IsValid() {
if multiportTx {
// TODO remove alloc here
raw := make([]byte, len(msg)+udp.RawOverhead)
@@ -330,7 +341,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
// Do the send
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
- if addr != nil {
+ if addr.IsValid() {
if multiportTx {
// TODO remove alloc here
raw := make([]byte, len(msg)+udp.RawOverhead)
@@ -379,7 +390,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
return
}
-func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
+func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
if hh == nil {
// Nothing here to tear down, got a bogus stage 2 packet
return true
@@ -389,8 +400,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
defer hh.Unlock()
hostinfo := hh.hostinfo
- if addr != nil {
- if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
+ if addr.IsValid() {
+ if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false
}
@@ -432,13 +443,13 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
hostinfo.multiportRx = hs.Details.ResponderMultiPort.TxSupported && f.multiPort.Rx
}
- if hs.Details.ResponderMultiPort != nil && hs.Details.ResponderMultiPort.BasePort != uint32(addr.Port) {
+ if hs.Details.ResponderMultiPort != nil && hs.Details.ResponderMultiPort.BasePort != uint32(addr.Port()) {
// The other side sent us a handshake from a different port, make sure
// we send responses back to the BasePort
- addr = &udp.Addr{
- IP: addr.IP,
- Port: uint16(hs.Details.ResponderMultiPort.BasePort),
- }
+ addr = netip.AddrPortFrom(
+ addr.Addr(),
+ uint16(hs.Details.ResponderMultiPort.BasePort),
+ )
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
@@ -456,7 +467,20 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
return true
}
- vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
+ vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
+ if !ok {
+ e := f.l.WithError(err).WithField("udpAddr", addr).
+ WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
+
+ if f.l.Level > logrus.DebugLevel {
+ e = e.WithField("cert", remoteCert)
+ }
+
+ e.Info("Invalid vpn ip from host")
+ return true
+ }
+
+ vpnIp = vpnIp.Unmap()
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer
@@ -521,7 +545,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
ci.eKey = NewNebulaCipherState(eKey)
// Make sure the current udpAddr being used is set for responding
- if addr != nil {
+ if addr.IsValid() {
hostinfo.SetRemote(addr)
} else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
diff --git a/handshake_manager.go b/handshake_manager.go
index d7a9ee719..ce8af3a6a 100644
--- a/handshake_manager.go
+++ b/handshake_manager.go
@@ -6,15 +6,15 @@ import (
"crypto/rand"
"encoding/binary"
"errors"
- "net"
+ "net/netip"
"sync"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
+ "golang.org/x/exp/slices"
)
const (
@@ -35,7 +35,7 @@ var (
type HandshakeConfig struct {
tryInterval time.Duration
- retries int
+ retries int64
triggerBuffer int
useRelays bool
@@ -46,14 +46,14 @@ type HandshakeManager struct {
// Mutex for interacting with the vpnIps and indexes maps
sync.RWMutex
- vpnIps map[iputil.VpnIp]*HandshakeHostInfo
+ vpnIps map[netip.Addr]*HandshakeHostInfo
indexes map[uint32]*HandshakeHostInfo
mainHostMap *HostMap
lightHouse *LightHouse
outside udp.Conn
config HandshakeConfig
- OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
+ OutboundHandshakeTimer *LockingTimerWheel[netip.Addr]
messageMetrics *MessageMetrics
metricInitiated metrics.Counter
metricTimedOut metrics.Counter
@@ -64,17 +64,17 @@ type HandshakeManager struct {
udpRaw *udp.RawConn
// can be used to trigger outbound handshake for the given vpnIp
- trigger chan iputil.VpnIp
+ trigger chan netip.Addr
}
type HandshakeHostInfo struct {
sync.Mutex
- startTime time.Time // Time that we first started trying with this handshake
- ready bool // Is the handshake ready
- counter int // How many attempts have we made so far
- lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt
- packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
+ startTime time.Time // Time that we first started trying with this handshake
+ ready bool // Is the handshake ready
+ counter int64 // How many attempts have we made so far
+ lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
+ packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
hostinfo *HostInfo
}
@@ -106,14 +106,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType,
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
- vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{},
+ vpnIps: map[netip.Addr]*HandshakeHostInfo{},
indexes: map[uint32]*HandshakeHostInfo{},
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
config: config,
- trigger: make(chan iputil.VpnIp, config.triggerBuffer),
- OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+ trigger: make(chan netip.Addr, config.triggerBuffer),
+ OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
messageMetrics: config.messageMetrics,
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -137,10 +137,10 @@ func (c *HandshakeManager) Run(ctx context.Context) {
}
}
-func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
// First remote allow list check before we know the vpnIp
- if addr != nil {
- if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
+ if addr.IsValid() {
+ if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
@@ -173,7 +173,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
}
}
-func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
+func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) {
hh := hm.queryVpnIp(vpnIp)
if hh == nil {
return
@@ -215,7 +215,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
}
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
- remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
+ remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes)
// We only care about a lighthouse trigger if we have new remotes to send to.
// This is a very specific optimization for a fast lighthouse reply.
@@ -237,9 +237,9 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
}
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
- var sentTo []*udp.Addr
+ var sentTo []netip.AddrPort
var sentMultiport bool
- hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
+ hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
@@ -294,13 +294,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays {
// Don't relay to myself, and don't relay through the host I'm trying to connect to
- if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp {
+ if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
continue
}
- relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay)
- if relayHostInfo == nil || relayHostInfo.remote == nil {
+ relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
+ if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
- hm.f.Handshake(*relay)
+ hm.f.Handshake(relay)
continue
}
// Check the relay HostInfo to see if we already established a relay through it
@@ -311,12 +311,17 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
case Requested:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
+
+ //TODO: IPV6-WORK
+ myVpnIpB := hm.f.myVpnNet.Addr().As4()
+ theirVpnIpB := vpnIp.As4()
+
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: existingRelay.LocalIndex,
- RelayFromIp: uint32(hm.lightHouse.myVpnIp),
- RelayToIp: uint32(vpnIp),
+ RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
+ RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
}
msg, err := m.Marshal()
if err != nil {
@@ -327,10 +332,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
- "relayFrom": hm.lightHouse.myVpnIp,
+ "relayFrom": hm.f.myVpnNet.Addr(),
"relayTo": vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex,
- "relay": *relay}).
+ "relay": relay}).
Info("send CreateRelayRequest")
}
default:
@@ -342,17 +347,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
}
} else {
// No relays exist or requested yet.
- if relayHostInfo.remote != nil {
+ if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
}
+ //TODO: IPV6-WORK
+ myVpnIpB := hm.f.myVpnNet.Addr().As4()
+ theirVpnIpB := vpnIp.As4()
+
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx,
- RelayFromIp: uint32(hm.lightHouse.myVpnIp),
- RelayToIp: uint32(vpnIp),
+ RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
+ RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
}
msg, err := m.Marshal()
if err != nil {
@@ -362,10 +371,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
- "relayFrom": hm.lightHouse.myVpnIp,
+ "relayFrom": hm.f.myVpnNet.Addr(),
"relayTo": vpnIp,
"initiatorRelayIndex": idx,
- "relay": *relay}).
+ "relay": relay}).
Info("send CreateRelayRequest")
}
}
@@ -381,7 +390,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
// The 2nd argument will be true if the hostinfo is ready to transmit traffic
-func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
+func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
hm.mainHostMap.RLock()
h, ok := hm.mainHostMap.Hosts[vpnIp]
hm.mainHostMap.RUnlock()
@@ -398,7 +407,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
}
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
-func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
+func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
hm.Lock()
if hh, ok := hm.vpnIps[vpnIp]; ok {
@@ -414,8 +423,8 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
- relays: map[iputil.VpnIp]struct{}{},
- relayForByIp: map[iputil.VpnIp]*Relay{},
+ relays: map[netip.Addr]struct{}{},
+ relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
@@ -505,7 +514,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
if found && existingPendingIndex.hostinfo != hostinfo {
// We have a collision, but for a different hostinfo
- return existingIndex, ErrLocalIndexCollision
+ return existingPendingIndex.hostinfo, ErrLocalIndexCollision
}
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
@@ -581,7 +590,7 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
delete(c.vpnIps, hostinfo.vpnIp)
if len(c.vpnIps) == 0 {
- c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{}
+ c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
}
delete(c.indexes, hostinfo.localIndexId)
@@ -596,7 +605,7 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
}
}
-func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
hh := hm.queryVpnIp(vpnIp)
if hh != nil {
return hh.hostinfo
@@ -605,7 +614,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
}
-func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo {
+func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo {
hm.RLock()
defer hm.RUnlock()
return hm.vpnIps[vpnIp]
@@ -625,7 +634,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
return hm.indexes[index]
}
-func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
+func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
return c.mainHostMap.GetPreferredRanges()
}
@@ -682,6 +691,6 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
return index, nil
}
-func hsTimeout(tries int, interval time.Duration) time.Duration {
- return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval)))
+func hsTimeout(tries int64, interval time.Duration) time.Duration {
+ return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval)))
}
diff --git a/handshake_manager_test.go b/handshake_manager_test.go
index 9a6335757..a78b45f54 100644
--- a/handshake_manager_test.go
+++ b/handshake_manager_test.go
@@ -1,13 +1,12 @@
package nebula
import (
- "net"
+ "net/netip"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
@@ -15,10 +14,11 @@ import (
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := test.NewLogger()
- _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
- _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
- ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
- preferredRanges := []*net.IPNet{localrange}
+ vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+ localrange := netip.MustParsePrefix("10.1.1.1/24")
+ ip := netip.MustParseAddr("172.1.1.2")
+
+ preferredRanges := []netip.Prefix{localrange}
mainHM := newHostMap(l, vpncidr)
mainHM.preferredRanges.Store(&preferredRanges)
@@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.NotContains(t, blah.vpnIps, ip)
}
-func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
+func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
for _, i := range tw.t.wheel {
n := i.Head
for n != nil {
@@ -80,7 +80,7 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
type mockEncWriter struct {
}
-func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
return
}
@@ -92,4 +92,4 @@ func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
return
}
-func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {}
+func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
diff --git a/hostmap.go b/hostmap.go
index 73bd563df..40031a33b 100644
--- a/hostmap.go
+++ b/hostmap.go
@@ -3,18 +3,17 @@ package nebula
import (
"errors"
"net"
+ "net/netip"
"sync"
"sync/atomic"
"time"
+ "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
- "github.com/slackhq/nebula/udp"
)
// const ProbeLen = 100
@@ -49,7 +48,7 @@ type Relay struct {
State int
LocalIndex uint32
RemoteIndex uint32
- PeerIp iputil.VpnIp
+ PeerIp netip.Addr
}
type HostMap struct {
@@ -57,9 +56,9 @@ type HostMap struct {
Indexes map[uint32]*HostInfo
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
RemoteIndexes map[uint32]*HostInfo
- Hosts map[iputil.VpnIp]*HostInfo
- preferredRanges atomic.Pointer[[]*net.IPNet]
- vpnCIDR *net.IPNet
+ Hosts map[netip.Addr]*HostInfo
+ preferredRanges atomic.Pointer[[]netip.Prefix]
+ vpnCIDR netip.Prefix
l *logrus.Logger
}
@@ -69,12 +68,12 @@ type HostMap struct {
type RelayState struct {
sync.RWMutex
- relays map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
- relayForByIp map[iputil.VpnIp]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
- relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
+ relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
+ relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
+ relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
}
-func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) {
+func (rs *RelayState) DeleteRelay(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
delete(rs.relays, ip)
@@ -90,33 +89,33 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
return ret
}
-func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) {
+func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
rs.RLock()
defer rs.RUnlock()
r, ok := rs.relayForByIp[ip]
return r, ok
}
-func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) {
+func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
rs.relays[ip] = struct{}{}
}
-func (rs *RelayState) CopyRelayIps() []iputil.VpnIp {
+func (rs *RelayState) CopyRelayIps() []netip.Addr {
rs.RLock()
defer rs.RUnlock()
- ret := make([]iputil.VpnIp, 0, len(rs.relays))
+ ret := make([]netip.Addr, 0, len(rs.relays))
for ip := range rs.relays {
ret = append(ret, ip)
}
return ret
}
-func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp {
+func (rs *RelayState) CopyRelayForIps() []netip.Addr {
rs.RLock()
defer rs.RUnlock()
- currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp))
+ currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
for relayIp := range rs.relayForByIp {
currentRelays = append(currentRelays, relayIp)
}
@@ -133,19 +132,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
return ret
}
-func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) {
- rs.Lock()
- defer rs.Unlock()
- r, ok := rs.relayForByIdx[localIdx]
- if !ok {
- return iputil.VpnIp(0), false
- }
- delete(rs.relayForByIdx, localIdx)
- delete(rs.relayForByIp, r.PeerIp)
- return r.PeerIp, true
-}
-
-func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool {
+func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
rs.Lock()
defer rs.Unlock()
r, ok := rs.relayForByIp[vpnIp]
@@ -175,7 +162,7 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
return &newRelay, true
}
-func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) {
+func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
rs.RLock()
defer rs.RUnlock()
r, ok := rs.relayForByIp[vpnIp]
@@ -189,7 +176,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
return r, ok
}
-func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
+func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.Lock()
defer rs.Unlock()
rs.relayForByIp[ip] = r
@@ -197,15 +184,15 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
}
type HostInfo struct {
- remote *udp.Addr
+ remote netip.AddrPort
remotes *RemoteList
promoteCounter atomic.Uint32
ConnectionState *ConnectionState
remoteIndexId uint32
localIndexId uint32
- vpnIp iputil.VpnIp
+ vpnIp netip.Addr
recvError atomic.Uint32
- remoteCidr *cidr.Tree4[struct{}]
+ remoteCidr *bart.Table[struct{}]
relayState RelayState
// If true, we should send to this remote using multiport
@@ -233,7 +220,7 @@ type HostInfo struct {
lastHandshakeTime uint64
lastRoam time.Time
- lastRoamRemote *udp.Addr
+ lastRoamRemote netip.AddrPort
// Used to track other hostinfos for this vpn ip since only 1 can be primary
// Synchronised via hostmap lock and not the hostinfo lock.
@@ -260,7 +247,7 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}
-func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap {
+func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
hm := newHostMap(l, vpnCIDR)
hm.reload(c, true)
@@ -275,12 +262,12 @@ func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *Ho
return hm
}
-func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
+func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
return &HostMap{
Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{},
- Hosts: map[iputil.VpnIp]*HostInfo{},
+ Hosts: map[netip.Addr]*HostInfo{},
vpnCIDR: vpnCIDR,
l: l,
}
@@ -288,11 +275,11 @@ func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
func (hm *HostMap) reload(c *config.C, initial bool) {
if initial || c.HasChanged("preferred_ranges") {
- var preferredRanges []*net.IPNet
+ var preferredRanges []netip.Prefix
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
for _, rawPreferredRange := range rawPreferredRanges {
- _, preferredRange, err := net.ParseCIDR(rawPreferredRange)
+ preferredRange, err := netip.ParsePrefix(rawPreferredRange)
if err != nil {
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
@@ -384,7 +371,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, hostinfo.vpnIp)
if len(hm.Hosts) == 0 {
- hm.Hosts = map[iputil.VpnIp]*HostInfo{}
+ hm.Hosts = map[netip.Addr]*HostInfo{}
}
if hostinfo.next != nil {
@@ -467,11 +454,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
}
}
-func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
return hm.queryVpnIp(vpnIp, nil)
}
-func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) {
+func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
hm.RLock()
defer hm.RUnlock()
@@ -489,7 +476,7 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
return nil, nil, errors.New("unable to find host with relay")
}
-func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo {
+func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock()
@@ -541,7 +528,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
}
}
-func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
+func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
return *hm.preferredRanges.Load()
}
@@ -566,14 +553,14 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
// TryPromoteBest handles re-querying lighthouses and probing for better paths
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
-func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
+func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) {
c := i.promoteCounter.Add(1)
if c%ifce.tryPromoteEvery.Load() == 0 {
remote := i.remote
// return early if we are already on a preferred remote
- if remote != nil {
- rIP := remote.IP
+ if remote.IsValid() {
+ rIP := remote.Addr()
for _, l := range preferredRanges {
if l.Contains(rIP) {
return
@@ -581,8 +568,8 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
}
}
- i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
- if remote != nil && (addr == nil || !preferred) {
+ i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) {
+ if remote.IsValid() && (!addr.IsValid() || !preferred) {
return
}
@@ -611,23 +598,23 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
return nil
}
-func (i *HostInfo) SetRemote(remote *udp.Addr) {
+func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object
- if !i.remote.Equals(remote) {
- i.remote = remote.Copy()
- i.remotes.LearnRemote(i.vpnIp, remote.Copy())
+ if i.remote != remote {
+ i.remote = remote
+ i.remotes.LearnRemote(i.vpnIp, remote)
}
}
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
// time on the HostInfo will also be updated.
-func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
- if newRemote == nil {
+func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
+ if !newRemote.IsValid() {
// relays have nil udp Addrs
return false
}
currentRemote := i.remote
- if currentRemote == nil {
+ if !currentRemote.IsValid() {
i.SetRemote(newRemote)
return true
}
@@ -637,11 +624,11 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
newIsPreferred := false
for _, l := range hm.GetPreferredRanges() {
// return early if we are already on a preferred remote
- if l.Contains(currentRemote.IP) {
+ if l.Contains(currentRemote.Addr()) {
return false
}
- if l.Contains(newRemote.IP) {
+ if l.Contains(newRemote.Addr()) {
newIsPreferred = true
}
}
@@ -649,7 +636,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
if newIsPreferred {
// Consider this a roaming event
i.lastRoam = time.Now()
- i.lastRoamRemote = currentRemote.Copy()
+ i.lastRoamRemote = currentRemote
i.SetRemote(newRemote)
@@ -672,13 +659,21 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
return
}
- remoteCidr := cidr.NewTree4[struct{}]()
+ remoteCidr := new(bart.Table[struct{}])
for _, ip := range c.Details.Ips {
- remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
+ //TODO: IPV6-WORK what to do when ip is invalid?
+ nip, _ := netip.AddrFromSlice(ip.IP)
+ nip = nip.Unmap()
+ bits, _ := ip.Mask.Size()
+ remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
}
for _, n := range c.Details.Subnets {
- remoteCidr.AddCIDR(n, struct{}{})
+ //TODO: IPV6-WORK what to do when ip is invalid?
+ nip, _ := netip.AddrFromSlice(n.IP)
+ nip = nip.Unmap()
+ bits, _ := n.Mask.Size()
+ remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
}
i.remoteCidr = remoteCidr
}
@@ -703,9 +698,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// Utility functions
-func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
+func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
//FIXME: This function is pretty garbage
- var ips []net.IP
+ var ips []netip.Addr
ifaces, _ := net.Interfaces()
for _, i := range ifaces {
allow := allowList.AllowName(i.Name)
@@ -727,20 +722,29 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
ip = v.IP
}
+ nip, ok := netip.AddrFromSlice(ip)
+ if !ok {
+ if l.Level >= logrus.DebugLevel {
+ l.WithField("localIp", ip).Debug("ip was invalid for netip")
+ }
+ continue
+ }
+ nip = nip.Unmap()
+
//TODO: Filtering out link local for now, this is probably the most correct thing
//TODO: Would be nice to filter out SLAAC MAC based ips as well
- if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() {
- allow := allowList.Allow(ip)
+ if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
+ allow := allowList.Allow(nip)
if l.Level >= logrus.TraceLevel {
- l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow")
+ l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
}
if !allow {
continue
}
- ips = append(ips, ip)
+ ips = append(ips, nip)
}
}
}
- return &ips
+ return ips
}
diff --git a/hostmap_test.go b/hostmap_test.go
index 8311cef0b..7e2feb810 100644
--- a/hostmap_test.go
+++ b/hostmap_test.go
@@ -1,7 +1,7 @@
package nebula
import (
- "net"
+ "net/netip"
"testing"
"github.com/slackhq/nebula/config"
@@ -13,18 +13,15 @@ func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger()
hm := newHostMap(
l,
- &net.IPNet{
- IP: net.IP{10, 0, 0, 1},
- Mask: net.IPMask{255, 255, 255, 0},
- },
+ netip.MustParsePrefix("10.0.0.1/24"),
)
f := &Interface{}
- h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
- h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
- h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
- h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
+ h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
+ h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
+ h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
+ h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
hm.unlockedAddHostInfo(h4, f)
hm.unlockedAddHostInfo(h3, f)
@@ -32,7 +29,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.unlockedAddHostInfo(h1, f)
// Make sure we go h1 -> h2 -> h3 -> h4
- prim := hm.QueryVpnIp(1)
+ prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -47,7 +44,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h3)
// Make sure we go h3 -> h1 -> h2 -> h4
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h3.localIndexId, prim.localIndexId)
assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -62,7 +59,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -77,7 +74,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -93,20 +90,17 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger()
hm := newHostMap(
l,
- &net.IPNet{
- IP: net.IP{10, 0, 0, 1},
- Mask: net.IPMask{255, 255, 255, 0},
- },
+ netip.MustParsePrefix("10.0.0.1/24"),
)
f := &Interface{}
- h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
- h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
- h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
- h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
- h5 := &HostInfo{vpnIp: 1, localIndexId: 5}
- h6 := &HostInfo{vpnIp: 1, localIndexId: 6}
+ h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
+ h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
+ h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
+ h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
+ h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
+ h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
hm.unlockedAddHostInfo(h6, f)
hm.unlockedAddHostInfo(h5, f)
@@ -122,7 +116,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h)
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
- prim := hm.QueryVpnIp(1)
+ prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -141,7 +135,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h1.next)
// Make sure we go h2 -> h3 -> h4 -> h5
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -159,7 +153,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h3.next)
// Make sure we go h2 -> h4 -> h5
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -175,7 +169,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h5.next)
// Make sure we go h2 -> h4
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -189,7 +183,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h2.next)
// Make sure we only have h4
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Nil(t, prim.prev)
assert.Nil(t, prim.next)
@@ -201,7 +195,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h4.next)
// Make sure we have nil
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Nil(t, prim)
}
@@ -211,14 +205,11 @@ func TestHostMap_reload(t *testing.T) {
hm := NewHostMapFromConfig(
l,
- &net.IPNet{
- IP: net.IP{10, 0, 0, 1},
- Mask: net.IPMask{255, 255, 255, 0},
- },
+ netip.MustParsePrefix("10.0.0.1/24"),
c,
)
- toS := func(ipn []*net.IPNet) []string {
+ toS := func(ipn []netip.Prefix) []string {
var s []string
for _, n := range ipn {
s = append(s, n.String())
diff --git a/hostmap_tester.go b/hostmap_tester.go
index 0d5d41bf7..b2d1d1b5b 100644
--- a/hostmap_tester.go
+++ b/hostmap_tester.go
@@ -5,9 +5,11 @@ package nebula
// This file contains functions used to export information to the e2e testing framework
-import "github.com/slackhq/nebula/iputil"
+import (
+ "net/netip"
+)
-func (i *HostInfo) GetVpnIp() iputil.VpnIp {
+func (i *HostInfo) GetVpnIp() netip.Addr {
return i.vpnIp
}
diff --git a/inside.go b/inside.go
index 429408bc3..467f1f28f 100644
--- a/inside.go
+++ b/inside.go
@@ -1,6 +1,8 @@
package nebula
import (
+ "net/netip"
+
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@@ -19,11 +21,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
}
// Ignore local broadcast packets
- if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast {
+ if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
return
}
- if fwPacket.RemoteIP == f.myVpnIp {
+ if fwPacket.RemoteIP == f.myVpnNet.Addr() {
// Immediately forward packets from self to self.
// This should only happen on Darwin-based and FreeBSD hosts, which
// routes packets from the Nebula IP to the Nebula IP through the Nebula
@@ -39,8 +41,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
- // Ignore broadcast packets
- if f.dropMulticast && isMulticast(fwPacket.RemoteIP) {
+ // Ignore multicast packets
+ if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
return
}
@@ -64,7 +66,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
- f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q, fwPacket)
+ f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q, fwPacket)
} else {
f.rejectInside(packet, out, q)
@@ -113,19 +115,19 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
return
}
- f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q, nil)
+ f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q, nil)
}
-func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
+func (f *Interface) Handshake(vpnIp netip.Addr) {
f.getOrHandshake(vpnIp, nil)
}
// getOrHandshake returns nil if the vpnIp is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
- if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) {
+func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+ if !f.myVpnNet.Contains(vpnIp) {
vpnIp = f.inside.RouteFor(vpnIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return nil, false
}
}
@@ -152,11 +154,11 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
return
}
- f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0, nil)
+ f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0, nil)
}
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
-func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
+func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
})
@@ -182,10 +184,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1)
- f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0, nil)
+ f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0, nil)
}
-func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) {
+func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0, nil)
}
@@ -255,7 +257,7 @@ func (f *Interface) SendVia(via *HostInfo,
f.connectionManager.RelayUsed(relay.LocalIndex)
}
-func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int, udpPortGetter udp.SendPortGetter) {
+func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int, udpPortGetter udp.SendPortGetter) {
if ci.eKey == nil {
//TODO: log warning
return
@@ -277,7 +279,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
}
}
- useRelay := remote == nil && hostinfo.remote == nil
+ useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
fullOut := out
if useRelay {
@@ -325,7 +327,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
return
}
- if remote != nil {
+ if remote.IsValid() {
if multiport {
rawOut = rawOut[:len(out)+udp.RawOverhead]
port := udpPortGetter.UDPSendPort(f.multiPort.TxPorts)
@@ -337,7 +339,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
- } else if hostinfo.remote != nil {
+ } else if hostinfo.remote.IsValid() {
if multiport {
rawOut = rawOut[:len(out)+udp.RawOverhead]
port := udpPortGetter.UDPSendPort(f.multiPort.TxPorts)
@@ -363,8 +365,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
}
}
}
-
-func isMulticast(ip iputil.VpnIp) bool {
- // Class D multicast
- return (((ip >> 24) & 0xff) & 0xf0) == 0xe0
-}
diff --git a/interface.go b/interface.go
index d933a3ebc..63abe8dde 100644
--- a/interface.go
+++ b/interface.go
@@ -2,10 +2,11 @@ package nebula
import (
"context"
+ "encoding/binary"
"errors"
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"runtime"
"sync/atomic"
@@ -16,7 +17,6 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
)
@@ -63,8 +63,8 @@ type Interface struct {
serveDns bool
createTime time.Time
lightHouse *LightHouse
- localBroadcast iputil.VpnIp
- myVpnIp iputil.VpnIp
+ myBroadcastAddr netip.Addr
+ myVpnNet netip.Prefix
dropLocalBroadcast bool
dropMulticast bool
routines int
@@ -103,7 +103,7 @@ type MultiPortConfig struct {
TxBasePort uint16
TxPorts int
TxHandshake bool
- TxHandshakeDelay int
+ TxHandshakeDelay int64
}
type EncWriter interface {
@@ -114,9 +114,9 @@ type EncWriter interface {
out []byte,
nocopy bool,
)
- SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
+ SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
- Handshake(vpnIp iputil.VpnIp)
+ Handshake(vpnIp netip.Addr)
}
type sendRecvErrorConfig uint8
@@ -127,10 +127,10 @@ const (
sendRecvErrorPrivate
)
-func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool {
+func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
switch s {
case sendRecvErrorPrivate:
- return ip.IsPrivate()
+ return ip.Addr().IsPrivate()
case sendRecvErrorAlways:
return true
case sendRecvErrorNever:
@@ -168,7 +168,27 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
}
certificate := c.pki.GetCertState().Certificate
- myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
+
+ myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
+ if !ok {
+ return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP)
+ }
+
+ myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask)
+ if !ok {
+ return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask)
+ }
+
+ myVpnAddr = myVpnAddr.Unmap()
+ myVpnMask = myVpnMask.Unmap()
+
+ if myVpnAddr.BitLen() != myVpnMask.BitLen() {
+ return nil, fmt.Errorf("ip address and mask are different lengths in certificate")
+ }
+
+ ones, _ := certificate.Details.Ips[0].Mask.Size()
+ myVpnNet := netip.PrefixFrom(myVpnAddr, ones)
+
ifce := &Interface{
pki: c.pki,
hostMap: c.HostMap,
@@ -180,14 +200,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
handshakeManager: c.HandshakeManager,
createTime: time.Now(),
lightHouse: c.lightHouse,
- localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast,
routines: c.routines,
version: c.version,
writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
- myVpnIp: myVpnIp,
+ myVpnNet: myVpnNet,
relayManager: c.relayManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
@@ -202,6 +221,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l,
}
+ if myVpnAddr.Is4() {
+ addr := myVpnNet.Masked().Addr().As4()
+ binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask))
+ ifce.myBroadcastAddr = netip.AddrFrom4(addr)
+ }
+
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
diff --git a/iputil/packet.go b/iputil/packet.go
index b18e52447..719e0349e 100644
--- a/iputil/packet.go
+++ b/iputil/packet.go
@@ -6,6 +6,8 @@ import (
"golang.org/x/net/ipv4"
)
+//TODO: IPV6-WORK can probably delete this
+
const (
// Need 96 bytes for the largest reject packet:
// - 20 byte ipv4 header
diff --git a/iputil/util.go b/iputil/util.go
deleted file mode 100644
index 65f7677aa..000000000
--- a/iputil/util.go
+++ /dev/null
@@ -1,93 +0,0 @@
-package iputil
-
-import (
- "encoding/binary"
- "fmt"
- "net"
- "net/netip"
-)
-
-type VpnIp uint32
-
-const maxIPv4StringLen = len("255.255.255.255")
-
-func (ip VpnIp) String() string {
- b := make([]byte, maxIPv4StringLen)
-
- n := ubtoa(b, 0, byte(ip>>24))
- b[n] = '.'
- n++
-
- n += ubtoa(b, n, byte(ip>>16&255))
- b[n] = '.'
- n++
-
- n += ubtoa(b, n, byte(ip>>8&255))
- b[n] = '.'
- n++
-
- n += ubtoa(b, n, byte(ip&255))
- return string(b[:n])
-}
-
-func (ip VpnIp) MarshalJSON() ([]byte, error) {
- return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
-}
-
-func (ip VpnIp) ToIP() net.IP {
- nip := make(net.IP, 4)
- binary.BigEndian.PutUint32(nip, uint32(ip))
- return nip
-}
-
-func (ip VpnIp) ToNetIpAddr() netip.Addr {
- var nip [4]byte
- binary.BigEndian.PutUint32(nip[:], uint32(ip))
- return netip.AddrFrom4(nip)
-}
-
-func Ip2VpnIp(ip []byte) VpnIp {
- if len(ip) == 16 {
- return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
- }
- return VpnIp(binary.BigEndian.Uint32(ip))
-}
-
-func ToNetIpAddr(ip net.IP) (netip.Addr, error) {
- addr, ok := netip.AddrFromSlice(ip)
- if !ok {
- return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip)
- }
- return addr, nil
-}
-
-func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) {
- addr, err := ToNetIpAddr(ipNet.IP)
- if err != nil {
- return netip.Prefix{}, err
- }
- ones, bits := ipNet.Mask.Size()
- if ones == 0 && bits == 0 {
- return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet)
- }
- return netip.PrefixFrom(addr, ones), nil
-}
-
-// ubtoa encodes the string form of the integer v to dst[start:] and
-// returns the number of bytes written to dst. The caller must ensure
-// that dst has sufficient length.
-func ubtoa(dst []byte, start int, v byte) int {
- if v < 10 {
- dst[start] = v + '0'
- return 1
- } else if v < 100 {
- dst[start+1] = v%10 + '0'
- dst[start] = v/10 + '0'
- return 2
- }
-
- dst[start+2] = v%10 + '0'
- dst[start+1] = (v/10)%10 + '0'
- dst[start] = v/100 + '0'
- return 3
-}
diff --git a/iputil/util_test.go b/iputil/util_test.go
deleted file mode 100644
index 712d4264b..000000000
--- a/iputil/util_test.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package iputil
-
-import (
- "net"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestVpnIp_String(t *testing.T) {
- assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
- assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
- assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
- assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
- assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
- assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
-}
diff --git a/lighthouse.go b/lighthouse.go
index df68e1e88..62f406560 100644
--- a/lighthouse.go
+++ b/lighthouse.go
@@ -7,16 +7,16 @@ import (
"fmt"
"net"
"net/netip"
+ "strconv"
"sync"
"sync/atomic"
"time"
+ "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
)
@@ -26,25 +26,18 @@ import (
var ErrHostNotKnown = errors.New("host not known")
-type netIpAndPort struct {
- ip net.IP
- port uint16
-}
-
type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
sync.RWMutex //Because we concurrently read and write to our maps
ctx context.Context
amLighthouse bool
- myVpnIp iputil.VpnIp
- myVpnZeros iputil.VpnIp
- myVpnNet *net.IPNet
+ myVpnNet netip.Prefix
punchConn udp.Conn
punchy *Punchy
// Local cache of answers from light houses
// map of vpn Ip to answers
- addrMap map[iputil.VpnIp]*RemoteList
+ addrMap map[netip.Addr]*RemoteList
// filters remote addresses allowed for each host
// - When we are a lighthouse, this filters what addresses we store and
@@ -57,26 +50,26 @@ type LightHouse struct {
localAllowList atomic.Pointer[LocalAllowList]
// used to trigger the HandshakeManager when we receive HostQueryReply
- handshakeTrigger chan<- iputil.VpnIp
+ handshakeTrigger chan<- netip.Addr
// staticList exists to avoid having a bool in each addrMap entry
// since static should be rare
- staticList atomic.Pointer[map[iputil.VpnIp]struct{}]
- lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}]
+ staticList atomic.Pointer[map[netip.Addr]struct{}]
+ lighthouses atomic.Pointer[map[netip.Addr]struct{}]
interval atomic.Int64
updateCancel context.CancelFunc
ifce EncWriter
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
- advertiseAddrs atomic.Pointer[[]netIpAndPort]
+ advertiseAddrs atomic.Pointer[[]netip.AddrPort]
// IP's of relays that can be used by peers to access me
- relaysForMe atomic.Pointer[[]iputil.VpnIp]
+ relaysForMe atomic.Pointer[[]netip.Addr]
- queryChan chan iputil.VpnIp
+ queryChan chan netip.Addr
- calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
+ calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
@@ -85,7 +78,7 @@ type LightHouse struct {
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
// addrMap should be nil unless this is during a config reload
-func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) {
+func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) {
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
nebulaPort := uint32(c.GetInt("listen.port", 0))
if amLighthouse && nebulaPort == 0 {
@@ -98,26 +91,23 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
if err != nil {
return nil, util.NewContextualError("Failed to get listening port", nil, err)
}
- nebulaPort = uint32(uPort.Port)
+ nebulaPort = uint32(uPort.Port())
}
- ones, _ := myVpnNet.Mask.Size()
h := LightHouse{
ctx: ctx,
amLighthouse: amLighthouse,
- myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP),
- myVpnZeros: iputil.VpnIp(32 - ones),
myVpnNet: myVpnNet,
- addrMap: make(map[iputil.VpnIp]*RemoteList),
+ addrMap: make(map[netip.Addr]*RemoteList),
nebulaPort: nebulaPort,
punchConn: pc,
punchy: p,
- queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)),
+ queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
}
- lighthouses := make(map[iputil.VpnIp]struct{})
+ lighthouses := make(map[netip.Addr]struct{})
h.lighthouses.Store(&lighthouses)
- staticList := make(map[iputil.VpnIp]struct{})
+ staticList := make(map[netip.Addr]struct{})
h.staticList.Store(&staticList)
if c.GetBool("stats.lighthouse_metrics", false) {
@@ -147,11 +137,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
return &h, nil
}
-func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} {
+func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
return *lh.staticList.Load()
}
-func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} {
+func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
return *lh.lighthouses.Load()
}
@@ -163,15 +153,15 @@ func (lh *LightHouse) GetLocalAllowList() *LocalAllowList {
return lh.localAllowList.Load()
}
-func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort {
+func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort {
return *lh.advertiseAddrs.Load()
}
-func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
+func (lh *LightHouse) GetRelaysForMe() []netip.Addr {
return *lh.relaysForMe.Load()
}
-func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] {
+func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] {
return lh.calculatedRemotes.Load()
}
@@ -182,25 +172,40 @@ func (lh *LightHouse) GetUpdateInterval() int64 {
func (lh *LightHouse) reload(c *config.C, initial bool) error {
if initial || c.HasChanged("lighthouse.advertise_addrs") {
rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{})
- advAddrs := make([]netIpAndPort, 0)
+ advAddrs := make([]netip.AddrPort, 0)
for i, rawAddr := range rawAdvAddrs {
- fIp, fPort, err := udp.ParseIPAndPort(rawAddr)
+ host, sport, err := net.SplitHostPort(rawAddr)
if err != nil {
return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
}
- if fPort == 0 {
- fPort = uint16(lh.nebulaPort)
+ ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host)
+ if err != nil {
+ return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
+ }
+ if len(ips) == 0 {
+ return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil)
+ }
+
+ port, err := strconv.Atoi(sport)
+ if err != nil {
+ return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
+ }
+
+ if port == 0 {
+ port = int(lh.nebulaPort)
}
- if ip4 := fIp.To4(); ip4 != nil && lh.myVpnNet.Contains(fIp) {
+ //TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used
+ ip := ips[0].Unmap()
+ if lh.myVpnNet.Contains(ip) {
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
continue
}
- advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort})
+ advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port)))
}
lh.advertiseAddrs.Store(&advAddrs)
@@ -278,8 +283,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.RUnlock()
}
// Build a new list based on current config.
- staticList := make(map[iputil.VpnIp]struct{})
- err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
+ staticList := make(map[netip.Addr]struct{})
+ err := lh.loadStaticMap(c, staticList)
if err != nil {
return err
}
@@ -303,8 +308,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
}
if initial || c.HasChanged("lighthouse.hosts") {
- lhMap := make(map[iputil.VpnIp]struct{})
- err := lh.parseLighthouses(c, lh.myVpnNet, lhMap)
+ lhMap := make(map[netip.Addr]struct{})
+ err := lh.parseLighthouses(c, lhMap)
if err != nil {
return err
}
@@ -323,16 +328,17 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
if len(c.GetStringSlice("relay.relays", nil)) > 0 {
lh.l.Info("Ignoring relays from config because am_relay is true")
}
- relaysForMe := []iputil.VpnIp{}
+ relaysForMe := []netip.Addr{}
lh.relaysForMe.Store(&relaysForMe)
case false:
- relaysForMe := []iputil.VpnIp{}
+ relaysForMe := []netip.Addr{}
for _, v := range c.GetStringSlice("relay.relays", nil) {
lh.l.WithField("relay", v).Info("Read relay from config")
- configRIP := net.ParseIP(v)
- if configRIP != nil {
- relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP))
+ configRIP, err := netip.ParseAddr(v)
+ //TODO: We could print the error here
+ if err == nil {
+ relaysForMe = append(relaysForMe, configRIP)
}
}
lh.relaysForMe.Store(&relaysForMe)
@@ -342,21 +348,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil
}
-func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
}
for i, host := range lhs {
- ip := net.ParseIP(host)
- if ip == nil {
- return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
+ ip, err := netip.ParseAddr(host)
+ if err != nil {
+ return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
}
- if !tunCidr.Contains(ip) {
- return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
+ if !lh.myVpnNet.Contains(ip) {
+ return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil)
}
- lhMap[iputil.Ip2VpnIp(ip)] = struct{}{}
+ lhMap[ip] = struct{}{}
}
if !lh.amLighthouse && len(lhMap) == 0 {
@@ -399,7 +405,7 @@ func getStaticMapNetwork(c *config.C) (string, error) {
return network, nil
}
-func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error {
d, err := getStaticMapCadence(c)
if err != nil {
return err
@@ -410,7 +416,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
return err
}
- lookup_timeout, err := getStaticMapLookupTimeout(c)
+ lookupTimeout, err := getStaticMapLookupTimeout(c)
if err != nil {
return err
}
@@ -419,16 +425,15 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
i := 0
for k, v := range shm {
- rip := net.ParseIP(fmt.Sprintf("%v", k))
- if rip == nil {
- return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil)
+ vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k))
+ if err != nil {
+ return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
}
- if !tunCidr.Contains(rip) {
- return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil)
+ if !lh.myVpnNet.Contains(vpnIp) {
+ return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil)
}
- vpnIp := iputil.Ip2VpnIp(rip)
vals, ok := v.([]interface{})
if !ok {
vals = []interface{}{v}
@@ -438,7 +443,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
}
- err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList)
+ err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList)
if err != nil {
return err
}
@@ -448,7 +453,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
return nil
}
-func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) Query(ip netip.Addr) *RemoteList {
if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip)
}
@@ -462,7 +467,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
}
// QueryServer is asynchronous so no reply should be expected
-func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
+func (lh *LightHouse) QueryServer(ip netip.Addr) {
// Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses
if lh.amLighthouse || lh.IsLighthouseIP(ip) {
return
@@ -471,7 +476,7 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
lh.queryChan <- ip
}
-func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList {
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock()
@@ -488,7 +493,7 @@ func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
-func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) {
+func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) {
lh.RLock()
// Do we have an entry in the main cache?
if v, ok := lh.addrMap[vpnIp]; ok {
@@ -511,7 +516,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in
return false, 0, nil
}
-func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
+func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) {
// First we check the static mapping
// and do nothing if it is there
if _, ok := lh.GetStaticHostList()[vpnIp]; ok {
@@ -532,7 +537,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
// NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
-func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error {
lh.Lock()
am := lh.unlockedGetRemoteList(vpnIp)
am.Lock()
@@ -553,20 +558,14 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetIPs() {
-
+ if !lh.shouldAdd(vpnIp, addrPort.Addr()) {
+ continue
+ }
switch {
case addrPort.Addr().Is4():
- to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
- if !lh.unlockedShouldAddV4(vpnIp, to) {
- continue
- }
- am.unlockedPrependV4(lh.myVpnIp, to)
+ am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
case addrPort.Addr().Is6():
- to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
- if !lh.unlockedShouldAddV6(vpnIp, to) {
- continue
- }
- am.unlockedPrependV6(lh.myVpnIp, to)
+ am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
}
}
@@ -578,12 +577,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
// addCalculatedRemotes adds any calculated remotes based on the
// lighthouse.calculated_remotes configuration. It returns true if any
// calculated remotes were added
-func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
+func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool {
tree := lh.getCalculatedRemotes()
if tree == nil {
return false
}
- ok, calculatedRemotes := tree.MostSpecificContains(vpnIp)
+ calculatedRemotes, ok := tree.Lookup(vpnIp)
if !ok {
return false
}
@@ -602,13 +601,13 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
defer am.Unlock()
lh.Unlock()
- am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4)
+ am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4)
return len(calculated) > 0
}
// unlockedGetRemoteList assumes you have the lh lock
-func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList {
am, ok := lh.addrMap[vpnIp]
if !ok {
am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
@@ -617,44 +616,27 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
return am
}
-func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool {
- switch {
- case to.Is4():
- ipBytes := to.As4()
- ip := iputil.Ip2VpnIp(ipBytes[:])
- allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
- if lh.l.Level >= logrus.TraceLevel {
- lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
- }
- if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) {
- return false
- }
- case to.Is6():
- ipBytes := to.As16()
-
- hi := binary.BigEndian.Uint64(ipBytes[:8])
- lo := binary.BigEndian.Uint64(ipBytes[8:])
- allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
- if lh.l.Level >= logrus.TraceLevel {
- lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
- }
-
- // We don't check our vpn network here because nebula does not support ipv6 on the inside
- if !allow {
- return false
- }
+func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool {
+ allow := lh.GetRemoteAllowList().Allow(vpnIp, to)
+ if lh.l.Level >= logrus.TraceLevel {
+ lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
+ }
+ if !allow || lh.myVpnNet.Contains(to) {
+ return false
}
+
return true
}
// unlockedShouldAddV4 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
- allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
+func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool {
+ ip := AddrPortFromIp4AndPort(to)
+ allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
- if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) {
+ if !allow || lh.myVpnNet.Contains(ip.Addr()) {
return false
}
@@ -662,14 +644,14 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo
}
// unlockedShouldAddV6 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool {
- allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo)
+func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool {
+ ip := AddrPortFromIp6AndPort(to)
+ allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
- // We don't check our vpn network here because nebula does not support ipv6 on the inside
- if !allow {
+ if !allow || lh.myVpnNet.Contains(ip.Addr()) {
return false
}
@@ -683,26 +665,39 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
return ip
}
-func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool {
+func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool {
if _, ok := lh.GetLighthouses()[vpnIp]; ok {
return true
}
return false
}
-func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta {
+func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta {
+ if vpnIp.Is6() {
+ //TODO: need to support ipv6
+ panic("ipv6 is not yet supported")
+ }
+
+ b := vpnIp.As4()
return &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
- VpnIp: uint32(VpnIp),
+ VpnIp: binary.BigEndian.Uint32(b[:]),
},
}
}
-func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
- ipp := Ip4AndPort{Port: port}
- ipp.Ip = uint32(iputil.Ip2VpnIp(ip))
- return &ipp
+func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort {
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], ip.Ip)
+ return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port))
+}
+
+func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort {
+ b := [16]byte{}
+ binary.BigEndian.PutUint64(b[:8], ip.Hi)
+ binary.BigEndian.PutUint64(b[8:], ip.Lo)
+ return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port))
}
func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
@@ -713,14 +708,7 @@ func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
}
}
-func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
- return &Ip6AndPort{
- Hi: binary.BigEndian.Uint64(ip[:8]),
- Lo: binary.BigEndian.Uint64(ip[8:]),
- Port: port,
- }
-}
-
+// TODO: IPV6-WORK we can delete some more of these
func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
ip6Addr := ip.As16()
return &Ip6AndPort{
@@ -729,17 +717,6 @@ func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
Port: uint32(port),
}
}
-func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
- ip := ipp.Ip
- return udp.NewAddr(
- net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
- uint16(ipp.Port),
- )
-}
-
-func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
- return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
-}
func (lh *LightHouse) startQueryWorker() {
if lh.amLighthouse {
@@ -761,7 +738,7 @@ func (lh *LightHouse) startQueryWorker() {
}()
}
-func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) {
+func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) {
if lh.IsLighthouseIP(ip) {
return
}
@@ -812,36 +789,41 @@ func (lh *LightHouse) SendUpdate() {
var v6 []*Ip6AndPort
for _, e := range lh.GetAdvertiseAddrs() {
- if ip := e.ip.To4(); ip != nil {
- v4 = append(v4, NewIp4AndPort(e.ip, uint32(e.port)))
+ if e.Addr().Is4() {
+ v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port()))
} else {
- v6 = append(v6, NewIp6AndPort(e.ip, uint32(e.port)))
+ v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port()))
}
}
lal := lh.GetLocalAllowList()
- for _, e := range *localIps(lh.l, lal) {
- if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) {
+ for _, e := range localIps(lh.l, lal) {
+ if lh.myVpnNet.Contains(e) {
continue
}
// Only add IPs that aren't my VPN/tun IP
- if ip := e.To4(); ip != nil {
- v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort))
+ if e.Is4() {
+ v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort)))
} else {
- v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort))
+ v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort)))
}
}
var relays []uint32
for _, r := range lh.GetRelaysForMe() {
- relays = append(relays, (uint32)(r))
+ //TODO: IPV6-WORK both relays and vpnip need ipv6 support
+ b := r.As4()
+ relays = append(relays, binary.BigEndian.Uint32(b[:]))
}
+ //TODO: IPV6-WORK both relays and vpnip need ipv6 support
+ b := lh.myVpnNet.Addr().As4()
+
m := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
- VpnIp: uint32(lh.myVpnIp),
+ VpnIp: binary.BigEndian.Uint32(b[:]),
Ip4AndPorts: v4,
Ip6AndPorts: v6,
RelayVpnIp: relays,
@@ -913,12 +895,12 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
}
func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
- return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) {
+ return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) {
lhh.HandleRequest(rAddr, vpnIp, p, f)
}
}
-func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
+func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) {
n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
@@ -956,7 +938,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
}
}
-func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
@@ -967,8 +949,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
//TODO: we can DRY this further
reqVpnIp := n.Details.VpnIp
+
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+ queryVpnIp := netip.AddrFrom4(b)
+
//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
- found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) {
+ found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
n.Details.VpnIp = reqVpnIp
@@ -994,8 +982,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification
- n.Details.VpnIp = uint32(vpnIp)
-
+ //TODO: IPV6-WORK
+ b = vpnIp.As4()
+ n.Details.VpnIp = binary.BigEndian.Uint32(b[:])
lhh.coalesceAnswers(c, n)
return n.MarshalTo(lhh.pb)
@@ -1011,7 +1000,11 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
}
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
- w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0])
+
+ //TODO: IPV6-WORK
+ binary.BigEndian.PutUint32(b[:], reqVpnIp)
+ sendTo := netip.AddrFrom4(b)
+ w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
}
func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
@@ -1034,34 +1027,52 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
}
if c.relay != nil {
- n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, c.relay.relay...)
+ //TODO: IPV6-WORK
+ relays := make([]uint32, len(c.relay.relay))
+ b := [4]byte{}
+ for i, _ := range relays {
+ b = c.relay.relay[i].As4()
+ relays[i] = binary.BigEndian.Uint32(b[:])
+ }
+ n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...)
}
}
-func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) {
+func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) {
if !lhh.lh.IsLighthouseIP(vpnIp) {
return
}
lhh.lh.Lock()
- am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp))
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+ certVpnIp := netip.AddrFrom4(b)
+ am := lhh.lh.unlockedGetRemoteList(certVpnIp)
am.Lock()
lhh.lh.Unlock()
- certVpnIp := iputil.VpnIp(n.Details.VpnIp)
+ //TODO: IPV6-WORK
am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
- am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
+
+ //TODO: IPV6-WORK
+ relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
+ for i, _ := range n.Details.RelayVpnIp {
+ binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
+ relays[i] = netip.AddrFrom4(b)
+ }
+ am.unlockedSetRelay(vpnIp, certVpnIp, relays)
am.Unlock()
// Non-blocking attempt to trigger, skip if it would block
select {
- case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp):
+ case lhh.lh.handshakeTrigger <- certVpnIp:
default:
}
}
-func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
@@ -1070,9 +1081,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
}
//Simple check that the host sent this not someone else
- if n.Details.VpnIp != uint32(vpnIp) {
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+ detailsVpnIp := netip.AddrFrom4(b)
+ if detailsVpnIp != vpnIp {
if lhh.l.Level >= logrus.DebugLevel {
- lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
+ lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update")
}
return
}
@@ -1082,15 +1097,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
am.Lock()
lhh.lh.Unlock()
- certVpnIp := iputil.VpnIp(n.Details.VpnIp)
- am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
- am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
- am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
+ am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+ am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+
+ //TODO: IPV6-WORK
+ relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
+ for i, _ := range n.Details.RelayVpnIp {
+ binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
+ relays[i] = netip.AddrFrom4(b)
+ }
+ am.unlockedSetRelay(vpnIp, detailsVpnIp, relays)
am.Unlock()
n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck
- n.Details.VpnIp = uint32(vpnIp)
+
+ //TODO: IPV6-WORK
+ vpnIpB := vpnIp.As4()
+ n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:])
ln, err := n.MarshalTo(lhh.pb)
if err != nil {
@@ -1102,14 +1126,14 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
}
-func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
if !lhh.lh.IsLighthouseIP(vpnIp) {
return
}
empty := []byte{0}
- punch := func(vpnPeer *udp.Addr) {
- if vpnPeer == nil {
+ punch := func(vpnPeer netip.AddrPort) {
+ if !vpnPeer.IsValid() {
return
}
@@ -1121,23 +1145,29 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
if lhh.l.Level >= logrus.DebugLevel {
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
- lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp))
+ //TODO: IPV6-WORK, make this debug line not suck
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+ lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b))
}
}
for _, a := range n.Details.Ip4AndPorts {
- punch(NewUDPAddrFromLH4(a))
+ punch(AddrPortFromIp4AndPort(a))
}
for _, a := range n.Details.Ip6AndPorts {
- punch(NewUDPAddrFromLH6(a))
+ punch(AddrPortFromIp6AndPort(a))
}
// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lhh.lh.punchy.GetRespond() {
- queryVpnIp := iputil.VpnIp(n.Details.VpnIp)
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+ queryVpnIp := netip.AddrFrom4(b)
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel {
@@ -1150,9 +1180,3 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
}()
}
}
-
-// ipMaskContains checks if testIp is contained by ip after applying a cidr.
-// zeros is 32 - bits from net.IPMask.Size()
-func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool {
- return (testIp^ip)>>zeros == 0
-}
diff --git a/lighthouse_test.go b/lighthouse_test.go
index 66427e339..2599f5f2e 100644
--- a/lighthouse_test.go
+++ b/lighthouse_test.go
@@ -2,15 +2,14 @@ package nebula
import (
"context"
+ "encoding/binary"
"fmt"
- "net"
+ "net/netip"
"testing"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
- "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
@@ -23,15 +22,17 @@ func TestOldIPv4Only(t *testing.T) {
var m Ip4AndPort
err := m.Unmarshal(b)
assert.NoError(t, err)
- assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String())
+ ip := netip.MustParseAddr("10.1.1.1")
+ bp := ip.As4()
+ assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
}
func TestNewLhQuery(t *testing.T) {
- myIp := net.ParseIP("192.1.1.1")
- myIpint := iputil.Ip2VpnIp(myIp)
+ myIp, err := netip.ParseAddr("192.1.1.1")
+ assert.NoError(t, err)
// Generating a new lh query should work
- a := NewLhQueryByInt(myIpint)
+ a := NewLhQueryByInt(myIp)
// The result should be a nebulameta protobuf
assert.IsType(t, &NebulaMeta{}, a)
@@ -49,7 +50,7 @@ func TestNewLhQuery(t *testing.T) {
func Test_lhStaticMapping(t *testing.T) {
l := test.NewLogger()
- _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
+ myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
lh1 := "10.128.0.2"
c := config.NewC(l)
@@ -68,7 +69,7 @@ func Test_lhStaticMapping(t *testing.T) {
func TestReloadLighthouseInterval(t *testing.T) {
l := test.NewLogger()
- _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
+ myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
lh1 := "10.128.0.2"
c := config.NewC(l)
@@ -83,21 +84,21 @@ func TestReloadLighthouseInterval(t *testing.T) {
lh.ifce = &mockEncWriter{}
// The first one routine is kicked off by main.go currently, lets make sure that one dies
- c.ReloadConfigString("lighthouse:\n interval: 5")
+ assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
assert.Equal(t, int64(5), lh.interval.Load())
// Subsequent calls are killed off by the LightHouse.Reload function
- c.ReloadConfigString("lighthouse:\n interval: 10")
+ assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
assert.Equal(t, int64(10), lh.interval.Load())
// If this completes then nothing is stealing our reload routine
- c.ReloadConfigString("lighthouse:\n interval: 11")
+ assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
assert.Equal(t, int64(11), lh.interval.Load())
}
func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := test.NewLogger()
- _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
+ myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
c := config.NewC(l)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
@@ -105,30 +106,33 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
b.Fatal()
}
- hAddr := udp.NewAddrFromString("4.5.6.7:12345")
- hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
- lh.addrMap[3] = NewRemoteList(nil)
- lh.addrMap[3].unlockedSetV4(
- 3,
- 3,
+ hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
+ hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
+
+ vpnIp3 := netip.MustParseAddr("0.0.0.3")
+ lh.addrMap[vpnIp3] = NewRemoteList(nil)
+ lh.addrMap[vpnIp3].unlockedSetV4(
+ vpnIp3,
+ vpnIp3,
[]*Ip4AndPort{
- NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
- NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
+ NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
+ NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
},
- func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+ func(netip.Addr, *Ip4AndPort) bool { return true },
)
- rAddr := udp.NewAddrFromString("1.2.2.3:12345")
- rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
- lh.addrMap[2] = NewRemoteList(nil)
- lh.addrMap[2].unlockedSetV4(
- 3,
- 3,
+ rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
+ rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
+ vpnIp2 := netip.MustParseAddr("0.0.0.3")
+ lh.addrMap[vpnIp2] = NewRemoteList(nil)
+ lh.addrMap[vpnIp2].unlockedSetV4(
+ vpnIp3,
+ vpnIp3,
[]*Ip4AndPort{
- NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
- NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
+ NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
+ NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
},
- func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+ func(netip.Addr, *Ip4AndPort) bool { return true },
)
mw := &mockEncWriter{}
@@ -145,7 +149,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
p, err := req.Marshal()
assert.NoError(b, err)
for n := 0; n < b.N; n++ {
- lhh.HandleRequest(rAddr, 2, p, mw)
+ lhh.HandleRequest(rAddr, vpnIp2, p, mw)
}
})
b.Run("found", func(b *testing.B) {
@@ -161,7 +165,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
assert.NoError(b, err)
for n := 0; n < b.N; n++ {
- lhh.HandleRequest(rAddr, 2, p, mw)
+ lhh.HandleRequest(rAddr, vpnIp2, p, mw)
}
})
}
@@ -169,51 +173,51 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
func TestLighthouse_Memory(t *testing.T) {
l := test.NewLogger()
- myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
- myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
- myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
- myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
- myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
- myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
- myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
- myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
- myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
- myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
- myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
- myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
- myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))
-
- theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
- theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
- theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
- theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
- theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
- theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
+ myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242")
+ myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242")
+ myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242")
+ myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242")
+ myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242")
+ myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243")
+ myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244")
+ myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245")
+ myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246")
+ myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247")
+ myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248")
+ myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249")
+ myVpnIp := netip.MustParseAddr("10.128.0.2")
+
+ theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242")
+ theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242")
+ theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242")
+ theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242")
+ theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242")
+ theirVpnIp := netip.MustParseAddr("10.128.0.3")
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
- lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+ lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
assert.NoError(t, err)
lhh := lh.NewRequestHandler()
// Test that my first update responds with just that
- newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh)
+ newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
// Ensure we don't accumulate addresses
- newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh)
+ newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
// Grow it back to 2
- newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh)
+ newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
// Update a different host and ask about it
- newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
+ newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
@@ -233,7 +237,7 @@ func TestLighthouse_Memory(t *testing.T) {
newLHHostUpdate(
myUdpAddr0,
myVpnIp,
- []*udp.Addr{
+ []netip.AddrPort{
myUdpAddr1,
myUdpAddr2,
myUdpAddr3,
@@ -256,10 +260,10 @@ func TestLighthouse_Memory(t *testing.T) {
)
// Make sure we won't add ips in our vpn network
- bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
- bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
- good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
- newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh)
+ bad1 := netip.MustParseAddrPort("10.128.0.99:4242")
+ bad2 := netip.MustParseAddrPort("10.128.0.100:4242")
+ good := netip.MustParseAddrPort("1.128.0.99:4242")
+ newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
}
@@ -269,7 +273,7 @@ func TestLighthouse_reload(t *testing.T) {
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
- lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+ lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
assert.NoError(t, err)
nc := map[interface{}]interface{}{
@@ -285,11 +289,13 @@ func TestLighthouse_reload(t *testing.T) {
assert.NoError(t, err)
}
-func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
+func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
+ //TODO: IPV6-WORK
+ bip := queryVpnIp.As4()
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
- VpnIp: uint32(queryVpnIp),
+ VpnIp: binary.BigEndian.Uint32(bip[:]),
},
}
@@ -306,17 +312,19 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh
return w.lastReply
}
-func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) {
+func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
+ //TODO: IPV6-WORK
+ bip := vpnIp.As4()
req := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
- VpnIp: uint32(vpnIp),
+ VpnIp: binary.BigEndian.Uint32(bip[:]),
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
},
}
for k, v := range addrs {
- req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)}
+ req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
}
b, err := req.Marshal()
@@ -394,16 +402,10 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
// )
//}
-func Test_ipMaskContains(t *testing.T) {
- assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
- assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
- assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
-}
-
type testLhReply struct {
nebType header.MessageType
nebSubType header.MessageSubType
- vpnIp iputil.VpnIp
+ vpnIp netip.Addr
msg *NebulaMeta
}
@@ -414,7 +416,7 @@ type testEncWriter struct {
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
}
-func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
+func (tw *testEncWriter) Handshake(vpnIp netip.Addr) {
}
func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) {
@@ -434,7 +436,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
}
}
-func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
+func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
msg := &NebulaMeta{}
err := msg.Unmarshal(p)
if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@@ -452,35 +454,16 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
}
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
-func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
- if !assert.Len(t, have, len(want)) {
- return
- }
-
- for k, w := range want {
- if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
- assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
- }
- }
-}
-
-// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
-func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
+func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
if !assert.Len(t, have, len(want)) {
return
}
for k, w := range want {
- if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
- assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
+ //TODO: IPV6-WORK
+ h := AddrPortFromIp4AndPort(have[k])
+ if !(h == w) {
+ assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
}
}
}
-
-func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
- addrs := make([]*udp.Addr, len(ips))
- for k, v := range ips {
- addrs[k] = NewUDPAddrFromLH4(v)
- }
- return addrs
-}
diff --git a/main.go b/main.go
index d36a2fda8..e60dbd922 100644
--- a/main.go
+++ b/main.go
@@ -5,6 +5,7 @@ import (
"encoding/binary"
"fmt"
"net"
+ "net/netip"
"time"
"github.com/sirupsen/logrus"
@@ -67,8 +68,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
- // TODO: make sure mask is 4 bytes
- tunCidr := certificate.Details.Ips[0]
+ ones, _ := certificate.Details.Ips[0].Mask.Size()
+ addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
+ if !ok {
+ err = util.NewContextualError(
+ "Invalid ip address in certificate",
+ m{"vpnIp": certificate.Details.Ips[0].IP},
+ nil,
+ )
+ return nil, err
+ }
+ tunCidr := netip.PrefixFrom(addr, ones)
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
if err != nil {
@@ -150,21 +160,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if !configTest {
rawListenHost := c.GetString("listen.host", "0.0.0.0")
- var listenHost *net.IPAddr
+ var listenHost netip.Addr
if rawListenHost == "[::]" {
// Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve.
- listenHost = &net.IPAddr{IP: net.IPv6zero}
+ listenHost = netip.IPv6Unspecified()
} else {
- listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
+ ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
}
+ if len(ips) == 0 {
+ return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
+ }
+ listenHost = ips[0].Unmap()
}
for i := 0; i < routines; i++ {
- l.Infof("listening %q %d", listenHost.IP, port)
- udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64))
+ l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
+ udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
}
@@ -178,7 +192,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if err != nil {
return nil, util.NewContextualError("Failed to get listening port", nil, err)
}
- port = int(uPort.Port)
+ port = int(uPort.Port())
}
}
}
@@ -201,7 +215,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
handshakeConfig := HandshakeConfig{
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
- retries: c.GetInt("handshakes.retries", DefaultHandshakeRetries),
+ retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
useRelays: useRelays,
@@ -289,7 +303,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
ifce.multiPort.TxBasePort = uint16(port)
ifce.multiPort.TxPorts = c.GetInt("tun.multiport.tx_ports", 100)
ifce.multiPort.TxHandshake = c.GetBool("tun.multiport.tx_handshake", false)
- ifce.multiPort.TxHandshakeDelay = c.GetInt("tun.multiport.tx_handshake_delay", 2)
+ ifce.multiPort.TxHandshakeDelay = int64(c.GetInt("tun.multiport.tx_handshake_delay", 2))
ifce.udpRaw.ReloadConfig(c)
}
ifce.multiPort.Tx = tx
diff --git a/outside.go b/outside.go
index 1595e6a00..274600441 100644
--- a/outside.go
+++ b/outside.go
@@ -4,6 +4,7 @@ import (
"encoding/binary"
"errors"
"fmt"
+ "net/netip"
"time"
"github.com/flynn/noise"
@@ -11,7 +12,6 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4"
"google.golang.org/protobuf/proto"
@@ -21,9 +21,10 @@ const (
minFwPacketLen = 4
)
+// TODO: IPV6-WORK this can likely be removed now
func readOutsidePackets(f *Interface) udp.EncReader {
return func(
- addr *udp.Addr,
+ addr netip.AddrPort,
out []byte,
packet []byte,
header *header.H,
@@ -37,27 +38,25 @@ func readOutsidePackets(f *Interface) udp.EncReader {
}
}
-func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// TODO: best if we return this and let caller log
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 {
- f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
+ f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
}
return
}
//l.Error("in packet ", header, packet[HeaderLen:])
- if addr != nil {
- if ip4 := addr.IP.To4(); ip4 != nil {
- if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) {
- if f.l.Level >= logrus.DebugLevel {
- f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
- }
- return
+ if ip.IsValid() {
+ if f.myVpnNet.Contains(ip.Addr()) {
+ if f.l.Level >= logrus.DebugLevel {
+ f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
}
+ return
}
}
@@ -77,7 +76,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
switch h.Type {
case header.Message:
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
- if !f.handleEncrypted(ci, addr, h) {
+ if !f.handleEncrypted(ci, ip, h) {
return
}
@@ -101,7 +100,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
- f.handleHostRoaming(hostinfo, addr)
+ f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo.localIndexId)
f.connectionManager.RelayUsed(h.RemoteIndex)
@@ -118,7 +117,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
- f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
+ f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
@@ -148,13 +147,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- if !f.handleEncrypted(ci, addr, h) {
+ if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
- hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+ hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
@@ -163,19 +162,19 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
return
}
- lhf(addr, hostinfo.vpnIp, d)
+ lhf(ip, hostinfo.vpnIp, d)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- if !f.handleEncrypted(ci, addr, h) {
+ if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
- hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+ hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet).
Error("Failed to decrypt test packet")
@@ -187,7 +186,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
- f.handleHostRoaming(hostinfo, addr)
+ f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
}
@@ -198,34 +197,34 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- f.handshakeManager.HandleIncoming(addr, via, packet, h)
+ f.handshakeManager.HandleIncoming(ip, via, packet, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- f.handleRecvError(addr, h)
+ f.handleRecvError(ip, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- if !f.handleEncrypted(ci, addr, h) {
+ if !f.handleEncrypted(ci, ip, h) {
return
}
- hostinfo.logger(f.l).WithField("udpAddr", addr).
+ hostinfo.logger(f.l).WithField("udpAddr", ip).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
case header.Control:
- if !f.handleEncrypted(ci, addr, h) {
+ if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
- hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+ hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet).
Error("Failed to decrypt Control packet")
return
@@ -241,11 +240,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
+ hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
return
}
- f.handleHostRoaming(hostinfo, addr)
+ f.handleHostRoaming(hostinfo, ip)
f.connectionManager.In(hostinfo.localIndexId)
}
@@ -264,47 +263,44 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
}
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
- if addr != nil && !hostinfo.remote.Equals(addr) {
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
+ if ip.IsValid() && hostinfo.remote != ip {
if hostinfo.multiportRx {
// If the remote is sending with multiport, we aren't roaming unless
// the IP has changed
- if hostinfo.remote.IP.Equal(addr.IP) {
+ if hostinfo.remote.Compare(ip) == 0 {
return
}
// Keep the port from the original hostinfo, because the remote is transmitting from multiport ports
- addr = &udp.Addr{
- IP: addr.IP,
- Port: hostinfo.remote.Port,
- }
+ ip = netip.AddrPortFrom(ip.Addr(), hostinfo.remote.Port())
}
- if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
- hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
+ if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
+ hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
return
}
- if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
+ if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel {
- hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
+ hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
}
return
}
- hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
+ hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote
- hostinfo.SetRemote(addr)
+ hostinfo.SetRemote(ip)
}
}
-func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool {
+func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
// If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
- if addr != nil {
+ if addr.IsValid() {
f.maybeSendRecvError(addr, h.RemoteIndex)
return false
} else {
@@ -353,8 +349,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Firewall packets are locally oriented
if incoming {
- fp.RemoteIP = iputil.Ip2VpnIp(data[12:16])
- fp.LocalIP = iputil.Ip2VpnIp(data[16:20])
+ //TODO: IPV6-WORK
+ fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
+ fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
@@ -363,8 +360,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
} else {
- fp.LocalIP = iputil.Ip2VpnIp(data[12:16])
- fp.RemoteIP = iputil.Ip2VpnIp(data[16:20])
+ //TODO: IPV6-WORK
+ fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
+ fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
@@ -438,13 +436,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return true
}
-func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) {
- if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) {
+func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
+ if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) {
f.sendRecvError(endpoint, index)
}
}
-func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
+func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
f.messageMetrics.Tx(header.RecvError, 0, 1)
//TODO: this should be a signed message so we can trust that we should drop the index
@@ -457,7 +455,7 @@ func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
}
}
-func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
+func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).
@@ -474,7 +472,7 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
return
}
- if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) {
+ if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return
}
diff --git a/outside_test.go b/outside_test.go
index 682107bb0..f9d4bfa48 100644
--- a/outside_test.go
+++ b/outside_test.go
@@ -2,10 +2,10 @@ package nebula
import (
"net"
+ "net/netip"
"testing"
"github.com/slackhq/nebula/firewall"
- "github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
)
@@ -55,8 +55,8 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
- assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
- assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
+ assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
+ assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.RemotePort, uint16(3))
assert.Equal(t, p.LocalPort, uint16(4))
@@ -76,8 +76,8 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(2))
- assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
- assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
+ assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
+ assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.RemotePort, uint16(6))
assert.Equal(t, p.LocalPort, uint16(5))
}
diff --git a/overlay/device.go b/overlay/device.go
index 3f3f2eb47..50ad6ad5b 100644
--- a/overlay/device.go
+++ b/overlay/device.go
@@ -2,16 +2,14 @@ package overlay
import (
"io"
- "net"
-
- "github.com/slackhq/nebula/iputil"
+ "net/netip"
)
type Device interface {
io.ReadWriteCloser
Activate() error
- Cidr() *net.IPNet
+ Cidr() netip.Prefix
Name() string
- RouteFor(iputil.VpnIp) iputil.VpnIp
+ RouteFor(netip.Addr) netip.Addr
NewMultiQueueReader() (io.ReadWriteCloser, error)
}
diff --git a/overlay/route.go b/overlay/route.go
index 64c624c7e..8ccc9943c 100644
--- a/overlay/route.go
+++ b/overlay/route.go
@@ -1,34 +1,30 @@
package overlay
import (
- "bytes"
"fmt"
"math"
"net"
+ "net/netip"
"runtime"
"strconv"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
)
type Route struct {
MTU int
Metric int
- Cidr *net.IPNet
- Via *iputil.VpnIp
+ Cidr netip.Prefix
+ Via netip.Addr
Install bool
}
// Equal determines if a route that could be installed in the system route table is equal to another
// Via is ignored since that is only consumed within nebula itself
func (r Route) Equal(t Route) bool {
- if !r.Cidr.IP.Equal(t.Cidr.IP) {
- return false
- }
- if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) {
+ if r.Cidr != t.Cidr {
return false
}
if r.Metric != t.Metric {
@@ -51,21 +47,21 @@ func (r Route) String() string {
return s
}
-func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
- routeTree := cidr.NewTree4[iputil.VpnIp]()
+func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
+ routeTree := new(bart.Table[netip.Addr])
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
}
- if r.Via != nil {
- routeTree.AddCIDR(r.Cidr, *r.Via)
+ if r.Via.IsValid() {
+ routeTree.Insert(r.Cidr, r.Via)
}
}
return routeTree, nil
}
-func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
+func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
var err error
r := c.Get("tun.routes")
@@ -116,12 +112,12 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
MTU: mtu,
}
- _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
+ r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
if err != nil {
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
}
- if !ipWithin(network, r.Cidr) {
+ if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
return nil, fmt.Errorf(
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
i+1,
@@ -136,7 +132,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
return routes, nil
}
-func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
+func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
var err error
r := c.Get("tun.unsafe_routes")
@@ -202,9 +198,9 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
}
- nVia := net.ParseIP(via)
- if nVia == nil {
- return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via)
+ viaVpnIp, err := netip.ParseAddr(via)
+ if err != nil {
+ return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
}
rRoute, ok := m["route"]
@@ -212,8 +208,6 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
}
- viaVpnIp := iputil.Ip2VpnIp(nVia)
-
install := true
rInstall, ok := m["install"]
if ok {
@@ -224,18 +218,18 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
}
r := Route{
- Via: &viaVpnIp,
+ Via: viaVpnIp,
MTU: mtu,
Metric: metric,
Install: install,
}
- _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
+ r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
if err != nil {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
}
- if ipWithin(network, r.Cidr) {
+ if network.Contains(r.Cidr.Addr()) {
return nil, fmt.Errorf(
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
i+1,
diff --git a/overlay/route_test.go b/overlay/route_test.go
index 46fb87ceb..d7913894b 100644
--- a/overlay/route_test.go
+++ b/overlay/route_test.go
@@ -2,11 +2,10 @@ package overlay
import (
"fmt"
- "net"
+ "net/netip"
"testing"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
@@ -14,7 +13,8 @@ import (
func Test_parseRoutes(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
- _, n, _ := net.ParseCIDR("10.0.0.0/24")
+ n, err := netip.ParsePrefix("10.0.0.0/24")
+ assert.NoError(t, err)
// test no routes config
routes, err := parseRoutes(c, n)
@@ -67,7 +67,7 @@ func Test_parseRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
routes, err = parseRoutes(c, n)
assert.Nil(t, routes)
- assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope")
+ assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// below network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
@@ -112,7 +112,8 @@ func Test_parseRoutes(t *testing.T) {
func Test_parseUnsafeRoutes(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
- _, n, _ := net.ParseCIDR("10.0.0.0/24")
+ n, err := netip.ParsePrefix("10.0.0.0/24")
+ assert.NoError(t, err)
// test no routes config
routes, err := parseUnsafeRoutes(c, n)
@@ -157,7 +158,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
- assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope")
+ assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
// missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
@@ -169,7 +170,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
- assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope")
+ assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// within network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
@@ -252,7 +253,8 @@ func Test_parseUnsafeRoutes(t *testing.T) {
func Test_makeRouteTree(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
- _, n, _ := net.ParseCIDR("10.0.0.0/24")
+ n, err := netip.ParsePrefix("10.0.0.0/24")
+ assert.NoError(t, err)
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
@@ -264,17 +266,26 @@ func Test_makeRouteTree(t *testing.T) {
routeTree, err := makeRouteTree(l, routes, true)
assert.NoError(t, err)
- ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
- ok, r := routeTree.MostSpecificContains(ip)
+ ip, err := netip.ParseAddr("1.0.0.2")
+ assert.NoError(t, err)
+ r, ok := routeTree.Lookup(ip)
assert.True(t, ok)
- assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
- ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
- ok, r = routeTree.MostSpecificContains(ip)
+ nip, err := netip.ParseAddr("192.168.0.1")
+ assert.NoError(t, err)
+ assert.Equal(t, nip, r)
+
+ ip, err = netip.ParseAddr("1.0.0.1")
+ assert.NoError(t, err)
+ r, ok = routeTree.Lookup(ip)
assert.True(t, ok)
- assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
- ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
- ok, r = routeTree.MostSpecificContains(ip)
+ nip, err = netip.ParseAddr("192.168.0.2")
+ assert.NoError(t, err)
+ assert.Equal(t, nip, r)
+
+ ip, err = netip.ParseAddr("1.1.0.1")
+ assert.NoError(t, err)
+ r, ok = routeTree.Lookup(ip)
assert.False(t, ok)
}
diff --git a/overlay/tun.go b/overlay/tun.go
index cedd7fe76..12460da1f 100644
--- a/overlay/tun.go
+++ b/overlay/tun.go
@@ -1,7 +1,7 @@
package overlay
import (
- "net"
+ "net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
@@ -11,9 +11,9 @@ import (
const DefaultMTU = 1300
// TODO: We may be able to remove routines
-type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error)
+type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
@@ -25,12 +25,12 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout
}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
- return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+ return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, tunCidr)
}
}
-func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) {
+func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
return false, nil, nil
}
diff --git a/overlay/tun_android.go b/overlay/tun_android.go
index c15827fe6..98ad9b408 100644
--- a/overlay/tun_android.go
+++ b/overlay/tun_android.go
@@ -6,27 +6,26 @@ package overlay
import (
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"sync/atomic"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
)
type tun struct {
io.ReadWriteCloser
fd int
- cidr *net.IPNet
+ cidr netip.Prefix
Routes atomic.Pointer[[]Route]
- routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
}
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -53,12 +52,12 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet)
return t, nil
}
-func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@@ -87,7 +86,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go
index 1c6382827..0b573e6b3 100644
--- a/overlay/tun_darwin.go
+++ b/overlay/tun_darwin.go
@@ -8,15 +8,15 @@ import (
"fmt"
"io"
"net"
+ "net/netip"
"os"
"sync/atomic"
"syscall"
"unsafe"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
@@ -25,10 +25,10 @@ import (
type tun struct {
io.ReadWriteCloser
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
DefaultMTU int
Routes atomic.Pointer[[]Route]
- routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
@@ -73,7 +73,7 @@ type ifreqMTU struct {
pad [8]byte
}
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "")
ifIndex := -1
if name != "" && name != "utun" {
@@ -172,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
@@ -188,8 +188,13 @@ func (t *tun) Activate() error {
var addr, mask [4]byte
- copy(addr[:], t.cidr.IP.To4())
- copy(mask[:], t.cidr.Mask)
+ if !t.cidr.Addr().Is4() {
+ //TODO: IPV6-WORK
+ panic("need ipv6")
+ }
+
+ addr = t.cidr.Addr().As4()
+ copy(mask[:], prefixToMask(t.cidr))
s, err := unix.Socket(
unix.AF_INET,
@@ -329,13 +334,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- ok, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, ok := t.routeTree.Load().Lookup(ip)
if ok {
return r
}
-
- return 0
+ return netip.Addr{}
}
// Get the LinkAddr for the interface of the given name
@@ -384,13 +388,19 @@ func (t *tun) addRoutes(logErrors bool) error {
maskAddr := &netroute.Inet4Addr{}
routes := *t.Routes.Load()
for _, r := range routes {
- if r.Via == nil || !r.Install {
+ if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
- copy(routeAddr.IP[:], r.Cidr.IP.To4())
- copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
+ if !r.Cidr.Addr().Is4() {
+ //TODO: implement ipv6
+ panic("Cant handle ipv6 routes yet")
+ }
+
+ routeAddr.IP = r.Cidr.Addr().As4()
+ //TODO: we could avoid the copy
+ copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil {
@@ -435,8 +445,13 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
- copy(routeAddr.IP[:], r.Cidr.IP.To4())
- copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
+ if r.Cidr.Addr().Is6() {
+ //TODO: implement ipv6
+ panic("Cant handle ipv6 routes yet")
+ }
+
+ routeAddr.IP = r.Cidr.Addr().As4()
+ copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil {
@@ -536,7 +551,7 @@ func (t *tun) Write(from []byte) (int, error) {
return n - 4, err
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
@@ -547,3 +562,11 @@ func (t *tun) Name() string {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
+
+func prefixToMask(prefix netip.Prefix) []byte {
+ pLen := 128
+ if prefix.Addr().Is4() {
+ pLen = 32
+ }
+ return net.CIDRMask(prefix.Bits(), pLen)
+}
diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go
index e1e4ede67..130f8f99f 100644
--- a/overlay/tun_disabled.go
+++ b/overlay/tun_disabled.go
@@ -3,7 +3,7 @@ package overlay
import (
"fmt"
"io"
- "net"
+ "net/netip"
"strings"
"github.com/rcrowley/go-metrics"
@@ -13,7 +13,7 @@ import (
type disabledTun struct {
read chan []byte
- cidr *net.IPNet
+ cidr netip.Prefix
// Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter
@@ -21,7 +21,7 @@ type disabledTun struct {
l *logrus.Logger
}
-func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
+func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{
cidr: cidr,
read: make(chan []byte, queueLen),
@@ -43,11 +43,11 @@ func (*disabledTun) Activate() error {
return nil
}
-func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
- return 0
+func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
+ return netip.Addr{}
}
-func (t *disabledTun) Cidr() *net.IPNet {
+func (t *disabledTun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go
index 3b1b80f1a..bdfeb5802 100644
--- a/overlay/tun_freebsd.go
+++ b/overlay/tun_freebsd.go
@@ -9,7 +9,7 @@ import (
"fmt"
"io"
"io/fs"
- "net"
+ "net/netip"
"os"
"os/exec"
"strconv"
@@ -17,10 +17,9 @@ import (
"syscall"
"unsafe"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
)
@@ -48,10 +47,10 @@ type ifreqDestroy struct {
type tun struct {
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
- routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
io.ReadWriteCloser
@@ -79,11 +78,11 @@ func (t *tun) Close() error {
return nil
}
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var file *os.File
var err error
@@ -174,7 +173,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error
func (t *tun) Activate() error {
var err error
// TODO use syscalls instead of exec.Command
- cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+ cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -233,12 +232,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
@@ -253,7 +252,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
- if r.Via == nil || !r.Install {
+ if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go
index ba15d665e..20981f08c 100644
--- a/overlay/tun_ios.go
+++ b/overlay/tun_ios.go
@@ -7,32 +7,31 @@ import (
"errors"
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"sync"
"sync/atomic"
"syscall"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
)
type tun struct {
io.ReadWriteCloser
- cidr *net.IPNet
+ cidr netip.Prefix
Routes atomic.Pointer[[]Route]
- routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
}
-func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{
cidr: cidr,
@@ -80,8 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@@ -143,7 +142,7 @@ func (tr *tunReadCloser) Close() error {
return tr.f.Close()
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go
index 2f06951af..0e7e20d41 100644
--- a/overlay/tun_linux.go
+++ b/overlay/tun_linux.go
@@ -4,19 +4,18 @@
package overlay
import (
- "bytes"
"fmt"
"io"
"net"
+ "net/netip"
"os"
"strings"
"sync/atomic"
"unsafe"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
@@ -26,7 +25,7 @@ type tun struct {
io.ReadWriteCloser
fd int
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
@@ -34,7 +33,7 @@ type tun struct {
ioctlFd uintptr
Routes atomic.Pointer[[]Route]
- routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeChan chan struct{}
useSystemRoutes bool
@@ -65,7 +64,7 @@ type ifreqQLEN struct {
pad [8]byte
}
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr)
@@ -78,7 +77,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet)
return t, nil
}
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -123,7 +122,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*t
return t, nil
}
-func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) {
+func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
@@ -231,8 +230,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return file, nil
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@@ -275,8 +274,10 @@ func (t *tun) Activate() error {
var addr, mask [4]byte
- copy(addr[:], t.cidr.IP.To4())
- copy(mask[:], t.cidr.Mask)
+ //TODO: IPV6-WORK
+ addr = t.cidr.Addr().As4()
+ tmask := net.CIDRMask(t.cidr.Bits(), 32)
+ copy(mask[:], tmask)
s, err := unix.Socket(
unix.AF_INET,
@@ -364,14 +365,19 @@ func (t *tun) setMTU() {
func (t *tun) setDefaultRoute() error {
// Default route
- dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
+
+ dr := &net.IPNet{
+ IP: t.cidr.Masked().Addr().AsSlice(),
+ Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
+ }
+
nr := netlink.Route{
LinkIndex: t.deviceIndex,
Dst: dr,
MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK,
- Src: t.cidr.IP,
+ Src: net.IP(t.cidr.Addr().AsSlice()),
Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
@@ -392,9 +398,14 @@ func (t *tun) addRoutes(logErrors bool) error {
continue
}
+ dr := &net.IPNet{
+ IP: r.Cidr.Masked().Addr().AsSlice(),
+ Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
+ }
+
nr := netlink.Route{
LinkIndex: t.deviceIndex,
- Dst: r.Cidr,
+ Dst: dr,
MTU: r.MTU,
AdvMSS: t.advMSS(r),
Scope: unix.RT_SCOPE_LINK,
@@ -426,9 +437,14 @@ func (t *tun) removeRoutes(routes []Route) {
continue
}
+ dr := &net.IPNet{
+ IP: r.Cidr.Masked().Addr().AsSlice(),
+ Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
+ }
+
nr := netlink.Route{
LinkIndex: t.deviceIndex,
- Dst: r.Cidr,
+ Dst: dr,
MTU: r.MTU,
AdvMSS: t.advMSS(r),
Scope: unix.RT_SCOPE_LINK,
@@ -447,7 +463,7 @@ func (t *tun) removeRoutes(routes []Route) {
}
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
@@ -499,7 +515,15 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return
}
- if !t.cidr.Contains(r.Gw) {
+ //TODO: IPV6-WORK what if not ok?
+ gwAddr, ok := netip.AddrFromSlice(r.Gw)
+ if !ok {
+ t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
+ return
+ }
+
+ gwAddr = gwAddr.Unmap()
+ if !t.cidr.Contains(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
return
@@ -511,28 +535,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return
}
- newTree := cidr.NewTree4[iputil.VpnIp]()
- if r.Type == unix.RTM_NEWROUTE {
- for _, oldR := range t.routeTree.Load().List() {
- newTree.AddCIDR(oldR.CIDR, oldR.Value)
- }
+ dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
+ if !ok {
+ t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
+ return
+ }
+
+ ones, _ := r.Dst.Mask.Size()
+ dst := netip.PrefixFrom(dstAddr, ones)
+
+ newTree := t.routeTree.Load().Clone()
+ if r.Type == unix.RTM_NEWROUTE {
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
- newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw))
+ newTree.Insert(dst, gwAddr)
} else {
- gw := iputil.Ip2VpnIp(r.Gw)
- for _, oldR := range t.routeTree.Load().List() {
- if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
- // This is the record to delete
- t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
- continue
- }
-
- newTree.AddCIDR(oldR.CIDR, oldR.Value)
- }
+ newTree.Delete(dst)
+ t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
}
-
t.routeTree.Store(newTree)
}
diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go
index cc0216fe9..24ab24f78 100644
--- a/overlay/tun_netbsd.go
+++ b/overlay/tun_netbsd.go
@@ -6,7 +6,7 @@ package overlay
import (
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"os/exec"
"regexp"
@@ -15,10 +15,9 @@ import (
"syscall"
"unsafe"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
)
@@ -29,10 +28,10 @@ type ifreqDestroy struct {
type tun struct {
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
- routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
io.ReadWriteCloser
@@ -59,13 +58,13 @@ func (t *tun) Close() error {
return nil
}
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var file *os.File
var err error
@@ -109,13 +108,13 @@ func (t *tun) Activate() error {
var err error
// TODO use syscalls instead of exec.Command
- cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+ cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
- cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String())
+ cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
@@ -168,12 +167,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
@@ -188,12 +187,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
- if r.Via == nil || !r.Install {
+ if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
- cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
+ cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -214,7 +213,7 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
- cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String())
+ cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go
index 53f57b137..6463ccbba 100644
--- a/overlay/tun_openbsd.go
+++ b/overlay/tun_openbsd.go
@@ -6,7 +6,7 @@ package overlay
import (
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"os/exec"
"regexp"
@@ -14,19 +14,18 @@ import (
"sync/atomic"
"syscall"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
)
type tun struct {
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
- routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
io.ReadWriteCloser
@@ -43,13 +42,13 @@ func (t *tun) Close() error {
return nil
}
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
@@ -127,7 +126,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
func (t *tun) Activate() error {
var err error
// TODO use syscalls instead of exec.Command
- cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+ cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -139,7 +138,7 @@ func (t *tun) Activate() error {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
- cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String())
+ cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
@@ -149,20 +148,20 @@ func (t *tun) Activate() error {
return t.addRoutes(false)
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
- if r.Via == nil || !r.Install {
+ if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
- cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
+ cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -183,7 +182,7 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
- cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String())
+ cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
@@ -194,7 +193,7 @@ func (t *tun) removeRoutes(routes []Route) error {
return nil
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go
index 383398322..ba15723a1 100644
--- a/overlay/tun_tester.go
+++ b/overlay/tun_tester.go
@@ -6,21 +6,20 @@ package overlay
import (
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"sync/atomic"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
)
type TestTun struct {
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
Routes []Route
- routeTree *cidr.Tree4[iputil.VpnIp]
+ routeTree *bart.Table[netip.Addr]
l *logrus.Logger
closed atomic.Bool
@@ -28,7 +27,7 @@ type TestTun struct {
TxPackets chan []byte // Packets transmitted outside by nebula
}
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, cidr, true)
if err != nil {
return nil, err
@@ -49,7 +48,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, e
}, nil
}
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}
@@ -87,8 +86,8 @@ func (t *TestTun) Get(block bool) []byte {
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
-func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.MostSpecificContains(ip)
+func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Lookup(ip)
return r
}
@@ -96,7 +95,7 @@ func (t *TestTun) Activate() error {
return nil
}
-func (t *TestTun) Cidr() *net.IPNet {
+func (t *TestTun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go
index a1acd2b25..d78f564cf 100644
--- a/overlay/tun_water_windows.go
+++ b/overlay/tun_water_windows.go
@@ -4,30 +4,30 @@ import (
"fmt"
"io"
"net"
+ "net/netip"
"os/exec"
"strconv"
"sync/atomic"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/songgao/water"
)
type waterTun struct {
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
- routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
f *net.Interface
*water.Interface
}
-func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) {
+func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
t := &waterTun{
cidr: cidr,
@@ -70,8 +70,8 @@ func (t *waterTun) Activate() error {
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", t.Device),
"source=static",
- fmt.Sprintf("addr=%s", t.cidr.IP),
- fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)),
+ fmt.Sprintf("addr=%s", t.cidr.Addr()),
+ fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
"gateway=none",
).Run()
if err != nil {
@@ -141,7 +141,7 @@ func (t *waterTun) addRoutes(logErrors bool) error {
// Path routes
routes := *t.Routes.Load()
for _, r := range routes {
- if r.Via == nil || !r.Install {
+ if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
@@ -182,12 +182,12 @@ func (t *waterTun) removeRoutes(routes []Route) {
}
}
-func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
-func (t *waterTun) Cidr() *net.IPNet {
+func (t *waterTun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go
index f85ee9cee..3d883093c 100644
--- a/overlay/tun_windows.go
+++ b/overlay/tun_windows.go
@@ -5,7 +5,7 @@ package overlay
import (
"fmt"
- "net"
+ "net/netip"
"os"
"path/filepath"
"runtime"
@@ -15,11 +15,11 @@ import (
"github.com/slackhq/nebula/config"
)
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) {
useWintun := true
if err := checkWinTunExists(); err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go
index 197e3a717..d0103879a 100644
--- a/overlay/tun_wintun_windows.go
+++ b/overlay/tun_wintun_windows.go
@@ -4,15 +4,13 @@ import (
"crypto"
"fmt"
"io"
- "net"
"net/netip"
"sync/atomic"
"unsafe"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
@@ -23,11 +21,10 @@ const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
Device string
- cidr *net.IPNet
- prefix netip.Prefix
+ cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
- routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
tun *wintun.NativeTun
@@ -52,22 +49,16 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
}
-func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) {
+func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
deviceName := c.GetString("tun.dev", "")
guid, err := generateGUIDByDeviceName(deviceName)
if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
- prefix, err := iputil.ToNetIpPrefix(*cidr)
- if err != nil {
- return nil, err
- }
-
t := &winTun{
Device: deviceName,
cidr: cidr,
- prefix: prefix,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
@@ -140,7 +131,7 @@ func (t *winTun) reload(c *config.C, initial bool) error {
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
- err := luid.SetIPAddresses([]netip.Prefix{t.prefix})
+ err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
@@ -159,24 +150,13 @@ func (t *winTun) addRoutes(logErrors bool) error {
foundDefault4 := false
for _, r := range routes {
- if r.Via == nil || !r.Install {
+ if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
- prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
- if err != nil {
- retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err)
- if logErrors {
- retErr.Log(t.l)
- continue
- } else {
- return retErr
- }
- }
-
// Add our unsafe route
- err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric))
+ err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
@@ -190,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error {
}
if !foundDefault4 {
- if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
+ if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
foundDefault4 = true
}
}
@@ -221,13 +201,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
continue
}
- prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
- if err != nil {
- t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix")
- continue
- }
-
- err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr())
+ err := luid.DeleteRoute(r.Cidr, r.Via)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
@@ -237,12 +211,12 @@ func (t *winTun) removeRoutes(routes []Route) error {
return nil
}
-func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
-func (t *winTun) Cidr() *net.IPNet {
+func (t *winTun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/user.go b/overlay/user.go
index 9d819ae99..1bb4ef5f7 100644
--- a/overlay/user.go
+++ b/overlay/user.go
@@ -2,18 +2,17 @@ package overlay
import (
"io"
- "net"
+ "net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
)
-func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
return NewUserDevice(tunCidr)
}
-func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
+func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
// these pipes guarantee each write/read will match 1:1
or, ow := io.Pipe()
ir, iw := io.Pipe()
@@ -27,7 +26,7 @@ func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
}
type UserDevice struct {
- tunCidr *net.IPNet
+ tunCidr netip.Prefix
outboundReader *io.PipeReader
outboundWriter *io.PipeWriter
@@ -39,9 +38,9 @@ type UserDevice struct {
func (d *UserDevice) Activate() error {
return nil
}
-func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr }
-func (d *UserDevice) Name() string { return "faketun0" }
-func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip }
+func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr }
+func (d *UserDevice) Name() string { return "faketun0" }
+func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil
}
diff --git a/pki.go b/pki.go
index 91478ce51..ab95a0477 100644
--- a/pki.go
+++ b/pki.go
@@ -80,6 +80,8 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
}
if !initial {
+ //TODO: include check for mask equality as well
+
// did IP in cert change? if so, don't set
currentCert := p.cs.Load().Certificate
oldIPs := currentCert.Details.Ips
diff --git a/relay_manager.go b/relay_manager.go
index 7aa06ccb4..1a3a4d48f 100644
--- a/relay_manager.go
+++ b/relay_manager.go
@@ -2,14 +2,15 @@ package nebula
import (
"context"
+ "encoding/binary"
"errors"
"fmt"
+ "net/netip"
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
)
type relayManager struct {
@@ -50,7 +51,7 @@ func (rm *relayManager) setAmRelay(v bool) {
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
-func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iputil.VpnIp, remoteIdx *uint32, relayType int, state int) (uint32, error) {
+func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
hm.Lock()
defer hm.Unlock()
for i := 0; i < 32; i++ {
@@ -113,13 +114,17 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Inter
func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) {
rm.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(m.RelayFromIp),
- "relayTo": iputil.VpnIp(m.RelayToIp),
+ "relayFrom": m.RelayFromIp,
+ "relayTo": m.RelayToIp,
"initiatorRelayIndex": m.InitiatorRelayIndex,
"responderRelayIndex": m.ResponderRelayIndex,
"vpnIp": h.vpnIp}).
Info("handleCreateRelayResponse")
- target := iputil.VpnIp(m.RelayToIp)
+ target := m.RelayToIp
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], m.RelayToIp)
+ targetAddr := netip.AddrFrom4(b)
relay, err := rm.EstablishRelay(h, m)
if err != nil {
@@ -136,18 +141,20 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
return
}
- peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target)
+ peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
if !ok {
rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
return
}
if peerRelay.State == PeerRequested {
+ //TODO: IPV6-WORK
+ b = peerHostInfo.vpnIp.As4()
peerRelay.State = Established
resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: peerRelay.LocalIndex,
InitiatorRelayIndex: peerRelay.RemoteIndex,
- RelayFromIp: uint32(peerHostInfo.vpnIp),
+ RelayFromIp: binary.BigEndian.Uint32(b[:]),
RelayToIp: uint32(target),
}
msg, err := resp.Marshal()
@@ -157,8 +164,8 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
} else {
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(resp.RelayFromIp),
- "relayTo": iputil.VpnIp(resp.RelayToIp),
+ "relayFrom": resp.RelayFromIp,
+ "relayTo": resp.RelayToIp,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": peerHostInfo.vpnIp}).
@@ -168,9 +175,13 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
}
func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) {
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
+ from := netip.AddrFrom4(b)
- from := iputil.VpnIp(m.RelayFromIp)
- target := iputil.VpnIp(m.RelayToIp)
+ binary.BigEndian.PutUint32(b[:], m.RelayToIp)
+ target := netip.AddrFrom4(b)
logMsg := rm.l.WithFields(logrus.Fields{
"relayFrom": from,
@@ -181,12 +192,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects.
- if from == f.myVpnIp {
- logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself")
+ if from == f.myVpnNet.Addr() {
+ logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
return
}
// Is the target of the relay me?
- if target == f.myVpnIp {
+ if target == f.myVpnNet.Addr() {
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
if ok {
switch existingRelay.State {
@@ -219,12 +230,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
return
}
+ //TODO: IPV6-WORK
+ fromB := from.As4()
+ targetB := target.As4()
+
resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex,
- RelayFromIp: uint32(from),
- RelayToIp: uint32(target),
+ RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
+ RelayToIp: binary.BigEndian.Uint32(targetB[:]),
}
msg, err := resp.Marshal()
if err != nil {
@@ -233,8 +248,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
} else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(resp.RelayFromIp),
- "relayTo": iputil.VpnIp(resp.RelayToIp),
+ //TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now
+ "relayFrom": from,
+ "relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}).
@@ -253,7 +269,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
f.Handshake(target)
return
}
- if peer.remote == nil {
+ if !peer.remote.IsValid() {
// Only create relays to peers for whom I have a direct connection
return
}
@@ -275,12 +291,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
sendCreateRequest = true
}
if sendCreateRequest {
+ //TODO: IPV6-WORK
+ fromB := h.vpnIp.As4()
+ targetB := target.As4()
+
// Send a CreateRelayRequest to the peer.
req := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index,
- RelayFromIp: uint32(h.vpnIp),
- RelayToIp: uint32(target),
+ RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
+ RelayToIp: binary.BigEndian.Uint32(targetB[:]),
}
msg, err := req.Marshal()
if err != nil {
@@ -289,8 +309,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
} else {
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(req.RelayFromIp),
- "relayTo": iputil.VpnIp(req.RelayToIp),
+ //TODO: IPV6-WORK another lazy used to use the req object
+ "relayFrom": h.vpnIp,
+ "relayTo": target,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": target}).
@@ -321,12 +342,15 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
return
}
+ //TODO: IPV6-WORK
+ fromB := h.vpnIp.As4()
+ targetB := target.As4()
resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex,
- RelayFromIp: uint32(h.vpnIp),
- RelayToIp: uint32(target),
+ RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
+ RelayToIp: binary.BigEndian.Uint32(targetB[:]),
}
msg, err := resp.Marshal()
if err != nil {
@@ -335,8 +359,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
} else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(resp.RelayFromIp),
- "relayTo": iputil.VpnIp(resp.RelayToIp),
+ //TODO: IPV6-WORK more lazy, used to use resp object
+ "relayFrom": h.vpnIp,
+ "relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}).
@@ -349,7 +374,3 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
}
}
}
-
-func (rm *relayManager) RemoveRelay(localIdx uint32) {
- rm.hostmap.RemoveRelay(localIdx)
-}
diff --git a/remote_list.go b/remote_list.go
index 60a1afdaf..fa14f4295 100644
--- a/remote_list.go
+++ b/remote_list.go
@@ -1,7 +1,6 @@
package nebula
import (
- "bytes"
"context"
"net"
"net/netip"
@@ -12,16 +11,14 @@ import (
"time"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/iputil"
- "github.com/slackhq/nebula/udp"
)
// forEachFunc is used to benefit folks that want to do work inside the lock
-type forEachFunc func(addr *udp.Addr, preferred bool)
+type forEachFunc func(addr netip.AddrPort, preferred bool)
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
-type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool
-type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool
+type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool
+type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool
// CacheMap is a struct that better represents the lighthouse cache for humans
// The string key is the owners vpnIp
@@ -30,9 +27,9 @@ type CacheMap map[string]*Cache
// Cache is the other part of CacheMap to better represent the lighthouse cache for humans
// We don't reason about ipv4 vs ipv6 here
type Cache struct {
- Learned []*udp.Addr `json:"learned,omitempty"`
- Reported []*udp.Addr `json:"reported,omitempty"`
- Relay []*net.IP `json:"relay"`
+ Learned []netip.AddrPort `json:"learned,omitempty"`
+ Reported []netip.AddrPort `json:"reported,omitempty"`
+ Relay []netip.Addr `json:"relay"`
}
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
@@ -46,7 +43,7 @@ type cache struct {
}
type cacheRelay struct {
- relay []uint32
+ relay []netip.Addr
}
// cacheV4 stores learned and reported ipv4 records under cache
@@ -130,7 +127,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
continue
}
for _, a := range addrs {
- netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
+ netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{}
}
}
origSet := r.ips.Load()
@@ -193,22 +190,22 @@ type RemoteList struct {
sync.RWMutex
// A deduplicated set of addresses. Any accessor should lock beforehand.
- addrs []*udp.Addr
+ addrs []netip.AddrPort
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
- relays []*iputil.VpnIp
+ relays []netip.Addr
// These are maps to store v4 and v6 addresses per lighthouse
// Map key is the vpnIp of the person that told us about this the cached entries underneath.
// For learned addresses, this is the vpnIp that sent the packet
- cache map[iputil.VpnIp]*cache
+ cache map[netip.Addr]*cache
hr *hostnamesResults
shouldAdd func(netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
- badRemotes []*udp.Addr
+ badRemotes []netip.AddrPort
// A flag that the cache may have changed and addrs needs to be rebuilt
shouldRebuild bool
@@ -217,9 +214,9 @@ type RemoteList struct {
// NewRemoteList creates a new empty RemoteList
func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
return &RemoteList{
- addrs: make([]*udp.Addr, 0),
- relays: make([]*iputil.VpnIp, 0),
- cache: make(map[iputil.VpnIp]*cache),
+ addrs: make([]netip.AddrPort, 0),
+ relays: make([]netip.Addr, 0),
+ cache: make(map[netip.Addr]*cache),
shouldAdd: shouldAdd,
}
}
@@ -232,7 +229,7 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
// Len locks and reports the size of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
+func (r *RemoteList) Len(preferredRanges []netip.Prefix) int {
r.Rebuild(preferredRanges)
r.RLock()
defer r.RUnlock()
@@ -241,18 +238,18 @@ func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
// ForEach locks and will call the forEachFunc for every deduplicated address in the list
// The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) {
+func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) {
r.Rebuild(preferredRanges)
r.RLock()
for _, v := range r.addrs {
- forEach(v, isPreferred(v.IP, preferredRanges))
+ forEach(v, isPreferred(v.Addr(), preferredRanges))
}
r.RUnlock()
}
// CopyAddrs locks and makes a deep copy of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
+func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort {
if r == nil {
return nil
}
@@ -261,9 +258,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
r.RLock()
defer r.RUnlock()
- c := make([]*udp.Addr, len(r.addrs))
+ c := make([]netip.AddrPort, len(r.addrs))
for i, v := range r.addrs {
- c[i] = v.Copy()
+ c[i] = v
}
return c
}
@@ -272,13 +269,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available
// TODO: this needs to support the allow list list
-func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) {
+func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
r.Lock()
defer r.Unlock()
- if v4 := addr.IP.To4(); v4 != nil {
- r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port)))
+ if remote.Addr().Is4() {
+ r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port()))
} else {
- r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port)))
+ r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port()))
}
}
@@ -293,9 +290,9 @@ func (r *RemoteList) CopyCache() *CacheMap {
c := cm[vpnIp]
if c == nil {
c = &Cache{
- Learned: make([]*udp.Addr, 0),
- Reported: make([]*udp.Addr, 0),
- Relay: make([]*net.IP, 0),
+ Learned: make([]netip.AddrPort, 0),
+ Reported: make([]netip.AddrPort, 0),
+ Relay: make([]netip.Addr, 0),
}
cm[vpnIp] = c
}
@@ -307,28 +304,27 @@ func (r *RemoteList) CopyCache() *CacheMap {
if mc.v4 != nil {
if mc.v4.learned != nil {
- c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned))
+ c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned))
}
for _, a := range mc.v4.reported {
- c.Reported = append(c.Reported, NewUDPAddrFromLH4(a))
+ c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a))
}
}
if mc.v6 != nil {
if mc.v6.learned != nil {
- c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned))
+ c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned))
}
for _, a := range mc.v6.reported {
- c.Reported = append(c.Reported, NewUDPAddrFromLH6(a))
+ c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a))
}
}
if mc.relay != nil {
for _, a := range mc.relay.relay {
- nip := iputil.VpnIp(a).ToIP()
- c.Relay = append(c.Relay, &nip)
+ c.Relay = append(c.Relay, a)
}
}
}
@@ -337,8 +333,8 @@ func (r *RemoteList) CopyCache() *CacheMap {
}
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
-func (r *RemoteList) BlockRemote(bad *udp.Addr) {
- if bad == nil {
+func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
+ if !bad.IsValid() {
// relays can have nil udp Addrs
return
}
@@ -351,20 +347,20 @@ func (r *RemoteList) BlockRemote(bad *udp.Addr) {
}
// We copy here because we are taking something else's memory and we can't trust everything
- r.badRemotes = append(r.badRemotes, bad.Copy())
+ r.badRemotes = append(r.badRemotes, bad)
// Mark the next interaction must recollect/dedupe
r.shouldRebuild = true
}
// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
-func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr {
+func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
r.RLock()
defer r.RUnlock()
- c := make([]*udp.Addr, len(r.badRemotes))
+ c := make([]netip.AddrPort, len(r.badRemotes))
for i, v := range r.badRemotes {
- c[i] = v.Copy()
+ c[i] = v
}
return c
}
@@ -378,7 +374,7 @@ func (r *RemoteList) ResetBlockedRemotes() {
// Rebuild locks and generates the deduplicated address list only if there is work to be done
// There is generally no reason to call this directly but it is safe to do so
-func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
+func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) {
r.Lock()
defer r.Unlock()
@@ -394,9 +390,9 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
}
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
-func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
+func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
for _, v := range r.badRemotes {
- if v.Equals(remote) {
+ if v == remote {
return true
}
}
@@ -405,14 +401,14 @@ func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
+func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
}
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) {
+func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
@@ -427,7 +423,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
}
}
-func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []uint32) {
+func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeRelay(ownerVpnIp)
@@ -440,7 +436,7 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnI
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
+func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
@@ -453,14 +449,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort)
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
+func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
}
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) {
+func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
@@ -477,7 +473,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
+func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
@@ -488,7 +484,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort)
}
}
-func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay {
+func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
@@ -503,7 +499,7 @@ func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay
// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
// The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
+func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
@@ -518,7 +514,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
// The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 {
+func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
@@ -540,14 +536,14 @@ func (r *RemoteList) unlockedCollect() {
for _, c := range r.cache {
if c.v4 != nil {
if c.v4.learned != nil {
- u := NewUDPAddrFromLH4(c.v4.learned)
+ u := AddrPortFromIp4AndPort(c.v4.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v4.reported {
- u := NewUDPAddrFromLH4(v)
+ u := AddrPortFromIp4AndPort(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
@@ -556,14 +552,14 @@ func (r *RemoteList) unlockedCollect() {
if c.v6 != nil {
if c.v6.learned != nil {
- u := NewUDPAddrFromLH6(c.v6.learned)
+ u := AddrPortFromIp6AndPort(c.v6.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v6.reported {
- u := NewUDPAddrFromLH6(v)
+ u := AddrPortFromIp6AndPort(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
@@ -572,8 +568,7 @@ func (r *RemoteList) unlockedCollect() {
if c.relay != nil {
for _, v := range c.relay.relay {
- ip := iputil.VpnIp(v)
- relays = append(relays, &ip)
+ relays = append(relays, v)
}
}
}
@@ -581,11 +576,7 @@ func (r *RemoteList) unlockedCollect() {
dnsAddrs := r.hr.GetIPs()
for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
- v6 := addr.Addr().As16()
- addrs = append(addrs, &udp.Addr{
- IP: v6[:],
- Port: addr.Port(),
- })
+ addrs = append(addrs, addr)
}
}
@@ -595,7 +586,7 @@ func (r *RemoteList) unlockedCollect() {
}
// unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
-func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
+func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
n := len(r.addrs)
if n < 2 {
return
@@ -606,8 +597,8 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
b := r.addrs[j]
// Preferred addresses first
- aPref := isPreferred(a.IP, preferredRanges)
- bPref := isPreferred(b.IP, preferredRanges)
+ aPref := isPreferred(a.Addr(), preferredRanges)
+ bPref := isPreferred(b.Addr(), preferredRanges)
switch {
case aPref && !bPref:
// If i is preferred and j is not, i is less than j
@@ -622,21 +613,21 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
}
// ipv6 addresses 2nd
- a4 := a.IP.To4()
- b4 := b.IP.To4()
+ a4 := a.Addr().Is4()
+ b4 := b.Addr().Is4()
switch {
- case a4 == nil && b4 != nil:
+ case a4 == false && b4 == true:
// If i is v6 and j is v4, i is less than j
return true
- case a4 != nil && b4 == nil:
+ case a4 == true && b4 == false:
// If j is v6 and i is v4, i is not less than j
return false
- case a4 != nil && b4 != nil:
- // Special case for ipv4, a4 and b4 are not nil
- aPrivate := isPrivateIP(a4)
- bPrivate := isPrivateIP(b4)
+ case a4 == true && b4 == true:
+ // i and j are both ipv4
+ aPrivate := a.Addr().IsPrivate()
+ bPrivate := b.Addr().IsPrivate()
switch {
case !aPrivate && bPrivate:
// If i is a public ip (not private) and j is a private ip, i is less then j
@@ -655,10 +646,10 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
}
// lexical order of ips 3rd
- c := bytes.Compare(a.IP, b.IP)
+ c := a.Addr().Compare(b.Addr())
if c == 0 {
// Ips are the same, Lexical order of ports 4th
- return a.Port < b.Port
+ return a.Port() < b.Port()
}
// Ip wasn't the same
@@ -671,7 +662,7 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
// Deduplicate
a, b := 0, 1
for b < n {
- if !r.addrs[a].Equals(r.addrs[b]) {
+ if r.addrs[a] != r.addrs[b] {
a++
if a != b {
r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
@@ -693,7 +684,7 @@ func minInt(a, b int) int {
}
// isPreferred returns true of the ip is contained in the preferredRanges list
-func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
+func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool {
//TODO: this would be better in a CIDR6Tree
for _, p := range preferredRanges {
if p.Contains(ip) {
@@ -702,14 +693,3 @@ func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
}
return false
}
-
-var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
-var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
-var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
-
-// isPrivateIP returns true if the ip is contained by a rfc 1918 private range
-func isPrivateIP(ip net.IP) bool {
- //TODO: another great cidrtree option
- //TODO: Private for ipv6 or just let it ride?
- return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
-}
diff --git a/remote_list_test.go b/remote_list_test.go
index 49aa17191..62a892b00 100644
--- a/remote_list_test.go
+++ b/remote_list_test.go
@@ -1,47 +1,47 @@
package nebula
import (
- "net"
+ "encoding/binary"
+ "net/netip"
"testing"
- "github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
)
func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList(nil)
rl.unlockedSetV4(
- 0,
- 0,
+ netip.MustParseAddr("0.0.0.0"),
+ netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe
+ newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
+ newIp4AndPortFromString("172.17.0.182:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
+ newIp4AndPortFromString("172.18.0.1:10101"), // this is duped
+ newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe
+ newIp4AndPortFromString("172.19.0.1:10101"),
+ newIp4AndPortFromString("172.31.0.1:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
+ newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
+ newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
},
- func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+ func(netip.Addr, *Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
- 1,
- 1,
+ netip.MustParseAddr("0.0.0.1"),
+ netip.MustParseAddr("0.0.0.1"),
[]*Ip6AndPort{
- NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped
- NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped
- NewIp6AndPort(net.ParseIP("1:100::1"), 1),
- NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
- NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
+ newIp6AndPortFromString("[1::1]:1"), // this is duped
+ newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
+ newIp6AndPortFromString("[1:100::1]:1"),
+ newIp6AndPortFromString("[1::1]:1"), // this is a dupe
+ newIp6AndPortFromString("[1::1]:2"), // this is a dupe
},
- func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+ func(netip.Addr, *Ip6AndPort) bool { return true },
)
- rl.Rebuild([]*net.IPNet{})
+ rl.Rebuild([]netip.Prefix{})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv6 first, sorted lexically within
@@ -59,9 +59,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
// Now ensure we can hoist ipv4 up
- _, ipNet, err := net.ParseCIDR("0.0.0.0/0")
- assert.NoError(t, err)
- rl.Rebuild([]*net.IPNet{ipNet})
+ rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv4 first, public then private, lexically within them
@@ -79,9 +77,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
// Ensure we can hoist a specific ipv4 range over anything else
- _, ipNet, err = net.ParseCIDR("172.17.0.0/16")
- assert.NoError(t, err)
- rl.Rebuild([]*net.IPNet{ipNet})
+ rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// Preferred ipv4 first
@@ -104,64 +100,61 @@ func TestRemoteList_Rebuild(t *testing.T) {
func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList(nil)
rl.unlockedSetV4(
- 0,
- 0,
+ netip.MustParseAddr("0.0.0.0"),
+ netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
+ newIp4AndPortFromString("70.199.182.92:1475"),
+ newIp4AndPortFromString("172.17.0.182:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"),
+ newIp4AndPortFromString("172.18.0.1:10101"),
+ newIp4AndPortFromString("172.19.0.1:10101"),
+ newIp4AndPortFromString("172.31.0.1:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
+ newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
},
- func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+ func(netip.Addr, *Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
- 0,
- 0,
+ netip.MustParseAddr("0.0.0.0"),
+ netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{
- NewIp6AndPort(net.ParseIP("1::1"), 1),
- NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
- NewIp6AndPort(net.ParseIP("1:100::1"), 1),
- NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+ newIp6AndPortFromString("[1::1]:1"),
+ newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
+ newIp6AndPortFromString("[1:100::1]:1"),
+ newIp6AndPortFromString("[1::1]:1"), // this is a dupe
},
- func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+ func(netip.Addr, *Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
- rl.Rebuild([]*net.IPNet{})
+ rl.Rebuild([]netip.Prefix{})
}
})
- _, ipNet, err := net.ParseCIDR("172.17.0.0/16")
- assert.NoError(b, err)
+ ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
- rl.Rebuild([]*net.IPNet{ipNet})
+ rl.Rebuild([]netip.Prefix{ipNet1})
}
})
- _, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
- assert.NoError(b, err)
+ ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+ rl.Rebuild([]netip.Prefix{ipNet2})
}
})
- _, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
- assert.NoError(b, err)
+ ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+ rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
}
})
}
@@ -169,67 +162,83 @@ func BenchmarkFullRebuild(b *testing.B) {
func BenchmarkSortRebuild(b *testing.B) {
rl := NewRemoteList(nil)
rl.unlockedSetV4(
- 0,
- 0,
+ netip.MustParseAddr("0.0.0.0"),
+ netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
+ newIp4AndPortFromString("70.199.182.92:1475"),
+ newIp4AndPortFromString("172.17.0.182:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"),
+ newIp4AndPortFromString("172.18.0.1:10101"),
+ newIp4AndPortFromString("172.19.0.1:10101"),
+ newIp4AndPortFromString("172.31.0.1:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
+ newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
},
- func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+ func(netip.Addr, *Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
- 0,
- 0,
+ netip.MustParseAddr("0.0.0.0"),
+ netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{
- NewIp6AndPort(net.ParseIP("1::1"), 1),
- NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
- NewIp6AndPort(net.ParseIP("1:100::1"), 1),
- NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+ newIp6AndPortFromString("[1::1]:1"),
+ newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
+ newIp6AndPortFromString("[1:100::1]:1"),
+ newIp6AndPortFromString("[1::1]:1"), // this is a dupe
},
- func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+ func(netip.Addr, *Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
- rl.Rebuild([]*net.IPNet{})
+ rl.Rebuild([]netip.Prefix{})
}
})
- _, ipNet, err := net.ParseCIDR("172.17.0.0/16")
- rl.Rebuild([]*net.IPNet{ipNet})
+ ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
+ rl.Rebuild([]netip.Prefix{ipNet1})
- assert.NoError(b, err)
b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
- rl.Rebuild([]*net.IPNet{ipNet})
+ rl.Rebuild([]netip.Prefix{ipNet1})
}
})
- _, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+ ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
+ rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
- assert.NoError(b, err)
b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+ rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
}
})
- _, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+ ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
+ rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
- assert.NoError(b, err)
b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+ rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
}
})
}
+
+func newIp4AndPortFromString(s string) *Ip4AndPort {
+ a := netip.MustParseAddrPort(s)
+ v4Addr := a.Addr().As4()
+ return &Ip4AndPort{
+ Ip: binary.BigEndian.Uint32(v4Addr[:]),
+ Port: uint32(a.Port()),
+ }
+}
+
+func newIp6AndPortFromString(s string) *Ip6AndPort {
+ a := netip.MustParseAddrPort(s)
+ v6Addr := a.Addr().As16()
+ return &Ip6AndPort{
+ Hi: binary.BigEndian.Uint64(v6Addr[:8]),
+ Lo: binary.BigEndian.Uint64(v6Addr[8:]),
+ Port: uint32(a.Port()),
+ }
+}
diff --git a/service/service.go b/service/service.go
index 6816be673..4ddd30182 100644
--- a/service/service.go
+++ b/service/service.go
@@ -8,6 +8,7 @@ import (
"log"
"math"
"net"
+ "net/netip"
"os"
"strings"
"sync"
@@ -91,7 +92,7 @@ func New(config *config.C) (*Service, error) {
ipNet := device.Cidr()
pa := tcpip.ProtocolAddress{
- AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(),
+ AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(),
Protocol: ipv4.ProtocolNumber,
}
if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
@@ -153,24 +154,48 @@ func New(config *config.C) (*Service, error) {
return &s, nil
}
-// DialContext dials the provided address. Currently only TCP is supported.
-func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
- if network != "tcp" && network != "tcp4" {
- return nil, errors.New("only tcp is supported")
- }
-
- addr, err := net.ResolveTCPAddr(network, address)
- if err != nil {
- return nil, err
+func getProtocolNumber(addr netip.Addr) tcpip.NetworkProtocolNumber {
+ if addr.Is6() {
+ return ipv6.ProtocolNumber
}
+ return ipv4.ProtocolNumber
+}
- fullAddr := tcpip.FullAddress{
- NIC: nicID,
- Addr: tcpip.AddrFromSlice(addr.IP),
- Port: uint16(addr.Port),
+// DialContext dials the provided address.
+func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ switch network {
+ case "udp", "udp4", "udp6":
+ addr, err := net.ResolveUDPAddr(network, address)
+ if err != nil {
+ return nil, err
+ }
+ fullAddr := tcpip.FullAddress{
+ NIC: nicID,
+ Addr: tcpip.AddrFromSlice(addr.IP),
+ Port: uint16(addr.Port),
+ }
+ num := getProtocolNumber(addr.AddrPort().Addr())
+ return gonet.DialUDP(s.ipstack, nil, &fullAddr, num)
+ case "tcp", "tcp4", "tcp6":
+ addr, err := net.ResolveTCPAddr(network, address)
+ if err != nil {
+ return nil, err
+ }
+ fullAddr := tcpip.FullAddress{
+ NIC: nicID,
+ Addr: tcpip.AddrFromSlice(addr.IP),
+ Port: uint16(addr.Port),
+ }
+ num := getProtocolNumber(addr.AddrPort().Addr())
+ return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, num)
+ default:
+ return nil, fmt.Errorf("unknown network type: %s", network)
}
+}
- return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber)
+// Dial dials the provided address
+func (s *Service) Dial(network, address string) (net.Conn, error) {
+ return s.DialContext(context.Background(), network, address)
}
// Listen listens on the provided address. Currently only TCP with wildcard
diff --git a/service/service_test.go b/service/service_test.go
index d1909cd15..31762090d 100644
--- a/service/service_test.go
+++ b/service/service_test.go
@@ -4,7 +4,7 @@ import (
"bytes"
"context"
"errors"
- "net"
+ "net/netip"
"testing"
"time"
@@ -18,12 +18,8 @@ import (
type m map[string]interface{}
-func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service {
-
- vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
- copy(vpnIpNet.IP, udpIp)
-
- _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
+ _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{})
caB, err := caCrt.MarshalToPEM()
if err != nil {
panic(err)
@@ -83,8 +79,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string,
}
func TestService(t *testing.T) {
- ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{
+ ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{
"static_host_map": m{},
"lighthouse": m{
"am_lighthouse": true,
@@ -94,7 +90,7 @@ func TestService(t *testing.T) {
"port": 4243,
},
})
- b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{
+ b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{
"static_host_map": m{
"10.0.0.1": []string{"localhost:4243"},
},
diff --git a/ssh.go b/ssh.go
index f0961211f..2ff0954d6 100644
--- a/ssh.go
+++ b/ssh.go
@@ -7,6 +7,7 @@ import (
"flag"
"fmt"
"net"
+ "net/netip"
"os"
"reflect"
"runtime"
@@ -18,9 +19,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/sshd"
- "github.com/slackhq/nebula/udp"
)
type sshListHostMapFlags struct {
@@ -431,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
}
sort.Slice(hm, func(i, j int) bool {
- return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
+ return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0
})
if fs.Json || fs.Pretty {
@@ -545,13 +544,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return w.WriteLine("No vpn ip was provided")
}
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -574,13 +572,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided")
}
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -616,13 +613,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("No vpn ip was provided")
}
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -636,16 +632,16 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
}
- var addr *udp.Addr
+ var addr netip.AddrPort
if flags.Address != "" {
- addr = udp.NewAddrFromString(flags.Address)
- if addr == nil {
+ addr, err = netip.ParseAddrPort(flags.Address)
+ if err != nil {
return w.WriteLine("Address could not be parsed")
}
}
hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
- if addr != nil {
+ if addr.IsValid() {
hostInfo.SetRemote(addr)
}
@@ -667,18 +663,17 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("No address was provided")
}
- addr := udp.NewAddrFromString(flags.Address)
- if addr == nil {
+ addr, err := netip.ParseAddrPort(flags.Address)
+ if err != nil {
return w.WriteLine("Address could not be parsed")
}
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -792,13 +787,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
cert := ifce.pki.GetCertState().Certificate
if len(a) > 0 {
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -862,14 +856,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
Error error
Type string
State string
- PeerIp iputil.VpnIp
+ PeerIp netip.Addr
LocalIndex uint32
RemoteIndex uint32
- RelayedThrough []iputil.VpnIp
+ RelayedThrough []netip.Addr
}
type RelayOutput struct {
- NebulaIp iputil.VpnIp
+ NebulaIp netip.Addr
RelayForIps []RelayFor
}
@@ -952,13 +946,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided")
}
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
diff --git a/test/tun.go b/test/tun.go
index 86656c920..fbf58295a 100644
--- a/test/tun.go
+++ b/test/tun.go
@@ -3,23 +3,21 @@ package test
import (
"errors"
"io"
- "net"
-
- "github.com/slackhq/nebula/iputil"
+ "net/netip"
)
type NoopTun struct{}
-func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
- return 0
+func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
+ return netip.Addr{}
}
func (NoopTun) Activate() error {
return nil
}
-func (NoopTun) Cidr() *net.IPNet {
- return nil
+func (NoopTun) Cidr() netip.Prefix {
+ return netip.Prefix{}
}
func (NoopTun) Name() string {
diff --git a/timeout_test.go b/timeout_test.go
index 3f81ff400..4c6364ef5 100644
--- a/timeout_test.go
+++ b/timeout_test.go
@@ -1,6 +1,7 @@
package nebula
import (
+ "net/netip"
"testing"
"time"
@@ -115,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
assert.Equal(t, 0, tw.current)
fps := []firewall.Packet{
- {LocalIP: 1},
- {LocalIP: 2},
- {LocalIP: 3},
- {LocalIP: 4},
+ {LocalIP: netip.MustParseAddr("0.0.0.1")},
+ {LocalIP: netip.MustParseAddr("0.0.0.2")},
+ {LocalIP: netip.MustParseAddr("0.0.0.3")},
+ {LocalIP: netip.MustParseAddr("0.0.0.4")},
}
tw.Add(fps[0], time.Second*1)
diff --git a/udp/conn.go b/udp/conn.go
index a2c24a1f1..fa4e44304 100644
--- a/udp/conn.go
+++ b/udp/conn.go
@@ -1,6 +1,8 @@
package udp
import (
+ "net/netip"
+
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@@ -9,7 +11,7 @@ import (
const MTU = 9001
type EncReader func(
- addr *Addr,
+ addr netip.AddrPort,
out []byte,
packet []byte,
header *header.H,
@@ -22,9 +24,9 @@ type EncReader func(
type Conn interface {
Rebind() error
- LocalAddr() (*Addr, error)
+ LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
- WriteTo(b []byte, addr *Addr) error
+ WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
Close() error
}
@@ -34,13 +36,13 @@ type NoopConn struct{}
func (NoopConn) Rebind() error {
return nil
}
-func (NoopConn) LocalAddr() (*Addr, error) {
- return nil, nil
+func (NoopConn) LocalAddr() (netip.AddrPort, error) {
+ return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
return
}
-func (NoopConn) WriteTo(_ []byte, _ *Addr) error {
+func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil
}
func (NoopConn) ReloadConfig(_ *config.C) {
diff --git a/udp/temp.go b/udp/temp.go
index 2efe31d24..b281906f5 100644
--- a/udp/temp.go
+++ b/udp/temp.go
@@ -1,9 +1,10 @@
package udp
import (
- "github.com/slackhq/nebula/iputil"
+ "net/netip"
)
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
-type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte)
+// TODO: IPV6-WORK this can likely be removed now
+type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)
diff --git a/udp/udp_all.go b/udp/udp_all.go
deleted file mode 100644
index 093bf69cc..000000000
--- a/udp/udp_all.go
+++ /dev/null
@@ -1,100 +0,0 @@
-package udp
-
-import (
- "encoding/json"
- "fmt"
- "net"
- "strconv"
-)
-
-type m map[string]interface{}
-
-type Addr struct {
- IP net.IP
- Port uint16
-}
-
-func NewAddr(ip net.IP, port uint16) *Addr {
- addr := Addr{IP: make([]byte, net.IPv6len), Port: port}
- copy(addr.IP, ip.To16())
- return &addr
-}
-
-func NewAddrFromString(s string) *Addr {
- ip, port, err := ParseIPAndPort(s)
- //TODO: handle err
- _ = err
- return &Addr{IP: ip.To16(), Port: port}
-}
-
-func (ua *Addr) Equals(t *Addr) bool {
- if t == nil || ua == nil {
- return t == nil && ua == nil
- }
- return ua.IP.Equal(t.IP) && ua.Port == t.Port
-}
-
-func (ua *Addr) String() string {
- if ua == nil {
- return ""
- }
-
- return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
-}
-
-func (ua *Addr) MarshalJSON() ([]byte, error) {
- if ua == nil {
- return nil, nil
- }
-
- return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
-}
-
-func (ua *Addr) Copy() *Addr {
- if ua == nil {
- return nil
- }
-
- nu := Addr{
- Port: ua.Port,
- IP: make(net.IP, len(ua.IP)),
- }
-
- copy(nu.IP, ua.IP)
- return &nu
-}
-
-type AddrSlice []*Addr
-
-func (a AddrSlice) Equal(b AddrSlice) bool {
- if len(a) != len(b) {
- return false
- }
-
- for i := range a {
- if !a[i].Equals(b[i]) {
- return false
- }
- }
-
- return true
-}
-
-func ParseIPAndPort(s string) (net.IP, uint16, error) {
- rIp, sPort, err := net.SplitHostPort(s)
- if err != nil {
- return nil, 0, err
- }
-
- addr, err := net.ResolveIPAddr("ip", rIp)
- if err != nil {
- return nil, 0, err
- }
-
- iPort, err := strconv.Atoi(sPort)
- if err != nil {
- return nil, 0, err
- }
-
- return addr.IP, uint16(iPort), nil
-}
diff --git a/udp/udp_android.go b/udp/udp_android.go
index 8d6907488..bb1919546 100644
--- a/udp/udp_android.go
+++ b/udp/udp_android.go
@@ -6,13 +6,14 @@ package udp
import (
"fmt"
"net"
+ "net/netip"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go
index 785aa6a74..65ef31a56 100644
--- a/udp/udp_bsd.go
+++ b/udp/udp_bsd.go
@@ -9,13 +9,14 @@ package udp
import (
"fmt"
"net"
+ "net/netip"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go
index 08e1b6a80..183ac7af2 100644
--- a/udp/udp_darwin.go
+++ b/udp/udp_darwin.go
@@ -8,13 +8,14 @@ package udp
import (
"fmt"
"net"
+ "net/netip"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
diff --git a/udp/udp_generic.go b/udp/udp_generic.go
index 1dd6d1de7..2d8453694 100644
--- a/udp/udp_generic.go
+++ b/udp/udp_generic.go
@@ -11,6 +11,7 @@ import (
"context"
"fmt"
"net"
+ "net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
@@ -25,7 +26,7 @@ type GenericConn struct {
var _ Conn = &GenericConn{}
-func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil {
@@ -37,23 +38,24 @@ func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
}
-func (u *GenericConn) WriteTo(b []byte, addr *Addr) error {
- _, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
+func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
+ _, err := u.UDPConn.WriteToUDPAddrPort(b, addr)
return err
}
-func (u *GenericConn) LocalAddr() (*Addr, error) {
+func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr()
switch v := a.(type) {
case *net.UDPAddr:
- addr := &Addr{IP: make([]byte, len(v.IP))}
- copy(addr.IP, v.IP)
- addr.Port = uint16(v.Port)
- return addr, nil
+ addr, ok := netip.AddrFromSlice(v.IP)
+ if !ok {
+ return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
+ }
+ return netip.AddrPortFrom(addr, uint16(v.Port)), nil
default:
- return nil, fmt.Errorf("LocalAddr returned: %#v", a)
+ return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
}
}
@@ -75,19 +77,26 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
- udpAddr := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
for {
// Just read one packet at a time
- n, rua, err := u.ReadFromUDP(buffer)
+ n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
- udpAddr.IP = rua.IP
- udpAddr.Port = uint16(rua.Port)
- r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+ r(
+ netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
+ plaintext[:0],
+ buffer[:n],
+ h,
+ fwPacket,
+ lhf,
+ nb,
+ q,
+ cache.Get(u.l),
+ )
}
}
diff --git a/udp/udp_linux.go b/udp/udp_linux.go
index 02c8ce0f1..2eee76ee2 100644
--- a/udp/udp_linux.go
+++ b/udp/udp_linux.go
@@ -7,6 +7,7 @@ import (
"encoding/binary"
"fmt"
"net"
+ "net/netip"
"syscall"
"unsafe"
@@ -35,10 +36,9 @@ func maybeIPV4(ip net.IP) (net.IP, bool) {
return ip, false
}
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
- ipV4, isV4 := maybeIPV4(ip)
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
af := unix.AF_INET6
- if isV4 {
+ if ip.Is4() {
af = unix.AF_INET
}
syscall.ForkLock.RLock()
@@ -61,13 +61,13 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
//TODO: support multiple listening IPs (for limiting ipv6)
var sa unix.Sockaddr
- if isV4 {
+ if ip.Is4() {
sa4 := &unix.SockaddrInet4{Port: port}
- copy(sa4.Addr[:], ipV4)
+ sa4.Addr = ip.As4()
sa = sa4
} else {
sa6 := &unix.SockaddrInet6{Port: port}
- copy(sa6.Addr[:], ip.To16())
+ sa6.Addr = ip.As16()
sa = sa6
}
if err = unix.Bind(fd, sa); err != nil {
@@ -79,7 +79,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err)
- return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err
+ return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
}
func (u *StdConn) Rebind() error {
@@ -102,30 +102,29 @@ func (u *StdConn) GetSendBuffer() (int, error) {
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
}
-func (u *StdConn) LocalAddr() (*Addr, error) {
+func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
sa, err := unix.Getsockname(u.sysFd)
if err != nil {
- return nil, err
+ return netip.AddrPort{}, err
}
- addr := &Addr{}
switch sa := sa.(type) {
case *unix.SockaddrInet4:
- addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
- addr.Port = uint16(sa.Port)
+ return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil
+
case *unix.SockaddrInet6:
- addr.IP = sa.Addr[0:]
- addr.Port = uint16(sa.Port)
- }
+ return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil
- return addr, nil
+ default:
+ return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa)
+ }
}
func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
- udpAddr := &Addr{}
+ var ip netip.Addr
nb := make([]byte, 12, 12)
//TODO: should we track this?
@@ -146,12 +145,23 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
//metric.Update(int64(n))
for i := 0; i < n; i++ {
if u.isV4 {
- udpAddr.IP = names[i][4:8]
+ ip, _ = netip.AddrFromSlice(names[i][4:8])
+ //TODO: IPV6-WORK what is not ok?
} else {
- udpAddr.IP = names[i][8:24]
+ ip, _ = netip.AddrFromSlice(names[i][8:24])
+ //TODO: IPV6-WORK what is not ok?
}
- udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
- r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+ r(
+ netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])),
+ plaintext[:0],
+ buffers[i][:msgs[i].Len],
+ h,
+ fwPacket,
+ lhf,
+ nb,
+ q,
+ cache.Get(u.l),
+ )
}
}
}
@@ -197,19 +207,18 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
}
}
-func (u *StdConn) WriteTo(b []byte, addr *Addr) error {
+func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
if u.isV4 {
- return u.writeTo4(b, addr)
+ return u.writeTo4(b, ip)
}
- return u.writeTo6(b, addr)
+ return u.writeTo6(b, ip)
}
-func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
+func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
- // Little Endian -> Network Endian
- rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
- copy(rsa.Addr[:], addr.IP.To16())
+ rsa.Addr = ip.Addr().As16()
+ binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port())
for {
_, _, err := unix.Syscall6(
@@ -232,17 +241,15 @@ func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
}
}
-func (u *StdConn) writeTo4(b []byte, addr *Addr) error {
- addrV4, isAddrV4 := maybeIPV4(addr.IP)
- if !isAddrV4 {
+func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
+ if !ip.Addr().Is4() {
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var rsa unix.RawSockaddrInet4
rsa.Family = unix.AF_INET
- // Little Endian -> Network Endian
- rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
- copy(rsa.Addr[:], addrV4)
+ rsa.Addr = ip.Addr().As4()
+ binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port())
for {
_, _, err := unix.Syscall6(
diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go
index 3c14face3..3b69159ad 100644
--- a/udp/udp_netbsd.go
+++ b/udp/udp_netbsd.go
@@ -8,13 +8,14 @@ package udp
import (
"fmt"
"net"
+ "net/netip"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
diff --git a/udp/udp_raw_linux.go b/udp/udp_raw_linux.go
index 647f3a973..ebb3438bf 100644
--- a/udp/udp_raw_linux.go
+++ b/udp/udp_raw_linux.go
@@ -7,6 +7,7 @@ import (
"encoding/binary"
"fmt"
"net"
+ "net/netip"
"syscall"
"unsafe"
@@ -74,10 +75,10 @@ func NewRawConn(l *logrus.Logger, ip string, port int, basePort uint16) (*RawCon
// WriteTo must be called with raw leaving the first `udp.RawOverhead` bytes empty,
// for the IP/UDP headers.
-func (u *RawConn) WriteTo(raw []byte, fromPort uint16, addr *Addr) error {
+func (u *RawConn) WriteTo(raw []byte, fromPort uint16, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet4
rsa.Family = unix.AF_INET
- copy(rsa.Addr[:], addr.IP.To4())
+ rsa.Addr = ip.Addr().As4()
totalLen := len(raw)
udpLen := totalLen - ipv4.HeaderLen
@@ -97,7 +98,7 @@ func (u *RawConn) WriteTo(raw []byte, fromPort uint16, addr *Addr) error {
// UDP header
fromPort = u.basePort + fromPort
binary.BigEndian.PutUint16(raw[20:22], uint16(fromPort)) // src port
- binary.BigEndian.PutUint16(raw[22:24], uint16(addr.Port)) // dst port
+ binary.BigEndian.PutUint16(raw[22:24], uint16(ip.Port())) // dst port
binary.BigEndian.PutUint16(raw[24:26], uint16(udpLen)) // UDP length
binary.BigEndian.PutUint16(raw[26:28], 0) // checksum (optional)
@@ -150,8 +151,8 @@ func (u *RawConn) GetSendBuffer() (int, error) {
return unix.GetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUF)
}
-func (u *RawConn) getMemInfo(meminfo *_SK_MEMINFO) error {
- var vallen uint32 = 4 * _SK_MEMINFO_VARS
+func (u *RawConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
+ var vallen uint32 = 4 * unix.SK_MEMINFO_VARS
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
if err != 0 {
return err
@@ -161,10 +162,10 @@ func (u *RawConn) getMemInfo(meminfo *_SK_MEMINFO) error {
func NewRawStatsEmitter(rawConn *RawConn) func() {
// Check if our kernel supports SO_MEMINFO before registering the gauges
- var gauges [_SK_MEMINFO_VARS]metrics.Gauge
- var meminfo _SK_MEMINFO
+ var gauges [unix.SK_MEMINFO_VARS]metrics.Gauge
+ var meminfo [unix.SK_MEMINFO_VARS]uint32
if err := rawConn.getMemInfo(&meminfo); err == nil {
- gauges = [_SK_MEMINFO_VARS]metrics.Gauge{
+ gauges = [unix.SK_MEMINFO_VARS]metrics.Gauge{
metrics.GetOrRegisterGauge("raw.rmem_alloc", nil),
metrics.GetOrRegisterGauge("raw.rcvbuf", nil),
metrics.GetOrRegisterGauge("raw.wmem_alloc", nil),
@@ -182,7 +183,7 @@ func NewRawStatsEmitter(rawConn *RawConn) func() {
return func() {
if err := rawConn.getMemInfo(&meminfo); err == nil {
- for j := 0; j < _SK_MEMINFO_VARS; j++ {
+ for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
gauges[j].Update(int64(meminfo[j]))
}
}
diff --git a/udp/udp_raw_unsupported.go b/udp/udp_raw_unsupported.go
index 287a9abce..e77276f97 100644
--- a/udp/udp_raw_unsupported.go
+++ b/udp/udp_raw_unsupported.go
@@ -5,6 +5,7 @@ package udp
import (
"fmt"
+ "net/netip"
"runtime"
"github.com/sirupsen/logrus"
@@ -19,7 +20,7 @@ func NewRawConn(l *logrus.Logger, ip string, port int, basePort uint16) (*RawCon
return nil, fmt.Errorf("multiport tx is not supported on %s", runtime.GOOS)
}
-func (u *RawConn) WriteTo(raw []byte, fromPort uint16, addr *Addr) error {
+func (u *RawConn) WriteTo(raw []byte, fromPort uint16, addr netip.AddrPort) error {
return fmt.Errorf("multiport tx is not supported on %s", runtime.GOOS)
}
diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go
index 31c1a554c..ee7e1e002 100644
--- a/udp/udp_rio_windows.go
+++ b/udp/udp_rio_windows.go
@@ -10,6 +10,7 @@ import (
"fmt"
"io"
"net"
+ "net/netip"
"sync"
"sync/atomic"
"syscall"
@@ -61,16 +62,14 @@ type RIOConn struct {
results [packetsPerRing]winrio.Result
}
-func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
+func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) {
if !winrio.Initialize() {
return nil, errors.New("could not initialize winrio")
}
u := &RIOConn{l: l}
- addr := [16]byte{}
- copy(addr[:], ip.To16())
- err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
+ err := u.bind(&windows.SockaddrInet6{Addr: addr.As16(), Port: port})
if err != nil {
return nil, fmt.Errorf("bind: %w", err)
}
@@ -124,7 +123,6 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
- udpAddr := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
for {
@@ -135,11 +133,17 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
return
}
- udpAddr.IP = rua.Addr[:]
- p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
- p[0] = byte(rua.Port >> 8)
- p[1] = byte(rua.Port)
- r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+ r(
+ netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)),
+ plaintext[:0],
+ buffer[:n],
+ h,
+ fwPacket,
+ lhf,
+ nb,
+ q,
+ cache.Get(u.l),
+ )
}
}
@@ -231,7 +235,7 @@ retry:
return n, ep, nil
}
-func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
+func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
if !u.isOpen.Load() {
return net.ErrClosed
}
@@ -274,10 +278,9 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
packet := u.tx.Push()
packet.addr.Family = windows.AF_INET6
- p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
- p[0] = byte(addr.Port >> 8)
- p[1] = byte(addr.Port)
- copy(packet.addr.Addr[:], addr.IP.To16())
+ packet.addr.Addr = ip.Addr().As16()
+ port := ip.Port()
+ packet.addr.Port = (port >> 8) | ((port & 0xff) << 8)
copy(packet.data[:], buf)
dataBuffer := &winrio.Buffer{
@@ -295,17 +298,15 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
-func (u *RIOConn) LocalAddr() (*Addr, error) {
+func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
sa, err := windows.Getsockname(u.sock)
if err != nil {
- return nil, err
+ return netip.AddrPort{}, err
}
v6 := sa.(*windows.SockaddrInet6)
- return &Addr{
- IP: v6.Addr[:],
- Port: uint16(v6.Port),
- }, nil
+ return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil
+
}
func (u *RIOConn) Rebind() error {
diff --git a/udp/udp_tester.go b/udp/udp_tester.go
index 55985f47f..f03a3535f 100644
--- a/udp/udp_tester.go
+++ b/udp/udp_tester.go
@@ -4,9 +4,8 @@
package udp
import (
- "fmt"
"io"
- "net"
+ "net/netip"
"sync/atomic"
"github.com/sirupsen/logrus"
@@ -16,30 +15,24 @@ import (
)
type Packet struct {
- ToIp net.IP
- ToPort uint16
- FromIp net.IP
- FromPort uint16
- Data []byte
+ To netip.AddrPort
+ From netip.AddrPort
+ Data []byte
}
func (u *Packet) Copy() *Packet {
n := &Packet{
- ToIp: make(net.IP, len(u.ToIp)),
- ToPort: u.ToPort,
- FromIp: make(net.IP, len(u.FromIp)),
- FromPort: u.FromPort,
- Data: make([]byte, len(u.Data)),
+ To: u.To,
+ From: u.From,
+ Data: make([]byte, len(u.Data)),
}
- copy(n.ToIp, u.ToIp)
- copy(n.FromIp, u.FromIp)
copy(n.Data, u.Data)
return n
}
type TesterConn struct {
- Addr *Addr
+ Addr netip.AddrPort
RxPackets chan *Packet // Packets to receive into nebula
TxPackets chan *Packet // Packets transmitted outside by nebula
@@ -48,9 +41,9 @@ type TesterConn struct {
l *logrus.Logger
}
-func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
return &TesterConn{
- Addr: &Addr{ip, uint16(port)},
+ Addr: netip.AddrPortFrom(ip, uint16(port)),
RxPackets: make(chan *Packet, 10),
TxPackets: make(chan *Packet, 10),
l: l,
@@ -71,7 +64,7 @@ func (u *TesterConn) Send(packet *Packet) {
}
if u.l.Level >= logrus.DebugLevel {
u.l.WithField("header", h).
- WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
+ WithField("udpAddr", packet.From).
WithField("dataLen", len(packet.Data)).
Debug("UDP receiving injected packet")
}
@@ -98,23 +91,18 @@ func (u *TesterConn) Get(block bool) *Packet {
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
-func (u *TesterConn) WriteTo(b []byte, addr *Addr) error {
+func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
if u.closed.Load() {
return io.ErrClosedPipe
}
p := &Packet{
- Data: make([]byte, len(b), len(b)),
- FromIp: make([]byte, 16),
- FromPort: u.Addr.Port,
- ToIp: make([]byte, 16),
- ToPort: addr.Port,
+ Data: make([]byte, len(b), len(b)),
+ From: u.Addr,
+ To: addr,
}
copy(p.Data, b)
- copy(p.ToIp, addr.IP.To16())
- copy(p.FromIp, u.Addr.IP.To16())
-
u.TxPackets <- p
return nil
}
@@ -123,7 +111,6 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
- ua := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
for {
@@ -131,9 +118,7 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
if !ok {
return
}
- ua.Port = p.FromPort
- copy(ua.IP, p.FromIp.To16())
- r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
+ r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}
@@ -144,7 +129,7 @@ func NewUDPStatsEmitter(_ []Conn) func() {
return func() {}
}
-func (u *TesterConn) LocalAddr() (*Addr, error) {
+func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil
}
diff --git a/udp/udp_windows.go b/udp/udp_windows.go
index ebcace670..1b777c374 100644
--- a/udp/udp_windows.go
+++ b/udp/udp_windows.go
@@ -6,12 +6,13 @@ package udp
import (
"fmt"
"net"
+ "net/netip"
"syscall"
"github.com/sirupsen/logrus"
)
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
if multi {
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
// The udp stack would need to be reworked to hide away the implementation differences between