Skip to content

Commit ee133db

Browse files
konstantin-s-bogomgvisor-bot
authored andcommitted
Check UDP packet size before allocation.
Reported-by: [email protected] PiperOrigin-RevId: 433892548
1 parent 4503ba3 commit ee133db

File tree

3 files changed

+145
-22
lines changed

3 files changed

+145
-22
lines changed

pkg/tcpip/transport/udp/endpoint.go

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -396,27 +396,31 @@ func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (
396396
return udpPacketInfo{}, err
397397
}
398398

399+
if p.Len() > header.UDPMaximumPacketSize {
400+
// Native linux behaviour differs for IPv4 and IPv6 packets; IPv4 packet
401+
// errors aren't report to the error queue at all.
402+
if ctx.PacketInfo().NetProto == header.IPv6ProtocolNumber {
403+
so := e.SocketOptions()
404+
if so.GetRecvError() {
405+
so.QueueLocalErr(
406+
&tcpip.ErrMessageTooLong{},
407+
e.net.NetProto(),
408+
uint32(p.Len()),
409+
dst,
410+
nil,
411+
)
412+
}
413+
}
414+
ctx.Release()
415+
return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
416+
}
417+
399418
// TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
400419
v := make([]byte, p.Len())
401420
if _, err := io.ReadFull(p, v); err != nil {
402421
ctx.Release()
403422
return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
404423
}
405-
if len(v) > header.UDPMaximumPacketSize {
406-
// Payload can't possibly fit in a packet.
407-
so := e.SocketOptions()
408-
if so.GetRecvError() {
409-
so.QueueLocalErr(
410-
&tcpip.ErrMessageTooLong{},
411-
e.net.NetProto(),
412-
header.UDPMaximumPacketSize,
413-
dst,
414-
v,
415-
)
416-
}
417-
ctx.Release()
418-
return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
419-
}
420424

421425
return udpPacketInfo{
422426
ctx: ctx,

pkg/tcpip/transport/udp/udp_test.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -423,15 +423,15 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
423423
// and verifies it fails with the provided error code.
424424
// TODO(https://gvisor.dev/issue/5623): Extract the test write methods in the
425425
// testing context.
426-
func testFailingWrite(c *context.Context, flow context.TestFlow, wantErr tcpip.Error) {
426+
func testFailingWrite(c *context.Context, flow context.TestFlow, payloadSize int, wantErr tcpip.Error) {
427427
c.T.Helper()
428428
// Take a snapshot of the stats to validate them at the end of the test.
429429
epstats := c.EP.Stats().(*tcpip.TransportEndpointStats).Clone()
430430
h := flow.MakeHeader4Tuple(context.Outgoing)
431431
writeDstAddr := flow.MapAddrIfApplicable(h.Dst.Addr)
432432

433433
var r bytes.Reader
434-
r.Reset(newRandomPayload(arbitraryPayloadSize))
434+
r.Reset(newRandomPayload(payloadSize))
435435
_, gotErr := c.EP.Write(&r, tcpip.WriteOptions{
436436
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.Dst.Port},
437437
})
@@ -590,7 +590,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
590590
testWrite(c, context.UnicastV6)
591591

592592
// Write to V4 mapped address.
593-
testFailingWrite(c, context.UnicastV4in6, &tcpip.ErrNetworkUnreachable{})
593+
testFailingWrite(c, context.UnicastV4in6, arbitraryPayloadSize, &tcpip.ErrNetworkUnreachable{})
594594
const want = 1
595595
if got := c.EP.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
596596
c.T.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
@@ -611,7 +611,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
611611
testWrite(c, context.UnicastV4in6)
612612

613613
// Write to v6 address.
614-
testFailingWrite(c, context.UnicastV6, &tcpip.ErrInvalidEndpointState{})
614+
testFailingWrite(c, context.UnicastV6, arbitraryPayloadSize, &tcpip.ErrInvalidEndpointState{})
615615
}
616616

617617
func TestV4WriteOnV6Only(t *testing.T) {
@@ -621,7 +621,7 @@ func TestV4WriteOnV6Only(t *testing.T) {
621621
c.CreateEndpointForFlow(context.UnicastV6Only, udp.ProtocolNumber)
622622

623623
// Write to V4 mapped address.
624-
testFailingWrite(c, context.UnicastV4in6, &tcpip.ErrNoRoute{})
624+
testFailingWrite(c, context.UnicastV4in6, arbitraryPayloadSize, &tcpip.ErrNoRoute{})
625625
}
626626

627627
func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
@@ -636,7 +636,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
636636
}
637637

638638
// Write to v6 address.
639-
testFailingWrite(c, context.UnicastV6, &tcpip.ErrInvalidEndpointState{})
639+
testFailingWrite(c, context.UnicastV6, arbitraryPayloadSize, &tcpip.ErrInvalidEndpointState{})
640640
}
641641

642642
func TestV6WriteOnConnected(t *testing.T) {
@@ -1772,7 +1772,7 @@ func TestShutdownWrite(t *testing.T) {
17721772
t.Fatalf("Shutdown failed: %s", err)
17731773
}
17741774

1775-
testFailingWrite(c, context.UnicastV6, &tcpip.ErrClosedForSend{})
1775+
testFailingWrite(c, context.UnicastV6, arbitraryPayloadSize, &tcpip.ErrClosedForSend{})
17761776
}
17771777

17781778
func TestOutgoingSubnetBroadcast(t *testing.T) {
@@ -2067,6 +2067,22 @@ func TestChecksumWithZeroValueOnesComplementSum(t *testing.T) {
20672067
}
20682068
}
20692069

2070+
// TestWritePayloadSizeTooBig verifies that writing anything bigger than
2071+
// header.UDPMaximumPacketSize fails.
2072+
func TestWritePayloadSizeTooBig(t *testing.T) {
2073+
c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4})
2074+
defer c.Cleanup()
2075+
2076+
c.CreateEndpoint(ipv6.ProtocolNumber, udp.ProtocolNumber)
2077+
2078+
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != nil {
2079+
c.T.Fatalf("Connect failed: %s", err)
2080+
}
2081+
2082+
testWrite(c, context.UnicastV6)
2083+
testFailingWrite(c, context.UnicastV6, header.UDPMaximumPacketSize+1, &tcpip.ErrMessageTooLong{})
2084+
}
2085+
20702086
func TestMain(m *testing.M) {
20712087
refs.SetLeakMode(refs.LeaksPanic)
20722088
code := m.Run()

test/syscalls/linux/socket_ip_udp_generic.cc

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include <errno.h>
1818
#ifdef __linux__
19+
#include <linux/errqueue.h>
1920
#include <linux/in6.h>
2021
#endif // __linux__
2122
#include <netinet/in.h>
@@ -541,5 +542,107 @@ TEST_P(UDPSocketPairTest, GetSocketAcceptConn) {
541542
EXPECT_EQ(got, 0);
542543
}
543544

545+
#ifdef __linux__
546+
TEST_P(UDPSocketPairTest, PayloadTooBig) {
547+
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
548+
549+
// Set IP_RECVERR socket option to enable error queueing.
550+
int v = kSockOptOn;
551+
socklen_t optlen = sizeof(v);
552+
int opt_level = SOL_IP;
553+
int opt_type = IP_RECVERR;
554+
if (sockets->first_addr()->sa_family == AF_INET6) {
555+
opt_level = SOL_IPV6;
556+
opt_type = IPV6_RECVERR;
557+
}
558+
ASSERT_THAT(setsockopt(sockets->first_fd(), opt_level, opt_type, &v, optlen),
559+
SyscallSucceeds());
560+
561+
// Buffers bigger than 0xffff should receive an error.
562+
const int kBufLen = 0x10000;
563+
char buf[kBufLen];
564+
RandomizeBuffer(buf, sizeof(buf));
565+
566+
EXPECT_THAT(send(sockets->first_fd(), buf, sizeof(buf), 0),
567+
SyscallFailsWithErrno(EMSGSIZE));
568+
569+
// Dequeue error using recvmsg(MSG_ERRQUEUE). Give a buffer big-enough for
570+
// the original message just in case.
571+
char got[kBufLen];
572+
struct iovec iov;
573+
iov.iov_base = reinterpret_cast<void*>(got);
574+
iov.iov_len = kBufLen;
575+
576+
const int addrlen_ = sockets->second_addr_size();
577+
size_t control_buf_len = CMSG_SPACE(sizeof(sock_extended_err) + addrlen_);
578+
std::vector<char> control_buf(control_buf_len);
579+
struct sockaddr_storage remote;
580+
memset(&remote, 0, sizeof(remote));
581+
struct msghdr msg = {};
582+
msg.msg_iov = &iov;
583+
msg.msg_iovlen = 1;
584+
msg.msg_flags = 0;
585+
msg.msg_control = control_buf.data();
586+
msg.msg_controllen = control_buf_len;
587+
msg.msg_name = reinterpret_cast<void*>(&remote);
588+
msg.msg_namelen = addrlen_;
589+
590+
struct sockaddr_storage addr;
591+
optlen = sizeof(addr);
592+
EXPECT_THAT(getpeername(sockets->first_fd(), AsSockAddr(&addr), &optlen),
593+
SyscallSucceeds());
594+
bool ipv6 = false;
595+
if (addr.ss_family == AF_INET6) {
596+
auto ipv6addr = reinterpret_cast<struct sockaddr_in6*>(&addr);
597+
598+
// Exclude IPv4-mapped addresses.
599+
uint8_t v4MappedPrefix[12] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
600+
0x00, 0x00, 0x00, 0x00, 0xff, 0xff};
601+
ipv6 = memcmp(&ipv6addr->sin6_addr.s6_addr[0], v4MappedPrefix,
602+
sizeof(v4MappedPrefix)) != 0;
603+
}
604+
// Native behaviour for IPv4 packets is to not report to ERRQUEUE.
605+
if (!ipv6) {
606+
EXPECT_THAT(recvmsg(sockets->first_fd(), &msg, MSG_ERRQUEUE),
607+
SyscallFailsWithErrno(EAGAIN));
608+
return;
609+
}
610+
611+
ASSERT_THAT(recvmsg(sockets->first_fd(), &msg, MSG_ERRQUEUE),
612+
SyscallSucceedsWithValue(0));
613+
614+
EXPECT_NE(msg.msg_flags & MSG_ERRQUEUE, 0);
615+
EXPECT_EQ(memcmp(&remote, sockets->second_addr(), addrlen_), 0);
616+
617+
// Check the contents of the control message.
618+
struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
619+
ASSERT_NE(cmsg, nullptr);
620+
EXPECT_EQ(CMSG_NXTHDR(&msg, cmsg), nullptr);
621+
EXPECT_EQ(cmsg->cmsg_level, opt_level);
622+
EXPECT_EQ(cmsg->cmsg_type, opt_type);
623+
EXPECT_EQ(cmsg->cmsg_len,
624+
sizeof(sock_extended_err) + addrlen_ + sizeof(cmsghdr));
625+
626+
// Check the contents of socket error.
627+
struct sock_extended_err* sock_err =
628+
reinterpret_cast<sock_extended_err*>(CMSG_DATA(cmsg));
629+
EXPECT_EQ(sock_err->ee_errno, EMSGSIZE);
630+
EXPECT_EQ(sock_err->ee_origin, SO_EE_ORIGIN_LOCAL);
631+
EXPECT_EQ(sock_err->ee_type, ICMP_ECHOREPLY);
632+
EXPECT_EQ(sock_err->ee_code, ICMP_NET_UNREACH);
633+
EXPECT_EQ(sock_err->ee_info, kBufLen);
634+
EXPECT_EQ(sock_err->ee_data, 0);
635+
636+
// Verify that no socket error was put on the queue.
637+
int err;
638+
optlen = sizeof(err);
639+
ASSERT_THAT(
640+
getsockopt(sockets->first_fd(), SOL_SOCKET, SO_ERROR, &err, &optlen),
641+
SyscallSucceeds());
642+
ASSERT_EQ(err, 0);
643+
ASSERT_EQ(optlen, sizeof(err));
644+
}
645+
#endif // __linux__
646+
544647
} // namespace testing
545648
} // namespace gvisor

0 commit comments

Comments
 (0)