diff --git a/net/multi_listen.go b/net/multi_listen.go new file mode 100644 index 00000000..9e4d10ef --- /dev/null +++ b/net/multi_listen.go @@ -0,0 +1,138 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "fmt" + "net" + "syscall" +) + +// multiListener implements net.Listener and uses multiplexing to listen to and accept +// TCP connections from multiple addresses. +type multiListener struct { + latestAcceptedFDIndex int + fds []int + addrs []net.Addr + stopCh chan struct{} +} + +// compile time check to ensure *multiListener implements net.Listener. +var _ net.Listener = &multiListener{} + +// NewMultiListener returns *multiListener as net.Listener allowing consumers to +// listen for TCP connections on multiple addresses. +func NewMultiListener(addresses []string) (net.Listener, error) { + ml := &multiListener{ + stopCh: make(chan struct{}), + } + for _, address := range addresses { + fd, addr, err := createBindAndListen(address) + if err != nil { + return nil, err + } + ml.fds = append(ml.fds, fd) + ml.addrs = append(ml.addrs, addr) + } + return ml, nil +} + +// Accept is part of net.Listener interface. +func (ml *multiListener) Accept() (net.Conn, error) { + return ml.accept() +} + +// Close is part of net.Listener interface. +func (ml *multiListener) Close() error { + close(ml.stopCh) + for _, fd := range ml.fds { + _ = syscall.Close(fd) + } + return nil +} + +// Addr is part of net.Listener interface. +func (ml *multiListener) Addr() net.Addr { + return ml.addrs[ml.latestAcceptedFDIndex] +} + +// createBindAndListen creates a TCP socket, binds it to the specified address, and starts listening on it. +func createBindAndListen(address string) (int, net.Addr, error) { + host, _, err := net.SplitHostPort(address) + if err != nil { + return -1, nil, err + } + + ipFamily := IPFamilyOf(ParseIPSloppy(host)) + var network string + var domain int + switch ipFamily { + case IPv4: + network = "tcp4" + domain = syscall.AF_INET + case IPv6: + network = "tcp6" + domain = syscall.AF_INET6 + default: + return -1, nil, fmt.Errorf("failed to idenfity ip family of host '%s'", host) + + } + + // resolve tcp addr + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return -1, nil, err + } + + // create socket + fd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0) + if err != nil { + return -1, nil, err + } + + // define socket address for bind + var sockAddr syscall.Sockaddr + if ipFamily == IPv4 { + var ipBytes [4]byte + copy(ipBytes[:], addr.IP.To4()) + sockAddr = &syscall.SockaddrInet4{ + Addr: ipBytes, + Port: addr.Port, + } + } else { + var ipBytes [16]byte + copy(ipBytes[:], addr.IP.To16()) + sockAddr = &syscall.SockaddrInet6{ + Addr: ipBytes, + Port: addr.Port, + } + } + + // bind socket to specified addr + if err = syscall.Bind(fd, sockAddr); err != nil { + _ = syscall.Close(fd) + return -1, nil, err + } + + // start listening on socket + if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil { + _ = syscall.Close(fd) + return -1, nil, err + } + + return fd, addr, nil +} diff --git a/net/multi_listen_darwin.go b/net/multi_listen_darwin.go new file mode 100644 index 00000000..08197f22 --- /dev/null +++ b/net/multi_listen_darwin.go @@ -0,0 +1,94 @@ +//go:build darwin +// +build darwin + +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "fmt" + "net" + "os" + "syscall" +) + +// Accept is part of net.Listener interface. +func (ml *multiListener) accept() (net.Conn, error) { + for { + readFds := &syscall.FdSet{} + maxfd := 0 + + for _, fd := range ml.fds { + if fd > maxfd { + maxfd = fd + } + addFDToFDSet(fd, readFds) + } + + // wait for any of the sockets to be ready for accepting new connection + timeout := syscall.Timeval{Sec: 1, Usec: 0} + err := syscall.Select(maxfd+1, readFds, nil, nil, &timeout) + if err != nil { + return nil, err + } + + for i, fd := range ml.fds { + if isFDInFDSet(fd, readFds) { + conn, err := acceptConnection(fd) + if err != nil { + return nil, err + } + ml.latestAcceptedFDIndex = i + return conn, nil + } + } + + select { + case <-ml.stopCh: + return nil, fmt.Errorf("multiListener closed") + default: + continue + } + } +} + +// addFDToFDSet adds fd to the given fd set +func addFDToFDSet(fd int, p *syscall.FdSet) { + mask := 1 << (uint(fd) % syscall.FD_SETSIZE) + p.Bits[fd/syscall.FD_SETSIZE] |= int32(mask) +} + +// isFDInFDSet returns true if fd is in fd set, false otherwise +func isFDInFDSet(fd int, p *syscall.FdSet) bool { + mask := 1 << (uint(fd) % syscall.FD_SETSIZE) + return p.Bits[fd/syscall.FD_SETSIZE]&int32(mask) != 0 +} + +// acceptConnection accepts connection and returns remote connection object +func acceptConnection(fd int) (net.Conn, error) { + connFD, _, err := syscall.Accept(fd) + if err != nil { + return nil, err + } + + conn, err := net.FileConn(os.NewFile(uintptr(connFD), fmt.Sprintf("fd %d", connFD))) + if err != nil { + _ = syscall.Close(connFD) + return nil, err + } + return conn, nil +} diff --git a/net/multi_listen_linux.go b/net/multi_listen_linux.go new file mode 100644 index 00000000..102cec4d --- /dev/null +++ b/net/multi_listen_linux.go @@ -0,0 +1,93 @@ +//go:build linux +// +build linux + +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package net + +import ( + "fmt" + "net" + "os" + "syscall" +) + +func (ml *multiListener) accept() (net.Conn, error) { + for { + readFds := &syscall.FdSet{} + maxfd := 0 + + for _, fd := range ml.fds { + if fd > maxfd { + maxfd = fd + } + addFDToFDSet(fd, readFds) + } + + // wait for any of the sockets to be ready for accepting new connection + timeout := syscall.Timeval{Sec: 1, Usec: 0} + n, err := syscall.Select(maxfd+1, readFds, nil, nil, &timeout) + if err != nil { + return nil, err + } + if n == 0 { + select { + case <-ml.stopCh: + return nil, fmt.Errorf("multiListener closed") + default: + continue + } + } + for i, fd := range ml.fds { + if isFDInFDSet(fd, readFds) { + conn, err := acceptConnection(fd) + if err != nil { + return nil, err + } + ml.latestAcceptedFDIndex = i + return conn, nil + } + } + } +} + +// addFDToFDSet adds fd to the given fd set +func addFDToFDSet(fd int, p *syscall.FdSet) { + mask := 1 << (uint(fd) % syscall.FD_SETSIZE) + p.Bits[fd/syscall.FD_SETSIZE] |= int64(mask) +} + +// isFDInFDSet returns true if fd is in fd set, false otherwise +func isFDInFDSet(fd int, p *syscall.FdSet) bool { + mask := 1 << (uint(fd) % syscall.FD_SETSIZE) + return p.Bits[fd/syscall.FD_SETSIZE]&int64(mask) != 0 +} + +// acceptConnection accepts connection and returns remote connection object +func acceptConnection(fd int) (net.Conn, error) { + connFD, _, err := syscall.Accept(fd) + if err != nil { + return nil, err + } + + conn, err := net.FileConn(os.NewFile(uintptr(connFD), fmt.Sprintf("fd %d", connFD))) + if err != nil { + _ = syscall.Close(connFD) + return nil, err + } + return conn, nil +}