Skip to content

Commit

Permalink
Run handler only when opening handshake succeeds.
Browse files Browse the repository at this point in the history
When process_request() or process_response() returned a HTTP response
without calling accept() or reject() and with a status code other than
101, the connection handler used to start, which was incorrect.

Fix #1419.

Also move start_keepalive() outside of handshake() and bring it together
with starting the connection handler, which is more logical.
  • Loading branch information
aaugustin committed Sep 11, 2024
1 parent d19ed26 commit 98f236f
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 26 deletions.
7 changes: 7 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ Improvements
Previously, :exc:`RuntimeError` was raised. For backwards compatibility,
:exc:`~exceptions.ConcurrencyError` is a subclass of :exc:`RuntimeError`.

Bug fixes
.........

* The new :mod:`asyncio` and :mod:`threading` implementations of servers don't
start the connection handler anymore when ``process_request`` or
``process_response`` returns a HTTP response.

13.0.1
------

Expand Down
5 changes: 2 additions & 3 deletions src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@ async def handshake(
# before receiving a response, when the response cannot be parsed, or
# when the response fails the handshake.

if self.protocol.handshake_exc is None:
self.start_keepalive()
else:
if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc

def process_event(self, event: Event) -> None:
Expand Down Expand Up @@ -465,6 +463,7 @@ async def __await_impl__(self) -> ClientConnection:
raise uri_or_exc from exc

else:
self.connection.start_keepalive()
return self.connection
else:
raise SecurityError(f"more than {MAX_REDIRECTS} redirects")
Expand Down
11 changes: 6 additions & 5 deletions src/websockets/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,11 @@ async def handshake(
self.protocol.send_response(self.response)

# self.protocol.handshake_exc is always set when the connection is lost
# before receiving a request, when the request cannot be parsed, or when
# the response fails the handshake.
# before receiving a request, when the request cannot be parsed, when
# the handshake encounters an error, or when process_request or
# process_response sends a HTTP response that rejects the handshake.

if self.protocol.handshake_exc is None:
self.start_keepalive()
else:
if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc

def process_event(self, event: Event) -> None:
Expand Down Expand Up @@ -369,7 +368,9 @@ async def conn_handler(self, connection: ServerConnection) -> None:
connection.close_transport()
return

assert connection.protocol.state is OPEN
try:
connection.start_keepalive()
await self.handler(connection)
except Exception:
connection.logger.error("connection handler failed", exc_info=True)
Expand Down
23 changes: 14 additions & 9 deletions src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def accept(self, request: Request) -> Response:
if protocol_header is not None:
headers["Sec-WebSocket-Protocol"] = protocol_header

self.logger.info("connection open")
return Response(101, "Switching Protocols", headers)

def process_request(
Expand Down Expand Up @@ -515,14 +514,7 @@ def reject(self, status: StatusLike, text: str) -> Response:
("Content-Type", "text/plain; charset=utf-8"),
]
)
response = Response(status.value, status.phrase, headers, body)
# When reject() is called from accept(), handshake_exc is already set.
# If a user calls reject(), set handshake_exc to guarantee invariant:
# "handshake_exc is None if and only if opening handshake succeeded."
if self.handshake_exc is None:
self.handshake_exc = InvalidStatus(response)
self.logger.info("connection rejected (%d %s)", status.value, status.phrase)
return response
return Response(status.value, status.phrase, headers, body)

def send_response(self, response: Response) -> None:
"""
Expand All @@ -545,7 +537,20 @@ def send_response(self, response: Response) -> None:
if response.status_code == 101:
assert self.state is CONNECTING
self.state = OPEN
self.logger.info("connection open")

else:
# handshake_exc may be already set if accept() encountered an error.
# If the connection isn't open, set handshake_exc to guarantee that
# handshake_exc is None if and only if opening handshake succeeded.
if self.handshake_exc is None:
self.handshake_exc = InvalidStatus(response)
self.logger.info(
"connection rejected (%d %s)",
response.status_code,
response.reason_phrase,
)

self.send_eof()
self.parser = self.discard()
next(self.parser) # start coroutine
Expand Down
8 changes: 5 additions & 3 deletions src/websockets/sync/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
validate_subprotocols,
)
from ..http11 import SERVER, Request, Response
from ..protocol import CONNECTING, Event
from ..protocol import CONNECTING, OPEN, Event
from ..server import ServerProtocol
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
from .connection import Connection
Expand Down Expand Up @@ -166,8 +166,9 @@ def handshake(
self.protocol.send_response(self.response)

# self.protocol.handshake_exc is always set when the connection is lost
# before receiving a request, when the request cannot be parsed, or when
# the response fails the handshake.
# before receiving a request, when the request cannot be parsed, when
# the handshake encounters an error, or when process_request or
# process_response sends a HTTP response that rejects the handshake.

if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
Expand Down Expand Up @@ -569,6 +570,7 @@ def protocol_select_subprotocol(
connection.recv_events_thread.join()
return

assert connection.protocol.state is OPEN
try:
handler(connection)
except Exception:
Expand Down
10 changes: 8 additions & 2 deletions tests/asyncio/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ async def test_process_request_returns_response(self):
def process_request(ws, request):
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")

async with serve(*args, process_request=process_request) as server:
async def handler(ws):
self.fail("handler must not run")

async with serve(handler, *args[1:], process_request=process_request) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
Expand All @@ -160,7 +163,10 @@ async def test_async_process_request_returns_response(self):
async def process_request(ws, request):
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")

async with serve(*args, process_request=process_request) as server:
async def handler(ws):
self.fail("handler must not run")

async with serve(handler, *args[1:], process_request=process_request) as server:
with self.assertRaises(InvalidStatus) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
Expand Down
5 changes: 4 additions & 1 deletion tests/sync/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,10 @@ def test_process_request_returns_response(self):
def process_request(ws, request):
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")

with run_server(process_request=process_request) as server:
def handler(ws):
self.fail("handler must not run")

with run_server(handler, process_request=process_request) as server:
with self.assertRaises(InvalidStatus) as raised:
with connect(get_uri(server)):
self.fail("did not raise")
Expand Down
45 changes: 42 additions & 3 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,11 @@ def make_request(self):
),
)

def test_send_accept(self):
def test_send_response_after_successful_accept(self):
server = ServerProtocol()
request = self.make_request()
with unittest.mock.patch("email.utils.formatdate", return_value=DATE):
response = server.accept(self.make_request())
response = server.accept(request)
self.assertIsInstance(response, Response)
server.send_response(response)
self.assertEqual(
Expand All @@ -126,7 +127,32 @@ def test_send_accept(self):
self.assertFalse(server.close_expected())
self.assertEqual(server.state, OPEN)

def test_send_reject(self):
def test_send_response_after_failed_accept(self):
server = ServerProtocol()
request = self.make_request()
del request.headers["Sec-WebSocket-Key"]
with unittest.mock.patch("email.utils.formatdate", return_value=DATE):
response = server.accept(request)
self.assertIsInstance(response, Response)
server.send_response(response)
self.assertEqual(
server.data_to_send(),
[
f"HTTP/1.1 400 Bad Request\r\n"
f"Date: {DATE}\r\n"
f"Connection: close\r\n"
f"Content-Length: 94\r\n"
f"Content-Type: text/plain; charset=utf-8\r\n"
f"\r\n"
f"Failed to open a WebSocket connection: "
f"missing Sec-WebSocket-Key header; 'sec-websocket-key'.\n".encode(),
b"",
],
)
self.assertTrue(server.close_expected())
self.assertEqual(server.state, CONNECTING)

def test_send_response_after_reject(self):
server = ServerProtocol()
with unittest.mock.patch("email.utils.formatdate", return_value=DATE):
response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n")
Expand All @@ -148,6 +174,19 @@ def test_send_reject(self):
self.assertTrue(server.close_expected())
self.assertEqual(server.state, CONNECTING)

def test_send_response_without_accept_or_reject(self):
server = ServerProtocol()
server.send_response(Response(410, "Gone", Headers(), b"AWOL.\n"))
self.assertEqual(
server.data_to_send(),
[
"HTTP/1.1 410 Gone\r\n\r\nAWOL.\n".encode(),
b"",
],
)
self.assertTrue(server.close_expected())
self.assertEqual(server.state, CONNECTING)

def test_accept_response(self):
server = ServerProtocol()
with unittest.mock.patch("email.utils.formatdate", return_value=DATE):
Expand Down

0 comments on commit 98f236f

Please sign in to comment.