Skip to content

Commit

Permalink
websocket: websocket codec implementation (envoyproxy#22542)
Browse files Browse the repository at this point in the history
Commit Message:
websocket: websocket codec implementation

Testing: Unit tests added

Additional Description:
Related issue comment => envoyproxy#13877 (comment)
This is as the initial step to support rate limiting for WebSocket.

Signed-off-by: Amila Senadheera <[email protected]>
  • Loading branch information
Amila-Rukshan authored Aug 29, 2022
1 parent 5183dbf commit e54aaa3
Show file tree
Hide file tree
Showing 6 changed files with 1,283 additions and 1 deletion.
3 changes: 2 additions & 1 deletion source/common/common/logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ namespace Logger {
FUNCTION(tracing) \
FUNCTION(upstream) \
FUNCTION(udp) \
FUNCTION(wasm)
FUNCTION(wasm) \
FUNCTION(websocket)

// clang-format off
enum class Id {
Expand Down
21 changes: 21 additions & 0 deletions source/common/websocket/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
load(
"//bazel:envoy_build_system.bzl",
"envoy_cc_library",
"envoy_package",
)

licenses(["notice"]) # Apache 2

envoy_package()

envoy_cc_library(
name = "codec_lib",
srcs = ["codec.cc"],
hdrs = ["codec.h"],
deps = [
"//envoy/buffer:buffer_interface",
"//source/common/buffer:buffer_lib",
"//source/common/common:minimal_logger_lib",
"//source/common/common:scalar_to_byte_vector_lib",
],
)
218 changes: 218 additions & 0 deletions source/common/websocket/codec.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
#include "source/common/websocket/codec.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <memory>
#include <vector>

#include "source/common/buffer/buffer_impl.h"
#include "source/common/common/scalar_to_byte_vector.h"

namespace Envoy {
namespace WebSocket {

absl::optional<std::vector<uint8_t>> Encoder::encodeFrameHeader(const Frame& frame) {
if (std::find(kFrameOpcodes.begin(), kFrameOpcodes.end(), frame.opcode_) == kFrameOpcodes.end()) {
ENVOY_LOG(debug, "Failed to encode websocket frame with invalid opcode: {}", frame.opcode_);
return absl::nullopt;
}
std::vector<uint8_t> output;
// Set flags and opcode
pushScalarToByteVector(
static_cast<uint8_t>(frame.final_fragment_ ? (0x80 | frame.opcode_) : frame.opcode_), output);

// Set payload length
if (frame.payload_length_ <= 125) {
// Set mask bit and 7-bit length
pushScalarToByteVector(frame.masking_key_.has_value()
? static_cast<uint8_t>(frame.payload_length_ | 0x80)
: static_cast<uint8_t>(frame.payload_length_),
output);
} else if (frame.payload_length_ <= 65535) {
// Set mask bit and 16-bit length indicator
pushScalarToByteVector(static_cast<uint8_t>(frame.masking_key_.has_value() ? 0xfe : 0x7e),
output);
// Set 16-bit length
pushScalarToByteVector(htobe16(frame.payload_length_), output);
} else {
// Set mask bit and 64-bit length indicator
pushScalarToByteVector(static_cast<uint8_t>(frame.masking_key_.has_value() ? 0xff : 0x7f),
output);
// Set 64-bit length
pushScalarToByteVector(htobe64(frame.payload_length_), output);
}
// Set masking key
if (frame.masking_key_.has_value()) {
pushScalarToByteVector(htobe32(frame.masking_key_.value()), output);
}
return output;
}

void Decoder::frameDataStart() {
frame_.payload_length_ = length_;
if (length_ == 0) {
state_ = State::FrameFinished;
} else {
if (max_payload_buffer_length_ > 0) {
frame_.payload_ = std::make_unique<Buffer::OwnedImpl>();
}
state_ = State::FramePayload;
}
}

void Decoder::frameData(const uint8_t* mem, uint64_t length) {
if (max_payload_buffer_length_ > 0) {
uint64_t allowed_length = max_payload_buffer_length_ - frame_.payload_->length();
frame_.payload_->add(mem, length <= allowed_length ? length : allowed_length);
}
}

void Decoder::frameDataEnd(std::vector<Frame>& output) {
output.push_back(std::move(frame_));
resetDecoder();
}

void Decoder::resetDecoder() {
frame_ = {false, 0, absl::nullopt, 0, nullptr};
state_ = State::FrameHeaderFlagsAndOpcode;
length_ = 0;
num_remaining_extended_length_bytes_ = 0;
num_remaining_masking_key_bytes_ = 0;
}

uint8_t Decoder::doDecodeFlagsAndOpcode(absl::Span<const uint8_t>& data) {
// Validate opcode (last 4 bits)
uint8_t opcode = data.front() & 0x0f;
if (std::find(kFrameOpcodes.begin(), kFrameOpcodes.end(), opcode) == kFrameOpcodes.end()) {
ENVOY_LOG(debug, "Failed to decode websocket frame with invalid opcode: {}", opcode);
return 0;
}
frame_.opcode_ = opcode;
frame_.final_fragment_ = data.front() & 0x80;
state_ = State::FrameHeaderMaskFlagAndLength;
return 1;
}

uint8_t Decoder::doDecodeMaskFlagAndLength(absl::Span<const uint8_t>& data) {
num_remaining_masking_key_bytes_ = data.front() & 0x80 ? kMaskingKeyLength : 0;
uint8_t length_indicator = data.front() & 0x7f;
if (length_indicator == 0x7e) {
num_remaining_extended_length_bytes_ = kPayloadLength16Bit;
state_ = State::FrameHeaderExtendedLength16Bit;
} else if (length_indicator == 0x7f) {
num_remaining_extended_length_bytes_ = kPayloadLength64Bit;
state_ = State::FrameHeaderExtendedLength64Bit;
} else if (num_remaining_masking_key_bytes_ > 0) {
length_ = length_indicator;
state_ = State::FrameHeaderMaskingKey;
} else {
length_ = length_indicator;
frameDataStart();
}
return 1;
}

uint8_t Decoder::doDecodeExtendedLength(absl::Span<const uint8_t>& data) {
uint64_t bytes_to_decode = data.length() <= num_remaining_extended_length_bytes_
? data.length()
: num_remaining_extended_length_bytes_;
uint8_t size_of_extended_length =
state_ == State::FrameHeaderExtendedLength16Bit ? kPayloadLength16Bit : kPayloadLength64Bit;
uint8_t shift_of_bytes = size_of_extended_length - num_remaining_extended_length_bytes_;
uint8_t* destination = reinterpret_cast<uint8_t*>(&length_) + shift_of_bytes;

ASSERT(shift_of_bytes >= 0);
ASSERT(shift_of_bytes < size_of_extended_length);
memcpy(destination, data.data(), bytes_to_decode); // NOLINT(safe-memcpy)
num_remaining_extended_length_bytes_ -= bytes_to_decode;

if (num_remaining_extended_length_bytes_ == 0) {
length_ = state_ == State::FrameHeaderExtendedLength16Bit ? htobe16(length_) : htobe64(length_);
if (num_remaining_masking_key_bytes_ > 0) {
state_ = State::FrameHeaderMaskingKey;
} else {
frameDataStart();
}
}
return bytes_to_decode;
}

uint8_t Decoder::doDecodeMaskingKey(absl::Span<const uint8_t>& data) {
if (!frame_.masking_key_.has_value()) {
frame_.masking_key_ = 0;
}
uint64_t bytes_to_decode = data.length() <= num_remaining_masking_key_bytes_
? data.length()
: num_remaining_masking_key_bytes_;
uint8_t shift_of_bytes = kMaskingKeyLength - num_remaining_masking_key_bytes_;
uint8_t* destination =
reinterpret_cast<uint8_t*>(&(frame_.masking_key_.value())) + shift_of_bytes;
ASSERT(shift_of_bytes >= 0);
ASSERT(shift_of_bytes < kMaskingKeyLength);
memcpy(destination, data.data(), bytes_to_decode); // NOLINT(safe-memcpy)
num_remaining_masking_key_bytes_ -= bytes_to_decode;

if (num_remaining_masking_key_bytes_ == 0) {
frame_.masking_key_ = htobe32(frame_.masking_key_.value());
frameDataStart();
}
return bytes_to_decode;
}

uint64_t Decoder::doDecodePayload(absl::Span<const uint8_t>& data) {
uint64_t remain_in_buffer = data.length();
uint64_t bytes_decoded = 0;
if (remain_in_buffer <= length_) {
frameData(data.data(), remain_in_buffer);
bytes_decoded += remain_in_buffer;
length_ -= remain_in_buffer;
} else {
frameData(data.data(), length_);
bytes_decoded += length_;
length_ = 0;
}
if (length_ == 0) {
state_ = State::FrameFinished;
}
return bytes_decoded;
}

absl::optional<std::vector<Frame>> Decoder::decode(const Buffer::Instance& input) {
absl::optional<std::vector<Frame>> output = std::vector<Frame>();
for (const Buffer::RawSlice& slice : input.getRawSlices()) {
absl::Span<const uint8_t> data(reinterpret_cast<uint8_t*>(slice.mem_), slice.len_);
while (!data.empty() || state_ == State::FrameFinished) {
uint64_t bytes_decoded = 0;
switch (state_) {
case State::FrameHeaderFlagsAndOpcode:
bytes_decoded = doDecodeFlagsAndOpcode(data);
if (bytes_decoded == 0) {
return absl::nullopt;
}
break;
case State::FrameHeaderMaskFlagAndLength:
bytes_decoded = doDecodeMaskFlagAndLength(data);
break;
case State::FrameHeaderExtendedLength16Bit:
case State::FrameHeaderExtendedLength64Bit:
bytes_decoded = doDecodeExtendedLength(data);
break;
case State::FrameHeaderMaskingKey:
bytes_decoded = doDecodeMaskingKey(data);
break;
case State::FramePayload:
bytes_decoded = doDecodePayload(data);
break;
case State::FrameFinished:
frameDataEnd(output.value());
break;
}
data.remove_prefix(bytes_decoded);
}
}
return !output->empty() ? std::move(output) : absl::nullopt;
}

} // namespace WebSocket
} // namespace Envoy
135 changes: 135 additions & 0 deletions source/common/websocket/codec.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#pragma once

#include <array>
#include <cstdint>
#include <vector>

#include "envoy/buffer/buffer.h"

#include "source/common/common/logger.h"

namespace Envoy {
namespace WebSocket {

// Opcodes (https://datatracker.ietf.org/doc/html/rfc6455#section-11.8)
constexpr uint8_t kFrameOpcodeContinuation = 0;
constexpr uint8_t kFrameOpcodeText = 1;
constexpr uint8_t kFrameOpcodeBinary = 2;
constexpr uint8_t kFrameOpcodeClose = 8;
constexpr uint8_t kFrameOpcodePing = 9;
constexpr uint8_t kFrameOpcodePong = 10;
constexpr std::array<uint8_t, 6> kFrameOpcodes = {kFrameOpcodeContinuation, kFrameOpcodeText,
kFrameOpcodeBinary, kFrameOpcodeClose,
kFrameOpcodePing, kFrameOpcodePong};

// Length of the masking key which is 4 bytes fixed size
constexpr uint8_t kMaskingKeyLength = 4;
// 16 bit payload length
constexpr uint8_t kPayloadLength16Bit = 2;
// 64 bit payload length
constexpr uint8_t kPayloadLength64Bit = 8;
// Maximum payload buffer length
constexpr uint64_t kMaxPayloadBufferLength = 0x7fffffffffffffff;

// Wire format (https://datatracker.ietf.org/doc/html/rfc6455#section-5.2)
// of WebSocket frame:
//
// 0 1 2 3
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-------+-+-------------+-------------------------------+
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
// |I|S|S|S| (4) |A| (7) | (16/64) |
// |N|V|V|V| |S| | (if payload len==126/127) |
// | |1|2|3| |K| | |
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
// | Extended payload length continued, if payload len == 127 |
// + - - - - - - - - - - - - - - - +-------------------------------+
// | | Masking-key, if MASK set to 1 |
// +-------------------------------+-------------------------------+
// | Masking-key (continued) | Payload Data |
// +-------------------------------- - - - - - - - - - - - - - - - +
// : .... Payload Data continued .... Payload Data continued ..... :
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
// | .... Payload Data continued .... Payload Data continued ..... |
// +---------------------------------------------------------------+

// In-memory representation of the contents of a WebSocket frame.
struct Frame {
// Indicates that this is the final fragment in a message.
bool final_fragment_;
// Frame opcode.
uint8_t opcode_;
// The 4 byte fixed size masking key used to mask the payload. Masking/unmasking should be
// performed as described in https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
absl::optional<uint32_t> masking_key_;
// Length of the payload as the number of bytes.
uint64_t payload_length_;
// WebSocket payload data (extension data and application data).
Buffer::InstancePtr payload_;
};

// Encoder encodes in memory WebSocket frames into frames in the wire format
class Encoder : public Logger::Loggable<Logger::Id::websocket> {
public:
Encoder() = default;

// Creates a new Websocket data frame header with the given frame data.
// @param frame supplies the frame to be encoded.
// @return std::vector<uint8_t> buffer with encoded header data.
absl::optional<std::vector<uint8_t>> encodeFrameHeader(const Frame& frame);
};

// Decoder decodes bytes in input buffer into in-memory WebSocket frames.
class Decoder : public Logger::Loggable<Logger::Id::websocket> {
public:
Decoder(uint64_t max_payload_length = 0)
: max_payload_buffer_length_{std::min(max_payload_length, kMaxPayloadBufferLength)} {};
// Decodes the given buffer into WebSocket frames. If the input is not sufficient to make a
// complete WebSocket frame, then the decoder saves the state of halfway decoded WebSocket
// frame until the next decode calls feed rest of the frame data.
// @param input supplies the binary octets wrapped in a WebSocket frame.
// @return the decoded frames.
absl::optional<std::vector<Frame>> decode(const Buffer::Instance& input);

private:
void resetDecoder();
void frameDataStart();
void frameData(const uint8_t* mem, uint64_t length);
void frameDataEnd(std::vector<Frame>& output);

uint8_t doDecodeFlagsAndOpcode(absl::Span<const uint8_t>& data);
uint8_t doDecodeMaskFlagAndLength(absl::Span<const uint8_t>& data);
uint8_t doDecodeExtendedLength(absl::Span<const uint8_t>& data);
uint8_t doDecodeMaskingKey(absl::Span<const uint8_t>& data);
uint64_t doDecodePayload(absl::Span<const uint8_t>& data);

// Current state of the frame that is being processed.
enum class State {
// Decoding the first byte. Waiting for decoding the final frame flag (1 bit)
// and reserved flags (3 bits) and opcode (4 bits) of the WebSocket data frame.
FrameHeaderFlagsAndOpcode,
// Decoding the second byte. Waiting for decoding the mask flag (1 bit) and
// length/length flag (7 bit) of the WebSocket data frame.
FrameHeaderMaskFlagAndLength,
// Waiting for decoding the extended 16 bit length.
FrameHeaderExtendedLength16Bit,
// Waiting for decoding the extended 64 bit length.
FrameHeaderExtendedLength64Bit,
// Waiting for decoding the masking key (4 bytes) only if the mask bit is set.
FrameHeaderMaskingKey,
// Waiting for decoding the payload (both extension data and application data).
FramePayload,
// Frame has finished decoding.
FrameFinished
};
uint64_t max_payload_buffer_length_;
// Current frame that is being decoded.
Frame frame_;
State state_ = State::FrameHeaderFlagsAndOpcode;
uint64_t length_ = 0;
uint8_t num_remaining_extended_length_bytes_ = 0;
uint8_t num_remaining_masking_key_bytes_ = 0;
};

} // namespace WebSocket
} // namespace Envoy
Loading

0 comments on commit e54aaa3

Please sign in to comment.