Skip to content

Commit

Permalink
Move various kj stream impls to separate stream-utils
Browse files Browse the repository at this point in the history
  • Loading branch information
jasnell committed Sep 15, 2023
1 parent 2453f95 commit a595505
Show file tree
Hide file tree
Showing 6 changed files with 283 additions and 113 deletions.
94 changes: 25 additions & 69 deletions src/workerd/api/global-scope.c++
Original file line number Diff line number Diff line change
Expand Up @@ -20,76 +20,32 @@
#include <workerd/util/thread-scopes.h>
#include <workerd/api/hibernatable-web-socket.h>
#include <workerd/api/util.h>
#include <workerd/util/stream-utils.h>

namespace workerd::api {

namespace {

// An InputStream that can be disconnected. Used for request body, which becomes invalid as
// soon as the response is returned.
class NeuterableInputStream: public kj::AsyncInputStream, public kj::Refcounted {
public:
NeuterableInputStream(kj::AsyncInputStream& inner): inner(&inner) {}

enum NeuterReason {
SENT_RESPONSE,
THREW_EXCEPTION,
CLIENT_DISCONNECTED
};

void neuter(NeuterReason reason) {
if (inner.is<kj::AsyncInputStream*>()) {
inner = reason;
if (!canceler.isEmpty()) {
canceler.cancel(makeException(reason));
}
}
}

kj::Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
return canceler.wrap(getStream().read(buffer, minBytes, maxBytes));
}
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return canceler.wrap(getStream().tryRead(buffer, minBytes, maxBytes));
}
kj::Maybe<uint64_t> tryGetLength() override {
return getStream().tryGetLength();
}
kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
return canceler.wrap(getStream().pumpTo(output, amount));
}

private:
kj::OneOf<kj::AsyncInputStream*, NeuterReason> inner;
kj::Canceler canceler;

kj::AsyncInputStream& getStream() {
KJ_SWITCH_ONEOF(inner) {
KJ_CASE_ONEOF(stream, kj::AsyncInputStream*) {
return *stream;
}
KJ_CASE_ONEOF(reason, NeuterReason) {
kj::throwFatalException(makeException(reason));
}
}
KJ_UNREACHABLE;
}
enum class NeuterReason {
SENT_RESPONSE,
THREW_EXCEPTION,
CLIENT_DISCONNECTED
};

kj::Exception makeException(NeuterReason reason) {
switch (reason) {
case SENT_RESPONSE:
return JSG_KJ_EXCEPTION(FAILED, TypeError,
"Can't read from request stream after response has been sent.");
case THREW_EXCEPTION:
return JSG_KJ_EXCEPTION(FAILED, TypeError,
"Can't read from request stream after responding with an exception.");
case CLIENT_DISCONNECTED:
return JSG_KJ_EXCEPTION(DISCONNECTED, TypeError,
"Can't read from request stream because client disconnected.");
}
KJ_UNREACHABLE;
kj::Exception makeNeuterException(NeuterReason reason) {
switch (reason) {
case NeuterReason::SENT_RESPONSE:
return JSG_KJ_EXCEPTION(FAILED, TypeError,
"Can't read from request stream after response has been sent.");
case NeuterReason::THREW_EXCEPTION:
return JSG_KJ_EXCEPTION(FAILED, TypeError,
"Can't read from request stream after responding with an exception.");
case NeuterReason::CLIENT_DISCONNECTED:
return JSG_KJ_EXCEPTION(DISCONNECTED, TypeError,
"Can't read from request stream because client disconnected.");
}
};
KJ_UNREACHABLE;
}

kj::String getEventName(v8::PromiseRejectEvent type) {
switch (type) {
Expand Down Expand Up @@ -151,14 +107,14 @@ kj::Promise<DeferredProxy<void>> ServiceWorkerGlobalScope::request(
// that it can drop the reference whenever it gets GC'd. But in this case the stream's lifetime
// is not under our control -- it's attached to the request. So, we wrap it in a
// NeuterableInputStream which allows us to disconnect the stream before it becomes invalid.
auto ownRequestBody = kj::refcounted<NeuterableInputStream>(requestBody);
auto ownRequestBody = newNeuterableInputStream(requestBody);
auto deferredNeuter = kj::defer([ownRequestBody = kj::addRef(*ownRequestBody)]() mutable {
// Make sure to cancel the request body stream since the native stream is no longer valid once
// the returned promise completes. Note that the KJ HTTP library deals with the fact that we
// haven't consumed the entire request body.
ownRequestBody->neuter(NeuterableInputStream::CLIENT_DISCONNECTED);
ownRequestBody->neuter(makeNeuterException(NeuterReason::CLIENT_DISCONNECTED));
});
KJ_ON_SCOPE_FAILURE(ownRequestBody->neuter(NeuterableInputStream::THREW_EXCEPTION));
KJ_ON_SCOPE_FAILURE(ownRequestBody->neuter(makeNeuterException(NeuterReason::THREW_EXCEPTION)));

auto& ioContext = IoContext::current();
jsg::Lock& js = lock;
Expand Down Expand Up @@ -311,17 +267,17 @@ kj::Promise<DeferredProxy<void>> ServiceWorkerGlobalScope::request(
// task finishes.
deferredProxy.proxyTask = deferredProxy.proxyTask
.then([body = kj::addRef(*ownRequestBody)]() mutable {
body->neuter(NeuterableInputStream::SENT_RESPONSE);
body->neuter(makeNeuterException(NeuterReason::SENT_RESPONSE));
}, [body = kj::addRef(*ownRequestBody)](kj::Exception&& e) mutable {
body->neuter(NeuterableInputStream::THREW_EXCEPTION);
body->neuter(makeNeuterException(NeuterReason::THREW_EXCEPTION));
kj::throwFatalException(kj::mv(e));
}).attach(kj::mv(deferredNeuter));

return deferredProxy;
}, [body = kj::mv(body2)](kj::Exception&& e) mutable -> DeferredProxy<void> {
// HACK: We depend on the fact that the success-case lambda above hasn't been destroyed yet
// so `deferredNeuter` hasn't been destroyed yet.
body->neuter(NeuterableInputStream::THREW_EXCEPTION);
body->neuter(makeNeuterException(NeuterReason::THREW_EXCEPTION));
kj::throwFatalException(kj::mv(e));
});
} else {
Expand Down
16 changes: 3 additions & 13 deletions src/workerd/api/http.c++
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <workerd/io/features.h>
#include <workerd/util/http-util.h>
#include <workerd/util/mimetype.h>
#include <workerd/util/stream-utils.h>
#include <workerd/util/thread-scopes.h>
#include <workerd/jsg/ser.h>
#include <workerd/io/io-context.h>
Expand Down Expand Up @@ -1448,17 +1449,6 @@ void FetchEvent::passThroughOnException() {

namespace {

class NullInputStream final: public kj::AsyncInputStream {
public:
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return size_t(0);
}

kj::Maybe<uint64_t> tryGetLength() override {
return uint64_t(0);
}
};

// Fetch spec requires (suggests?) 20: https://fetch.spec.whatwg.org/#http-redirect-fetch
constexpr auto MAX_REDIRECT_COUNT = 20;

Expand Down Expand Up @@ -1552,7 +1542,7 @@ jsg::Promise<jsg::Ref<Response>> fetchImplNoOutputLock(
return js.resolvedPromise(makeHttpResponse(js,
jsRequest->getMethodEnum(), kj::mv(urlList),
response.statusCode, response.statusText, *response.headers,
kj::heap<NullInputStream>(),
newNullInputStream(),
jsg::alloc<WebSocket>(kj::mv(webSocket), WebSocket::REMOTE),
Response::BodyEncoding::AUTO,
kj::mv(signal)));
Expand Down Expand Up @@ -1899,7 +1889,7 @@ static jsg::Promise<Fetcher::GetResult> parseResponse(
return js.resolvedPromise(
Fetcher::GetResult(jsg::alloc<ReadableStream>(
IoContext::current(),
newSystemStream(kj::heap<NullInputStream>(), StreamEncoding::IDENTITY))));
newSystemStream(newNullInputStream(), StreamEncoding::IDENTITY))));
}
}

Expand Down
17 changes: 2 additions & 15 deletions src/workerd/io/worker.c++
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "actor-cache.h"
#include <workerd/util/batch-queue.h>
#include <workerd/util/mimetype.h>
#include <workerd/util/stream-utils.h>
#include <workerd/util/thread-scopes.h>
#include <workerd/util/xthreadnotifier.h>
#include <workerd/api/actor-state.h>
Expand Down Expand Up @@ -3368,20 +3369,6 @@ double getWallTimeForProcessSandboxOnly() {
auto timePoint = kj::systemPreciseCalendarClock().now();
return (timePoint - kj::UNIX_EPOCH) / kj::MILLISECONDS / 1e3;
}

class NullOutputStream final: public kj::AsyncOutputStream {
public:
kj::Promise<void> write(const void* buffer, size_t size) override {
return kj::READY_NOW;
}
kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
return kj::READY_NOW;
}
kj::Promise<void> whenWriteDisconnected() override {
return kj::NEVER_DONE;
}
};

} // namespace

class Worker::Isolate::ResponseStreamWrapper final: public kj::AsyncOutputStream {
Expand Down Expand Up @@ -3679,7 +3666,7 @@ kj::Promise<void> Worker::Isolate::SubrequestClient::request(
// TODO(someday): Support sending WebSocket frames over CDP. For now we fake an empty
// response.
signalResponse(kj::mv(requestId), 101, "Switching Protocols", headers,
kj::heap<NullOutputStream>());
newNullOutputStream());
return kj::mv(webSocket);
}

Expand Down
19 changes: 3 additions & 16 deletions src/workerd/tests/test-fixture.c++
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <workerd/io/worker-entrypoint.h>
#include <workerd/jsg/modules.h>
#include <workerd/server/workerd-api.h>
#include <workerd/util/stream-utils.h>

#include "test-fixture.h"

Expand Down Expand Up @@ -229,20 +230,6 @@ struct MockResponse final: public kj::HttpService::Response {
KJ_FAIL_REQUIRE("NOT SUPPORTED");
}
};

struct MemoryInputStream final: public kj::AsyncInputStream {
kj::ArrayPtr<const byte> data;

MemoryInputStream(kj::ArrayPtr<const byte> data) : data(data) { }

kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
auto toRead = kj::min(data.size(), maxBytes);
memcpy(buffer, data.begin(), toRead);
data = data.slice(toRead, data.size());
return toRead;
}
};

} // namespace


Expand Down Expand Up @@ -335,15 +322,15 @@ TestFixture::Response TestFixture::runRequest(
kj::HttpMethod method, kj::StringPtr url, kj::StringPtr body) {
kj::HttpHeaders requestHeaders(*headerTable);
MockResponse response;
MemoryInputStream requestBody(body.asBytes());
auto requestBody = newMemoryInputStream(body);

runInIoContext([&](const TestFixture::Environment& env) {
auto& globalScope = env.lock.getGlobalScope();
return globalScope.request(
method,
url,
requestHeaders,
requestBody,
*requestBody,
response,
"{}"_kj,
env.lock,
Expand Down
Loading

0 comments on commit a595505

Please sign in to comment.