diff --git a/src/workerd/api/streams/readable.c++ b/src/workerd/api/streams/readable.c++ index 9ec71a46e65..21ff1ddbc22 100644 --- a/src/workerd/api/streams/readable.c++ +++ b/src/workerd/api/streams/readable.c++ @@ -611,6 +611,57 @@ private: kj::Maybe expectedLength; }; +// Wrapper around ReadableStreamSource that prevents deferred proxying. We need this for RPC +// streams because although they are "system streams", they become disconnected when the IoContext +// is destroyed, due to the JsRpcCustomEventImpl being canceled. +// +// TODO(someday): Devise a better way for RPC streams to extend the lifetime of the RPC session +// beyond the destruction of the IoContext, if it is being used for deferred proxying. +class NoDeferredProxyReadableStream: public ReadableStreamSource { +public: + NoDeferredProxyReadableStream(kj::Own inner, IoContext& ioctx) + : inner(kj::mv(inner)), ioctx(ioctx) {} + + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + return inner->tryRead(buffer, minBytes, maxBytes); + } + + kj::Promise> pumpTo(WritableStreamSink& output, bool end) override { + // Move the deferred proxy part of the task over to the non-deferred part. To do this, + // we use `ioctx.waitForDeferredProxy()`, which returns a single promise covering both parts + // (and, importantly, registering pending events where needed). Then, we add a noop deferred + // proxy to the end of that. + return addNoopDeferredProxy(ioctx.waitForDeferredProxy(inner->pumpTo(output, end))); + } + + StreamEncoding getPreferredEncoding() override { + return inner->getPreferredEncoding(); + } + + kj::Maybe tryGetLength(StreamEncoding encoding) override { + return inner->tryGetLength(encoding); + } + + void cancel(kj::Exception reason) override { + return inner->cancel(kj::mv(reason)); + } + + kj::Maybe tryTee(uint64_t limit) override { + return inner->tryTee(limit).map([&](Tee tee) { + return Tee { + .branches = { + kj::heap(kj::mv(tee.branches[0]), ioctx), + kj::heap(kj::mv(tee.branches[1]), ioctx), + } + }; + }); + } + +private: + kj::Own inner; + IoContext& ioctx; +}; + } // namespace void ReadableStream::serialize(jsg::Lock& js, jsg::Serializer& serializer) { @@ -693,7 +744,9 @@ jsg::Ref ReadableStream::deserialize( externalHandler->setLastStream(ioctx.getByteStreamFactory().kjToCapnp(kj::mv(out))); - return jsg::alloc(ioctx, newSystemStream(kj::mv(in), encoding, ioctx)); + return jsg::alloc(ioctx, + kj::heap( + newSystemStream(kj::mv(in), encoding, ioctx), ioctx)); } kj::StringPtr ReaderImpl::jsgGetMemoryName() const { return "ReaderImpl"_kjc; } diff --git a/src/workerd/api/tests/js-rpc-test.js b/src/workerd/api/tests/js-rpc-test.js index dde0e8d1867..bc5d661b430 100644 --- a/src/workerd/api/tests/js-rpc-test.js +++ b/src/workerd/api/tests/js-rpc-test.js @@ -74,11 +74,18 @@ export let nonClass = { async fetch(req, env, ctx) { // This is used in the stream test to fetch some gziped data. - return new Response("this text was gzipped", { - headers: { - "Content-Encoding": "gzip" - } - }); + if (req.url.endsWith("/gzip")) { + return new Response("this text was gzipped", { + headers: { + "Content-Encoding": "gzip" + } + }); + } else if (req.url.endsWith("/stream-from-rpc")) { + let stream = await env.MyService.returnReadableStream(); + return new Response(stream); + } else { + throw new Error("unknown route"); + } } } @@ -1062,7 +1069,7 @@ export let streams = { // Send an encoded ReadableStream { - let gzippedResp = await env.self.fetch("http://foo"); + let gzippedResp = await env.self.fetch("http://foo/gzip"); let text = await env.MyService.readFromStream(gzippedResp.body); @@ -1087,6 +1094,13 @@ export let streams = { assert.strictEqual(await readPromise, "foo, bar, baz!"); } + + // Perform an HTTP request whose response uses a ReadableStream obtained over RPC. + { + let resp = await env.self.fetch("http://foo/stream-from-rpc"); + + assert.strictEqual(await resp.text(), "foo, bar, baz!"); + } } }