Skip to content

Commit

Permalink
Coroutines Conversion: update api/streams/internal.c++
Browse files Browse the repository at this point in the history
  • Loading branch information
jasnell committed Sep 11, 2023
1 parent e421a98 commit 5b1c5d4
Showing 1 changed file with 56 additions and 31 deletions.
87 changes: 56 additions & 31 deletions src/workerd/api/streams/internal.c++
Original file line number Diff line number Diff line change
Expand Up @@ -47,64 +47,89 @@ kj::Promise<void> pumpTo(ReadableStreamSource& input, WritableStreamSink& output

// Modified from AllReader in kj/async-io.c++.
class AllReader {
using PartList = kj::Array<kj::ArrayPtr<byte>>;

public:
explicit AllReader(ReadableStreamSource& input, uint64_t limit)
: input(input), limit(limit) {
JSG_REQUIRE(limit > 0, TypeError, "Memory limit exceeded before EOF.");
KJ_IF_MAYBE(length, input.tryGetLength(StreamEncoding::IDENTITY)) {
KJ_IF_SOME(length, input.tryGetLength(StreamEncoding::IDENTITY)) {
// Oh hey, we might be able to bail early.
JSG_REQUIRE(*length < limit, TypeError, "Memory limit would be exceeded before EOF.");
JSG_REQUIRE(length < limit, TypeError, "Memory limit would be exceeded before EOF.");
}
}
KJ_DISALLOW_COPY_AND_MOVE(AllReader);

kj::Promise<kj::Array<byte>> readAllBytes() {
return loop().then([this](PartList&& partPtrs) {
auto out = kj::heapArray<byte>(runningTotal);
copyInto(out, kj::mv(partPtrs));
return kj::mv(out);
});
kj::Promise<kj::Array<kj::byte>> readAllBytes() {
return read<kj::byte>();
}

kj::Promise<kj::String> readAllText() {
return loop().then([this](PartList&& partPtrs) {
auto out = kj::heapArray<char>(runningTotal + 1);
copyInto(out.slice(0, out.size() - 1).asBytes(), kj::mv(partPtrs));
out.back() = '\0';
return kj::String(kj::mv(out));
});
auto data = co_await read<char>(ReadOption::NULL_TERMINATE);
co_return kj::String(kj::mv(data));
}

private:
ReadableStreamSource& input;
uint64_t limit;
kj::Vector<kj::Array<kj::byte>> parts;
uint64_t runningTotal = 0;

kj::Promise<PartList> loop() {
auto bytes = kj::heapArray<kj::byte>(4096);
enum class ReadOption {
NONE,
NULL_TERMINATE,
};

template <typename T>
kj::Promise<kj::Array<T>> read(ReadOption option = ReadOption::NONE) {
kj::Vector<kj::Array<T>> parts;
uint64_t runningTotal = 0;
static constexpr size_t DEFAULT_BUFFER_CHUNK = 4096;
static constexpr size_t MAX_BUFFER_CHUNK = DEFAULT_BUFFER_CHUNK * 4;

// If we know in advance how much data we'll be reading, then we can attempt to
// optimize the loop here by setting the value specifically so we are only
// allocating once. But, to be safe, let's enforce an upper bound on each allocation
// even if we do know the total.
size_t amountToRead = kj::min(MAX_BUFFER_CHUNK,
input.tryGetLength(StreamEncoding::IDENTITY).orDefault(DEFAULT_BUFFER_CHUNK));

for (;;) {
// TODO(perf): We can likely further optimize this loop by checking to see
// how much of the buffer was filled and using the remaining buffer space if
// it is not completely filled by the previous iteration. Doing so makes this
// loop a bit more complicated tho, so for now let's keep things simple.
auto bytes = kj::heapArray<T>(amountToRead);
size_t amount = co_await input.tryRead(bytes.begin(), 1, bytes.size());

return input.tryRead(bytes.begin(), 1, bytes.size())
.then([this, bytes = kj::mv(bytes)](size_t amount) mutable
-> kj::Promise<PartList> {
if (amount == 0) {
return KJ_MAP(p, parts) { return p.asPtr(); };
break;
}

runningTotal += amount;
if (runningTotal >= limit) {
return JSG_KJ_EXCEPTION(FAILED, TypeError, "Memory limit exceeded before EOF.");
}
JSG_REQUIRE(runningTotal < limit, TypeError, "Memory limit exceeded before EOF.");
parts.add(bytes.slice(0, amount).attach(kj::mv(bytes)));
return loop();
});
};

if (option == ReadOption::NULL_TERMINATE) {
auto out = kj::heapArray<T>(runningTotal + 1);
out[runningTotal] = '\0';
copyInto<T>(out, parts.asPtr());
co_return kj::mv(out);
}

// As an optimization, if there's only a single part in the list, we can avoid
// further copies.
if (parts.size() == 1) {
co_return kj::mv(parts[0]);
}

auto out = kj::heapArray<T>(runningTotal);
copyInto<T>(out, parts.asPtr());
co_return kj::mv(out);
}

void copyInto(kj::ArrayPtr<byte> out, PartList in) {
template <typename T>
void copyInto(kj::ArrayPtr<T> out, kj::ArrayPtr<kj::Array<T>> in) {
size_t pos = 0;
for (auto& part: in) {
KJ_ASSERT(part.size() <= out.size() - pos);
KJ_DASSERT(part.size() <= out.size() - pos);
memcpy(out.begin() + pos, part.begin(), part.size());
pos += part.size();
}
Expand Down

0 comments on commit 5b1c5d4

Please sign in to comment.