Commit b401fda

cmccabe and omkreddy authored and committed
MINOR: Add more validation during KRPC deserialization
When deserializing KRPC (which is used for RPCs sent to Kafka, Kafka Metadata records, and some other things), check that we have at least N bytes remaining before allocating an array of size N.

Remove DataInputStreamReadable since it was hard to make this class aware of how many bytes were remaining. Instead, when reading an individual record in the Raft layer, simply create a ByteBufferAccessor with a ByteBuffer containing just the bytes we're interested in.

Add SimpleArraysMessageTest and ByteBufferAccessorTest. Also add some additional tests in RequestResponseTest.

Reviewers: Tom Bentley <[email protected]>, Mickael Maison <[email protected]>, Colin McCabe <[email protected]>

Co-authored-by: Colin McCabe <[email protected]>
Co-authored-by: Manikumar Reddy <[email protected]>
Co-authored-by: Mickael Maison <[email protected]>
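To make the two ideas above concrete, here is a minimal standalone sketch, assuming hypothetical names (this is not code from the patch): validate a declared size against the bytes actually remaining before allocating, and hand a per-record parser a buffer limited to just that record's bytes.

import java.nio.ByteBuffer;

public class BoundedReadSketch {
    // Hypothetical helper: refuse to allocate more bytes than the buffer still holds.
    static byte[] readSizedArray(ByteBuffer buf, int size) {
        if (size < 0 || size > buf.remaining()) {
            throw new RuntimeException("Error reading byte array of " + size
                + " byte(s): only " + buf.remaining() + " byte(s) available");
        }
        byte[] arr = new byte[size];
        buf.get(arr);
        return arr;
    }

    // Hypothetical sketch of the Raft-layer idea: expose exactly one record's
    // bytes to the parser, then advance the batch buffer past that record.
    static ByteBuffer sliceRecord(ByteBuffer batch, int recordLength) {
        ByteBuffer record = batch.slice();   // shares content from the current position
        record.limit(recordLength);          // cap the slice at this record's length
        batch.position(batch.position() + recordLength);
        return record;
    }

    public static void main(String[] args) {
        ByteBuffer buf = ByteBuffer.wrap(new byte[] {1, 2, 3, 4});
        System.out.println(readSizedArray(buf, 4).length); // prints 4
        // readSizedArray(buf, 1) would now throw: only 0 byte(s) available
    }
}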
1 parent 8e522c5 commit b401fda

File tree

16 files changed: +433 -188 lines

checkstyle/suppressions.xml

+4
@@ -170,6 +170,10 @@
     <suppress checks="JavaNCSS"
               files="DistributedHerderTest.java"/>
 
+    <!-- Raft -->
+    <suppress checks="NPathComplexity"
+              files="RecordsIterator.java"/>
+
     <!-- Streams -->
     <suppress checks="ClassFanOutComplexity"
               files="(KafkaStreams|KStreamImpl|KTableImpl|InternalTopologyBuilder|StreamsPartitionAssignor|StreamThread|IQv2StoreIntegrationTest|KStreamImplTest).java"/>

clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java

+8 -1

@@ -54,8 +54,15 @@ public double readDouble() {
     }
 
     @Override
-    public void readArray(byte[] arr) {
+    public byte[] readArray(int size) {
+        int remaining = buf.remaining();
+        if (size > remaining) {
+            throw new RuntimeException("Error reading byte array of " + size + " byte(s): only " + remaining +
+                " byte(s) available");
+        }
+        byte[] arr = new byte[size];
         buf.get(arr);
+        return arr;
     }
 
     @Override
clients/src/main/java/org/apache/kafka/common/protocol/DataInputStreamReadable.java

-139
This file was deleted.

clients/src/main/java/org/apache/kafka/common/protocol/Readable.java

+3 -5

@@ -32,25 +32,23 @@ public interface Readable {
     int readInt();
     long readLong();
     double readDouble();
-    void readArray(byte[] arr);
+    byte[] readArray(int length);
     int readUnsignedVarint();
     ByteBuffer readByteBuffer(int length);
     int readVarint();
     long readVarlong();
     int remaining();
 
     default String readString(int length) {
-        byte[] arr = new byte[length];
-        readArray(arr);
+        byte[] arr = readArray(length);
         return new String(arr, StandardCharsets.UTF_8);
     }
 
     default List<RawTaggedField> readUnknownTaggedField(List<RawTaggedField> unknowns, int tag, int size) {
         if (unknowns == null) {
             unknowns = new ArrayList<>();
         }
-        byte[] data = new byte[size];
-        readArray(data);
+        byte[] data = readArray(size);
         unknowns.add(new RawTaggedField(tag, data));
         return unknowns;
     }
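Since readString and readUnknownTaggedField now allocate through readArray, both default methods inherit the remaining-bytes check from whichever Readable implementation is in use, with no extra code of their own.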

clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java

+2
@@ -342,6 +342,8 @@ private static DefaultRecord readFrom(ByteBuffer buffer,
         int numHeaders = ByteUtils.readVarint(buffer);
         if (numHeaders < 0)
             throw new InvalidRecordException("Found invalid number of record headers " + numHeaders);
+        if (numHeaders > buffer.remaining())
+            throw new InvalidRecordException("Found invalid number of record headers. " + numHeaders + " is larger than the remaining size of the buffer");
 
         final Header[] headers;
         if (numHeaders == 0)
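Every encoded header occupies at least one byte, so a header count larger than buffer.remaining() cannot belong to a well-formed record; rejecting it up front avoids sizing the headers array from an untrusted count.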
clients/src/test/java/org/apache/kafka/common/message/SimpleArraysMessageTest.java (new file)

+54

@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.message;
+
+import org.apache.kafka.common.protocol.ByteBufferAccessor;
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class SimpleArraysMessageTest {
+    @Test
+    public void testArrayBoundsChecking() {
+        // SimpleArraysMessageData takes 2 arrays
+        final ByteBuffer buf = ByteBuffer.wrap(new byte[] {
+            (byte) 0x7f, // Set size of first array to 126 which is larger than the size of this buffer
+            (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00
+        });
+        final SimpleArraysMessageData out = new SimpleArraysMessageData();
+        ByteBufferAccessor accessor = new ByteBufferAccessor(buf);
+        assertEquals("Tried to allocate a collection of size 126, but there are only 7 bytes remaining.",
+            assertThrows(RuntimeException.class, () -> out.read(accessor, (short) 2)).getMessage());
+    }
+
+    @Test
+    public void testArrayBoundsCheckingOtherArray() {
+        // SimpleArraysMessageData takes 2 arrays
+        final ByteBuffer buf = ByteBuffer.wrap(new byte[] {
+            (byte) 0x01, // Set size of first array to 0
+            (byte) 0x7e, // Set size of second array to 125 which is larger than the size of this buffer
+            (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00
+        });
+        final SimpleArraysMessageData out = new SimpleArraysMessageData();
+        ByteBufferAccessor accessor = new ByteBufferAccessor(buf);
+        assertEquals("Tried to allocate a collection of size 125, but there are only 6 bytes remaining.",
+            assertThrows(RuntimeException.class, () -> out.read(accessor, (short) 2)).getMessage());
+    }
+}
clients/src/test/java/org/apache/kafka/common/protocol/ByteBufferAccessorTest.java (new file)

+58

@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.protocol;
+
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class ByteBufferAccessorTest {
+    @Test
+    public void testReadArray() {
+        ByteBuffer buf = ByteBuffer.allocate(1024);
+        ByteBufferAccessor accessor = new ByteBufferAccessor(buf);
+        final byte[] testArray = new byte[] {0x4b, 0x61, 0x46};
+        accessor.writeByteArray(testArray);
+        accessor.writeInt(12345);
+        accessor.flip();
+        final byte[] testArray2 = accessor.readArray(3);
+        assertArrayEquals(testArray, testArray2);
+        assertEquals(12345, accessor.readInt());
+        assertEquals("Error reading byte array of 3 byte(s): only 0 byte(s) available",
+            assertThrows(RuntimeException.class,
+                () -> accessor.readArray(3)).getMessage());
+    }
+
+    @Test
+    public void testReadString() {
+        ByteBuffer buf = ByteBuffer.allocate(1024);
+        ByteBufferAccessor accessor = new ByteBufferAccessor(buf);
+        String testString = "ABC";
+        final byte[] testArray = testString.getBytes(StandardCharsets.UTF_8);
+        accessor.writeByteArray(testArray);
+        accessor.flip();
+        assertEquals("ABC", accessor.readString(3));
+        assertEquals("Error reading byte array of 2 byte(s): only 0 byte(s) available",
+            assertThrows(RuntimeException.class,
+                () -> accessor.readString(2)).getMessage());
+    }
+}

clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java

+14
@@ -247,6 +247,20 @@ public void testInvalidNumHeaders() {
         buf.flip();
         assertThrows(InvalidRecordException.class,
             () -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null));
+
+        ByteBuffer buf2 = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes));
+        ByteUtils.writeVarint(sizeOfBodyInBytes, buf2);
+        buf2.put(attributes);
+        ByteUtils.writeVarlong(timestampDelta, buf2);
+        ByteUtils.writeVarint(offsetDelta, buf2);
+        ByteUtils.writeVarint(-1, buf2); // null key
+        ByteUtils.writeVarint(-1, buf2); // null value
+        ByteUtils.writeVarint(sizeOfBodyInBytes, buf2); // more headers than remaining buffer size, not allowed
+        buf2.position(buf2.limit());
+
+        buf2.flip();
+        assertThrows(InvalidRecordException.class,
+            () -> DefaultRecord.readFrom(buf2, 0L, 0L, RecordBatch.NO_SEQUENCE, null));
     }
 
     @Test
