Skip to content

Commit

Permalink
WritableStream will now wait for socket connection before closing.
Browse files Browse the repository at this point in the history
  • Loading branch information
dom96 committed Oct 9, 2023
1 parent 461a137 commit 5d1720a
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 16 deletions.
14 changes: 12 additions & 2 deletions src/workerd/api/sockets.c++
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ jsg::Ref<Socket> setupSocket(
if (!allowHalfOpen) {
eofPromise = readable->onEof(js);
}
auto writable = jsg::alloc<WritableStream>(ioContext, kj::mv(sysStreams.writable));
auto openedPrPair = js.newPromiseAndResolver<void>();
openedPrPair.promise.markAsHandled(js);
auto writable = jsg::alloc<WritableStream>(
ioContext, kj::mv(sysStreams.writable), kj::none, openedPrPair.promise.whenResolved(js));

auto result = jsg::alloc<Socket>(
js, ioContext,
Expand All @@ -156,7 +159,9 @@ jsg::Ref<Socket> setupSocket(
kj::mv(tlsStarter),
isSecureSocket,
kj::mv(domain),
isDefaultFetchPort);
isDefaultFetchPort,
kj::mv(openedPrPair));

KJ_IF_SOME(p, eofPromise) {
result->handleReadableEof(js, kj::mv(p));
}
Expand Down Expand Up @@ -337,6 +342,8 @@ void Socket::handleProxyStatus(
" — consider using fetch instead");
}
handleProxyError(js, JSG_KJ_EXCEPTION(FAILED, Error, msg));
} else {
openedResolver.resolve(js);
}
});
result.markAsHandled(js);
Expand All @@ -357,13 +364,16 @@ void Socket::handleProxyStatus(jsg::Lock& js, kj::Promise<kj::Maybe<kj::Exceptio
[this, self = JSG_THIS](jsg::Lock& js, kj::Maybe<kj::Exception> result) -> void {
if (result != kj::none) {
handleProxyError(js, JSG_KJ_EXCEPTION(FAILED, Error, "connection attempt failed"));
} else {
openedResolver.resolve(js);
}
});
result.markAsHandled(js);
}

void Socket::handleProxyError(jsg::Lock& js, kj::Exception e) {
resolveFulfiller(js, kj::mv(e));
openedResolver.reject(js, kj::cp(e));
readable->getController().cancel(js, kj::none).markAsHandled(js);
writable->getController().abort(js, js.error(e.getDescription())).markAsHandled(js);
}
Expand Down
13 changes: 11 additions & 2 deletions src/workerd/api/sockets.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class Socket: public jsg::Object {
jsg::Ref<ReadableStream> readableParam, jsg::Ref<WritableStream> writable,
jsg::PromiseResolverPair<void> closedPrPair, kj::Promise<void> watchForDisconnectTask,
jsg::Optional<SocketOptions> options, kj::Own<kj::TlsStarterCallback> tlsStarter,
bool isSecureSocket, kj::String domain, bool isDefaultFetchPort)
bool isSecureSocket, kj::String domain, bool isDefaultFetchPort,
jsg::PromiseResolverPair<void> openedPrPair)
: connectionStream(context.addObject(kj::mv(connectionStream))),
readable(kj::mv(readableParam)), writable(kj::mv(writable)),
closedResolver(kj::mv(closedPrPair.resolver)),
Expand All @@ -57,13 +58,18 @@ class Socket: public jsg::Object {
tlsStarter(context.addObject(kj::mv(tlsStarter))),
isSecureSocket(isSecureSocket),
domain(kj::mv(domain)),
isDefaultFetchPort(isDefaultFetchPort) { };
isDefaultFetchPort(isDefaultFetchPort),
openedResolver(kj::mv(openedPrPair.resolver)),
openedPromise(kj::mv(openedPrPair.promise)) { };

jsg::Ref<ReadableStream> getReadable() { return readable.addRef(); }
jsg::Ref<WritableStream> getWritable() { return writable.addRef(); }
jsg::MemoizedIdentity<jsg::Promise<void>>& getClosed() {
return closedPromise;
}
jsg::MemoizedIdentity<jsg::Promise<void>>& getOpened() {
return openedPromise;
}

// Closes the socket connection.
jsg::Promise<void> close(jsg::Lock& js);
Expand Down Expand Up @@ -118,6 +124,9 @@ class Socket: public jsg::Object {
kj::String domain;
// Whether the port this socket connected to is 80/443. Used for nicer errors.
bool isDefaultFetchPort;
// This fulfiller is used to resolve the `openedPromise` below.
jsg::Promise<void>::Resolver openedResolver;
jsg::MemoizedIdentity<jsg::Promise<void>> openedPromise;

kj::Promise<kj::Own<kj::AsyncIoStream>> processConnection();
jsg::Promise<void> maybeCloseWriteSide(jsg::Lock& js);
Expand Down
3 changes: 2 additions & 1 deletion src/workerd/api/streams/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,8 @@ kj::Own<WritableStreamController> newWritableStreamJsController();
kj::Own<WritableStreamController> newWritableStreamInternalController(
IoContext& ioContext,
kj::Own<WritableStreamSink> source,
kj::Maybe<uint64_t> maybeHighWaterMark = kj::none);
kj::Maybe<uint64_t> maybeHighWaterMark = kj::none,
kj::Maybe<jsg::Promise<void>> maybeClosureWaitable = kj::none);

struct Unlocked {};
struct Locked {};
Expand Down
22 changes: 17 additions & 5 deletions src/workerd/api/streams/internal.c++
Original file line number Diff line number Diff line change
Expand Up @@ -921,9 +921,7 @@ void WritableStreamInternalController::setHighWaterMark(uint64_t highWaterMark)
maybeHighWaterMark = highWaterMark;
}

jsg::Promise<void> WritableStreamInternalController::close(
jsg::Lock& js,
bool markAsHandled) {
jsg::Promise<void> WritableStreamInternalController::closeImpl(jsg::Lock& js, bool markAsHandled) {
if (isClosedOrClosing()) {
auto reason = js.v8TypeError("This WritableStream has been closed."_kj);
return rejectedMaybeHandledPromise<void>(js, reason, markAsHandled);
Expand Down Expand Up @@ -955,6 +953,18 @@ jsg::Promise<void> WritableStreamInternalController::close(
KJ_UNREACHABLE;
}

jsg::Promise<void> WritableStreamInternalController::close(
jsg::Lock& js,
bool markAsHandled) {
KJ_IF_SOME(closureWaitable, maybeClosureWaitable) {
return closureWaitable.then(js, [&](jsg::Lock& js) {
return closeImpl(js, markAsHandled);
});
} else {
return closeImpl(js, markAsHandled);
}
}

jsg::Promise<void> WritableStreamInternalController::flush(
jsg::Lock& js,
bool markAsHandled) {
Expand Down Expand Up @@ -2219,10 +2229,12 @@ kj::Own<ReadableStreamController> newReadableStreamInternalController(
kj::Own<WritableStreamController> newWritableStreamInternalController(
IoContext& ioContext,
kj::Own<WritableStreamSink> sink,
kj::Maybe<uint64_t> maybeHighWaterMark) {
kj::Maybe<uint64_t> maybeHighWaterMark,
kj::Maybe<jsg::Promise<void>> maybeClosureWaitable) {
return kj::heap<WritableStreamInternalController>(
ioContext.addObject(kj::mv(sink)),
maybeHighWaterMark);
maybeHighWaterMark,
kj::mv(maybeClosureWaitable));
}

} // namespace workerd::api
12 changes: 9 additions & 3 deletions src/workerd/api/streams/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,10 @@ class WritableStreamInternalController: public WritableStreamController {
explicit WritableStreamInternalController(StreamStates::Errored errored)
: state(kj::mv(errored)) {}
explicit WritableStreamInternalController(Writable writable,
kj::Maybe<uint64_t> maybeHighWaterMark = kj::none)
: state(kj::mv(writable)),
maybeHighWaterMark(maybeHighWaterMark) {
kj::Maybe<uint64_t> maybeHighWaterMark = kj::none,
kj::Maybe<jsg::Promise<void>> maybeClosureWaitable = kj::none) : state(kj::mv(writable)),
maybeHighWaterMark(maybeHighWaterMark),
maybeClosureWaitable(kj::mv(maybeClosureWaitable)) {
}

WritableStreamInternalController(WritableStreamInternalController&& other) = default;
Expand Down Expand Up @@ -226,6 +227,7 @@ class WritableStreamInternalController: public WritableStreamController {
void drain(jsg::Lock& js, v8::Local<v8::Value> reason);
void finishClose(jsg::Lock& js);
void finishError(jsg::Lock& js, v8::Local<v8::Value> reason);
jsg::Promise<void> closeImpl(jsg::Lock& js, bool markAsHandled);

struct PipeLocked {
ReadableStream& ref;
Expand All @@ -245,6 +247,10 @@ class WritableStreamInternalController: public WritableStreamController {
// promise on the writer.
kj::Maybe<uint64_t> maybeHighWaterMark;

// Used by Sockets code to ensure the connection is established before the associated
// WritableStream is closed.
kj::Maybe<jsg::Promise<void>> maybeClosureWaitable;

void increaseCurrentWriteBufferSize(jsg::Lock& js, uint64_t amount);
void decreaseCurrentWriteBufferSize(jsg::Lock& js, uint64_t amount);
void updateBackpressure(jsg::Lock& js, bool backpressure);
Expand Down
5 changes: 3 additions & 2 deletions src/workerd/api/streams/writable.c++
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,10 @@ void WritableStreamDefaultWriter::visitForGc(jsg::GcVisitor& visitor) {
WritableStream::WritableStream(
IoContext& ioContext,
kj::Own<WritableStreamSink> sink,
kj::Maybe<uint64_t> maybeHighWaterMark)
kj::Maybe<uint64_t> maybeHighWaterMark,
kj::Maybe<jsg::Promise<void>> maybeClosureWaitable)
: WritableStream(newWritableStreamInternalController(ioContext, kj::mv(sink),
maybeHighWaterMark)) {}
maybeHighWaterMark, kj::mv(maybeClosureWaitable))) {}

WritableStream::WritableStream(kj::Own<WritableStreamController> controller)
: ioContext(tryGetIoContext()),
Expand Down
3 changes: 2 additions & 1 deletion src/workerd/api/streams/writable.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ class WritableStream: public jsg::Object {
public:
explicit WritableStream(IoContext& ioContext,
kj::Own<WritableStreamSink> sink,
kj::Maybe<uint64_t> maybeHighWaterMark = kj::none);
kj::Maybe<uint64_t> maybeHighWaterMark = kj::none,
kj::Maybe<jsg::Promise<void>> maybeClosureWaitable = kj::none);

explicit WritableStream(kj::Own<WritableStreamController> controller);

Expand Down

0 comments on commit 5d1720a

Please sign in to comment.