Skip to content

Commit

Permalink
Merge pull request #379 from cloudflare/bcaimano/web-socket-pump
Browse files Browse the repository at this point in the history
Reset WebSocket outgoing message status in a single continuation
  • Loading branch information
xortive authored Feb 21, 2023
2 parents c417d75 + 2b857f8 commit c336d40
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 51 deletions.
91 changes: 43 additions & 48 deletions src/workerd/api/web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -437,22 +437,22 @@ void WebSocket::send(jsg::Lock& js, kj::OneOf<kj::Array<byte>, kj::String> messa
JSG_REQUIRE(native.state.is<Accepted>(), TypeError,
"You must call accept() on this WebSocket before sending messages.");

KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(text, kj::String) {
outgoingMessages->insert(GatedMessage{
IoContext::current().waitForOutputLocksIfNecessary(),
kj::mv(text),
});
break;
}
KJ_CASE_ONEOF(data, kj::Array<byte>) {
outgoingMessages->insert(GatedMessage{
IoContext::current().waitForOutputLocksIfNecessary(),
kj::mv(data),
});
break;
auto maybeOutputLock = IoContext::current().waitForOutputLocksIfNecessary();
auto msg = [&]() -> kj::WebSocket::Message {
KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(text, kj::String) {
return kj::mv(text);
break;
}
KJ_CASE_ONEOF(data, kj::Array<byte>) {
return kj::mv(data);
break;
}
}
}

KJ_UNREACHABLE;
}();
outgoingMessages->insert(GatedMessage{kj::mv(maybeOutputLock), kj::mv(msg)});

ensurePumping(js);
}
Expand Down Expand Up @@ -564,20 +564,19 @@ void WebSocket::dispatchOpen(jsg::Lock& js) {
void WebSocket::ensurePumping(jsg::Lock& js) {
auto& native = *farNative;
if (!native.isPumping) {
native.isPumping = true;
auto& context = IoContext::current();
auto& accepted = KJ_ASSERT_NONNULL(native.state.tryGet<Accepted>());
auto promise = kj::evalNow([&]() {
return accepted.canceler.wrap(pump(context, *outgoingMessages, *accepted.ws));
return accepted.canceler.wrap(pump(context, *outgoingMessages, *accepted.ws, native));
});

// TODO(cleanup): We use awaitIoLegacy() here because we don't want this to count as a pending
// event if this is a WebSocketPair with the other end being handled in the same isolate.
// In that case, the pump can hang if accept() is never called on the other end. Ideally,
// this scenario would be handled in-isolate using jsg::Promise, but that would take some
// refactoring.
context.awaitIoLegacy(kj::mv(promise)).then(js, [this, thisHandle = JSG_THIS](jsg::Lock& js) {
auto& native = *farNative;
native.isPumping = false;
if (native.outgoingAborted) {
// Apparently, the peer stopped accepting messages (probably, disconnected entirely), but
// this didn't cause our writes to fail, maybe due to timing. Let's set the error now.
Expand All @@ -588,26 +587,11 @@ void WebSocket::ensurePumping(jsg::Lock& js) {
native.state.init<Released>();
}
}, [this](jsg::Lock& js, jsg::Value&& exception) mutable {
farNative->isPumping = false;
outgoingMessages->clear();
reportError(js, kj::mv(exception));
});
}
}

kj::Promise<void> WebSocket::pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws) {
if (outgoingMessages.size() == 0) {
return kj::READY_NOW;
} else KJ_IF_MAYBE(promise, outgoingMessages.ordered().begin()->outputLock) {
return promise->then([&context, &outgoingMessages, &ws]() mutable {
return pumpAfterFrontOutputLock(context, outgoingMessages, ws);
});
} else {
return pumpAfterFrontOutputLock(context, outgoingMessages, ws);
}
}

namespace {
size_t countBytesFromMessage(const kj::WebSocket::Message& message) {
// This does not count the extra data of the RPC frame or the savings from any compression.
Expand All @@ -633,37 +617,48 @@ size_t countBytesFromMessage(const kj::WebSocket::Message& message) {
}
}

kj::Promise<void> WebSocket::pumpAfterFrontOutputLock(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws) {
GatedMessage gatedMessage =
outgoingMessages.release(*outgoingMessages.ordered().begin());
auto size = countBytesFromMessage(gatedMessage.message);
kj::Promise<void> WebSocket::pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native) {
KJ_ASSERT(!native.isPumping);
native.isPumping = true;
KJ_DEFER({
// We use a KJ_DEFER to set native.isPumping = false to ensure that it happens -- we had a bug
// in the past where this was handled by the caller of WebSocket::pump() and it allowed for
// messages to get stuck in `outgoingMessages` until the pump task was restarted.
native.isPumping = false;

// Either we were already through all our outgoing messages or we experienced failure/
// cancellation and cannot send these anyway.
outgoingMessages.clear();
});

while (outgoingMessages.size() > 0) {
GatedMessage gatedMessage = outgoingMessages.release(*outgoingMessages.ordered().begin());
KJ_IF_MAYBE(promise, gatedMessage.outputLock) {
co_await *promise;
}

auto size = countBytesFromMessage(gatedMessage.message);

kj::Promise<void> promise = nullptr;
{
KJ_SWITCH_ONEOF(gatedMessage.message) {
KJ_CASE_ONEOF(text, kj::String) {
promise = ws.send(text);
co_await ws.send(text);
break;
}
KJ_CASE_ONEOF(data, kj::Array<byte>) {
promise = ws.send(data);
co_await ws.send(data);
break;
}
KJ_CASE_ONEOF(close, kj::WebSocket::Close) {
promise = ws.close(close.code, close.reason);
co_await ws.close(close.code, close.reason);
break;
}
}
}

return promise.attach(kj::mv(gatedMessage.message))
.then([&context, &outgoingMessages, &ws, size]() {
KJ_IF_MAYBE(a, context.getActor()) {
a->getMetrics().sentWebSocketMessage(size);
}
return pump(context, outgoingMessages, ws);
});
}
}

kj::Promise<void> WebSocket::readLoop(kj::WebSocket& ws) {
Expand Down
4 changes: 1 addition & 3 deletions src/workerd/api/web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,7 @@ class WebSocket: public EventTarget {
void ensurePumping(jsg::Lock& js);

static kj::Promise<void> pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws);
static kj::Promise<void> pumpAfterFrontOutputLock(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws);
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native);
// Write messages from `outgoingMessages` into `ws`.
//
// These are not necessarily called under isolate lock, but they are called on the given
Expand Down

0 comments on commit c336d40

Please sign in to comment.