forked from envoyproxy/envoy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
websocket: websocket codec implementation (envoyproxy#22542)
Commit Message: websocket: websocket codec implementation Testing: Unit tests added Additional Description: Related issue comment => envoyproxy#13877 (comment) This is as the initial step to support rate limiting for WebSocket. Signed-off-by: Amila Senadheera <[email protected]>
- Loading branch information
1 parent
5183dbf
commit e54aaa3
Showing
6 changed files
with
1,283 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
load( | ||
"//bazel:envoy_build_system.bzl", | ||
"envoy_cc_library", | ||
"envoy_package", | ||
) | ||
|
||
licenses(["notice"]) # Apache 2 | ||
|
||
envoy_package() | ||
|
||
envoy_cc_library( | ||
name = "codec_lib", | ||
srcs = ["codec.cc"], | ||
hdrs = ["codec.h"], | ||
deps = [ | ||
"//envoy/buffer:buffer_interface", | ||
"//source/common/buffer:buffer_lib", | ||
"//source/common/common:minimal_logger_lib", | ||
"//source/common/common:scalar_to_byte_vector_lib", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
#include "source/common/websocket/codec.h" | ||
|
||
#include <algorithm> | ||
#include <array> | ||
#include <cstdint> | ||
#include <memory> | ||
#include <vector> | ||
|
||
#include "source/common/buffer/buffer_impl.h" | ||
#include "source/common/common/scalar_to_byte_vector.h" | ||
|
||
namespace Envoy { | ||
namespace WebSocket { | ||
|
||
absl::optional<std::vector<uint8_t>> Encoder::encodeFrameHeader(const Frame& frame) { | ||
if (std::find(kFrameOpcodes.begin(), kFrameOpcodes.end(), frame.opcode_) == kFrameOpcodes.end()) { | ||
ENVOY_LOG(debug, "Failed to encode websocket frame with invalid opcode: {}", frame.opcode_); | ||
return absl::nullopt; | ||
} | ||
std::vector<uint8_t> output; | ||
// Set flags and opcode | ||
pushScalarToByteVector( | ||
static_cast<uint8_t>(frame.final_fragment_ ? (0x80 | frame.opcode_) : frame.opcode_), output); | ||
|
||
// Set payload length | ||
if (frame.payload_length_ <= 125) { | ||
// Set mask bit and 7-bit length | ||
pushScalarToByteVector(frame.masking_key_.has_value() | ||
? static_cast<uint8_t>(frame.payload_length_ | 0x80) | ||
: static_cast<uint8_t>(frame.payload_length_), | ||
output); | ||
} else if (frame.payload_length_ <= 65535) { | ||
// Set mask bit and 16-bit length indicator | ||
pushScalarToByteVector(static_cast<uint8_t>(frame.masking_key_.has_value() ? 0xfe : 0x7e), | ||
output); | ||
// Set 16-bit length | ||
pushScalarToByteVector(htobe16(frame.payload_length_), output); | ||
} else { | ||
// Set mask bit and 64-bit length indicator | ||
pushScalarToByteVector(static_cast<uint8_t>(frame.masking_key_.has_value() ? 0xff : 0x7f), | ||
output); | ||
// Set 64-bit length | ||
pushScalarToByteVector(htobe64(frame.payload_length_), output); | ||
} | ||
// Set masking key | ||
if (frame.masking_key_.has_value()) { | ||
pushScalarToByteVector(htobe32(frame.masking_key_.value()), output); | ||
} | ||
return output; | ||
} | ||
|
||
void Decoder::frameDataStart() { | ||
frame_.payload_length_ = length_; | ||
if (length_ == 0) { | ||
state_ = State::FrameFinished; | ||
} else { | ||
if (max_payload_buffer_length_ > 0) { | ||
frame_.payload_ = std::make_unique<Buffer::OwnedImpl>(); | ||
} | ||
state_ = State::FramePayload; | ||
} | ||
} | ||
|
||
void Decoder::frameData(const uint8_t* mem, uint64_t length) { | ||
if (max_payload_buffer_length_ > 0) { | ||
uint64_t allowed_length = max_payload_buffer_length_ - frame_.payload_->length(); | ||
frame_.payload_->add(mem, length <= allowed_length ? length : allowed_length); | ||
} | ||
} | ||
|
||
void Decoder::frameDataEnd(std::vector<Frame>& output) { | ||
output.push_back(std::move(frame_)); | ||
resetDecoder(); | ||
} | ||
|
||
void Decoder::resetDecoder() { | ||
frame_ = {false, 0, absl::nullopt, 0, nullptr}; | ||
state_ = State::FrameHeaderFlagsAndOpcode; | ||
length_ = 0; | ||
num_remaining_extended_length_bytes_ = 0; | ||
num_remaining_masking_key_bytes_ = 0; | ||
} | ||
|
||
uint8_t Decoder::doDecodeFlagsAndOpcode(absl::Span<const uint8_t>& data) { | ||
// Validate opcode (last 4 bits) | ||
uint8_t opcode = data.front() & 0x0f; | ||
if (std::find(kFrameOpcodes.begin(), kFrameOpcodes.end(), opcode) == kFrameOpcodes.end()) { | ||
ENVOY_LOG(debug, "Failed to decode websocket frame with invalid opcode: {}", opcode); | ||
return 0; | ||
} | ||
frame_.opcode_ = opcode; | ||
frame_.final_fragment_ = data.front() & 0x80; | ||
state_ = State::FrameHeaderMaskFlagAndLength; | ||
return 1; | ||
} | ||
|
||
uint8_t Decoder::doDecodeMaskFlagAndLength(absl::Span<const uint8_t>& data) { | ||
num_remaining_masking_key_bytes_ = data.front() & 0x80 ? kMaskingKeyLength : 0; | ||
uint8_t length_indicator = data.front() & 0x7f; | ||
if (length_indicator == 0x7e) { | ||
num_remaining_extended_length_bytes_ = kPayloadLength16Bit; | ||
state_ = State::FrameHeaderExtendedLength16Bit; | ||
} else if (length_indicator == 0x7f) { | ||
num_remaining_extended_length_bytes_ = kPayloadLength64Bit; | ||
state_ = State::FrameHeaderExtendedLength64Bit; | ||
} else if (num_remaining_masking_key_bytes_ > 0) { | ||
length_ = length_indicator; | ||
state_ = State::FrameHeaderMaskingKey; | ||
} else { | ||
length_ = length_indicator; | ||
frameDataStart(); | ||
} | ||
return 1; | ||
} | ||
|
||
uint8_t Decoder::doDecodeExtendedLength(absl::Span<const uint8_t>& data) { | ||
uint64_t bytes_to_decode = data.length() <= num_remaining_extended_length_bytes_ | ||
? data.length() | ||
: num_remaining_extended_length_bytes_; | ||
uint8_t size_of_extended_length = | ||
state_ == State::FrameHeaderExtendedLength16Bit ? kPayloadLength16Bit : kPayloadLength64Bit; | ||
uint8_t shift_of_bytes = size_of_extended_length - num_remaining_extended_length_bytes_; | ||
uint8_t* destination = reinterpret_cast<uint8_t*>(&length_) + shift_of_bytes; | ||
|
||
ASSERT(shift_of_bytes >= 0); | ||
ASSERT(shift_of_bytes < size_of_extended_length); | ||
memcpy(destination, data.data(), bytes_to_decode); // NOLINT(safe-memcpy) | ||
num_remaining_extended_length_bytes_ -= bytes_to_decode; | ||
|
||
if (num_remaining_extended_length_bytes_ == 0) { | ||
length_ = state_ == State::FrameHeaderExtendedLength16Bit ? htobe16(length_) : htobe64(length_); | ||
if (num_remaining_masking_key_bytes_ > 0) { | ||
state_ = State::FrameHeaderMaskingKey; | ||
} else { | ||
frameDataStart(); | ||
} | ||
} | ||
return bytes_to_decode; | ||
} | ||
|
||
uint8_t Decoder::doDecodeMaskingKey(absl::Span<const uint8_t>& data) { | ||
if (!frame_.masking_key_.has_value()) { | ||
frame_.masking_key_ = 0; | ||
} | ||
uint64_t bytes_to_decode = data.length() <= num_remaining_masking_key_bytes_ | ||
? data.length() | ||
: num_remaining_masking_key_bytes_; | ||
uint8_t shift_of_bytes = kMaskingKeyLength - num_remaining_masking_key_bytes_; | ||
uint8_t* destination = | ||
reinterpret_cast<uint8_t*>(&(frame_.masking_key_.value())) + shift_of_bytes; | ||
ASSERT(shift_of_bytes >= 0); | ||
ASSERT(shift_of_bytes < kMaskingKeyLength); | ||
memcpy(destination, data.data(), bytes_to_decode); // NOLINT(safe-memcpy) | ||
num_remaining_masking_key_bytes_ -= bytes_to_decode; | ||
|
||
if (num_remaining_masking_key_bytes_ == 0) { | ||
frame_.masking_key_ = htobe32(frame_.masking_key_.value()); | ||
frameDataStart(); | ||
} | ||
return bytes_to_decode; | ||
} | ||
|
||
uint64_t Decoder::doDecodePayload(absl::Span<const uint8_t>& data) { | ||
uint64_t remain_in_buffer = data.length(); | ||
uint64_t bytes_decoded = 0; | ||
if (remain_in_buffer <= length_) { | ||
frameData(data.data(), remain_in_buffer); | ||
bytes_decoded += remain_in_buffer; | ||
length_ -= remain_in_buffer; | ||
} else { | ||
frameData(data.data(), length_); | ||
bytes_decoded += length_; | ||
length_ = 0; | ||
} | ||
if (length_ == 0) { | ||
state_ = State::FrameFinished; | ||
} | ||
return bytes_decoded; | ||
} | ||
|
||
absl::optional<std::vector<Frame>> Decoder::decode(const Buffer::Instance& input) { | ||
absl::optional<std::vector<Frame>> output = std::vector<Frame>(); | ||
for (const Buffer::RawSlice& slice : input.getRawSlices()) { | ||
absl::Span<const uint8_t> data(reinterpret_cast<uint8_t*>(slice.mem_), slice.len_); | ||
while (!data.empty() || state_ == State::FrameFinished) { | ||
uint64_t bytes_decoded = 0; | ||
switch (state_) { | ||
case State::FrameHeaderFlagsAndOpcode: | ||
bytes_decoded = doDecodeFlagsAndOpcode(data); | ||
if (bytes_decoded == 0) { | ||
return absl::nullopt; | ||
} | ||
break; | ||
case State::FrameHeaderMaskFlagAndLength: | ||
bytes_decoded = doDecodeMaskFlagAndLength(data); | ||
break; | ||
case State::FrameHeaderExtendedLength16Bit: | ||
case State::FrameHeaderExtendedLength64Bit: | ||
bytes_decoded = doDecodeExtendedLength(data); | ||
break; | ||
case State::FrameHeaderMaskingKey: | ||
bytes_decoded = doDecodeMaskingKey(data); | ||
break; | ||
case State::FramePayload: | ||
bytes_decoded = doDecodePayload(data); | ||
break; | ||
case State::FrameFinished: | ||
frameDataEnd(output.value()); | ||
break; | ||
} | ||
data.remove_prefix(bytes_decoded); | ||
} | ||
} | ||
return !output->empty() ? std::move(output) : absl::nullopt; | ||
} | ||
|
||
} // namespace WebSocket | ||
} // namespace Envoy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
#pragma once | ||
|
||
#include <array> | ||
#include <cstdint> | ||
#include <vector> | ||
|
||
#include "envoy/buffer/buffer.h" | ||
|
||
#include "source/common/common/logger.h" | ||
|
||
namespace Envoy { | ||
namespace WebSocket { | ||
|
||
// Opcodes (https://datatracker.ietf.org/doc/html/rfc6455#section-11.8) | ||
constexpr uint8_t kFrameOpcodeContinuation = 0; | ||
constexpr uint8_t kFrameOpcodeText = 1; | ||
constexpr uint8_t kFrameOpcodeBinary = 2; | ||
constexpr uint8_t kFrameOpcodeClose = 8; | ||
constexpr uint8_t kFrameOpcodePing = 9; | ||
constexpr uint8_t kFrameOpcodePong = 10; | ||
constexpr std::array<uint8_t, 6> kFrameOpcodes = {kFrameOpcodeContinuation, kFrameOpcodeText, | ||
kFrameOpcodeBinary, kFrameOpcodeClose, | ||
kFrameOpcodePing, kFrameOpcodePong}; | ||
|
||
// Length of the masking key which is 4 bytes fixed size | ||
constexpr uint8_t kMaskingKeyLength = 4; | ||
// 16 bit payload length | ||
constexpr uint8_t kPayloadLength16Bit = 2; | ||
// 64 bit payload length | ||
constexpr uint8_t kPayloadLength64Bit = 8; | ||
// Maximum payload buffer length | ||
constexpr uint64_t kMaxPayloadBufferLength = 0x7fffffffffffffff; | ||
|
||
// Wire format (https://datatracker.ietf.org/doc/html/rfc6455#section-5.2) | ||
// of WebSocket frame: | ||
// | ||
// 0 1 2 3 | ||
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 | ||
// +-+-+-+-+-------+-+-------------+-------------------------------+ | ||
// |F|R|R|R| opcode|M| Payload len | Extended payload length | | ||
// |I|S|S|S| (4) |A| (7) | (16/64) | | ||
// |N|V|V|V| |S| | (if payload len==126/127) | | ||
// | |1|2|3| |K| | | | ||
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + | ||
// | Extended payload length continued, if payload len == 127 | | ||
// + - - - - - - - - - - - - - - - +-------------------------------+ | ||
// | | Masking-key, if MASK set to 1 | | ||
// +-------------------------------+-------------------------------+ | ||
// | Masking-key (continued) | Payload Data | | ||
// +-------------------------------- - - - - - - - - - - - - - - - + | ||
// : .... Payload Data continued .... Payload Data continued ..... : | ||
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + | ||
// | .... Payload Data continued .... Payload Data continued ..... | | ||
// +---------------------------------------------------------------+ | ||
|
||
// In-memory representation of the contents of a WebSocket frame. | ||
struct Frame { | ||
// Indicates that this is the final fragment in a message. | ||
bool final_fragment_; | ||
// Frame opcode. | ||
uint8_t opcode_; | ||
// The 4 byte fixed size masking key used to mask the payload. Masking/unmasking should be | ||
// performed as described in https://datatracker.ietf.org/doc/html/rfc6455#section-5.3 | ||
absl::optional<uint32_t> masking_key_; | ||
// Length of the payload as the number of bytes. | ||
uint64_t payload_length_; | ||
// WebSocket payload data (extension data and application data). | ||
Buffer::InstancePtr payload_; | ||
}; | ||
|
||
// Encoder encodes in memory WebSocket frames into frames in the wire format | ||
class Encoder : public Logger::Loggable<Logger::Id::websocket> { | ||
public: | ||
Encoder() = default; | ||
|
||
// Creates a new Websocket data frame header with the given frame data. | ||
// @param frame supplies the frame to be encoded. | ||
// @return std::vector<uint8_t> buffer with encoded header data. | ||
absl::optional<std::vector<uint8_t>> encodeFrameHeader(const Frame& frame); | ||
}; | ||
|
||
// Decoder decodes bytes in input buffer into in-memory WebSocket frames. | ||
class Decoder : public Logger::Loggable<Logger::Id::websocket> { | ||
public: | ||
Decoder(uint64_t max_payload_length = 0) | ||
: max_payload_buffer_length_{std::min(max_payload_length, kMaxPayloadBufferLength)} {}; | ||
// Decodes the given buffer into WebSocket frames. If the input is not sufficient to make a | ||
// complete WebSocket frame, then the decoder saves the state of halfway decoded WebSocket | ||
// frame until the next decode calls feed rest of the frame data. | ||
// @param input supplies the binary octets wrapped in a WebSocket frame. | ||
// @return the decoded frames. | ||
absl::optional<std::vector<Frame>> decode(const Buffer::Instance& input); | ||
|
||
private: | ||
void resetDecoder(); | ||
void frameDataStart(); | ||
void frameData(const uint8_t* mem, uint64_t length); | ||
void frameDataEnd(std::vector<Frame>& output); | ||
|
||
uint8_t doDecodeFlagsAndOpcode(absl::Span<const uint8_t>& data); | ||
uint8_t doDecodeMaskFlagAndLength(absl::Span<const uint8_t>& data); | ||
uint8_t doDecodeExtendedLength(absl::Span<const uint8_t>& data); | ||
uint8_t doDecodeMaskingKey(absl::Span<const uint8_t>& data); | ||
uint64_t doDecodePayload(absl::Span<const uint8_t>& data); | ||
|
||
// Current state of the frame that is being processed. | ||
enum class State { | ||
// Decoding the first byte. Waiting for decoding the final frame flag (1 bit) | ||
// and reserved flags (3 bits) and opcode (4 bits) of the WebSocket data frame. | ||
FrameHeaderFlagsAndOpcode, | ||
// Decoding the second byte. Waiting for decoding the mask flag (1 bit) and | ||
// length/length flag (7 bit) of the WebSocket data frame. | ||
FrameHeaderMaskFlagAndLength, | ||
// Waiting for decoding the extended 16 bit length. | ||
FrameHeaderExtendedLength16Bit, | ||
// Waiting for decoding the extended 64 bit length. | ||
FrameHeaderExtendedLength64Bit, | ||
// Waiting for decoding the masking key (4 bytes) only if the mask bit is set. | ||
FrameHeaderMaskingKey, | ||
// Waiting for decoding the payload (both extension data and application data). | ||
FramePayload, | ||
// Frame has finished decoding. | ||
FrameFinished | ||
}; | ||
uint64_t max_payload_buffer_length_; | ||
// Current frame that is being decoded. | ||
Frame frame_; | ||
State state_ = State::FrameHeaderFlagsAndOpcode; | ||
uint64_t length_ = 0; | ||
uint8_t num_remaining_extended_length_bytes_ = 0; | ||
uint8_t num_remaining_masking_key_bytes_ = 0; | ||
}; | ||
|
||
} // namespace WebSocket | ||
} // namespace Envoy |
Oops, something went wrong.