diff --git a/src/main/java/org/logstash/beats/AckEncoder.java b/src/main/java/org/logstash/beats/AckEncoder.java
index d78e1dda..d4e57f60 100644
--- a/src/main/java/org/logstash/beats/AckEncoder.java
+++ b/src/main/java/org/logstash/beats/AckEncoder.java
@@ -10,6 +10,11 @@
  *
  */
 public class AckEncoder extends MessageToByteEncoder<Ack> {
+
+    public AckEncoder() {
+        super(false);
+    }
+
     @Override
     protected void encode(ChannelHandlerContext ctx, Ack ack, ByteBuf out) throws Exception {
         out.writeByte(ack.getProtocol());
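Note: the boolean passed to MessageToByteEncoder's constructor is Netty's preferDirect flag; with false, the encoder allocates its outbound buffer from the heap instead of direct memory, so ACK frames no longer draw on the direct-memory pool. A minimal sketch of the same pattern outside this codebase (HeapOnlyEncoder is illustrative, not part of the change):

    import io.netty.buffer.ByteBuf;
    import io.netty.channel.ChannelHandlerContext;
    import io.netty.handler.codec.MessageToByteEncoder;

    import java.nio.charset.StandardCharsets;

    // Illustrative only: any encoder can opt out of direct buffers this way.
    public class HeapOnlyEncoder extends MessageToByteEncoder<String> {
        public HeapOnlyEncoder() {
            super(false); // preferDirect=false: the outbound ByteBuf is heap-allocated
        }

        @Override
        protected void encode(ChannelHandlerContext ctx, String msg, ByteBuf out) {
            out.writeBytes(msg.getBytes(StandardCharsets.UTF_8));
        }
    }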
diff --git a/src/main/java/org/logstash/beats/BeatsHandler.java b/src/main/java/org/logstash/beats/BeatsHandler.java
index 15dfb7e9..d69aa1a7 100644
--- a/src/main/java/org/logstash/beats/BeatsHandler.java
+++ b/src/main/java/org/logstash/beats/BeatsHandler.java
@@ -1,5 +1,7 @@
 package org.logstash.beats;
 
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.SimpleChannelInboundHandler;
 import org.apache.logging.log4j.LogManager;
@@ -93,7 +95,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E
             }
         } else {
             final Throwable realCause = extractCause(cause, 0);
-            if (logger.isDebugEnabled()){
+            if (logger.isDebugEnabled()) {
                 logger.info(format("Handling exception: " + cause + " (caused by: " + realCause + ")"), cause);
             } else {
                 logger.info(format("Handling exception: " + cause + " (caused by: " + realCause + ")"));
diff --git a/src/main/java/org/logstash/beats/BeatsParser.java b/src/main/java/org/logstash/beats/BeatsParser.java
index 812150b1..25e6fa27 100644
--- a/src/main/java/org/logstash/beats/BeatsParser.java
+++ b/src/main/java/org/logstash/beats/BeatsParser.java
@@ -21,6 +21,8 @@ public class BeatsParser extends ByteToMessageDecoder {
     private final static Logger logger = LogManager.getLogger(BeatsParser.class);
 
+    private static final int CHUNK_SIZE = 64 * 1024; // chunk size in which payloads are read and accumulated on the heap
+
     private Batch batch;
 
     private enum States {
@@ -30,7 +32,9 @@ private enum States {
         READ_JSON_HEADER(8),
         READ_COMPRESSED_FRAME_HEADER(4),
         READ_COMPRESSED_FRAME(-1), // -1 means the length to read is variable and defined in the frame itself.
+        READ_COMPRESSED_FRAME_JAVA_HEAP(-1), // -1 means the length to read is variable and defined in the frame itself.
         READ_JSON(-1),
+        READ_JSON_JAVA_HEAP(-1),
         READ_DATA_FIELDS(-1);
 
         private int length;
@@ -41,10 +45,53 @@ private enum States {
     }
 
+    static class ChunkedAccumulator {
+        private ByteBuf payloadAccumulator;
+        private int readBytes; // count of bytes read so far
+        private int payloadSize; // total size of the payload to accumulate
+
+        /**
+         * Returns the size of the first chunk to read.
+         */
+        public int startRead(int payloadSize, ChannelHandlerContext ctx) {
+            this.payloadSize = payloadSize;
+            this.readBytes = 0;
+            payloadAccumulator = ctx.alloc().heapBuffer(payloadSize);
+            // read the payload in chunks of at most 64KB and aggregate them on the Java heap
+            return Math.min(this.payloadSize, CHUNK_SIZE);
+        }
+
+        public void readChunk(ByteBuf in) {
+            int missingBytes = payloadSize - readBytes;
+            int bytesToRead = Math.min(in.readableBytes(), missingBytes);
+            in.readBytes(payloadAccumulator, bytesToRead);
+            readBytes += bytesToRead;
+        }
+
+        public boolean isReadComplete() {
+            return readBytes == payloadSize;
+        }
+
+        public void stopAccumulating() {
+            payloadSize = -1;
+            readBytes = -1;
+            payloadAccumulator.release();
+        }
+
+        public ByteBuf getPayload() {
+            return payloadAccumulator;
+        }
+
+        public int getPayloadSize() {
+            return payloadSize;
+        }
+    }
+
     private States currentState = States.READ_HEADER;
     private int requiredBytes = 0;
     private int sequence = 0;
     private boolean decodingCompressedBuffer = false;
+    private final ChunkedAccumulator accumulator = new ChunkedAccumulator();
 
     @Override
     protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws InvalidFrameProtocolException, IOException {
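Note: ChunkedAccumulator copies at most CHUNK_SIZE (64KB) out of the network buffer per decode() call, re-entering the *_JAVA_HEAP state until readBytes reaches payloadSize, so a large payload lands on the Java heap without one large direct-memory read. A self-contained sketch of that arithmetic, collapsing the re-entrant state machine into a plain loop and using Unpooled in place of ctx.alloc() (both substitutions are assumptions of this sketch):

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    public class ChunkedCopyDemo {
        private static final int CHUNK_SIZE = 64 * 1024;

        public static void main(String[] args) {
            int payloadSize = 150 * 1024; // pretend this came from the frame header
            ByteBuf in = Unpooled.wrappedBuffer(new byte[payloadSize]);
            ByteBuf accumulator = Unpooled.buffer(payloadSize); // heap-backed

            int readBytes = 0;
            while (readBytes < payloadSize) {
                // never read past the payload, even if more bytes are buffered
                int bytesToRead = Math.min(Math.min(in.readableBytes(), CHUNK_SIZE),
                        payloadSize - readBytes);
                in.readBytes(accumulator, bytesToRead);
                readBytes += bytesToRead;
            }
            System.out.println("accumulated " + accumulator.readableBytes() + " bytes on heap");
            in.release();
            accumulator.release();
        }
    }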
@@ -166,23 +213,25 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
                 sequence = (int) in.readUnsignedInt();
                 int jsonPayloadSize = (int) in.readUnsignedInt();
 
-                if(jsonPayloadSize <= 0) {
+                if (jsonPayloadSize <= 0) {
                     throw new InvalidFrameProtocolException("Invalid json length, received: " + jsonPayloadSize);
                 }
-
-                transition(States.READ_JSON, jsonPayloadSize);
+                logger.trace("READ_JSON_HEADER: jsonPayloadSize: {}", jsonPayloadSize);
+                final int bytesToRead = accumulator.startRead(jsonPayloadSize, ctx);
+                transition(States.READ_JSON_JAVA_HEAP, bytesToRead);
                 break;
             }
             case READ_COMPRESSED_FRAME_HEADER: {
                 logger.trace("Running: READ_COMPRESSED_FRAME_HEADER");
-                transition(States.READ_COMPRESSED_FRAME, in.readInt());
+                final int bytesToRead = accumulator.startRead(in.readInt(), ctx);
+                transition(States.READ_COMPRESSED_FRAME_JAVA_HEAP, bytesToRead);
                 break;
             }
             case READ_COMPRESSED_FRAME: {
                 logger.trace("Running: READ_COMPRESSED_FRAME");
-                inflateCompressedFrame(ctx, in, (buffer) -> {
+                inflateCompressedFrame(ctx, in, requiredBytes, (buffer) -> {
                     transition(States.READ_HEADER);
 
                     decodingCompressedBuffer = true;
@@ -197,11 +246,38 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
                 });
                 break;
             }
+            case READ_COMPRESSED_FRAME_JAVA_HEAP: {
+                logger.trace("Running: READ_COMPRESSED_FRAME_JAVA_HEAP");
+                accumulator.readChunk(in);
+
+                if (accumulator.isReadComplete()) {
+                    logger.debug("Finished accumulating");
+                    // inflate the accumulated compressed payload on the Java heap
+                    inflateCompressedFrame(ctx, accumulator.getPayload(), accumulator.getPayloadSize(), (buffer) -> {
+                        transition(States.READ_HEADER);
+                        accumulator.stopAccumulating();
+
+                        decodingCompressedBuffer = true;
+                        try {
+                            while (buffer.readableBytes() > 0) {
+                                decode(ctx, buffer, out);
+                            }
+                        } finally {
+                            decodingCompressedBuffer = false;
+                            transition(States.READ_HEADER);
+                        }
+                    });
+                } else {
+                    logger.debug("Read next chunk");
+                    transition(States.READ_COMPRESSED_FRAME_JAVA_HEAP, CHUNK_SIZE);
+                }
+                break;
+            }
             case READ_JSON: {
                 logger.trace("Running: READ_JSON");
                 ((V2Batch) batch).addMessage(sequence, in, requiredBytes);
-                if(batch.isComplete()) {
-                    if(logger.isTraceEnabled()) {
+                if (batch.isComplete()) {
+                    if (logger.isTraceEnabled()) {
                         logger.trace("Sending batch size: " + this.batch.size() + ", windowSize: " + batch.getBatchSize() + " , seq: " + sequence);
                     }
                     out.add(batch);
@@ -211,28 +287,54 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
                 transition(States.READ_HEADER);
                 break;
             }
+            case READ_JSON_JAVA_HEAP: {
+                logger.trace("Running: READ_JSON_JAVA_HEAP");
+                accumulator.readChunk(in);
+
+                if (accumulator.isReadComplete()) {
+                    logger.trace("Finished accumulating: READ_JSON_JAVA_HEAP");
+
+                    ByteBuf payload = accumulator.getPayload();
+                    ((V2Batch) batch).addMessage(sequence, payload, accumulator.getPayloadSize());
+                    accumulator.stopAccumulating();
+                    if (batch.isComplete()) {
+                        if (logger.isTraceEnabled()) {
+                            logger.trace("Sending batch size: {}, windowSize: {} , seq: {}",
+                                this.batch.size(), batch.getBatchSize(), sequence);
+                        }
+                        out.add(batch);
+                        batchComplete();
+                    }
+
+                    transition(States.READ_HEADER);
+                } else {
+                    logger.trace("Read next chunk");
+                    transition(States.READ_JSON_JAVA_HEAP, CHUNK_SIZE);
+                }
+                break;
+            }
         }
     }
 
-    private void inflateCompressedFrame(final ChannelHandlerContext ctx, final ByteBuf in, final CheckedConsumer<ByteBuf> fn)
+    private void inflateCompressedFrame(final ChannelHandlerContext ctx, final ByteBuf in, int deflatedSize, final CheckedConsumer<ByteBuf> fn)
         throws IOException {
         // Use the compressed size as the safe start for the buffer.
-        ByteBuf buffer = ctx.alloc().buffer(requiredBytes);
+        ByteBuf buffer = ctx.alloc().heapBuffer(deflatedSize);
         try {
-            decompressImpl(in, buffer);
+            decompressImpl(in, buffer, deflatedSize);
             fn.accept(buffer);
         } finally {
             buffer.release();
         }
     }
 
-    private void decompressImpl(final ByteBuf in, final ByteBuf out) throws IOException {
+    private void decompressImpl(final ByteBuf in, final ByteBuf out, int deflatedSize) throws IOException {
         Inflater inflater = new Inflater();
         try (
             ByteBufOutputStream buffOutput = new ByteBufOutputStream(out);
             InflaterOutputStream inflaterStream = new InflaterOutputStream(buffOutput, inflater)
         ) {
-            in.readBytes(inflaterStream, requiredBytes);
+            in.readBytes(inflaterStream, deflatedSize);
         } finally {
             inflater.end();
         }
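Note: inflateCompressedFrame and decompressImpl now take the deflated size as a parameter instead of reading the requiredBytes field, which under chunked accumulation only reflects the last chunk rather than the whole payload; the deflated size also seeds the initial capacity of a heap buffer that grows as the inflater writes. A self-contained round trip of the same decompression path (Deflater and Unpooled stand in for the Beats client and ctx.alloc(); not part of the change):

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.ByteBufOutputStream;
    import io.netty.buffer.Unpooled;

    import java.io.IOException;
    import java.nio.charset.StandardCharsets;
    import java.util.zip.Deflater;
    import java.util.zip.Inflater;
    import java.util.zip.InflaterOutputStream;

    public class InflateDemo {
        public static void main(String[] args) throws IOException {
            byte[] json = "{\"message\":\"hello beats\"}".getBytes(StandardCharsets.UTF_8);

            // compress, playing the role of the Beats client
            Deflater deflater = new Deflater();
            deflater.setInput(json);
            deflater.finish();
            byte[] compressed = new byte[json.length * 2 + 64];
            int deflatedSize = deflater.deflate(compressed);
            deflater.end();

            ByteBuf in = Unpooled.wrappedBuffer(compressed, 0, deflatedSize);
            // the compressed size is only a starting capacity; the buffer grows as needed
            ByteBuf out = Unpooled.buffer(deflatedSize);

            Inflater inflater = new Inflater();
            try (ByteBufOutputStream buffOutput = new ByteBufOutputStream(out);
                 InflaterOutputStream inflaterStream = new InflaterOutputStream(buffOutput, inflater)) {
                in.readBytes(inflaterStream, deflatedSize);
            } finally {
                inflater.end();
            }

            System.out.println(out.toString(StandardCharsets.UTF_8)); // prints the original JSON
            in.release();
            out.release();
        }
    }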
diff --git a/src/main/java/org/logstash/beats/V2Batch.java b/src/main/java/org/logstash/beats/V2Batch.java
index 84c529fa..8a278528 100644
--- a/src/main/java/org/logstash/beats/V2Batch.java
+++ b/src/main/java/org/logstash/beats/V2Batch.java
@@ -9,15 +9,15 @@
  * Implementation of {@link Batch} for the v2 protocol backed by ByteBuf. *must* be released after use.
  */
 public class V2Batch implements Batch {
-    private ByteBuf internalBuffer = PooledByteBufAllocator.DEFAULT.buffer();
+    private final ByteBuf internalBuffer = PooledByteBufAllocator.DEFAULT.heapBuffer();
     private int written = 0;
     private int read = 0;
     private static final int SIZE_OF_INT = 4;
     private int batchSize;
     private int highestSequence = -1;
 
-    public void setProtocol(byte protocol){
-        if (protocol != Protocol.VERSION_2){
+    public void setProtocol(byte protocol) {
+        if (protocol != Protocol.VERSION_2) {
             throw new IllegalArgumentException("Only version 2 protocol is supported");
         }
     }
@@ -27,7 +27,7 @@ public byte getProtocol() {
         return Protocol.VERSION_2;
     }
 
-    public Iterator<Message> iterator(){
+    public Iterator<Message> iterator() {
         internalBuffer.resetReaderIndex();
         return new Iterator<Message>() {
             @Override
@@ -80,19 +80,19 @@ public int getHighestSequence(){
 
     /**
      * Adds a message to the batch, which will be constructed into an actual {@link Message} lazily.
-     * @param sequenceNumber sequence number of the message within the batch
+     * @param sequenceNumber sequence number of the message within the batch
      * @param buffer A ByteBuf pointing to serialized JSon
      * @param size size of the serialized Json
      */
     void addMessage(int sequenceNumber, ByteBuf buffer, int size) {
         written++;
-        if (internalBuffer.writableBytes() < size + (2 * SIZE_OF_INT)){
+        if (internalBuffer.writableBytes() < size + (2 * SIZE_OF_INT)) {
             internalBuffer.capacity(internalBuffer.capacity() + size + (2 * SIZE_OF_INT));
         }
         internalBuffer.writeInt(sequenceNumber);
         internalBuffer.writeInt(size);
         buffer.readBytes(internalBuffer, size);
 
-        if (sequenceNumber > highestSequence){
+        if (sequenceNumber > highestSequence) {
             highestSequence = sequenceNumber;
         }
     }
diff --git a/src/test/java/org/logstash/beats/BeatsParserTest.java b/src/test/java/org/logstash/beats/BeatsParserTest.java
index 4fde2994..f84b4f0f 100644
--- a/src/test/java/org/logstash/beats/BeatsParserTest.java
+++ b/src/test/java/org/logstash/beats/BeatsParserTest.java
@@ -39,7 +39,7 @@ public void setup() throws Exception{
         this.v1Batch = new V1Batch();
 
         for(int i = 1; i <= numberOfMessage; i++) {
-            Map map = new HashMap();
+            Map map = new HashMap<>();
             map.put("line", "Another world");
             map.put("from", "Little big Adventure");
@@ -50,7 +50,7 @@ public void setup() throws Exception{
         this.byteBufBatch = new V2Batch();
 
         for(int i = 1; i <= numberOfMessage; i++) {
-            Map map = new HashMap();
+            Map map = new HashMap<>();
             map.put("line", "Another world");
             map.put("from", "Little big Adventure");
             ByteBuf bytebuf = Unpooled.wrappedBuffer(MAPPER.writeValueAsBytes(map));
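Taken together, every intermediate buffer in the v2 path is now heap-backed: the chunked accumulator, the inflate output, and V2Batch's internal buffer (PooledByteBufAllocator.DEFAULT.heapBuffer()). A hypothetical smoke test for the READ_JSON_JAVA_HEAP path, assuming the standard lumberjack v2 framing ('2' 'W' + window size, then '2' 'J' + sequence + payload length + JSON payload) and an EmbeddedChannel in place of a real connection:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;
    import io.netty.channel.embedded.EmbeddedChannel;

    import java.nio.charset.StandardCharsets;

    public class HeapDecodeSmokeTest {
        public static void main(String[] args) {
            EmbeddedChannel channel = new EmbeddedChannel(new BeatsParser());

            byte[] json = "{\"message\":\"hello\"}".getBytes(StandardCharsets.UTF_8);
            ByteBuf frame = Unpooled.buffer();
            frame.writeByte('2').writeByte('W').writeInt(1);  // window of one event
            frame.writeByte('2').writeByte('J').writeInt(1)   // sequence 1
                 .writeInt(json.length).writeBytes(json);     // payload length + payload

            channel.writeInbound(frame);
            V2Batch batch = channel.readInbound();            // decoded entirely on heap buffers
            System.out.println("complete: " + batch.isComplete());
            batch.release();                                  // V2Batch *must* be released after use
            channel.finishAndReleaseAll();
        }
    }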