diff --git a/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java b/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java index 9e585a22..2bab3d17 100644 --- a/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java +++ b/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java @@ -78,36 +78,42 @@ protected byte decodeProtocolVersion(ByteBuf in) { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - in.markReaderIndex(); - ProtocolCode protocolCode; - Protocol protocol; try { - protocolCode = decodeProtocolCode(in); - if (protocolCode == null) { - // read to end - return; - } + in.markReaderIndex(); + ProtocolCode protocolCode; + Protocol protocol; + try { + protocolCode = decodeProtocolCode(in); + if (protocolCode == null) { + // read to end + return; + } - byte protocolVersion = decodeProtocolVersion(in); - if (ctx.channel().attr(Connection.PROTOCOL).get() == null) { - ctx.channel().attr(Connection.PROTOCOL).set(protocolCode); - if (DEFAULT_ILLEGAL_PROTOCOL_VERSION_LENGTH != protocolVersion) { - ctx.channel().attr(Connection.VERSION).set(protocolVersion); + byte protocolVersion = decodeProtocolVersion(in); + if (ctx.channel().attr(Connection.PROTOCOL).get() == null) { + ctx.channel().attr(Connection.PROTOCOL).set(protocolCode); + if (DEFAULT_ILLEGAL_PROTOCOL_VERSION_LENGTH != protocolVersion) { + ctx.channel().attr(Connection.VERSION).set(protocolVersion); + } } + + protocol = ProtocolManager.getProtocol(protocolCode); + } finally { + // reset the readerIndex before throwing an exception or decoding content + // to ensure that the packet is complete + in.resetReaderIndex(); } - protocol = ProtocolManager.getProtocol(protocolCode); - } finally { - // reset the readerIndex before throwing an exception or decoding content - // to ensure that the packet is complete - in.resetReaderIndex(); - } + if (protocol == null) { + throw new CodecException("Unknown protocol code: [" + protocolCode + + "] while decode in ProtocolDecoder."); + } - if (protocol == null) { - throw new CodecException("Unknown protocol code: [" + protocolCode - + "] while decode in ProtocolDecoder."); + protocol.getDecoder().decode(ctx, in, out); + } catch (Exception e) { + // 清空可读取区域,让 AbstractBatchDecoder#L257行release它 + in.skipBytes(in.readableBytes()); + throw e; } - - protocol.getDecoder().decode(ctx, in, out); } } diff --git a/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java b/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java index b7b16da9..44617622 100644 --- a/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java +++ b/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java @@ -30,6 +30,7 @@ import io.netty.channel.ChannelProgressivePromise; import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.Attribute; import io.netty.util.AttributeKey; import io.netty.util.concurrent.EventExecutor; @@ -52,6 +53,7 @@ public void testDecodeIllegalPacket() throws Exception { ProtocolCodeBasedDecoder decoder = new ProtocolCodeBasedDecoder(1); int readerIndex = byteBuf.readerIndex(); + int readableBytes = byteBuf.readableBytes(); Assert.assertEquals(0, readerIndex); Exception exception = null; @@ -65,7 +67,33 @@ public void testDecodeIllegalPacket() throws Exception { Assert.assertNotNull(exception); readerIndex = byteBuf.readerIndex(); + Assert.assertEquals(readableBytes, readerIndex); + } + + @Test + public void testDecodeIllegalPacket2() { + EmbeddedChannel channel = new EmbeddedChannel(); + ProtocolCodeBasedDecoder decoder = new ProtocolCodeBasedDecoder(1); + channel.pipeline().addLast(decoder); + + ByteBuf byteBuf = ByteBufAllocator.DEFAULT.buffer(8); + byteBuf.writeByte((byte) 13); + + int readerIndex = byteBuf.readerIndex(); + int readableBytes = byteBuf.readableBytes(); Assert.assertEquals(0, readerIndex); + Exception exception = null; + try { + channel.writeInbound(byteBuf); + } catch (Exception e) { + // ignore + exception = e; + } + Assert.assertNotNull(exception); + readerIndex = byteBuf.readerIndex(); + Assert.assertEquals(readableBytes, readerIndex); + + Assert.assertTrue(byteBuf.refCnt() == 0); } class MockedChannel implements Channel {