Skip to content

Commit

Permalink
Check UDP packet size before allocation.
Browse files Browse the repository at this point in the history
Reported-by: [email protected]
PiperOrigin-RevId: 433892548
  • Loading branch information
konstantin-s-bogom authored and gvisor-bot committed Mar 11, 2022
1 parent 4503ba3 commit ee133db
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 22 deletions.
34 changes: 19 additions & 15 deletions pkg/tcpip/transport/udp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,27 +396,31 @@ func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (
return udpPacketInfo{}, err
}

if p.Len() > header.UDPMaximumPacketSize {
// Native linux behaviour differs for IPv4 and IPv6 packets; IPv4 packet
// errors aren't report to the error queue at all.
if ctx.PacketInfo().NetProto == header.IPv6ProtocolNumber {
so := e.SocketOptions()
if so.GetRecvError() {
so.QueueLocalErr(
&tcpip.ErrMessageTooLong{},
e.net.NetProto(),
uint32(p.Len()),
dst,
nil,
)
}
}
ctx.Release()
return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
}

// TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
ctx.Release()
return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
so := e.SocketOptions()
if so.GetRecvError() {
so.QueueLocalErr(
&tcpip.ErrMessageTooLong{},
e.net.NetProto(),
header.UDPMaximumPacketSize,
dst,
v,
)
}
ctx.Release()
return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
}

return udpPacketInfo{
ctx: ctx,
Expand Down
30 changes: 23 additions & 7 deletions pkg/tcpip/transport/udp/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,15 +423,15 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
// and verifies it fails with the provided error code.
// TODO(https://gvisor.dev/issue/5623): Extract the test write methods in the
// testing context.
func testFailingWrite(c *context.Context, flow context.TestFlow, wantErr tcpip.Error) {
func testFailingWrite(c *context.Context, flow context.TestFlow, payloadSize int, wantErr tcpip.Error) {
c.T.Helper()
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.EP.Stats().(*tcpip.TransportEndpointStats).Clone()
h := flow.MakeHeader4Tuple(context.Outgoing)
writeDstAddr := flow.MapAddrIfApplicable(h.Dst.Addr)

var r bytes.Reader
r.Reset(newRandomPayload(arbitraryPayloadSize))
r.Reset(newRandomPayload(payloadSize))
_, gotErr := c.EP.Write(&r, tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.Dst.Port},
})
Expand Down Expand Up @@ -590,7 +590,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
testWrite(c, context.UnicastV6)

// Write to V4 mapped address.
testFailingWrite(c, context.UnicastV4in6, &tcpip.ErrNetworkUnreachable{})
testFailingWrite(c, context.UnicastV4in6, arbitraryPayloadSize, &tcpip.ErrNetworkUnreachable{})
const want = 1
if got := c.EP.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
c.T.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
Expand All @@ -611,7 +611,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
testWrite(c, context.UnicastV4in6)

// Write to v6 address.
testFailingWrite(c, context.UnicastV6, &tcpip.ErrInvalidEndpointState{})
testFailingWrite(c, context.UnicastV6, arbitraryPayloadSize, &tcpip.ErrInvalidEndpointState{})
}

func TestV4WriteOnV6Only(t *testing.T) {
Expand All @@ -621,7 +621,7 @@ func TestV4WriteOnV6Only(t *testing.T) {
c.CreateEndpointForFlow(context.UnicastV6Only, udp.ProtocolNumber)

// Write to V4 mapped address.
testFailingWrite(c, context.UnicastV4in6, &tcpip.ErrNoRoute{})
testFailingWrite(c, context.UnicastV4in6, arbitraryPayloadSize, &tcpip.ErrNoRoute{})
}

func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
Expand All @@ -636,7 +636,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
}

// Write to v6 address.
testFailingWrite(c, context.UnicastV6, &tcpip.ErrInvalidEndpointState{})
testFailingWrite(c, context.UnicastV6, arbitraryPayloadSize, &tcpip.ErrInvalidEndpointState{})
}

func TestV6WriteOnConnected(t *testing.T) {
Expand Down Expand Up @@ -1772,7 +1772,7 @@ func TestShutdownWrite(t *testing.T) {
t.Fatalf("Shutdown failed: %s", err)
}

testFailingWrite(c, context.UnicastV6, &tcpip.ErrClosedForSend{})
testFailingWrite(c, context.UnicastV6, arbitraryPayloadSize, &tcpip.ErrClosedForSend{})
}

func TestOutgoingSubnetBroadcast(t *testing.T) {
Expand Down Expand Up @@ -2067,6 +2067,22 @@ func TestChecksumWithZeroValueOnesComplementSum(t *testing.T) {
}
}

// TestWritePayloadSizeTooBig verifies that writing anything bigger than
// header.UDPMaximumPacketSize fails.
func TestWritePayloadSizeTooBig(t *testing.T) {
c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
defer c.Cleanup()

c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)

if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil {
c.T.Fatalf("Connect failed: %s", err)
}

testWrite(c, context.UnicastV6)
testFailingWrite(c, context.UnicastV6, header.UDPMaximumPacketSize+1, &tcpip.ErrMessageTooLong{})
}

func TestMain(m *testing.M) {
refs.SetLeakMode(refs.LeaksPanic)
code := m.Run()
Expand Down
103 changes: 103 additions & 0 deletions test/syscalls/linux/socket_ip_udp_generic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <errno.h>
#ifdef __linux__
#include <linux/errqueue.h>
#include <linux/in6.h>
#endif // __linux__
#include <netinet/in.h>
Expand Down Expand Up @@ -541,5 +542,107 @@ TEST_P(UDPSocketPairTest, GetSocketAcceptConn) {
EXPECT_EQ(got, 0);
}

#ifdef __linux__
TEST_P(UDPSocketPairTest, PayloadTooBig) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());

// Set IP_RECVERR socket option to enable error queueing.
int v = kSockOptOn;
socklen_t optlen = sizeof(v);
int opt_level = SOL_IP;
int opt_type = IP_RECVERR;
if (sockets->first_addr()->sa_family == AF_INET6) {
opt_level = SOL_IPV6;
opt_type = IPV6_RECVERR;
}
ASSERT_THAT(setsockopt(sockets->first_fd(), opt_level, opt_type, &v, optlen),
SyscallSucceeds());

// Buffers bigger than 0xffff should receive an error.
const int kBufLen = 0x10000;
char buf[kBufLen];
RandomizeBuffer(buf, sizeof(buf));

EXPECT_THAT(send(sockets->first_fd(), buf, sizeof(buf), 0),
SyscallFailsWithErrno(EMSGSIZE));

// Dequeue error using recvmsg(MSG_ERRQUEUE). Give a buffer big-enough for
// the original message just in case.
char got[kBufLen];
struct iovec iov;
iov.iov_base = reinterpret_cast<void*>(got);
iov.iov_len = kBufLen;

const int addrlen_ = sockets->second_addr_size();
size_t control_buf_len = CMSG_SPACE(sizeof(sock_extended_err) + addrlen_);
std::vector<char> control_buf(control_buf_len);
struct sockaddr_storage remote;
memset(&remote, 0, sizeof(remote));
struct msghdr msg = {};
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_flags = 0;
msg.msg_control = control_buf.data();
msg.msg_controllen = control_buf_len;
msg.msg_name = reinterpret_cast<void*>(&remote);
msg.msg_namelen = addrlen_;

struct sockaddr_storage addr;
optlen = sizeof(addr);
EXPECT_THAT(getpeername(sockets->first_fd(), AsSockAddr(&addr), &optlen),
SyscallSucceeds());
bool ipv6 = false;
if (addr.ss_family == AF_INET6) {
auto ipv6addr = reinterpret_cast<struct sockaddr_in6*>(&addr);

// Exclude IPv4-mapped addresses.
uint8_t v4MappedPrefix[12] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0xff, 0xff};
ipv6 = memcmp(&ipv6addr->sin6_addr.s6_addr[0], v4MappedPrefix,
sizeof(v4MappedPrefix)) != 0;
}
// Native behaviour for IPv4 packets is to not report to ERRQUEUE.
if (!ipv6) {
EXPECT_THAT(recvmsg(sockets->first_fd(), &msg, MSG_ERRQUEUE),
SyscallFailsWithErrno(EAGAIN));
return;
}

ASSERT_THAT(recvmsg(sockets->first_fd(), &msg, MSG_ERRQUEUE),
SyscallSucceedsWithValue(0));

EXPECT_NE(msg.msg_flags & MSG_ERRQUEUE, 0);
EXPECT_EQ(memcmp(&remote, sockets->second_addr(), addrlen_), 0);

// Check the contents of the control message.
struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
ASSERT_NE(cmsg, nullptr);
EXPECT_EQ(CMSG_NXTHDR(&msg, cmsg), nullptr);
EXPECT_EQ(cmsg->cmsg_level, opt_level);
EXPECT_EQ(cmsg->cmsg_type, opt_type);
EXPECT_EQ(cmsg->cmsg_len,
sizeof(sock_extended_err) + addrlen_ + sizeof(cmsghdr));

// Check the contents of socket error.
struct sock_extended_err* sock_err =
reinterpret_cast<sock_extended_err*>(CMSG_DATA(cmsg));
EXPECT_EQ(sock_err->ee_errno, EMSGSIZE);
EXPECT_EQ(sock_err->ee_origin, SO_EE_ORIGIN_LOCAL);
EXPECT_EQ(sock_err->ee_type, ICMP_ECHOREPLY);
EXPECT_EQ(sock_err->ee_code, ICMP_NET_UNREACH);
EXPECT_EQ(sock_err->ee_info, kBufLen);
EXPECT_EQ(sock_err->ee_data, 0);

// Verify that no socket error was put on the queue.
int err;
optlen = sizeof(err);
ASSERT_THAT(
getsockopt(sockets->first_fd(), SOL_SOCKET, SO_ERROR, &err, &optlen),
SyscallSucceeds());
ASSERT_EQ(err, 0);
ASSERT_EQ(optlen, sizeof(err));
}
#endif // __linux__

} // namespace testing
} // namespace gvisor

0 comments on commit ee133db

Please sign in to comment.