Skip to content

Commit

Permalink
Cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
taras committed Aug 19, 2024
1 parent df8b7af commit 5404fb9
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Connects to an echo server, sends a message and disconnect upon reply.
async def main(endpoint):
(_, client) = await ws_connect(endpoint, ClientListener)
await client.transport.wait_until_closed()
await client.transport.wait_disconnected()
if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion examples/echo_client_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame):
self._transport.send(WSMsgType.BINARY, msg)

(_, client) = await ws_connect(endpoint, PicowsClientListener, ssl=ssl_context)
await client._transport.wait_until_closed()
await client._transport.wait_disconnected()


async def websockets_main(endpoint: str, msg: bytes, duration: int, ssl_context):
Expand Down
2 changes: 1 addition & 1 deletion examples/echo_client_cython.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,5 @@ cdef class EchoClientListener(WSListener):
async def picows_main_cython(url: str, data: bytes, duration: int, ssl_context):
cdef EchoClientListener client
(_, client) = await ws_connect(url, lambda: EchoClientListener(data, duration), ssl=ssl_context)
await client._transport.wait_until_closed()
await client._transport.wait_disconnected()
return client.rps
39 changes: 24 additions & 15 deletions picows/picows.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ from libc.stdlib cimport rand

PICOWS_DEBUG_LL = 9

cdef:
set _ALLOWED_CLOSE_CODES = {int(i) for i in WSCloseCode}
bytes _WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


cdef extern from * nogil:
"""
Expand Down Expand Up @@ -207,12 +211,6 @@ cdef class WSFrame:
f"lib={True if self.last_in_buffer else False}, psz={self.payload_size}, tsz={self.tail_size})")


cdef:
set ALLOWED_CLOSE_CODES = {int(i) for i in WSCloseCode}
bytes _WS_DEFLATE_TRAILING = bytes([0x00, 0x00, 0xFF, 0xFF])
bytes _WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


cdef class MemoryBuffer:
def __init__(self, Py_ssize_t default_capacity=2048):
self.size = 0
Expand Down Expand Up @@ -367,17 +365,22 @@ cdef class WSTransport:

cpdef disconnect(self):
"""
Immediately disconnect the underlying transport.
It is ok to call this method multiple times. It does nothing if the transport is already disconnected.
Close the underlying transport.
If there is unsent outgoing data in the buffer, it will be flushed
asynchronously. No more data will be received.
It is ok to call this method multiple times.
It does nothing if the transport is already closed.
"""
if self.underlying_transport.is_closing():
return
self.underlying_transport.close()

async def wait_until_closed(self):
async def wait_disconnected(self):
"""
Coroutine that conveniently allows to wait until websocket is completely closed
(underlying transport is disconnected)
Coroutine that conveniently allows to wait until websocket is
completely disconnected.
(underlying transport is closed, on_ws_disconnected has been called)
"""
if not self._disconnected_future.done():
await asyncio.shield(self._disconnected_future)
Expand Down Expand Up @@ -661,7 +664,7 @@ cdef class WSProtocol:

if self._handshake_complete_future.done():
if self._handshake_complete_future.exception() is None:
self.listener.on_ws_disconnected(self.transport)
self._invoke_on_ws_disconnected()
else:
self._handshake_complete_future.set_result(None)

Expand Down Expand Up @@ -1029,7 +1032,7 @@ cdef class WSProtocol:
self._state = WSParserState.READ_HEADER

if frame.msg_type == WSMsgType.CLOSE:
if frame.get_close_code() < 3000 and frame.get_close_code() not in ALLOWED_CLOSE_CODES:
if frame.get_close_code() < 3000 and frame.get_close_code() not in _ALLOWED_CLOSE_CODES:
raise _WSParserError(WSCloseCode.PROTOCOL_ERROR,
f"Invalid close code: {frame.get_close_code()}")

Expand Down Expand Up @@ -1063,6 +1066,12 @@ cdef class WSProtocol:
else:
self._logger.exception("Unhandled exception in on_ws_frame")

cdef _invoke_on_ws_disconnected(self):
try:
self.listener.on_ws_disconnected(self.transport)
except:
self._logger.exception("Unhandled exception in on_ws_disconnected")

cdef _shrink_buffer(self):
if self._f_curr_frame_start_pos > 0:
memmove(self._buffer.data,
Expand Down Expand Up @@ -1092,6 +1101,8 @@ async def ws_connect(str url: str,
logger_name: str="client"
) -> Tuple[WSTransport, WSListener]:
"""
Open a websocket connection to a given URL.
:param url: Destination URL
:param ws_listener_factory:
A parameterless factory function that returns a user handler. User handler has to derive from :any:`WSListener`.
Expand All @@ -1111,8 +1122,6 @@ async def ws_connect(str url: str,
:param logger_name:
picows will use `picows.<logger_name>` logger to do all the logging.
:return: :any:`WSTransport` object and a user handler returned by `ws_listener_factory()'
Open a websocket connection to a given URL.
"""

url_parts = urllib.parse.urlparse(url, allow_fragments=False)
Expand Down
18 changes: 17 additions & 1 deletion tests/test_echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ async def get_message(self):
try:
# Gracefull shutdown, expect server to disconnect us because we have sent close message
async with async_timeout.timeout(1):
await client.transport.wait_until_closed()
await client.transport.wait_disconnected()
finally:
client.transport.disconnect()

Expand Down Expand Up @@ -191,3 +191,19 @@ def factory_listener(r):

with pytest.raises(picows.WSError, match="500 Internal Server Error"):
(_, client) = await picows.ws_connect(url, picows.WSListener)


async def test_server_bad_request():
server = await picows.ws_create_server(lambda _: picows.WSListener(),
"127.0.0.1", 0)

async with ServerAsyncContext(server):
r, w = await asyncio.open_connection("127.0.0.1", server.sockets[0].getsockname()[1])

w.write(b"zzzz\r\n\r\n")
resp_header = await r.readuntil(b"\r\n\r\n")
assert b"400 Bad Request" in resp_header
resp_data = await r.read()
assert r.at_eof()
# TODO: Why this fails?
# assert w.transport.is_closing()

0 comments on commit 5404fb9

Please sign in to comment.