Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Coroutines Conversion: update api/streams/internal.c++ #1137

Merged
merged 1 commit into from
Sep 11, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 56 additions & 31 deletions src/workerd/api/streams/internal.c++
Original file line number Diff line number Diff line change
Expand Up @@ -47,64 +47,89 @@ kj::Promise<void> pumpTo(ReadableStreamSource& input, WritableStreamSink& output

// Modified from AllReader in kj/async-io.c++.
class AllReader {
using PartList = kj::Array<kj::ArrayPtr<byte>>;

public:
explicit AllReader(ReadableStreamSource& input, uint64_t limit)
: input(input), limit(limit) {
JSG_REQUIRE(limit > 0, TypeError, "Memory limit exceeded before EOF.");
KJ_IF_MAYBE(length, input.tryGetLength(StreamEncoding::IDENTITY)) {
KJ_IF_SOME(length, input.tryGetLength(StreamEncoding::IDENTITY)) {
// Oh hey, we might be able to bail early.
JSG_REQUIRE(*length < limit, TypeError, "Memory limit would be exceeded before EOF.");
JSG_REQUIRE(length < limit, TypeError, "Memory limit would be exceeded before EOF.");
}
}
KJ_DISALLOW_COPY_AND_MOVE(AllReader);

kj::Promise<kj::Array<byte>> readAllBytes() {
return loop().then([this](PartList&& partPtrs) {
auto out = kj::heapArray<byte>(runningTotal);
copyInto(out, kj::mv(partPtrs));
return kj::mv(out);
});
kj::Promise<kj::Array<kj::byte>> readAllBytes() {
return read<kj::byte>();
}

kj::Promise<kj::String> readAllText() {
return loop().then([this](PartList&& partPtrs) {
auto out = kj::heapArray<char>(runningTotal + 1);
copyInto(out.slice(0, out.size() - 1).asBytes(), kj::mv(partPtrs));
out.back() = '\0';
return kj::String(kj::mv(out));
});
auto data = co_await read<char>(ReadOption::NULL_TERMINATE);
co_return kj::String(kj::mv(data));
}

private:
ReadableStreamSource& input;
uint64_t limit;
kj::Vector<kj::Array<kj::byte>> parts;
uint64_t runningTotal = 0;

kj::Promise<PartList> loop() {
auto bytes = kj::heapArray<kj::byte>(4096);
enum class ReadOption {
NONE,
NULL_TERMINATE,
};

template <typename T>
kj::Promise<kj::Array<T>> read(ReadOption option = ReadOption::NONE) {
kj::Vector<kj::Array<T>> parts;
uint64_t runningTotal = 0;
static constexpr size_t DEFAULT_BUFFER_CHUNK = 4096;
static constexpr size_t MAX_BUFFER_CHUNK = DEFAULT_BUFFER_CHUNK * 4;

// If we know in advance how much data we'll be reading, then we can attempt to
// optimize the loop here by setting the value specifically so we are only
// allocating once. But, to be safe, let's enforce an upper bound on each allocation
// even if we do know the total.
size_t amountToRead = kj::min(MAX_BUFFER_CHUNK,
input.tryGetLength(StreamEncoding::IDENTITY).orDefault(DEFAULT_BUFFER_CHUNK));

for (;;) {
// TODO(perf): We can likely further optimize this loop by checking to see
// how much of the buffer was filled and using the remaining buffer space if
// it is not completely filled by the previous iteration. Doing so makes this
// loop a bit more complicated tho, so for now let's keep things simple.
auto bytes = kj::heapArray<T>(amountToRead);
size_t amount = co_await input.tryRead(bytes.begin(), 1, bytes.size());

return input.tryRead(bytes.begin(), 1, bytes.size())
.then([this, bytes = kj::mv(bytes)](size_t amount) mutable
-> kj::Promise<PartList> {
if (amount == 0) {
return KJ_MAP(p, parts) { return p.asPtr(); };
break;
}

runningTotal += amount;
if (runningTotal >= limit) {
return JSG_KJ_EXCEPTION(FAILED, TypeError, "Memory limit exceeded before EOF.");
}
JSG_REQUIRE(runningTotal < limit, TypeError, "Memory limit exceeded before EOF.");
parts.add(bytes.slice(0, amount).attach(kj::mv(bytes)));
return loop();
});
};

if (option == ReadOption::NULL_TERMINATE) {
auto out = kj::heapArray<T>(runningTotal + 1);
out[runningTotal] = '\0';
copyInto<T>(out, parts.asPtr());
co_return kj::mv(out);
}

// As an optimization, if there's only a single part in the list, we can avoid
// further copies.
if (parts.size() == 1) {
co_return kj::mv(parts[0]);
}

auto out = kj::heapArray<T>(runningTotal);
copyInto<T>(out, parts.asPtr());
co_return kj::mv(out);
}

void copyInto(kj::ArrayPtr<byte> out, PartList in) {
template <typename T>
void copyInto(kj::ArrayPtr<T> out, kj::ArrayPtr<kj::Array<T>> in) {
size_t pos = 0;
for (auto& part: in) {
KJ_ASSERT(part.size() <= out.size() - pos);
KJ_DASSERT(part.size() <= out.size() - pos);
memcpy(out.begin() + pos, part.begin(), part.size());
pos += part.size();
}
Expand Down
Loading