From bd0659013fa0bcc621db81f85b60d9d96b073802 Mon Sep 17 00:00:00 2001 From: James M Snell Date: Tue, 19 Sep 2023 08:12:18 -0700 Subject: [PATCH] Move various kj stream impls to separate stream-utils (#1189) --- src/workerd/api/global-scope.c++ | 94 ++++--------- src/workerd/api/http.c++ | 16 +-- src/workerd/io/worker.c++ | 17 +-- src/workerd/tests/test-fixture.c++ | 19 +-- src/workerd/util/stream-utils.c++ | 214 +++++++++++++++++++++++++++++ src/workerd/util/stream-utils.h | 30 ++++ 6 files changed, 277 insertions(+), 113 deletions(-) create mode 100644 src/workerd/util/stream-utils.c++ create mode 100644 src/workerd/util/stream-utils.h diff --git a/src/workerd/api/global-scope.c++ b/src/workerd/api/global-scope.c++ index b21b0907b59..ef2b3230802 100644 --- a/src/workerd/api/global-scope.c++ +++ b/src/workerd/api/global-scope.c++ @@ -20,76 +20,32 @@ #include #include #include +#include namespace workerd::api { namespace { -// An InputStream that can be disconnected. Used for request body, which becomes invalid as -// soon as the response is returned. -class NeuterableInputStream: public kj::AsyncInputStream, public kj::Refcounted { -public: - NeuterableInputStream(kj::AsyncInputStream& inner): inner(&inner) {} - - enum NeuterReason { - SENT_RESPONSE, - THREW_EXCEPTION, - CLIENT_DISCONNECTED - }; - - void neuter(NeuterReason reason) { - if (inner.is()) { - inner = reason; - if (!canceler.isEmpty()) { - canceler.cancel(makeException(reason)); - } - } - } - - kj::Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { - return canceler.wrap(getStream().read(buffer, minBytes, maxBytes)); - } - kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return canceler.wrap(getStream().tryRead(buffer, minBytes, maxBytes)); - } - kj::Maybe tryGetLength() override { - return getStream().tryGetLength(); - } - kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { - return canceler.wrap(getStream().pumpTo(output, amount)); - } - -private: - kj::OneOf inner; - kj::Canceler canceler; - - kj::AsyncInputStream& getStream() { - KJ_SWITCH_ONEOF(inner) { - KJ_CASE_ONEOF(stream, kj::AsyncInputStream*) { - return *stream; - } - KJ_CASE_ONEOF(reason, NeuterReason) { - kj::throwFatalException(makeException(reason)); - } - } - KJ_UNREACHABLE; - } +enum class NeuterReason { + SENT_RESPONSE, + THREW_EXCEPTION, + CLIENT_DISCONNECTED +}; - kj::Exception makeException(NeuterReason reason) { - switch (reason) { - case SENT_RESPONSE: - return JSG_KJ_EXCEPTION(FAILED, TypeError, - "Can't read from request stream after response has been sent."); - case THREW_EXCEPTION: - return JSG_KJ_EXCEPTION(FAILED, TypeError, - "Can't read from request stream after responding with an exception."); - case CLIENT_DISCONNECTED: - return JSG_KJ_EXCEPTION(DISCONNECTED, TypeError, - "Can't read from request stream because client disconnected."); - } - KJ_UNREACHABLE; +kj::Exception makeNeuterException(NeuterReason reason) { + switch (reason) { + case NeuterReason::SENT_RESPONSE: + return JSG_KJ_EXCEPTION(FAILED, TypeError, + "Can't read from request stream after response has been sent."); + case NeuterReason::THREW_EXCEPTION: + return JSG_KJ_EXCEPTION(FAILED, TypeError, + "Can't read from request stream after responding with an exception."); + case NeuterReason::CLIENT_DISCONNECTED: + return JSG_KJ_EXCEPTION(DISCONNECTED, TypeError, + "Can't read from request stream because client disconnected."); } -}; + KJ_UNREACHABLE; +} kj::String getEventName(v8::PromiseRejectEvent type) { switch (type) { @@ -151,14 +107,14 @@ kj::Promise> ServiceWorkerGlobalScope::request( // that it can drop the reference whenever it gets GC'd. But in this case the stream's lifetime // is not under our control -- it's attached to the request. So, we wrap it in a // NeuterableInputStream which allows us to disconnect the stream before it becomes invalid. - auto ownRequestBody = kj::refcounted(requestBody); + auto ownRequestBody = newNeuterableInputStream(requestBody); auto deferredNeuter = kj::defer([ownRequestBody = kj::addRef(*ownRequestBody)]() mutable { // Make sure to cancel the request body stream since the native stream is no longer valid once // the returned promise completes. Note that the KJ HTTP library deals with the fact that we // haven't consumed the entire request body. - ownRequestBody->neuter(NeuterableInputStream::CLIENT_DISCONNECTED); + ownRequestBody->neuter(makeNeuterException(NeuterReason::CLIENT_DISCONNECTED)); }); - KJ_ON_SCOPE_FAILURE(ownRequestBody->neuter(NeuterableInputStream::THREW_EXCEPTION)); + KJ_ON_SCOPE_FAILURE(ownRequestBody->neuter(makeNeuterException(NeuterReason::THREW_EXCEPTION))); auto& ioContext = IoContext::current(); jsg::Lock& js = lock; @@ -311,9 +267,9 @@ kj::Promise> ServiceWorkerGlobalScope::request( // task finishes. deferredProxy.proxyTask = deferredProxy.proxyTask .then([body = kj::addRef(*ownRequestBody)]() mutable { - body->neuter(NeuterableInputStream::SENT_RESPONSE); + body->neuter(makeNeuterException(NeuterReason::SENT_RESPONSE)); }, [body = kj::addRef(*ownRequestBody)](kj::Exception&& e) mutable { - body->neuter(NeuterableInputStream::THREW_EXCEPTION); + body->neuter(makeNeuterException(NeuterReason::THREW_EXCEPTION)); kj::throwFatalException(kj::mv(e)); }).attach(kj::mv(deferredNeuter)); @@ -321,7 +277,7 @@ kj::Promise> ServiceWorkerGlobalScope::request( }, [body = kj::mv(body2)](kj::Exception&& e) mutable -> DeferredProxy { // HACK: We depend on the fact that the success-case lambda above hasn't been destroyed yet // so `deferredNeuter` hasn't been destroyed yet. - body->neuter(NeuterableInputStream::THREW_EXCEPTION); + body->neuter(makeNeuterException(NeuterReason::THREW_EXCEPTION)); kj::throwFatalException(kj::mv(e)); }); } else { diff --git a/src/workerd/api/http.c++ b/src/workerd/api/http.c++ index 24023898dff..5a6a9520d4b 100644 --- a/src/workerd/api/http.c++ +++ b/src/workerd/api/http.c++ @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -1448,17 +1449,6 @@ void FetchEvent::passThroughOnException() { namespace { -class NullInputStream final: public kj::AsyncInputStream { -public: - kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return size_t(0); - } - - kj::Maybe tryGetLength() override { - return uint64_t(0); - } -}; - // Fetch spec requires (suggests?) 20: https://fetch.spec.whatwg.org/#http-redirect-fetch constexpr auto MAX_REDIRECT_COUNT = 20; @@ -1552,7 +1542,7 @@ jsg::Promise> fetchImplNoOutputLock( return js.resolvedPromise(makeHttpResponse(js, jsRequest->getMethodEnum(), kj::mv(urlList), response.statusCode, response.statusText, *response.headers, - kj::heap(), + newNullInputStream(), jsg::alloc(kj::mv(webSocket), WebSocket::REMOTE), Response::BodyEncoding::AUTO, kj::mv(signal))); @@ -1899,7 +1889,7 @@ static jsg::Promise parseResponse( return js.resolvedPromise( Fetcher::GetResult(jsg::alloc( IoContext::current(), - newSystemStream(kj::heap(), StreamEncoding::IDENTITY)))); + newSystemStream(newNullInputStream(), StreamEncoding::IDENTITY)))); } } diff --git a/src/workerd/io/worker.c++ b/src/workerd/io/worker.c++ index 7b5ea1f7df1..16ed0289e21 100644 --- a/src/workerd/io/worker.c++ +++ b/src/workerd/io/worker.c++ @@ -8,6 +8,7 @@ #include "actor-cache.h" #include #include +#include #include #include #include @@ -3368,20 +3369,6 @@ double getWallTimeForProcessSandboxOnly() { auto timePoint = kj::systemPreciseCalendarClock().now(); return (timePoint - kj::UNIX_EPOCH) / kj::MILLISECONDS / 1e3; } - -class NullOutputStream final: public kj::AsyncOutputStream { -public: - kj::Promise write(const void* buffer, size_t size) override { - return kj::READY_NOW; - } - kj::Promise write(kj::ArrayPtr> pieces) override { - return kj::READY_NOW; - } - kj::Promise whenWriteDisconnected() override { - return kj::NEVER_DONE; - } -}; - } // namespace class Worker::Isolate::ResponseStreamWrapper final: public kj::AsyncOutputStream { @@ -3679,7 +3666,7 @@ kj::Promise Worker::Isolate::SubrequestClient::request( // TODO(someday): Support sending WebSocket frames over CDP. For now we fake an empty // response. signalResponse(kj::mv(requestId), 101, "Switching Protocols", headers, - kj::heap()); + newNullOutputStream()); return kj::mv(webSocket); } diff --git a/src/workerd/tests/test-fixture.c++ b/src/workerd/tests/test-fixture.c++ index a1fee6660ed..cbc2f53ad3c 100644 --- a/src/workerd/tests/test-fixture.c++ +++ b/src/workerd/tests/test-fixture.c++ @@ -12,6 +12,7 @@ #include #include #include +#include #include "test-fixture.h" @@ -229,20 +230,6 @@ struct MockResponse final: public kj::HttpService::Response { KJ_FAIL_REQUIRE("NOT SUPPORTED"); } }; - -struct MemoryInputStream final: public kj::AsyncInputStream { - kj::ArrayPtr data; - - MemoryInputStream(kj::ArrayPtr data) : data(data) { } - - kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - auto toRead = kj::min(data.size(), maxBytes); - memcpy(buffer, data.begin(), toRead); - data = data.slice(toRead, data.size()); - return toRead; - } -}; - } // namespace @@ -335,7 +322,7 @@ TestFixture::Response TestFixture::runRequest( kj::HttpMethod method, kj::StringPtr url, kj::StringPtr body) { kj::HttpHeaders requestHeaders(*headerTable); MockResponse response; - MemoryInputStream requestBody(body.asBytes()); + auto requestBody = newMemoryInputStream(body); runInIoContext([&](const TestFixture::Environment& env) { auto& globalScope = env.lock.getGlobalScope(); @@ -343,7 +330,7 @@ TestFixture::Response TestFixture::runRequest( method, url, requestHeaders, - requestBody, + *requestBody, response, "{}"_kj, env.lock, diff --git a/src/workerd/util/stream-utils.c++ b/src/workerd/util/stream-utils.c++ new file mode 100644 index 00000000000..4696dbe16e2 --- /dev/null +++ b/src/workerd/util/stream-utils.c++ @@ -0,0 +1,214 @@ +#include "stream-utils.h" +#include +#include +#include +#include + +namespace workerd { + +namespace { +class NullIoStream final: public kj::AsyncIoStream { +public: + void shutdownWrite() override {} + + kj::Promise write(const void* buffer, size_t size) override { + return kj::READY_NOW; + } + kj::Promise write(kj::ArrayPtr> pieces) override { + return kj::READY_NOW; + } + kj::Promise whenWriteDisconnected() override { + return kj::NEVER_DONE; + } + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return kj::constPromise(); + } + + kj::Maybe tryGetLength() override { + return kj::Maybe((uint64_t)0); + } + + kj::Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { + return kj::constPromise(); + } +}; + +class MemoryInputStream final: public kj::AsyncInputStream { +public: + MemoryInputStream(kj::ArrayPtr data) + : data(data) { } + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + size_t toRead = kj::min(data.size(), maxBytes); + memcpy(buffer, data.begin(), toRead); + data = data.slice(toRead, data.size()); + return toRead; + } + +private: + kj::ArrayPtr data; +}; + +class NeuterableInputStreamImpl final: public NeuterableInputStream { +public: + NeuterableInputStreamImpl(kj::AsyncInputStream& inner): inner(&inner) {} + + void neuter(kj::Exception exception) override { + if (inner.is()) { + inner = kj::cp(exception); + if (!canceler.isEmpty()) { + canceler.cancel(kj::mv(exception)); + } + } + } + + kj::Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { + return canceler.wrap(getStream().read(buffer, minBytes, maxBytes)); + } + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return canceler.wrap(getStream().tryRead(buffer, minBytes, maxBytes)); + } + kj::Maybe tryGetLength() override { + return getStream().tryGetLength(); + } + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + return canceler.wrap(getStream().pumpTo(output, amount)); + } + +private: + kj::OneOf inner; + kj::Canceler canceler; + + kj::AsyncInputStream& getStream() { + KJ_SWITCH_ONEOF(inner) { + KJ_CASE_ONEOF(stream, kj::AsyncInputStream*) { + return *stream; + } + KJ_CASE_ONEOF(exception, kj::Exception) { + kj::throwFatalException(kj::cp(exception)); + } + } + KJ_UNREACHABLE; + } +}; + +class NeuterableIoStreamImpl final: public NeuterableIoStream { +public: + NeuterableIoStreamImpl(kj::AsyncIoStream& inner): inner(&inner) {} + + void neuter(kj::Exception reason) override { + if (inner.is()) { + inner = kj::cp(reason); + if (!canceler.isEmpty()) { + canceler.cancel(kj::mv(reason)); + } + } + } + + // AsyncInputStream + + kj::Promise read(void* buffer, size_t minBytes, size_t maxBytes) override { + return canceler.wrap(getStream().read(buffer, minBytes, maxBytes)); + } + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return canceler.wrap(getStream().tryRead(buffer, minBytes, maxBytes)); + } + kj::Maybe tryGetLength() override { + return getStream().tryGetLength(); + } + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { + return canceler.wrap(getStream().pumpTo(output, amount)); + } + + // AsyncOutputStream + + kj::Promise write(const void* buffer, size_t size) override { + return canceler.wrap(getStream().write(buffer, size)); + } + kj::Promise write(kj::ArrayPtr> pieces) override { + return canceler.wrap(getStream().write(pieces)); + } + kj::Maybe> tryPumpFrom( + kj::AsyncInputStream& input, uint64_t amount) override { + return getStream().tryPumpFrom(input, amount).map([this](kj::Promise promise) { + return canceler.wrap(kj::mv(promise)); + }); + } + kj::Promise whenWriteDisconnected() override { + return canceler.wrap(getStream().whenWriteDisconnected()); + } + + // AsyncIoStream + + void shutdownWrite() override { + getStream().shutdownWrite(); + }; + void abortRead() override { + getStream().abortRead(); + } + void getsockopt(int level, int option, void* value, kj::uint* length) override { + getStream().getsockopt(level, option, value, length); + } + void setsockopt(int level, int option, const void* value, kj::uint length) override { + getStream().setsockopt(level, option, value, length); + } + void getsockname(struct sockaddr* addr, kj::uint* length) override { + getStream().getsockname(addr, length); + } + void getpeername(struct sockaddr* addr, kj::uint* length) override { + getStream().getpeername(addr, length); + } + virtual kj::Maybe getFd() const override { + return getStream().getFd(); + } + +private: + kj::OneOf inner; + kj::Canceler canceler; + + kj::AsyncIoStream& getStream() { + KJ_IF_SOME(stream, inner.tryGet()) { + return *stream; + } + kj::throwFatalException(kj::cp(inner.get())); + } + kj::AsyncIoStream& getStream() const { + KJ_IF_SOME(stream, inner.tryGet()) { + return *stream; + } + kj::throwFatalException(kj::cp(inner.get())); + } +}; + +} // namespace + +kj::Own newNullIoStream() { + return kj::heap(); +} + +kj::Own newNullInputStream() { + return kj::heap(); +} + +kj::Own newNullOutputStream() { + return kj::heap(); +} + +kj::Own newMemoryInputStream(kj::ArrayPtr data) { + return kj::heap(data); +} + +kj::Own newMemoryInputStream(kj::StringPtr data) { + return kj::heap(data.asBytes()); +} + +kj::Own newNeuterableInputStream(kj::AsyncInputStream& inner) { + return kj::refcounted(inner); +} + +kj::Own newNeuterableIoStream(kj::AsyncIoStream& inner) { + return kj::heap(inner); +} + +} // namespace workerd diff --git a/src/workerd/util/stream-utils.h b/src/workerd/util/stream-utils.h new file mode 100644 index 00000000000..3a06996c979 --- /dev/null +++ b/src/workerd/util/stream-utils.h @@ -0,0 +1,30 @@ +#pragma once + +#include + +namespace workerd { + +kj::Own newNullIoStream(); +kj::Own newNullInputStream(); +kj::Own newNullOutputStream(); + +kj::Own newMemoryInputStream(kj::ArrayPtr); +kj::Own newMemoryInputStream(kj::StringPtr); + +// An InputStream that can be disconnected. +class NeuterableInputStream: public kj::AsyncInputStream, + public kj::Refcounted { +public: + virtual void neuter(kj::Exception ex) = 0; +}; + +class NeuterableIoStream: public kj::AsyncIoStream { +public: + virtual void neuter(kj::Exception ex) = 0; +}; + +kj::Own newNeuterableInputStream(kj::AsyncInputStream&); +kj::Own newNeuterableIoStream(kj::AsyncIoStream&); + + +} // namespace workerd