Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ProtocolCodeBasedDecoder memory leak #355

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,36 +78,42 @@ protected byte decodeProtocolVersion(ByteBuf in) {

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
in.markReaderIndex();
ProtocolCode protocolCode;
Protocol protocol;
try {
protocolCode = decodeProtocolCode(in);
if (protocolCode == null) {
// read to end
return;
}
in.markReaderIndex();
ProtocolCode protocolCode;
Protocol protocol;
try {
protocolCode = decodeProtocolCode(in);
if (protocolCode == null) {
// read to end
return;
}

byte protocolVersion = decodeProtocolVersion(in);
if (ctx.channel().attr(Connection.PROTOCOL).get() == null) {
ctx.channel().attr(Connection.PROTOCOL).set(protocolCode);
if (DEFAULT_ILLEGAL_PROTOCOL_VERSION_LENGTH != protocolVersion) {
ctx.channel().attr(Connection.VERSION).set(protocolVersion);
byte protocolVersion = decodeProtocolVersion(in);
if (ctx.channel().attr(Connection.PROTOCOL).get() == null) {
ctx.channel().attr(Connection.PROTOCOL).set(protocolCode);
if (DEFAULT_ILLEGAL_PROTOCOL_VERSION_LENGTH != protocolVersion) {
ctx.channel().attr(Connection.VERSION).set(protocolVersion);
}
}

protocol = ProtocolManager.getProtocol(protocolCode);
} finally {
// reset the readerIndex before throwing an exception or decoding content
// to ensure that the packet is complete
in.resetReaderIndex();
}

protocol = ProtocolManager.getProtocol(protocolCode);
} finally {
// reset the readerIndex before throwing an exception or decoding content
// to ensure that the packet is complete
in.resetReaderIndex();
}
if (protocol == null) {
throw new CodecException("Unknown protocol code: [" + protocolCode
+ "] while decode in ProtocolDecoder.");
}

if (protocol == null) {
throw new CodecException("Unknown protocol code: [" + protocolCode
+ "] while decode in ProtocolDecoder.");
protocol.getDecoder().decode(ctx, in, out);
} catch (Exception e) {
// 清空可读取区域,让 AbstractBatchDecoder#L257行release它
in.skipBytes(in.readableBytes());
throw e;
}

protocol.getDecoder().decode(ctx, in, out);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.netty.channel.ChannelProgressivePromise;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoop;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.EventExecutor;
Expand All @@ -52,6 +53,7 @@ public void testDecodeIllegalPacket() throws Exception {
ProtocolCodeBasedDecoder decoder = new ProtocolCodeBasedDecoder(1);

int readerIndex = byteBuf.readerIndex();
int readableBytes = byteBuf.readableBytes();
Assert.assertEquals(0, readerIndex);

Exception exception = null;
Expand All @@ -65,7 +67,33 @@ public void testDecodeIllegalPacket() throws Exception {
Assert.assertNotNull(exception);

readerIndex = byteBuf.readerIndex();
Assert.assertEquals(readableBytes, readerIndex);
}

@Test
public void testDecodeIllegalPacket2() {
EmbeddedChannel channel = new EmbeddedChannel();
ProtocolCodeBasedDecoder decoder = new ProtocolCodeBasedDecoder(1);
channel.pipeline().addLast(decoder);

ByteBuf byteBuf = ByteBufAllocator.DEFAULT.buffer(8);
byteBuf.writeByte((byte) 13);

int readerIndex = byteBuf.readerIndex();
int readableBytes = byteBuf.readableBytes();
Assert.assertEquals(0, readerIndex);
Exception exception = null;
try {
channel.writeInbound(byteBuf);
} catch (Exception e) {
// ignore
exception = e;
}
Assert.assertNotNull(exception);
readerIndex = byteBuf.readerIndex();
Assert.assertEquals(readableBytes, readerIndex);

Assert.assertTrue(byteBuf.refCnt() == 0);
}

class MockedChannel implements Channel {
Expand Down
Loading