-
Notifications
You must be signed in to change notification settings - Fork 200
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
net: add multi listener impl for net.Listener
This adds an implementation of net.Listener which listens on and accepts connections from multiple addresses. Signed-off-by: Daman Arora <[email protected]>
- Loading branch information
1 parent
fe8a2dd
commit 3327c9c
Showing
3 changed files
with
325 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
/* | ||
Copyright 2024 The Kubernetes Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package net | ||
|
||
import ( | ||
"fmt" | ||
"net" | ||
"syscall" | ||
) | ||
|
||
// multiListener implements net.Listener and uses multiplexing to listen to and accept | ||
// TCP connections from multiple addresses. | ||
type multiListener struct { | ||
latestAcceptedFDIndex int | ||
fds []int | ||
addrs []net.Addr | ||
stopCh chan struct{} | ||
} | ||
|
||
// compile time check to ensure *multiListener implements net.Listener. | ||
var _ net.Listener = &multiListener{} | ||
|
||
// NewMultiListener returns *multiListener as net.Listener allowing consumers to | ||
// listen for TCP connections on multiple addresses. | ||
func NewMultiListener(addresses []string) (net.Listener, error) { | ||
ml := &multiListener{ | ||
stopCh: make(chan struct{}), | ||
} | ||
for _, address := range addresses { | ||
fd, addr, err := createBindAndListen(address) | ||
if err != nil { | ||
return nil, err | ||
} | ||
ml.fds = append(ml.fds, fd) | ||
ml.addrs = append(ml.addrs, addr) | ||
} | ||
return ml, nil | ||
} | ||
|
||
// Accept is part of net.Listener interface. | ||
func (ml *multiListener) Accept() (net.Conn, error) { | ||
return ml.accept() | ||
} | ||
|
||
// Close is part of net.Listener interface. | ||
func (ml *multiListener) Close() error { | ||
close(ml.stopCh) | ||
for _, fd := range ml.fds { | ||
_ = syscall.Close(fd) | ||
} | ||
return nil | ||
} | ||
|
||
// Addr is part of net.Listener interface. | ||
func (ml *multiListener) Addr() net.Addr { | ||
return ml.addrs[ml.latestAcceptedFDIndex] | ||
} | ||
|
||
// createBindAndListen creates a TCP socket, binds it to the specified address, and starts listening on it. | ||
func createBindAndListen(address string) (int, net.Addr, error) { | ||
host, _, err := net.SplitHostPort(address) | ||
if err != nil { | ||
return -1, nil, err | ||
} | ||
|
||
ipFamily := IPFamilyOf(ParseIPSloppy(host)) | ||
var network string | ||
var domain int | ||
switch ipFamily { | ||
case IPv4: | ||
network = "tcp4" | ||
domain = syscall.AF_INET | ||
case IPv6: | ||
network = "tcp6" | ||
domain = syscall.AF_INET6 | ||
default: | ||
return -1, nil, fmt.Errorf("failed to idenfity ip family of host '%s'", host) | ||
|
||
} | ||
|
||
// resolve tcp addr | ||
addr, err := net.ResolveTCPAddr(network, address) | ||
if err != nil { | ||
return -1, nil, err | ||
} | ||
|
||
// create socket | ||
fd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0) | ||
if err != nil { | ||
return -1, nil, err | ||
} | ||
|
||
// define socket address for bind | ||
var sockAddr syscall.Sockaddr | ||
if ipFamily == IPv4 { | ||
var ipBytes [4]byte | ||
copy(ipBytes[:], addr.IP.To4()) | ||
sockAddr = &syscall.SockaddrInet4{ | ||
Addr: ipBytes, | ||
Port: addr.Port, | ||
} | ||
} else { | ||
var ipBytes [16]byte | ||
copy(ipBytes[:], addr.IP.To16()) | ||
sockAddr = &syscall.SockaddrInet6{ | ||
Addr: ipBytes, | ||
Port: addr.Port, | ||
} | ||
} | ||
|
||
// bind socket to specified addr | ||
if err = syscall.Bind(fd, sockAddr); err != nil { | ||
_ = syscall.Close(fd) | ||
return -1, nil, err | ||
} | ||
|
||
// start listening on socket | ||
if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil { | ||
_ = syscall.Close(fd) | ||
return -1, nil, err | ||
} | ||
|
||
return fd, addr, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
//go:build darwin | ||
// +build darwin | ||
|
||
/* | ||
Copyright 2024 The Kubernetes Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package net | ||
|
||
import ( | ||
"fmt" | ||
"net" | ||
"os" | ||
"syscall" | ||
) | ||
|
||
// Accept is part of net.Listener interface. | ||
func (ml *multiListener) accept() (net.Conn, error) { | ||
for { | ||
readFds := &syscall.FdSet{} | ||
maxfd := 0 | ||
|
||
for _, fd := range ml.fds { | ||
if fd > maxfd { | ||
maxfd = fd | ||
} | ||
addFDToFDSet(fd, readFds) | ||
} | ||
|
||
// wait for any of the sockets to be ready for accepting new connection | ||
timeout := syscall.Timeval{Sec: 1, Usec: 0} | ||
err := syscall.Select(maxfd+1, readFds, nil, nil, &timeout) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
for i, fd := range ml.fds { | ||
if isFDInFDSet(fd, readFds) { | ||
conn, err := acceptConnection(fd) | ||
if err != nil { | ||
return nil, err | ||
} | ||
ml.latestAcceptedFDIndex = i | ||
return conn, nil | ||
} | ||
} | ||
|
||
select { | ||
case <-ml.stopCh: | ||
return nil, fmt.Errorf("multiListener closed") | ||
default: | ||
continue | ||
} | ||
} | ||
} | ||
|
||
// addFDToFDSet adds fd to the given fd set | ||
func addFDToFDSet(fd int, p *syscall.FdSet) { | ||
mask := 1 << (uint(fd) % syscall.FD_SETSIZE) | ||
p.Bits[fd/syscall.FD_SETSIZE] |= int32(mask) | ||
} | ||
|
||
// isFDInFDSet returns true if fd is in fd set, false otherwise | ||
func isFDInFDSet(fd int, p *syscall.FdSet) bool { | ||
mask := 1 << (uint(fd) % syscall.FD_SETSIZE) | ||
return p.Bits[fd/syscall.FD_SETSIZE]&int32(mask) != 0 | ||
} | ||
|
||
// acceptConnection accepts connection and returns remote connection object | ||
func acceptConnection(fd int) (net.Conn, error) { | ||
connFD, _, err := syscall.Accept(fd) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
conn, err := net.FileConn(os.NewFile(uintptr(connFD), fmt.Sprintf("fd %d", connFD))) | ||
if err != nil { | ||
_ = syscall.Close(connFD) | ||
return nil, err | ||
} | ||
return conn, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
//go:build linux | ||
// +build linux | ||
|
||
/* | ||
Copyright 2024 The Kubernetes Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package net | ||
|
||
import ( | ||
"fmt" | ||
"net" | ||
"os" | ||
"syscall" | ||
) | ||
|
||
func (ml *multiListener) accept() (net.Conn, error) { | ||
for { | ||
readFds := &syscall.FdSet{} | ||
maxfd := 0 | ||
|
||
for _, fd := range ml.fds { | ||
if fd > maxfd { | ||
maxfd = fd | ||
} | ||
addFDToFDSet(fd, readFds) | ||
} | ||
|
||
// wait for any of the sockets to be ready for accepting new connection | ||
timeout := syscall.Timeval{Sec: 1, Usec: 0} | ||
n, err := syscall.Select(maxfd+1, readFds, nil, nil, &timeout) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if n == 0 { | ||
select { | ||
case <-ml.stopCh: | ||
return nil, fmt.Errorf("multiListener closed") | ||
default: | ||
continue | ||
} | ||
} | ||
for i, fd := range ml.fds { | ||
if isFDInFDSet(fd, readFds) { | ||
conn, err := acceptConnection(fd) | ||
if err != nil { | ||
return nil, err | ||
} | ||
ml.latestAcceptedFDIndex = i | ||
return conn, nil | ||
} | ||
} | ||
} | ||
} | ||
|
||
// addFDToFDSet adds fd to the given fd set | ||
func addFDToFDSet(fd int, p *syscall.FdSet) { | ||
mask := 1 << (uint(fd) % syscall.FD_SETSIZE) | ||
p.Bits[fd/syscall.FD_SETSIZE] |= int64(mask) | ||
} | ||
|
||
// isFDInFDSet returns true if fd is in fd set, false otherwise | ||
func isFDInFDSet(fd int, p *syscall.FdSet) bool { | ||
mask := 1 << (uint(fd) % syscall.FD_SETSIZE) | ||
return p.Bits[fd/syscall.FD_SETSIZE]&int64(mask) != 0 | ||
} | ||
|
||
// acceptConnection accepts connection and returns remote connection object | ||
func acceptConnection(fd int) (net.Conn, error) { | ||
connFD, _, err := syscall.Accept(fd) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
conn, err := net.FileConn(os.NewFile(uintptr(connFD), fmt.Sprintf("fd %d", connFD))) | ||
if err != nil { | ||
_ = syscall.Close(connFD) | ||
return nil, err | ||
} | ||
return conn, nil | ||
} |