diff --git a/libnetwork/drivers/bridge/port_mapping_linux_test.go b/libnetwork/drivers/bridge/port_mapping_linux_test.go index 2bc21e7eb029d..82e1fc8a134f2 100644 --- a/libnetwork/drivers/bridge/port_mapping_linux_test.go +++ b/libnetwork/drivers/bridge/port_mapping_linux_test.go @@ -413,7 +413,7 @@ func TestAddPortMappings(t *testing.T) { ctrIP4 := newIPNet(t, "172.19.0.2/16") ctrIP4Mapped := newIPNet(t, "::ffff:172.19.0.2/112") ctrIP6 := newIPNet(t, "fdf8:b88e:bb5c:3483::2/64") - firstEphemPort := uint16(portallocator.Get().Begin) + firstEphemPort, _ := portallocator.GetPortRange() testcases := []struct { name string @@ -876,8 +876,7 @@ func TestAddPortMappings(t *testing.T) { return net.ParseIP("127.0.0.1") } - err = portallocator.Get().ReleaseAll() - assert.NilError(t, err) + portallocator.Get().ReleaseAll() pbs, err := n.addPortMappings(tc.epAddrV4, tc.epAddrV6, tc.cfg, tc.defHostIP) if tc.expErr != "" { diff --git a/libnetwork/portallocator/portallocator.go b/libnetwork/portallocator/portallocator.go index 0df927748a459..4e3b970642db4 100644 --- a/libnetwork/portallocator/portallocator.go +++ b/libnetwork/portallocator/portallocator.go @@ -21,7 +21,6 @@ var ( ErrAllPortsAllocated = errors.New("all ports are allocated") // ErrUnknownProtocol is returned when an unknown protocol was specified ErrUnknownProtocol = errors.New("unknown protocol") - defaultIP = net.ParseIP("0.0.0.0") once sync.Once instance *PortAllocator ) @@ -62,10 +61,11 @@ func (e ErrPortAlreadyAllocated) Error() string { type ( // PortAllocator manages the transport ports database PortAllocator struct { - mutex sync.Mutex - ipMap ipMapping - Begin int - End int + mutex sync.Mutex + defaultIP net.IP + ipMap ipMapping + begin int + end int } portRange struct { begin int @@ -80,6 +80,15 @@ type ( protoMap map[string]*portMap ) +// GetPortRange returns the PortAllocator's default port range. +// +// This function is for internal use in tests, and must not be used +// for other purposes. +func GetPortRange() (start, end uint16) { + p := Get() + return uint16(p.begin), uint16(p.end) +} + // Get returns the PortAllocator func Get() *PortAllocator { // Port Allocator is a singleton @@ -96,9 +105,10 @@ func newInstance() *PortAllocator { start, end = defaultPortRangeStart, defaultPortRangeEnd } return &PortAllocator{ - ipMap: ipMapping{}, - Begin: start, - End: end, + ipMap: ipMapping{}, + defaultIP: net.IPv4zero, + begin: start, + end: end, } } @@ -106,15 +116,18 @@ func newInstance() *PortAllocator { // If port is 0 it returns first free port. Otherwise it checks port availability // in proto's pool and returns that port or error if port is already busy. func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, error) { - return p.RequestPortInRange(ip, proto, port, port) + if ip == nil { + ip = p.defaultIP // FIXME(thaJeztah): consider making this a required argument and producing an error instead, or set default when constructing. + } + return p.RequestPortsInRange([]net.IP{ip}, proto, port, port) } -// RequestPortInRange is equivalent to [RequestPortsInRange] with a single IP address. -// -// If ip is nil, a port is instead requested for the defaultIP. +// RequestPortInRange is equivalent to [PortAllocator.RequestPortsInRange] with +// a single IP address. If ip is nil, a port is instead requested for the +// default IP (0.0.0.0). func (p *PortAllocator) RequestPortInRange(ip net.IP, proto string, portStart, portEnd int) (int, error) { if ip == nil { - ip = defaultIP + ip = p.defaultIP // FIXME(thaJeztah): consider making this a required argument and producing an error instead, or set default when constructing. } return p.RequestPortsInRange([]net.IP{ip}, proto, portStart, portEnd) } @@ -129,6 +142,13 @@ func (p *PortAllocator) RequestPortsInRange(ips []net.IP, proto string, portStar return 0, ErrUnknownProtocol } + if portStart != 0 || portEnd != 0 { + // Validate custom port-range + if portStart == 0 || portEnd == 0 || portEnd < portStart { + return 0, fmt.Errorf("invalid port range: %d-%d", portStart, portEnd) + } + } + p.mutex.Lock() defer p.mutex.Unlock() @@ -137,9 +157,9 @@ func (p *PortAllocator) RequestPortsInRange(ips []net.IP, proto string, portStar ipstr := ip.String() if _, ok := p.ipMap[ipstr]; !ok { p.ipMap[ipstr] = protoMap{ - "tcp": p.newPortMap(), - "udp": p.newPortMap(), - "sctp": p.newPortMap(), + "tcp": newPortMap(p.begin, p.end), + "udp": newPortMap(p.begin, p.end), + "sctp": newPortMap(p.begin, p.end), } } pMaps[i] = p.ipMap[ipstr][proto] @@ -163,11 +183,7 @@ func (p *PortAllocator) RequestPortsInRange(ips []net.IP, proto string, portStar // Create/fetch ranges for each portMap. pRanges := make([]*portRange, len(pMaps)) for i, pMap := range pMaps { - var err error - pRanges[i], err = pMap.getPortRange(portStart, portEnd) - if err != nil { - return 0, err - } + pRanges[i] = pMap.getPortRange(portStart, portEnd) } // Starting after the last port allocated for the first address, search @@ -199,7 +215,7 @@ func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) { defer p.mutex.Unlock() if ip == nil { - ip = defaultIP + ip = p.defaultIP // FIXME(thaJeztah): consider making this a required argument and producing an error instead, or set default when constructing. } protomap, ok := p.ipMap[ip.String()] if !ok { @@ -208,24 +224,11 @@ func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) { delete(protomap[proto].p, port) } -func (p *PortAllocator) newPortMap() *portMap { - defaultKey := getRangeKey(p.Begin, p.End) - pm := &portMap{ - p: map[int]struct{}{}, - defaultRange: defaultKey, - portRanges: map[string]*portRange{ - defaultKey: newPortRange(p.Begin, p.End), - }, - } - return pm -} - // ReleaseAll releases all ports for all ips. -func (p *PortAllocator) ReleaseAll() error { +func (p *PortAllocator) ReleaseAll() { p.mutex.Lock() p.ipMap = ipMapping{} p.mutex.Unlock() - return nil } func getRangeKey(portStart, portEnd int) string { @@ -240,26 +243,32 @@ func newPortRange(portStart, portEnd int) *portRange { } } -func (pm *portMap) getPortRange(portStart, portEnd int) (*portRange, error) { +func newPortMap(portStart, portEnd int) *portMap { + defaultKey := getRangeKey(portStart, portEnd) + return &portMap{ + p: map[int]struct{}{}, + defaultRange: defaultKey, + portRanges: map[string]*portRange{ + defaultKey: newPortRange(portStart, portEnd), + }, + } +} + +func (pm *portMap) getPortRange(portStart, portEnd int) *portRange { var key string if portStart == 0 && portEnd == 0 { key = pm.defaultRange } else { key = getRangeKey(portStart, portEnd) - if portStart == portEnd || - portStart == 0 || portEnd == 0 || - portEnd < portStart { - return nil, fmt.Errorf("invalid port range: %s", key) - } } // Return existing port range, if already known. if pr, exists := pm.portRanges[key]; exists { - return pr, nil + return pr } // Otherwise create a new port range. pr := newPortRange(portStart, portEnd) pm.portRanges[key] = pr - return pr, nil + return pr } diff --git a/libnetwork/portallocator/portallocator_test.go b/libnetwork/portallocator/portallocator_test.go index c48d52406b1ad..cc0831d722143 100644 --- a/libnetwork/portallocator/portallocator_test.go +++ b/libnetwork/portallocator/portallocator_test.go @@ -8,29 +8,23 @@ import ( is "gotest.tools/v3/assert/cmp" ) -func resetPortAllocator() { - instance = newInstance() -} - func TestRequestNewPort(t *testing.T) { - p := Get() - defer resetPortAllocator() + p := newInstance() - port, err := p.RequestPort(defaultIP, "tcp", 0) + port, err := p.RequestPort(net.IPv4zero, "tcp", 0) if err != nil { t.Fatal(err) } - if expected := p.Begin; port != expected { + if expected := p.begin; port != expected { t.Fatalf("Expected port %d got %d", expected, port) } } func TestRequestSpecificPort(t *testing.T) { - p := Get() - defer resetPortAllocator() + p := newInstance() - port, err := p.RequestPort(defaultIP, "tcp", 5000) + port, err := p.RequestPort(net.IPv4zero, "tcp", 5000) if err != nil { t.Fatal(err) } @@ -41,9 +35,9 @@ func TestRequestSpecificPort(t *testing.T) { } func TestReleasePort(t *testing.T) { - p := Get() + p := newInstance() - port, err := p.RequestPort(defaultIP, "tcp", 5000) + port, err := p.RequestPort(net.IPv4zero, "tcp", 5000) if err != nil { t.Fatal(err) } @@ -51,14 +45,13 @@ func TestReleasePort(t *testing.T) { t.Fatalf("Expected port 5000 got %d", port) } - p.ReleasePort(defaultIP, "tcp", 5000) + p.ReleasePort(net.IPv4zero, "tcp", 5000) } func TestReuseReleasedPort(t *testing.T) { - p := Get() - defer resetPortAllocator() + p := newInstance() - port, err := p.RequestPort(defaultIP, "tcp", 5000) + port, err := p.RequestPort(net.IPv4zero, "tcp", 5000) if err != nil { t.Fatal(err) } @@ -66,9 +59,9 @@ func TestReuseReleasedPort(t *testing.T) { t.Fatalf("Expected port 5000 got %d", port) } - p.ReleasePort(defaultIP, "tcp", 5000) + p.ReleasePort(net.IPv4zero, "tcp", 5000) - port, err = p.RequestPort(defaultIP, "tcp", 5000) + port, err = p.RequestPort(net.IPv4zero, "tcp", 5000) if err != nil { t.Fatal(err) } @@ -78,10 +71,9 @@ func TestReuseReleasedPort(t *testing.T) { } func TestReleaseUnreadledPort(t *testing.T) { - p := Get() - defer resetPortAllocator() + p := newInstance() - port, err := p.RequestPort(defaultIP, "tcp", 5000) + port, err := p.RequestPort(net.IPv4zero, "tcp", 5000) if err != nil { t.Fatal(err) } @@ -89,7 +81,7 @@ func TestReleaseUnreadledPort(t *testing.T) { t.Fatalf("Expected port 5000 got %d", port) } - _, err = p.RequestPort(defaultIP, "tcp", 5000) + _, err = p.RequestPort(net.IPv4zero, "tcp", 5000) switch err.(type) { case ErrPortAlreadyAllocated: @@ -99,39 +91,40 @@ func TestReleaseUnreadledPort(t *testing.T) { } func TestUnknowProtocol(t *testing.T) { - if _, err := Get().RequestPort(defaultIP, "tcpp", 0); err != ErrUnknownProtocol { + p := newInstance() + + if _, err := p.RequestPort(net.IPv4zero, "tcpp", 0); err != ErrUnknownProtocol { t.Fatalf("Expected error %s got %s", ErrUnknownProtocol, err) } } func TestAllocateAllPorts(t *testing.T) { - p := Get() - defer resetPortAllocator() + p := newInstance() - for i := 0; i <= p.End-p.Begin; i++ { - port, err := p.RequestPort(defaultIP, "tcp", 0) + for i := 0; i <= p.end-p.begin; i++ { + port, err := p.RequestPort(net.IPv4zero, "tcp", 0) if err != nil { t.Fatal(err) } - if expected := p.Begin + i; port != expected { + if expected := p.begin + i; port != expected { t.Fatalf("Expected port %d got %d", expected, port) } } - if _, err := p.RequestPort(defaultIP, "tcp", 0); err != ErrAllPortsAllocated { + if _, err := p.RequestPort(net.IPv4zero, "tcp", 0); err != ErrAllPortsAllocated { t.Fatalf("Expected error %s got %s", ErrAllPortsAllocated, err) } - _, err := p.RequestPort(defaultIP, "udp", 0) + _, err := p.RequestPort(net.IPv4zero, "udp", 0) if err != nil { t.Fatal(err) } // release a port in the middle and ensure we get another tcp port - port := p.Begin + 5 - p.ReleasePort(defaultIP, "tcp", port) - newPort, err := p.RequestPort(defaultIP, "tcp", 0) + port := p.begin + 5 + p.ReleasePort(net.IPv4zero, "tcp", port) + newPort, err := p.RequestPort(net.IPv4zero, "tcp", 0) if err != nil { t.Fatal(err) } @@ -141,8 +134,8 @@ func TestAllocateAllPorts(t *testing.T) { // now pm.last == newPort, release it so that it's the only free port of // the range, and ensure we get it back - p.ReleasePort(defaultIP, "tcp", newPort) - port, err = p.RequestPort(defaultIP, "tcp", 0) + p.ReleasePort(net.IPv4zero, "tcp", newPort) + port, err = p.RequestPort(net.IPv4zero, "tcp", 0) if err != nil { t.Fatal(err) } @@ -152,29 +145,25 @@ func TestAllocateAllPorts(t *testing.T) { } func BenchmarkAllocatePorts(b *testing.B) { - p := Get() - defer resetPortAllocator() + p := newInstance() for i := 0; i < b.N; i++ { - for i := 0; i <= p.End-p.Begin; i++ { - port, err := p.RequestPort(defaultIP, "tcp", 0) + for i := 0; i <= p.end-p.begin; i++ { + port, err := p.RequestPort(net.IPv4zero, "tcp", 0) if err != nil { b.Fatal(err) } - if expected := p.Begin + i; port != expected { + if expected := p.begin + i; port != expected { b.Fatalf("Expected port %d got %d", expected, port) } } - if err := p.ReleaseAll(); err != nil { - b.Fatal(err) - } + p.ReleaseAll() } } func TestPortAllocation(t *testing.T) { - p := Get() - defer resetPortAllocator() + p := newInstance() ip := net.ParseIP("192.168.0.1") ip2 := net.ParseIP("192.168.0.2") @@ -233,31 +222,30 @@ func TestPortAllocation(t *testing.T) { } func TestPortAllocationWithCustomRange(t *testing.T) { - p := Get() - defer resetPortAllocator() + p := newInstance() start, end := 8081, 8082 specificPort := 8000 // get an ephemeral port. - port1, err := p.RequestPortInRange(defaultIP, "tcp", 0, 0) + port1, err := p.RequestPortInRange(net.IPv4zero, "tcp", 0, 0) if err != nil { t.Fatal(err) } // request invalid ranges - if _, err := p.RequestPortInRange(defaultIP, "tcp", 0, end); err == nil { + if _, err := p.RequestPortInRange(net.IPv4zero, "tcp", 0, end); err == nil { t.Fatalf("Expected error for invalid range %d-%d", 0, end) } - if _, err := p.RequestPortInRange(defaultIP, "tcp", start, 0); err == nil { + if _, err := p.RequestPortInRange(net.IPv4zero, "tcp", start, 0); err == nil { t.Fatalf("Expected error for invalid range %d-%d", 0, end) } - if _, err := p.RequestPortInRange(defaultIP, "tcp", 8081, 8080); err == nil { + if _, err := p.RequestPortInRange(net.IPv4zero, "tcp", 8081, 8080); err == nil { t.Fatalf("Expected error for invalid range %d-%d", 0, end) } // request a single port - port, err := p.RequestPortInRange(defaultIP, "tcp", specificPort, specificPort) + port, err := p.RequestPortInRange(net.IPv4zero, "tcp", specificPort, specificPort) if err != nil { t.Fatal(err) } @@ -266,7 +254,7 @@ func TestPortAllocationWithCustomRange(t *testing.T) { } // get a port from the range - port2, err := p.RequestPortInRange(defaultIP, "tcp", start, end) + port2, err := p.RequestPortInRange(net.IPv4zero, "tcp", start, end) if err != nil { t.Fatal(err) } @@ -274,7 +262,7 @@ func TestPortAllocationWithCustomRange(t *testing.T) { t.Fatalf("Expected a port between %d and %d, got %d", start, end, port2) } // get another ephemeral port (should be > port1) - port3, err := p.RequestPortInRange(defaultIP, "tcp", 0, 0) + port3, err := p.RequestPortInRange(net.IPv4zero, "tcp", 0, 0) if err != nil { t.Fatal(err) } @@ -282,7 +270,7 @@ func TestPortAllocationWithCustomRange(t *testing.T) { t.Fatalf("Expected new port > %d in the ephemeral range, got %d", port1, port3) } // get another (and in this case the only other) port from the range - port4, err := p.RequestPortInRange(defaultIP, "tcp", start, end) + port4, err := p.RequestPortInRange(net.IPv4zero, "tcp", start, end) if err != nil { t.Fatal(err) } @@ -293,38 +281,36 @@ func TestPortAllocationWithCustomRange(t *testing.T) { t.Fatal("Allocated the same port from a custom range") } // request 3rd port from the range of 2 - if _, err := p.RequestPortInRange(defaultIP, "tcp", start, end); err != ErrAllPortsAllocated { + if _, err := p.RequestPortInRange(net.IPv4zero, "tcp", start, end); err != ErrAllPortsAllocated { t.Fatalf("Expected error %s got %s", ErrAllPortsAllocated, err) } } func TestNoDuplicateBPR(t *testing.T) { - p := Get() - defer resetPortAllocator() + p := newInstance() - if port, err := p.RequestPort(defaultIP, "tcp", p.Begin); err != nil { + if port, err := p.RequestPort(net.IPv4zero, "tcp", p.begin); err != nil { t.Fatal(err) - } else if port != p.Begin { - t.Fatalf("Expected port %d got %d", p.Begin, port) + } else if port != p.begin { + t.Fatalf("Expected port %d got %d", p.begin, port) } - if port, err := p.RequestPort(defaultIP, "tcp", 0); err != nil { + if port, err := p.RequestPort(net.IPv4zero, "tcp", 0); err != nil { t.Fatal(err) - } else if port == p.Begin { + } else if port == p.begin { t.Fatalf("Acquire(0) allocated the same port twice: %d", port) } } func TestRequestPortForMultipleIPs(t *testing.T) { - p := Get() - defer resetPortAllocator() + p := newInstance() addrs := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::")} // Default port range. port, err := p.RequestPortsInRange(addrs, "tcp", 0, 0) assert.Check(t, err) - assert.Check(t, is.Equal(port, p.Begin)) + assert.Check(t, is.Equal(port, p.begin)) // Single-port range. port, err = p.RequestPortsInRange(addrs, "tcp", 10000, 10000)