Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
taras committed Aug 21, 2024
1 parent f216791 commit 650f5ff
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 49 deletions.
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ Connects to an echo server, sends a message and disconnect upon reply.
transport.disconnect()
async def main(endpoint):
(_, client) = await ws_connect(endpoint, ClientListener)
async def main(url):
(_, client) = await ws_connect(ClientListener, url)
await client.transport.wait_disconnected()
Expand Down
2 changes: 1 addition & 1 deletion examples/echo_client_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame):
else:
self._transport.send(WSMsgType.BINARY, msg)

(_, client) = await ws_connect(endpoint, PicowsClientListener, ssl=ssl_context)
(_, client) = await ws_connect(PicowsClientListener, endpoint, ssl_context=ssl_context)
await client._transport.wait_disconnected()


Expand Down
4 changes: 3 additions & 1 deletion examples/echo_client_cython.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ cdef class EchoClientListener(WSListener):

async def picows_main_cython(url: str, data: bytes, duration: int, ssl_context):
cdef EchoClientListener client
(_, client) = await ws_connect(url, lambda: EchoClientListener(data, duration), ssl=ssl_context)
(_, client) = await ws_connect(lambda: EchoClientListener(data, duration),
url,
ssl_context=ssl_context)
await client._transport.wait_disconnected()
return client.rps
40 changes: 13 additions & 27 deletions picows/picows.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1096,50 +1096,40 @@ cdef class WSProtocol:
self.transport.disconnect()


async def ws_connect(str url: str,
ws_listener_factory: Callable[[], WSListener],
ssl: Optional[Union[bool, SSLContext]]=None,
async def ws_connect(ws_listener_factory: Callable[[], WSListener],
str url: str,
ssl_context: Optional[Union[bool, SSLContext]]=None,
bint disconnect_on_exception: bool=True,
ssl_handshake_timeout=5,
ssl_shutdown_timeout=5,
logger_name: str="client",
websocket_handshake_timeout=5,
local_addr: Optional[Tuple[str, int]]=None,
logger_name: str="client"
**kwargs
) -> Tuple[WSTransport, WSListener]:
"""
Open a websocket connection to a given URL.
:param url: Destination URL
:param ws_listener_factory:
A parameterless factory function that returns a user handler. User handler has to derive from :any:`WSListener`.
:param ssl: optional SSLContext to override default one when wss scheme is used
:param url: Destination URL
:param ssl_context: optional SSLContext to override default one when wss scheme is used
:param disconnect_on_exception:
Indicates whether the client should initiate disconnect on any exception
thrown from WSListener.on_ws* callbacks
:param ssl_handshake_timeout:
is (for a TLS connection) the time in seconds to wait for the TLS handshake to complete before aborting the connection.
:param ssl_shutdown_timeout:
is the time in seconds to wait for the SSL shutdown to complete before aborting the connection.
:param websocket_handshake_timeout:
is the time in seconds to wait for the websocket server to reply to websocket handshake request
:param local_addr:
if given, is a (local_host, local_port) tuple used to bind the socket locally. The local_host and local_port
are looked up using getaddrinfo(), similarly to host and port from url.
:param logger_name:
picows will use `picows.<logger_name>` logger to do all the logging.
:return: :any:`WSTransport` object and a user handler returned by `ws_listener_factory()'
"""

assert "ssl" not in kwargs, "explicit 'ssl' argument for loop.create_connection is not supported"
assert "sock" not in kwargs, "explicit 'sock' argument for loop.create_connection is not supported"
assert "all_errors" not in kwargs, "explicit 'all_errors' argument for loop.create_connection is not supported"

url_parts = urllib.parse.urlparse(url, allow_fragments=False)

if url_parts.scheme == "wss":
if ssl is None:
ssl = True
ssl = ssl_context if ssl_context is not None else True
port = url_parts.port or 443
elif url_parts.scheme == "ws":
ssl = None
ssl_handshake_timeout = None
ssl_shutdown_timeout = None
port = url_parts.port or 80
else:
raise ValueError(f"invalid url scheme: {url}")
Expand All @@ -1150,11 +1140,7 @@ async def ws_connect(str url: str,
cdef WSProtocol ws_protocol

(_, ws_protocol) = await asyncio.get_running_loop().create_connection(
ws_protocol_factory, url_parts.hostname, port,
local_addr=local_addr,
ssl=ssl,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
ws_protocol_factory, url_parts.hostname, port, ssl=ssl, **kwargs)

await ws_protocol.wait_until_handshake_complete()

Expand Down
24 changes: 6 additions & 18 deletions tests/test_echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ async def get_message(self):
async with async_timeout.timeout(1):
return await self.msg_queue.get()

(_, client) = await picows.ws_connect(echo_server, PicowsClientListener,
ssl=create_client_ssl_context(),
(_, client) = await picows.ws_connect(PicowsClientListener, echo_server,
ssl_context=create_client_ssl_context(),
websocket_handshake_timeout=0.5)
yield client

Expand Down Expand Up @@ -153,8 +153,8 @@ async def test_close(echo_client):
async def test_client_handshake_timeout(echo_server):
# Set unreasonably small timeout
with pytest.raises(TimeoutError):
(_, client) = await picows.ws_connect(echo_server, picows.WSListener,
ssl=create_client_ssl_context(),
(_, client) = await picows.ws_connect(picows.WSListener, echo_server,
ssl_context=create_client_ssl_context(),
websocket_handshake_timeout=0.00001)


Expand All @@ -178,7 +178,7 @@ async def test_route_not_found():
url = f"ws://127.0.0.1:{server.sockets[0].getsockname()[1]}/"

with pytest.raises(picows.WSError, match="404 Not Found"):
(_, client) = await picows.ws_connect(url, picows.WSListener)
(_, client) = await picows.ws_connect(picows.WSListener, url)


async def test_server_internal_error():
Expand All @@ -190,7 +190,7 @@ def factory_listener(r):
url = f"ws://127.0.0.1:{server.sockets[0].getsockname()[1]}/"

with pytest.raises(picows.WSError, match="500 Internal Server Error"):
(_, client) = await picows.ws_connect(url, picows.WSListener)
(_, client) = await picows.ws_connect(picows.WSListener, url)


async def test_server_bad_request():
Expand All @@ -205,15 +205,3 @@ async def test_server_bad_request():
assert b"400 Bad Request" in resp_header
await r.read()
assert r.at_eof()

# From ChatGPT:
# Common Misconception
# It's a common misconception that w.wait_closed() should detect the server's disconnection and automatically close the connection. In reality:
#
# w.wait_closed() only waits for the client-side closure process to finish.
# You must explicitly call w.close() to initiate the closing process after detecting that the server has closed its side of the connection.
#
# I didn't know that? Isn't that stupid and defies the whole purpose of
# wait_closed?
w.close()
await w.wait_closed()

0 comments on commit 650f5ff

Please sign in to comment.