Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reset WebSocket outgoing message status in a single continuation #379

Merged
merged 2 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 43 additions & 48 deletions src/workerd/api/web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -437,22 +437,22 @@ void WebSocket::send(jsg::Lock& js, kj::OneOf<kj::Array<byte>, kj::String> messa
JSG_REQUIRE(native.state.is<Accepted>(), TypeError,
"You must call accept() on this WebSocket before sending messages.");

KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(text, kj::String) {
outgoingMessages->insert(GatedMessage{
IoContext::current().waitForOutputLocksIfNecessary(),
kj::mv(text),
});
break;
}
KJ_CASE_ONEOF(data, kj::Array<byte>) {
outgoingMessages->insert(GatedMessage{
IoContext::current().waitForOutputLocksIfNecessary(),
kj::mv(data),
});
break;
auto maybeOutputLock = IoContext::current().waitForOutputLocksIfNecessary();
auto msg = [&]() -> kj::WebSocket::Message {
KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(text, kj::String) {
return kj::mv(text);
break;
}
KJ_CASE_ONEOF(data, kj::Array<byte>) {
return kj::mv(data);
break;
}
}
}

KJ_UNREACHABLE;
}();
outgoingMessages->insert(GatedMessage{kj::mv(maybeOutputLock), kj::mv(msg)});

ensurePumping(js);
}
Expand Down Expand Up @@ -564,20 +564,19 @@ void WebSocket::dispatchOpen(jsg::Lock& js) {
void WebSocket::ensurePumping(jsg::Lock& js) {
auto& native = *farNative;
if (!native.isPumping) {
native.isPumping = true;
auto& context = IoContext::current();
auto& accepted = KJ_ASSERT_NONNULL(native.state.tryGet<Accepted>());
auto promise = kj::evalNow([&]() {
return accepted.canceler.wrap(pump(context, *outgoingMessages, *accepted.ws));
return accepted.canceler.wrap(pump(context, *outgoingMessages, *accepted.ws, native));
});

// TODO(cleanup): We use awaitIoLegacy() here because we don't want this to count as a pending
// event if this is a WebSocketPair with the other end being handled in the same isolate.
// In that case, the pump can hang if accept() is never called on the other end. Ideally,
// this scenario would be handled in-isolate using jsg::Promise, but that would take some
// refactoring.
context.awaitIoLegacy(kj::mv(promise)).then(js, [this, thisHandle = JSG_THIS](jsg::Lock& js) {
jasnell marked this conversation as resolved.
Show resolved Hide resolved
auto& native = *farNative;
native.isPumping = false;
if (native.outgoingAborted) {
// Apparently, the peer stopped accepting messages (probably, disconnected entirely), but
// this didn't cause our writes to fail, maybe due to timing. Let's set the error now.
Expand All @@ -588,26 +587,11 @@ void WebSocket::ensurePumping(jsg::Lock& js) {
native.state.init<Released>();
}
}, [this](jsg::Lock& js, jsg::Value&& exception) mutable {
farNative->isPumping = false;
outgoingMessages->clear();
reportError(js, kj::mv(exception));
});
}
}

kj::Promise<void> WebSocket::pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws) {
if (outgoingMessages.size() == 0) {
return kj::READY_NOW;
} else KJ_IF_MAYBE(promise, outgoingMessages.ordered().begin()->outputLock) {
return promise->then([&context, &outgoingMessages, &ws]() mutable {
return pumpAfterFrontOutputLock(context, outgoingMessages, ws);
});
} else {
return pumpAfterFrontOutputLock(context, outgoingMessages, ws);
}
}

namespace {
size_t countBytesFromMessage(const kj::WebSocket::Message& message) {
// This does not count the extra data of the RPC frame or the savings from any compression.
Expand All @@ -633,37 +617,48 @@ size_t countBytesFromMessage(const kj::WebSocket::Message& message) {
}
}

kj::Promise<void> WebSocket::pumpAfterFrontOutputLock(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws) {
GatedMessage gatedMessage =
outgoingMessages.release(*outgoingMessages.ordered().begin());
auto size = countBytesFromMessage(gatedMessage.message);
kj::Promise<void> WebSocket::pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native) {
KJ_ASSERT(!native.isPumping);
native.isPumping = true;
xortive marked this conversation as resolved.
Show resolved Hide resolved
KJ_DEFER({
// We use a KJ_DEFER to set native.isPumping = false to ensure that it happens -- we had a bug
// in the past where this was handled by the caller of WebSocket::pump() and it allowed for
// messages to get stuck in `outgoingMessages` until the pump task was restarted.
native.isPumping = false;
xortive marked this conversation as resolved.
Show resolved Hide resolved

// Either we were already through all our outgoing messages or we experienced failure/
// cancellation and cannot send these anyway.
outgoingMessages.clear();
});

while (outgoingMessages.size() > 0) {
GatedMessage gatedMessage = outgoingMessages.release(*outgoingMessages.ordered().begin());
KJ_IF_MAYBE(promise, gatedMessage.outputLock) {
co_await *promise;
}

auto size = countBytesFromMessage(gatedMessage.message);

kj::Promise<void> promise = nullptr;
{
KJ_SWITCH_ONEOF(gatedMessage.message) {
KJ_CASE_ONEOF(text, kj::String) {
promise = ws.send(text);
co_await ws.send(text);
break;
}
KJ_CASE_ONEOF(data, kj::Array<byte>) {
promise = ws.send(data);
co_await ws.send(data);
break;
}
KJ_CASE_ONEOF(close, kj::WebSocket::Close) {
promise = ws.close(close.code, close.reason);
co_await ws.close(close.code, close.reason);
break;
}
}
}

return promise.attach(kj::mv(gatedMessage.message))
.then([&context, &outgoingMessages, &ws, size]() {
KJ_IF_MAYBE(a, context.getActor()) {
a->getMetrics().sentWebSocketMessage(size);
}
return pump(context, outgoingMessages, ws);
});
}
}

kj::Promise<void> WebSocket::readLoop(kj::WebSocket& ws) {
Expand Down
4 changes: 1 addition & 3 deletions src/workerd/api/web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,7 @@ class WebSocket: public EventTarget {
void ensurePumping(jsg::Lock& js);

static kj::Promise<void> pump(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws);
static kj::Promise<void> pumpAfterFrontOutputLock(
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws);
IoContext& context, OutgoingMessagesMap& outgoingMessages, kj::WebSocket& ws, Native& native);
// Write messages from `outgoingMessages` into `ws`.
//
// These are not necessarily called under isolate lock, but they are called on the given
Expand Down