From 5c724e2b38f0da5ce9f14edd7720ae4d616d36b0 Mon Sep 17 00:00:00 2001 From: Milan Miladinovic Date: Wed, 22 Feb 2023 17:06:29 -0500 Subject: [PATCH] Implement and use HibernationManagerImpl Previously we introduced the HibernationManager class, but this commit actually implements it. The HibernationManager needs to be stored somewhere that lives long, so we'll put it in the deferred proxy task. --- src/workerd/api/actor-state.c++ | 5 +- src/workerd/api/hibernatable-web-socket.c++ | 9 +- src/workerd/api/hibernatable-web-socket.h | 1 - src/workerd/api/http.c++ | 13 +- src/workerd/io/hibernation-manager.c++ | 199 ++++++++++++++++++++ src/workerd/io/hibernation-manager.h | 191 +++++++++++++++++++ 6 files changed, 412 insertions(+), 6 deletions(-) create mode 100644 src/workerd/io/hibernation-manager.c++ create mode 100644 src/workerd/io/hibernation-manager.h diff --git a/src/workerd/api/actor-state.c++ b/src/workerd/api/actor-state.c++ index b35a1464a4f3..c33fa27d68e4 100644 --- a/src/workerd/api/actor-state.c++ +++ b/src/workerd/api/actor-state.c++ @@ -14,6 +14,7 @@ #include #include "sql.h" #include +#include namespace workerd::api { @@ -759,7 +760,9 @@ void DurableObjectState::acceptWebSocket( // We need to get a HibernationManager to give the websocket to. auto& a = KJ_REQUIRE_NONNULL(IoContext::current().getActor()); if (a.getHibernationManager() == nullptr) { - // TODO(now): Actually set the hibernation manager. + a.setHibernationManager( + kj::refcounted( + a.getLoopback(), KJ_REQUIRE_NONNULL(a.getHibernationEventType()))); } // HibernationManager's acceptWebSocket() will throw if the websocket is in an incompatible state. // Note that not providing a tag is equivalent to providing an empty tag array. 
diff --git a/src/workerd/api/hibernatable-web-socket.c++ b/src/workerd/api/hibernatable-web-socket.c++ index dc26df344c7b..00443c60cd67 100644 --- a/src/workerd/api/hibernatable-web-socket.c++ +++ b/src/workerd/api/hibernatable-web-socket.c++ @@ -5,6 +5,7 @@ #include "hibernatable-web-socket.h" #include #include +#include namespace workerd::api { @@ -12,9 +13,11 @@ HibernatableWebSocketEvent::HibernatableWebSocketEvent() : ExtendableEvent("webSocketMessage") {}; jsg::Ref HibernatableWebSocketEvent::getWebSocket(jsg::Lock& lock) { - // This is just a stub implementation and is to be replaced once the new websocket manager - // needs it - return jsg::alloc(kj::str(""), WebSocket::Locality::LOCAL); + auto& manager = static_cast( + KJ_REQUIRE_NONNULL( + KJ_REQUIRE_NONNULL(IoContext::current().getActor()).getHibernationManager())); + auto& hibernatableWebSocket = KJ_REQUIRE_NONNULL(manager.webSocketForEventHandler); + return hibernatableWebSocket.getActiveOrUnhibernate(lock); } kj::Promise HibernatableWebSocketCustomEventImpl::run( diff --git a/src/workerd/api/hibernatable-web-socket.h b/src/workerd/api/hibernatable-web-socket.h index 4bee64699b08..dd4bdf5f2827 100644 --- a/src/workerd/api/hibernatable-web-socket.h +++ b/src/workerd/api/hibernatable-web-socket.h @@ -21,7 +21,6 @@ class HibernatableWebSocketEvent final: public ExtendableEvent { static jsg::Ref constructor(kj::String type) = delete; - // TODO(soon): return correct ws instead of the current stub implementation jsg::Ref getWebSocket(jsg::Lock& lock); JSG_RESOURCE_TYPE(HibernatableWebSocketEvent) { diff --git a/src/workerd/api/http.c++ b/src/workerd/api/http.c++ index 7ff312febd66..9bbf8ecac587 100644 --- a/src/workerd/api/http.c++ +++ b/src/workerd/api/http.c++ @@ -1338,7 +1338,18 @@ kj::Promise> Response::send( } auto clientSocket = outer.acceptWebSocket(outHeaders); - return (*ws)->couple(kj::mv(clientSocket)); + auto wsPromise = (*ws)->couple(kj::mv(clientSocket)); + + KJ_IF_MAYBE(a, context.getActor()) 
{ + KJ_IF_MAYBE(hib, (*a).getHibernationManager()) { + // We attach a reference to the deferred proxy task so the HibernationManager lives at least + // as long as the websocket connection. + // The actor still retains its reference to the manager, so any subsequent requests prior + // to hibernation will not need to re-obtain a reference. + wsPromise = wsPromise.attach(kj::addRef(*hib)); + } + } + return wsPromise; } else KJ_IF_MAYBE(jsBody, getBody()) { auto encoding = getContentEncoding(context, outHeaders, bodyEncoding); auto maybeLength = (*jsBody)->tryGetLength(encoding); diff --git a/src/workerd/io/hibernation-manager.c++ b/src/workerd/io/hibernation-manager.c++ new file mode 100644 index 000000000000..b33076845d0d --- /dev/null +++ b/src/workerd/io/hibernation-manager.c++ @@ -0,0 +1,199 @@ +// Copyright (c) 2017-2023 Cloudflare, Inc. +// Licensed under the Apache 2.0 license found in the LICENSE file or at: +// https://opensource.org/licenses/Apache-2.0 + +#include "io-context.h" +#include + +namespace workerd { + +HibernationManagerImpl::~HibernationManagerImpl() { + // Note that the HibernatableWebSocket destructor handles removing any references to itself in + // `tagToWs`, and even removes the hashmap entry if there are no more entries in the bucket. + allWs.clear(); + KJ_ASSERT(tagToWs.size() == 0, "tagToWs hashmap wasn't cleared."); +} + +void HibernationManagerImpl::acceptWebSocket( + jsg::Ref ws, + kj::ArrayPtr tags) { + // First, we create the HibernatableWebSocket and add it to the collection where it'll stay + // until it's destroyed. + + JSG_REQUIRE(allWs.size() < ACTIVE_CONNECTION_LIMIT, Error, "only ", ACTIVE_CONNECTION_LIMIT, + " websockets can be accepted on a single Durable Object instance"); + + auto hib = kj::heap(kj::mv(ws), tags, *this); + HibernatableWebSocket& refToHibernatable = *hib.get(); + allWs.push_front(kj::mv(hib)); + refToHibernatable.node = allWs.begin(); + + // If the `tags` array is empty (i.e. 
user did not provide a tag), we skip the population of the + // `tagToWs` HashMap below and go straight to initiating the readLoop. + + // It is the caller's responsibility to ensure all elements of `tags` are unique. + // TODO(cleanup): Maybe we could enforce uniqueness by using an immutable type that + // can only be constructed if the elements in the collection are distinct, ex. "DistinctArray". + // + // We need to add the HibernatableWebSocket to each bucket in `tagToWs` corresponding to its tags. + // 1. Create the entry if it doesn't exist + // 2. Fill the TagListItem in the HibernatableWebSocket's tagItems array + size_t position = 0; + for (auto tag = tags.begin(); tag < tags.end(); tag++, position++) { + auto& tagCollection = tagToWs.findOrCreate(*tag, [&tag]() { + auto item = kj::heap( + kj::mv(*tag), kj::heap>()); + return decltype(tagToWs)::Entry { + item->tag, + kj::mv(item) + }; + }); + // This TagListItem sits in the HibernatableWebSocket's tagItems array. + auto& tagListItem = refToHibernatable.tagItems[position]; + tagListItem.hibWS = refToHibernatable; + tagListItem.tag = tagCollection->tag.asPtr(); + + auto& list = tagCollection->list; + list->add(tagListItem); + // We also give the TagListItem a reference to the list it was added to so the + // HibernatableWebSocket can quickly remove itself from the list without doing a lookup + // in `tagToWs`. + tagListItem.list = *list.get(); + } + + // Finally, we initiate the readloop for this HibernatableWebSocket. + kj::Promise> readLoopPromise = kj::evalNow([&] { + return readLoop(refToHibernatable); + }).then([]() -> kj::Maybe { return nullptr; }, + [](kj::Exception&& e) -> kj::Maybe { return kj::mv(e); }); + + // Give the task to the HibernationManager so it lives long. 
+ readLoopTasks.add(readLoopPromise.then( + [&refToHibernatable, this](kj::Maybe&& maybeError) -> kj::Promise { + return handleSocketTermination(refToHibernatable, maybeError); + })); +} + +kj::Vector> HibernationManagerImpl::getWebSockets( + jsg::Lock& js, + kj::Maybe maybeTag) { + kj::Vector> matches; + KJ_IF_MAYBE(tag, maybeTag) { + KJ_IF_MAYBE(item, tagToWs.find(*tag)) { + auto& list = *((*item)->list); + for (auto& entry: list) { + auto& hibWS = KJ_REQUIRE_NONNULL(entry.hibWS); + matches.add(hibWS.getActiveOrUnhibernate(js)); + } + } + } else { + // Add all websockets! + for (auto& hibWS : allWs) { + matches.add(hibWS->getActiveOrUnhibernate(js)); + } + } + return kj::mv(matches); +} + +void HibernationManagerImpl::hibernateWebSockets(Worker::Lock& lock) { + jsg::Lock& js(lock); + v8::HandleScope handleScope(js.v8Isolate); + v8::Context::Scope contextScope(lock.getContext()); + for (auto& ws : allWs) { + KJ_IF_MAYBE(active, ws->activeOrPackage.tryGet>()) { + // Transfers ownership of properties from api::WebSocket to HibernatableWebSocket via the + // HibernationPackage. + ws->activeOrPackage.init( + active->get()->buildPackageForHibernation()); + } + } +} + +void HibernationManagerImpl::dropHibernatableWebSocket(HibernatableWebSocket& hib) { + removeFromAllWs(hib); +} + +inline void HibernationManagerImpl::removeFromAllWs(HibernatableWebSocket& hib) { + auto& node = KJ_REQUIRE_NONNULL(hib.node); + allWs.erase(node); +} + +kj::Promise HibernationManagerImpl::handleSocketTermination( + HibernatableWebSocket& hib, kj::Maybe& maybeError) { + kj::Maybe> event; + KJ_IF_MAYBE(error, maybeError) { + webSocketForEventHandler = hib; + if (!hib.hasDispatchedClose && + (error->getType() == kj::Exception::Type::DISCONNECTED)) { + // If premature disconnect/cancel, dispatch a close event if we haven't already. + auto params = api::HibernatableSocketParams( + 1006, + kj::str("WebSocket disconnected without sending Close frame."), + false); + // Dispatch the close event. 
+ auto workerInterface = loopback->getWorker(IoChannelFactory::SubrequestMetadata{}); + event = workerInterface->customEvent(kj::heap( + hibernationEventType, readLoopTasks, kj::mv(params), *this)) + .then([&](auto _) { hib.hasDispatchedClose = true; }); + } else { + // Otherwise, we need to dispatch an error event! + auto params = api::HibernatableSocketParams(kj::mv(*error)); + + // Dispatch the error event. + auto workerInterface = loopback->getWorker(IoChannelFactory::SubrequestMetadata{}); + event = workerInterface->customEvent(kj::heap( + hibernationEventType, readLoopTasks, kj::mv(params), *this)).ignoreResult(); + } + } + + // Returning the event promise will store it in readLoopTasks. + // After the task completes, we want to drop the websocket since we've closed the connection. + KJ_IF_MAYBE(promise, event) { + return kj::mv(*promise).then([&]() { + dropHibernatableWebSocket(hib); + }); + } else { + dropHibernatableWebSocket(hib); + return kj::READY_NOW; + } +} + +kj::Promise HibernationManagerImpl::readLoop(HibernatableWebSocket& hib) { + // Like the api::WebSocket readLoop(), but we dispatch different types of events. + auto& ws = *hib.ws; + while (true) { + kj::WebSocket::Message message = co_await ws.receive(); + // Note that errors are handled by the caller of `readLoop`, since we throw from `receive()`. + webSocketForEventHandler = hib; + + // Build the event params depending on what type of message we got. + kj::Maybe maybeParams; + KJ_SWITCH_ONEOF(message) { + KJ_CASE_ONEOF(text, kj::String) { + maybeParams.emplace(kj::mv(text)); + } + KJ_CASE_ONEOF(data, kj::Array) { + maybeParams.emplace(kj::mv(data)); + } + KJ_CASE_ONEOF(close, kj::WebSocket::Close) { + maybeParams.emplace(close.code, kj::mv(close.reason), true); + } + } + + auto params = kj::mv(KJ_REQUIRE_NONNULL(maybeParams)); + auto isClose = params.isCloseEvent(); + // Dispatch the event. 
+ auto workerInterface = loopback->getWorker(IoChannelFactory::SubrequestMetadata{}); + co_await workerInterface->customEvent( + kj::heap( + hibernationEventType, readLoopTasks, kj::mv(params), *this)); + if (isClose) { + // We've dispatched the close event, so let's mark our websocket as having done so to + // prevent a situation where we dispatch it twice. + hib.hasDispatchedClose = true; + co_return; + } + } +} + +}; // namespace workerd diff --git a/src/workerd/io/hibernation-manager.h b/src/workerd/io/hibernation-manager.h new file mode 100644 index 000000000000..efdc7490eac6 --- /dev/null +++ b/src/workerd/io/hibernation-manager.h @@ -0,0 +1,191 @@ +// Copyright (c) 2017-2023 Cloudflare, Inc. +// Licensed under the Apache 2.0 license found in the LICENSE file or at: +// https://opensource.org/licenses/Apache-2.0 + +#pragma once + +#include +#include +#include +#include "v8-isolate.h" +#include + +#include + +namespace workerd { + +class HibernationManagerImpl final : public Worker::Actor::HibernationManager { + // Implements the HibernationManager class. +public: + HibernationManagerImpl(kj::Own loopback, uint16_t hibernationEventType) + : loopback(kj::mv(loopback)), + hibernationEventType(hibernationEventType), + onDisconnect(DisconnectHandler{}), + readLoopTasks(onDisconnect) {} + ~HibernationManagerImpl(); + + void acceptWebSocket(jsg::Ref ws, kj::ArrayPtr tags) override; + // Tells the HibernationManager to create a new HibernatableWebSocket with the associated tags + // and to initiate the `readLoop()` for this websocket. The `tags` array *must* contain only + // unique elements. + + kj::Vector> getWebSockets( + jsg::Lock& js, + kj::Maybe tag) override; + // Gets a collection of websockets associated with the given tag. Any hibernating websockets will + // be woken up. If no tag is provided, we return all accepted websockets. + + void hibernateWebSockets(Worker::Lock& lock) override; + // Hibernates all the websockets held by the HibernationManager. 
+ // This converts our activeOrPackage from an api::WebSocket to a HibernationPackage. + + friend jsg::Ref api::HibernatableWebSocketEvent::getWebSocket(jsg::Lock& lock); + +private: + class HibernatableWebSocket; + struct TagListItem { + // Each HibernatableWebSocket can have multiple tags, so we want to store a reference + // in our kj::List. + kj::Maybe hibWS; + kj::ListLink link; + kj::StringPtr tag; + kj::Maybe&> list; + // The List that refers to this TagListItem. + // If `list` is null, we've already removed this item from the list. + }; + + class HibernatableWebSocket { + // api::WebSockets cannot survive hibernation, but kj::WebSockets do. This class helps us + // manage the transition of an api::WebSocket from its active state to a hibernated state + // and vice versa. + // + // Some properties of the JS websocket object need to be retained throughout hibernation, + // such as `attachment`, `url`, `extensions`, etc. These properties are only read/modified + // when initiating, or waking from hibernation. + public: + HibernatableWebSocket(jsg::Ref websocket, + kj::ArrayPtr tags, + HibernationManagerImpl& manager) + : tagItems(kj::heapArray(tags.size())), + activeOrPackage(kj::mv(websocket)), + // Extracts the kj::Own from api::WebSocket so the HibernatableWebSocket + // can own it. The api::WebSocket retains a reference to our ws. + ws(activeOrPackage.get>()->acceptAsHibernatable()), + manager(manager) {} + + ~HibernatableWebSocket() { + // We expect this dtor to be called when we're removing a HibernatableWebSocket + // from our `allWs` collection in the HibernationManager. + + // This removal is fast because we have direct access to each kj::List, as well as direct + // access to each TagListItem we want to remove. + for (auto& item: tagItems) { + KJ_IF_MAYBE(list, item.list) { + // The list reference is non-null, so we still have a valid reference to this + // TagListItem in the list, which we will now remove. 
+ list->remove(item); + if (list->empty()) { + // Remove the bucket in tagToWs if the tag has no more websockets. + manager.tagToWs.erase(kj::mv(item.tag)); + } + } + item.hibWS = nullptr; + item.list = nullptr; + } + } + + jsg::Ref getActiveOrUnhibernate(jsg::Lock& js) { + // Returns a reference to the active websocket. If the websocket is currently hibernating, + // we have to unhibernate it first. The process moves values from the HibernatableWebSocket + // to the api::WebSocket. + KJ_IF_MAYBE(package, activeOrPackage.tryGet()) { + activeOrPackage.init>( + api::WebSocket::hibernatableFromNative(js, *ws, kj::mv(*package))); + } + return activeOrPackage.get>().addRef(); + } + + kj::ListLink link; + + kj::Array tagItems; + // An array of all the items/nodes that refer to this HibernatableWebSocket. + // Keeping track of these items allows us to quickly remove every reference from `tagToWs` + // once the websocket disconnects -- rather than iterating through each relevant tag in the + // hashmap and removing it from each kj::List. + + kj::OneOf, api::WebSocket::HibernationPackage> activeOrPackage; + // If active, we have an api::WebSocket reference, otherwise, we're hibernating, so we retain + // the websocket's properties in a HibernationPackage until it's time to wake up. + kj::Own ws; + HibernationManagerImpl& manager; + // TODO(someday): We (currently) only use the HibernationManagerImpl reference to refer to + // `tagToWs` when running the dtor for `HibernatableWebSocket`. This feels a bit excessive, + // I would rather have the HibernationManager deal with its collections than have the + // HibernatableWebSocket do so. Maybe come back to this at some point? + kj::Maybe>::iterator> node; + // Reference to the Node in `allWs` that allows us to do fast deletion on disconnect. + + bool hasDispatchedClose = false; + // True once we have dispatched the close event. + // This prevents us from dispatching it if we have already done so. 
+ friend HibernationManagerImpl; + }; + +private: + void dropHibernatableWebSocket(HibernatableWebSocket& hib); + // Removes a HibernatableWebSocket from the HibernationManager's various collections. + + inline void removeFromAllWs(HibernatableWebSocket& hib); + // Removes the HibernatableWebSocket from `allWs`. + + kj::Promise handleSocketTermination( + HibernatableWebSocket& hib, kj::Maybe& maybeError); + // Handles the termination of the websocket. If termination was not clean, we might try to + // dispatch a close event (if we haven't already), or an error event. + // We will also remove the HibernatableWebSocket from the HibernationManager's collections. + + kj::Promise readLoop(HibernatableWebSocket& hib); + // Like the api::WebSocket readLoop(), but we dispatch different types of events. + + struct TagCollection { + // This struct is held by the `tagToWs` hashmap. The key is a StringPtr to tag, and the value + // is this struct itself. + kj::String tag; + kj::Own> list; + + TagCollection(kj::String tag, decltype(list) list): tag(kj::mv(tag)), list(kj::mv(list)) {} + TagCollection(TagCollection&& other) = default; + }; + + kj::HashMap> tagToWs; + // A hashmap of tags to HibernatableWebSockets associated with the tag. + // We use a kj::List so we can quickly remove websockets that have disconnected. + // Also note that we box the keys and values such that in the event of a hashmap resizing we don't + // move the underlying data (thereby keeping any references intact). + + std::list> allWs; + // We store all of our HibernatableWebSockets in a doubly linked-list. + + kj::Own loopback; + // Used to obtain the worker so we can dispatch Hibernatable websocket events. + + uint16_t hibernationEventType; + // Passed to HibernatableWebSocket custom event as the typeId. + + kj::Maybe webSocketForEventHandler; + // Allows the HibernatableWebSocket event handler that is currently running to access the + // HibernatableWebSocket that it needs to execute. 
+ + const size_t ACTIVE_CONNECTION_LIMIT = 1024 * 64; + // The maximum number of Hibernatable WebSocket connections a single HibernationManagerImpl + // instance can manage. + + class DisconnectHandler: public kj::TaskSet::ErrorHandler { + public: + void taskFailed(kj::Exception&& exception) override {}; + // We don't need to do anything here; we already handle disconnects in the caller of readLoop(). + }; + DisconnectHandler onDisconnect; + kj::TaskSet readLoopTasks; +}; +}; // namespace workerd