Skip to content

Commit

Permalink
Implement cross-request promise waits using promise context tagging
Browse files Browse the repository at this point in the history
  • Loading branch information
jasnell committed Jun 22, 2023
1 parent 7a4a57a commit 6cd27dd
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/workerd/io/io-context.c++
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,8 @@ public:
threadScope(context),
workerLock(*context.worker, lockType),
handleScope(workerLock.getIsolate()),
jsContextScope(workerLock.getContext()) {
jsContextScope(workerLock.getContext()),
promiseContextScope(workerLock.getIsolate(), context.getPromiseContextTag(workerLock)) {
KJ_REQUIRE(context.currentInputLock == nullptr);
KJ_REQUIRE(context.currentLock == nullptr);
context.currentInputLock = kj::mv(inputLock);
Expand Down Expand Up @@ -1125,6 +1126,7 @@ private:
Worker::Lock workerLock;
v8::HandleScope handleScope;
v8::Context::Scope jsContextScope;
v8::Isolate::PromiseContextScope promiseContextScope;
};

void IoContext::runImpl(Runnable& runnable, bool takePendingEvent,
Expand Down Expand Up @@ -1285,6 +1287,8 @@ void IoContext::runFinalizers(Worker::AsyncLock& asyncLock) {
RunnableImpl runnable(*this, kj::mv(warnings));
runImpl(runnable, false, asyncLock, nullptr, true);
}

promiseContextTag = nullptr;
}

#ifdef KJ_DEBUG
Expand Down Expand Up @@ -1409,4 +1413,11 @@ void IoContext::requireCurrentOrThrowJs() {
"of Cloudflare Workers which allows us to improve overall performance.");
}

v8::Local<v8::Object> IoContext::getPromiseContextTag(jsg::Lock& js) {
if (promiseContextTag == nullptr) {
promiseContextTag = js.v8Ref(v8::Object::New(js.v8Isolate));
}
return KJ_REQUIRE_NONNULL(promiseContextTag).getHandle(js);
}

} // namespace workerd
4 changes: 4 additions & 0 deletions src/workerd/io/io-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,8 @@ class IoContext final: public kj::Refcounted, private kj::TaskSet::ErrorHandler

void writeLogfwdr(uint channel, kj::FunctionParam<void(capnp::AnyPointer::Builder)> buildMessage);

v8::Local<v8::Object> getPromiseContextTag(jsg::Lock& js);

private:
ThreadContext& thread;

Expand Down Expand Up @@ -1004,6 +1006,8 @@ class IoContext final: public kj::Refcounted, private kj::TaskSet::ErrorHandler
void requireCurrent();
void checkFarGet(const DeleteQueue* expectedQueue);

kj::Maybe<jsg::V8Ref<v8::Object>> promiseContextTag;

class Runnable {
public:
virtual void run(Worker::Lock& lock) = 0;
Expand Down
93 changes: 93 additions & 0 deletions src/workerd/io/worker.c++
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <workerd/jsg/util.h>
#include <workerd/io/cdp.capnp.h>
#include <workerd/io/compatibility-date.h>
#include <workerd/util/wait-list.h>
#include <capnp/compat/json.h>
#include <kj/compat/gzip.h>
#include <kj/compat/brotli.h>
Expand Down Expand Up @@ -978,6 +979,58 @@ kj::Maybe<kj::String> makeCompatJson(kj::ArrayPtr<kj::StringPtr> enableFlags) {
return kj::String(json.releaseAsArray());
}

struct CrossThreadPromiseWaiter {
// When a promise is created in a different IoContext, we need to use a
// CrossThreadWaitList in order to wait on it. The Waiter instance will
// be held on the Promise itself, and will be fulfilled/rejected when the
// promise is resolved or rejected. This will signal all of the waiters
// from other IoContexts.
CrossThreadWaitList waitList;
kj::Maybe<jsg::Value> result;
jsg::Promise<void> addWaiter(jsg::Lock& js) {
return IoContext::current().awaitIo(js, waitList.addWaiter());
}
void fulfill(jsg::Value value) {
result = kj::mv(value);
waitList.fulfill();
}
void reject(kj::Exception ex) {
waitList.reject(kj::mv(ex));
}
CrossThreadPromiseWaiter() = default;
KJ_DISALLOW_COPY_AND_MOVE(CrossThreadPromiseWaiter);
};

CrossThreadPromiseWaiter& getCrossThreadPromiseWaiter(jsg::Lock& js,
jsg::V8Ref<v8::Promise>& promise) {

auto key = v8::Private::ForApi(js.v8Isolate, jsg::v8StrIntern(js.v8Isolate, "waiter"));
auto handle = promise.getHandle(js);
auto opaqueWaiter = jsg::check(handle->GetPrivate(js.v8Context(), key));
if (opaqueWaiter->IsUndefined()) {
// Is opaqueWaiter is undefined here, then this is the first time we are waiting on
// this promise from a different context. We'll create our CrossThreadWaitList and
// set the promise up to resolve it.

auto waiter = kj::heap<CrossThreadPromiseWaiter>();
auto onSuccess = [promise=promise.addRef(js)]
(jsg::Lock& js, jsg::Value value) mutable {
getCrossThreadPromiseWaiter(js, promise).fulfill(kj::mv(value));
};

auto onFailure = [promise=promise.addRef(js)]
(jsg::Lock& js, jsg::Value exception) mutable {
getCrossThreadPromiseWaiter(js, promise).reject(js.exceptionToKj(kj::mv(exception)));
};

js.toPromise(handle).then(js, kj::mv(onSuccess), kj::mv(onFailure)).markAsHandled();

opaqueWaiter = jsg::wrapOpaque(js.v8Context(), kj::mv(waiter));
jsg::check(handle->SetPrivate(js.v8Context(), key, opaqueWaiter));
}
return *jsg::unwrapOpaqueRef<kj::Own<CrossThreadPromiseWaiter>>(js.v8Isolate, opaqueWaiter);
}

} // namespace

Worker::Isolate::Isolate(kj::Own<ApiIsolate> apiIsolateParam,
Expand Down Expand Up @@ -1060,6 +1113,46 @@ Worker::Isolate::Isolate(kj::Own<ApiIsolate> apiIsolateParam,
}
}
});

// The PromiseCrossContextCallback is used to allow cross-IoContext promise following.
// When the IoContext::Scope is entered, we set the "promise context tag" associated
// with the IoContext on the Isolate that is locked. Any Promise that is created within
// that scope will be tagged with the same promise context tag. When an attempt to
// follow a promise occurs (e.g. either using Promise.prototype.then() or await, etc)
// our patched v8 logic will check to see if the followed promise's tag matches the
// current Isolate tag. If they do not, then v8 will invoke this callback. The promise
// here is the promise that belongs to a different IoContext.
lock->v8Isolate->SetPromiseCrossContextCallback([](v8::Local<v8::Context> context,
v8::Local<v8::Promise> promise,
v8::Local<v8::Object> tag) ->
v8::MaybeLocal<v8::Promise> {
if (!IoContext::hasCurrent()) return promise;
auto isolate = context->GetIsolate();
auto& js = jsg::Lock::from(isolate);
auto promiseRef = js.v8Ref(promise);
// The CrossThreadPromiseWaiter is a help that is attached to the promise. It is
// allows multiple cross-thread/cross-request waiters to be attached to the promise,
// all of which will be fulfilled when promise is resolved.
auto& waiter = getCrossThreadPromiseWaiter(js, promiseRef);

// Note that we capture a strong (non-traced) reference to the promise here to ensure
// that we can still get access to the CrossThreadPromiseWaiter so we can extract the
// resolved value after the waiter promise resolves. We have to do it this way because
// the CrossThreadWaitList used by the waiter is not currently capable of following
// anything but a kj::Promise<void>. We end up having to pass the value out of band.
// That's unfortunate but works for now.
// TODO(cleanup): Update CrossThreadWaitList to support passing non-void values.
return js.wrapSimplePromise(waiter.addWaiter(js).then(js,
[promise=kj::mv(promiseRef)](jsg::Lock& js) mutable -> jsg::Promise<jsg::Value> {
KJ_IF_MAYBE(value, getCrossThreadPromiseWaiter(js, promise).result) {
return js.resolvedPromise(value->addRef(js));
} else {
// It should never be the case that the value is not set here but we still
// handle it just in case.
return js.resolvedPromise(js.v8Ref(js.v8Undefined()));
}
}));
});
}

Worker::Script::Script(kj::Own<const Isolate> isolateParam, kj::StringPtr id,
Expand Down
2 changes: 2 additions & 0 deletions src/workerd/jsg/jsg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1918,12 +1918,14 @@ class Lock {
// pays attention to the return value.
// TODO(later): See if we can easily combine wrapSimpleFunction and wrapReturningFunction
// into one.
virtual v8::Local<v8::Promise> wrapSimplePromise(Promise<Value> promise) = 0;

bool toBool(v8::Local<v8::Value> value);
virtual kj::String toString(v8::Local<v8::Value> value) = 0;
virtual jsg::Dict<v8::Local<v8::Value>> toDict(v8::Local<v8::Value> value) = 0;
// Convenience methods to unwrap various types of V8 values. All of these could be done manually
// via the V8 API, but these methods are much easier.
virtual Promise<Value> toPromise(v8::Local<v8::Promise> promise) = 0;

// ---------------------------------------------------------------------------
// Setup stuff
Expand Down
7 changes: 7 additions & 0 deletions src/workerd/jsg/setup.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,13 @@ class Isolate: public IsolateBase {
return jsgIsolate.wrapper->template unwrap<jsg::Dict<v8::Local<v8::Value>>>(
v8Isolate->GetCurrentContext(), value, jsg::TypeErrorContext::other());
}
v8::Local<v8::Promise> wrapSimplePromise(jsg::Promise<jsg::Value> promise) override {
return jsgIsolate.wrapper->wrap(v8Context(), nullptr, kj::mv(promise));
}
jsg::Promise<jsg::Value> toPromise(v8::Local<v8::Promise> promise) override {
return jsgIsolate.wrapper->template unwrap<jsg::Promise<jsg::Value>>(
v8Isolate->GetCurrentContext(), promise, jsg::TypeErrorContext::other());
}

template <typename T, typename... Args>
JsContext<T> newContext(Args&&... args) {
Expand Down

0 comments on commit 6cd27dd

Please sign in to comment.