Skip to content

Commit

Permalink
Merge pull request #1130 from cloudflare/joaquim/ws-autoresponse-race…
Browse files Browse the repository at this point in the history
…-fix

Fix hibernatable websocket race between auto-response and regular messages
  • Loading branch information
jqmmes authored Oct 2, 2023
2 parents f8a4aa6 + 1754538 commit 66cb3ce
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 18 deletions.
69 changes: 62 additions & 7 deletions src/workerd/api/web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ IoOwn<WebSocket::Native> WebSocket::initNative(
// We might have called `close()` when this WebSocket was previously active.
// If so, we want to prevent any future calls to `send()`.
nativeObj->closedOutgoing = closedOutgoingConn;
autoResponseStatus.isClosed = nativeObj->closedOutgoing;
return ioContext.addObject(kj::mv(nativeObj));
}

Expand Down Expand Up @@ -531,7 +532,10 @@ void WebSocket::send(jsg::Lock& js, kj::OneOf<kj::Array<byte>, kj::String> messa

KJ_UNREACHABLE;
}();
outgoingMessages->insert(GatedMessage{kj::mv(maybeOutputLock), kj::mv(msg)});

auto pendingAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size() - autoResponseStatus.queuedAutoResponses;
autoResponseStatus.queuedAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size();
outgoingMessages->insert(GatedMessage{kj::mv(maybeOutputLock), kj::mv(msg), pendingAutoResponses});

ensurePumping(js);
}
Expand Down Expand Up @@ -574,17 +578,25 @@ void WebSocket::close(
if (reason != kj::none) {
// The default code of 1005 cannot have a reason, per the standard, so if a reason is specified
// then there must be a code, too.
JSG_REQUIRE(code != nullptr, TypeError,
JSG_REQUIRE(code != kj::none, TypeError,
"If you specify a WebSocket close reason, you must also specify a code.");
}

// pendingAutoResponses stores the number of queued auto-responses that will be pumped before
// sending the current GatedMessage, guaranteeing order.
// queuedAutoResponses stores the total number of auto-response messages that are already
// accounted for in previous GatedMessages. This makes it easy to calculate the number of
// pendingAutoResponses for each new GatedMessage.
auto pendingAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size() -
autoResponseStatus.queuedAutoResponses;
autoResponseStatus.queuedAutoResponses = autoResponseStatus.pendingAutoResponseDeque.size();
outgoingMessages->insert(GatedMessage{
IoContext::current().waitForOutputLocksIfNecessary(),
kj::WebSocket::Close {
// Code 1005 actually translates to sending a close message with no body on the wire.
static_cast<uint16_t>(code.orDefault(1005)),
kj::mv(reason).orDefault(nullptr),
},
}, pendingAutoResponses
});

native.closedOutgoing = true;
Expand Down Expand Up @@ -670,11 +682,12 @@ void WebSocket::serializeAttachment(jsg::Lock& js, jsg::JsValue attachment) {
serializedAttachment = kj::mv(released.data);
}

void WebSocket::setAutoResponseTimestamp(kj::Maybe<kj::Date> time) {
void WebSocket::setAutoResponseStatus(kj::Maybe<kj::Date> time,
kj::Promise<void> autoResponsePromise) {
autoResponseTimestamp = time;
autoResponseStatus.ongoingAutoResponse = kj::mv(autoResponsePromise);
}


// Returns the value currently stored in `autoResponseTimestamp` for this WebSocket.
// The result is a kj::Maybe, so callers must handle the case where no timestamp is stored.
kj::Maybe<kj::Date> WebSocket::getAutoResponseTimestamp() {
return autoResponseTimestamp;
}
Expand All @@ -689,7 +702,8 @@ void WebSocket::ensurePumping(jsg::Lock& js) {
auto& context = IoContext::current();
auto& accepted = KJ_ASSERT_NONNULL(native.state.tryGet<Accepted>());
auto promise = kj::evalNow([&]() {
return accepted.canceler.wrap(pump(context, *outgoingMessages, *accepted.ws, native));
return accepted.canceler.wrap(pump(context, *outgoingMessages,
*accepted.ws, native, autoResponseStatus));
});

// TODO(cleanup): We use awaitIoLegacy() here because we don't want this to count as a pending
Expand Down Expand Up @@ -729,6 +743,17 @@ void WebSocket::ensurePumping(jsg::Lock& js) {
}
}

// Sends a hibernatable-websocket auto-response `message` over `ws` without racing the
// regular outgoing-message pump.
//
// - While the pump is running, the message is only queued; the pump drains the queue itself.
// - If the pump is idle and the outgoing side has been closed, the message is dropped.
// - Otherwise the message is sent right away, and a branch of the in-flight send is published
//   in `autoResponseStatus.ongoingAutoResponse` for as long as the send is in progress.
kj::Promise<void> WebSocket::sendAutoResponse(kj::String message, kj::WebSocket& ws) {
  auto& status = autoResponseStatus;

  if (status.isPumping) {
    // The pump owns the socket right now; hand the message over to it.
    status.pendingAutoResponseDeque.push_back(kj::mv(message));
    co_return;
  }

  if (status.isClosed) {
    // Nothing may be sent after the outgoing side has been closed.
    co_return;
  }

  // Fork the send so that one branch can be stored in the status while we await the other.
  auto forkedSend = ws.send(message).fork();
  status.ongoingAutoResponse = forkedSend.addBranch();
  co_await forkedSend;

  // The send has completed, so there is no longer anything to wait for.
  status.ongoingAutoResponse = kj::READY_NOW;
}

namespace {

size_t countBytesFromMessage(const kj::WebSocket::Message& message) {
Expand Down Expand Up @@ -757,9 +782,11 @@ size_t countBytesFromMessage(const kj::WebSocket::Message& message) {
} // namespace

kj::Promise<void> WebSocket::pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native) {
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native,
AutoResponse& autoResponse) {
KJ_ASSERT(!native.isPumping);
native.isPumping = true;
autoResponse.isPumping = true;
KJ_DEFER({
// We use a KJ_DEFER to set native.isPumping = false to ensure that it happens -- we had a bug
// in the past where this was handled by the caller of WebSocket::pump() and it allowed for
Expand All @@ -769,8 +796,19 @@ kj::Promise<void> WebSocket::pump(
// Either we were already through all our outgoing messages or we experienced failure/
// cancellation and cannot send these anyway.
outgoingMessages.clear();

autoResponse.isPumping = false;

if (autoResponse.pendingAutoResponseDeque.size() > 0) {
autoResponse.pendingAutoResponseDeque.clear();
}
});

// If we have an ongoingAutoResponse, we must co_await it here because there's a ws.send()
// in progress. Otherwise the two ws.send() calls could race.
co_await autoResponse.ongoingAutoResponse;
autoResponse.ongoingAutoResponse = kj::READY_NOW;

while (outgoingMessages.size() > 0) {
GatedMessage gatedMessage = outgoingMessages.release(*outgoingMessages.ordered().begin());
KJ_IF_SOME(promise, gatedMessage.outputLock) {
Expand All @@ -779,6 +817,15 @@ kj::Promise<void> WebSocket::pump(

auto size = countBytesFromMessage(gatedMessage.message);

while (gatedMessage.pendingAutoResponses > 0) {
KJ_ASSERT(autoResponse.pendingAutoResponseDeque.size() >= gatedMessage.pendingAutoResponses);
auto message = kj::mv(autoResponse.pendingAutoResponseDeque.front());
autoResponse.pendingAutoResponseDeque.pop_front();
gatedMessage.pendingAutoResponses--;
autoResponse.queuedAutoResponses--;
co_await ws.send(message);
}

KJ_SWITCH_ONEOF(gatedMessage.message) {
KJ_CASE_ONEOF(text, kj::String) {
co_await ws.send(text);
Expand All @@ -790,6 +837,7 @@ kj::Promise<void> WebSocket::pump(
}
KJ_CASE_ONEOF(close, kj::WebSocket::Close) {
co_await ws.close(close.code, close.reason);
autoResponse.isClosed = true;
break;
}
}
Expand All @@ -798,6 +846,13 @@ kj::Promise<void> WebSocket::pump(
a.getMetrics().sentWebSocketMessage(size);
}
}
// If there are any auto-responses left to process, we should do it now.
// We should also check if the last sent message was a close. Shouldn't happen.
while (autoResponse.pendingAutoResponseDeque.size() > 0 && !autoResponse.isClosed) {
auto message = kj::mv(autoResponse.pendingAutoResponseDeque.front());
autoResponse.pendingAutoResponseDeque.pop_front();
co_await ws.send(message);
}
}

void WebSocket::tryReleaseNative(jsg::Lock& js) {
Expand Down
22 changes: 20 additions & 2 deletions src/workerd/api/web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "basics.h"
#include <workerd/io/io-context.h>
#include <workerd/jsg/ser.h>
#include <stdlib.h>

namespace workerd {
class ActorObserver;
Expand Down Expand Up @@ -395,12 +396,15 @@ class WebSocket: public EventTarget {

// Used to get/store the last auto request/response timestamp for this WebSocket.
// These methods are c++ only and are not exposed to our js interface.
void setAutoResponseTimestamp(kj::Maybe<kj::Date> time);
// Also used to track hibernatable websockets auto-response sends.
void setAutoResponseStatus(kj::Maybe<kj::Date> time, kj::Promise<void> autoResponsePromise);

// Used to get/store the last auto request/response timestamp for this WebSocket.
// These methods are c++ only and are not exposed to our js interface.
kj::Maybe<kj::Date> getAutoResponseTimestamp();

kj::Promise<void> sendAutoResponse(kj::String message, kj::WebSocket& ws);

int getReadyState();

bool isAccepted();
Expand Down Expand Up @@ -642,12 +646,25 @@ class WebSocket: public EventTarget {
// One entry of the outgoing-message queue consumed by pump().
// NOTE: instances are built with positional aggregate initialization, so the field order
// below must not be changed.
struct GatedMessage {
kj::Maybe<kj::Promise<void>> outputLock; // must wait for this before actually sending
// The message to send (or the close frame to issue) on the underlying kj::WebSocket.
kj::WebSocket::Message message;
// Number of queued auto-response messages that pump() sends before sending `message`,
// so that auto-responses and regular messages keep their relative order.
size_t pendingAutoResponses = 0;
};
using OutgoingMessagesMap = kj::Table<GatedMessage, kj::InsertionOrderIndex>;
// Queue of messages to be sent. This is wrapped in an IoOwn so that the pump loop can safely
// access the map without locking the isolate.
IoOwn<OutgoingMessagesMap> outgoingMessages;

// Keep track of current hibernatable websockets auto-response status to avoid racing
// between regular websocket messages, and auto-responses.
// Shared state between sendAutoResponse(), send()/close() and pump().
// NOTE(review): `std::deque` requires <deque>; confirm that header is included by this file
// (directly or transitively).
struct AutoResponse {
// Resolves when the auto-response ws.send() currently in flight (if any) has finished.
// pump() awaits this before sending anything else.
kj::Promise<void> ongoingAutoResponse = kj::READY_NOW;
// Auto-response messages received while the pump was running; drained by pump().
std::deque<kj::String> pendingAutoResponseDeque;
// How many entries of pendingAutoResponseDeque are already accounted for by the
// `pendingAutoResponses` count of some queued GatedMessage.
size_t queuedAutoResponses = 0;
// True while pump() is running; auto-responses are queued instead of sent directly.
bool isPumping = false;
// True once the outgoing side has been closed; no further auto-responses are sent.
bool isClosed = false;
};

AutoResponse autoResponseStatus;

Locality locality;

// Contains a websocket and possibly some data from the WebSocketResponse headers.
Expand Down Expand Up @@ -677,7 +694,8 @@ class WebSocket: public EventTarget {
// objects so are safe to access from the thread without the isolate lock. The whole task is
// owned by the `IoContext` so it'll be canceled if the `IoContext` is destroyed.
static kj::Promise<void> pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native);
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native,
AutoResponse& autoResponse);

kj::Promise<kj::Maybe<kj::Exception>> readLoop();

Expand Down
26 changes: 20 additions & 6 deletions src/workerd/io/hibernation-manager.c++
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,29 @@ kj::Promise<void> HibernationManagerImpl::readLoop(HibernatableWebSocket& hib) {
// We'll store the current timestamp in the HibernatableWebSocket to assure it gets
// stored even if the WebSocket is currently hibernating. In that scenario, the timestamp
// value will be loaded into the WebSocket during unhibernation.
KJ_IF_SOME(active, hib.activeOrPackage.tryGet<jsg::Ref<api::WebSocket>>()) {
// If the actor is not hibernated/If the WebSocket is active, we need to update
// autoResponseTimestamp on the active websocket.
(active)->setAutoResponseTimestamp(hib.autoResponseTimestamp);
KJ_SWITCH_ONEOF(hib.activeOrPackage){
KJ_CASE_ONEOF(apiWs, jsg::Ref<api::WebSocket>) {
// If the actor is not hibernated/If the WebSocket is active, we need to update
// autoResponseTimestamp on the active websocket.
apiWs->setAutoResponseStatus(hib.autoResponseTimestamp, kj::READY_NOW);
co_await apiWs->sendAutoResponse(kj::str(reqResp->getResponse().asArray()), ws);
}
KJ_CASE_ONEOF(package, api::WebSocket::HibernationPackage) {
if (!package.closedOutgoingConnection) {
// We need to store the autoResponsePromise because we may instantiate an api::WebSocket.
// If we do that, we have to provide it with the promise to avoid races. This can
// happen if a hibernating websocket unhibernates and sends a message while the
// ws.send() for the auto-response is still in progress.
auto p = ws.send(reqResp->getResponse().asArray()).fork();
hib.autoResponsePromise = p.addBranch();
co_await p;
hib.autoResponsePromise = kj::READY_NOW;
}
}
}
co_await ws.send((reqResp)->getResponse().asArray());
skip = true;
// If we've sent an auto response message, we should not unhibernate or deliver the
// received message to the actor
skip = true;
}
}
KJ_CASE_ONEOF_DEFAULT {}
Expand Down
12 changes: 9 additions & 3 deletions src/workerd/io/hibernation-manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,13 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager {
// to the api::WebSocket.
jsg::Ref<api::WebSocket> getActiveOrUnhibernate(jsg::Lock& js) {
KJ_IF_SOME(package, activeOrPackage.tryGet<api::WebSocket::HibernationPackage>()) {
// Now that we unhibernated the WebSocket, we can set the last received autoResponse timestamp
// that was stored in the corresponding HibernatableWebSocket. We also move autoResponsePromise
// from the hibernation manager to api::websocket to prevent possible ws.send races.
activeOrPackage.init<jsg::Ref<api::WebSocket>>(
api::WebSocket::hibernatableFromNative(js, *KJ_REQUIRE_NONNULL(ws), kj::mv(package))
)->setAutoResponseTimestamp(autoResponseTimestamp);
// Now that we unhibernated the WebSocket, we can set the last received autoResponse timestamp
// that was stored in the corresponding HibernatableWebSocket.
)->setAutoResponseStatus(autoResponseTimestamp, kj::mv(autoResponsePromise));
autoResponsePromise = kj::READY_NOW;
}
return activeOrPackage.get<jsg::Ref<api::WebSocket>>().addRef();
}
Expand Down Expand Up @@ -151,6 +153,10 @@ class HibernationManagerImpl final : public Worker::Actor::HibernationManager {
// Stores the last received autoResponseRequest timestamp.
kj::Maybe<kj::Date> autoResponseTimestamp;

// Keeps track of the currently ongoing websocket auto-response send promise. This promise may
// be moved to api::WebSocket if a hibernating websocket unhibernates.
kj::Promise<void> autoResponsePromise = kj::READY_NOW;

friend HibernationManagerImpl;
};

Expand Down

0 comments on commit 66cb3ce

Please sign in to comment.