From 76137712f61712569cee5887cdacbaba921e96f0 Mon Sep 17 00:00:00 2001 From: PerchunPak Date: Sun, 16 Oct 2022 09:56:54 +0200 Subject: [PATCH] Use `with` blocks for closing and opening connections Fixes #337. --- mcstatus/protocol/connection.py | 61 +++++++++++++------------- mcstatus/server.py | 35 +++++++-------- tests/protocol/test_connection.py | 72 ++++++++++++++++--------------- tests/test_async_support.py | 4 +- tests/test_server.py | 47 ++++++++++---------- tests/test_timeout.py | 14 +++--- 6 files changed, 113 insertions(+), 120 deletions(-) diff --git a/mcstatus/protocol/connection.py b/mcstatus/protocol/connection.py index e9c55ed8..8c9b40be 100644 --- a/mcstatus/protocol/connection.py +++ b/mcstatus/protocol/connection.py @@ -17,7 +17,7 @@ from mcstatus.address import Address if TYPE_CHECKING: - from typing_extensions import SupportsIndex, TypeAlias + from typing_extensions import Self, SupportsIndex, TypeAlias BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]" @@ -512,12 +512,11 @@ def close(self) -> None: self.socket.shutdown(socket.SHUT_RDWR) self.socket.close() - def __del__(self) -> None: - """Close self.socket.""" - try: - self.close() - except OSError: # Probably, the socket was already closed by the OS - pass + def __enter__(self) -> Self: + return self + + def __exit__(self, *_) -> None: + self.close() class TCPSocketConnection(SocketConnection): @@ -588,18 +587,18 @@ def write(self, data: Connection | str | bytes | bytearray) -> None: class TCPAsyncSocketConnection(BaseAsyncReadSyncWriteConnection): """Asynchronous TCP Connection class""" - __slots__ = ("reader", "writer", "timeout") + __slots__ = ("reader", "writer", "timeout", "_addr") - def __init__(self) -> None: + def __init__(self, addr: Address, timeout: float = 3) -> None: # These will only be None until connect is called, ignore the None type assignment self.reader: asyncio.StreamReader = None # type: ignore[assignment] self.writer: asyncio.StreamWriter = None # type: ignore[assignment] - self.timeout: float = None # type: ignore[assignment] + self.timeout: float = timeout + self._addr = addr - async def connect(self, addr: Address, timeout: float = 3) -> None: + async def connect(self) -> None: """Use asyncio to open a connection to address. Timeout is in seconds.""" - self.timeout = timeout - conn = asyncio.open_connection(addr[0], addr[1]) + conn = asyncio.open_connection(*self._addr) self.reader, self.writer = await asyncio.wait_for(conn, timeout=self.timeout) async def read(self, length: int) -> bytearray: @@ -623,30 +622,30 @@ def write(self, data: Connection | str | bytes | bytearray) -> None: def close(self) -> None: """Close self.writer.""" if self.writer is not None: # If initialized - try: - self.writer.close() - except RuntimeError: # There is a case where event loop is closed - pass + self.writer.close() + + async def __aenter__(self) -> Self: + await self.connect() + return self - def __del__(self) -> None: - """Close self.""" + async def __aexit__(self, *_) -> None: self.close() class UDPAsyncSocketConnection(BaseAsyncConnection): """Asynchronous UDP Connection class""" - __slots__ = ("stream", "timeout") + __slots__ = ("stream", "timeout", "_addr") - def __init__(self) -> None: + def __init__(self, addr: Address, timeout: float = 3) -> None: # This will only be None until connect is called, ignore the None type assignment self.stream: asyncio_dgram.aio.DatagramClient = None # type: ignore[assignment] - self.timeout: float = None # type: ignore[assignment] + self.timeout: float = timeout + self._addr = addr - async def connect(self, addr: Address, timeout: float = 3) -> None: + async def connect(self) -> None: """Connect to address. Timeout is in seconds.""" - self.timeout = timeout - conn = asyncio_dgram.connect((addr[0], addr[1])) + conn = asyncio_dgram.connect(self._addr) self.stream = await asyncio.wait_for(conn, timeout=self.timeout) def remaining(self) -> int: @@ -671,9 +670,9 @@ def close(self) -> None: if self.stream is not None: # If initialized self.stream.close() - def __del__(self) -> None: - """Close self.stream""" - try: - self.close() - except (asyncio.exceptions.CancelledError, asyncio.exceptions.TimeoutError): - return + async def __aenter__(self) -> Self: + await self.connect() + return self + + async def __aexit__(self, *_) -> None: + self.close() diff --git a/mcstatus/server.py b/mcstatus/server.py index 92f14c55..f96c07d0 100644 --- a/mcstatus/server.py +++ b/mcstatus/server.py @@ -91,8 +91,8 @@ def ping(self, **kwargs) -> float: :return: The latency between the Minecraft Server and you. """ - connection = TCPSocketConnection(self.address, self.timeout) - return self._retry_ping(connection, **kwargs) + with TCPSocketConnection(self.address, self.timeout) as connection: + return self._retry_ping(connection, **kwargs) @retry(tries=3) def _retry_ping(self, connection: TCPSocketConnection, **kwargs) -> float: @@ -107,9 +107,8 @@ async def async_ping(self, **kwargs) -> float: :return: The latency between the Minecraft Server and you. """ - connection = TCPAsyncSocketConnection() - await connection.connect(self.address, self.timeout) - return await self._retry_async_ping(connection, **kwargs) + async with TCPAsyncSocketConnection(self.address, self.timeout) as connection: + return await self._retry_async_ping(connection, **kwargs) @retry(tries=3) async def _retry_async_ping(self, connection: TCPAsyncSocketConnection, **kwargs) -> float: @@ -125,8 +124,8 @@ def status(self, **kwargs) -> PingResponse: :return: Status information in a `PingResponse` instance. """ - connection = TCPSocketConnection(self.address, self.timeout) - return self._retry_status(connection, **kwargs) + with TCPSocketConnection(self.address, self.timeout) as connection: + return self._retry_status(connection, **kwargs) @retry(tries=3) def _retry_status(self, connection: TCPSocketConnection, **kwargs) -> PingResponse: @@ -142,9 +141,8 @@ async def async_status(self, **kwargs) -> PingResponse: :return: Status information in a `PingResponse` instance. """ - connection = TCPAsyncSocketConnection() - await connection.connect(self.address, self.timeout) - return await self._retry_async_status(connection, **kwargs) + async with TCPAsyncSocketConnection(self.address, self.timeout) as connection: + return await self._retry_async_status(connection, **kwargs) @retry(tries=3) async def _retry_async_status(self, connection: TCPAsyncSocketConnection, **kwargs) -> PingResponse: @@ -169,10 +167,10 @@ def query(self) -> QueryResponse: @retry(tries=3) def _retry_query(self, addr: Address) -> QueryResponse: - connection = UDPSocketConnection(addr, self.timeout) - querier = ServerQuerier(connection) - querier.handshake() - return querier.read_query() + with UDPSocketConnection(addr, self.timeout) as connection: + querier = ServerQuerier(connection) + querier.handshake() + return querier.read_query() async def async_query(self) -> QueryResponse: """Asynchronously checks the status of a Minecraft Java Edition server via the query protocol.""" @@ -190,11 +188,10 @@ async def async_query(self) -> QueryResponse: @retry(tries=3) async def _retry_async_query(self, address: Address) -> QueryResponse: - connection = UDPAsyncSocketConnection() - await connection.connect(address, self.timeout) - querier = AsyncServerQuerier(connection) - await querier.handshake() - return await querier.read_query() + async with UDPAsyncSocketConnection(address, self.timeout) as connection: + querier = AsyncServerQuerier(connection) + await querier.handshake() + return await querier.read_query() class BedrockServer(MCServer): diff --git a/tests/protocol/test_connection.py b/tests/protocol/test_connection.py index 0ecb62ab..0a183b9f 100644 --- a/tests/protocol/test_connection.py +++ b/tests/protocol/test_connection.py @@ -206,76 +206,80 @@ def test_write_buffer(self): class TestTCPSocketConnection: - def setup_method(self): - self.test_addr = Address("localhost", 1234) + @pytest.fixture(scope="class") + def connection(self): + test_addr = Address("localhost", 1234) socket = Mock() socket.recv = Mock() socket.send = Mock() with patch("socket.create_connection") as create_connection: create_connection.return_value = socket - self.connection = TCPSocketConnection(self.test_addr) + with TCPSocketConnection(test_addr) as connection: + yield connection - def test_flush(self): + def test_flush(self, connection): with pytest.raises(TypeError): - self.connection.flush() + connection.flush() - def test_receive(self): + def test_receive(self, connection): with pytest.raises(TypeError): - self.connection.receive("") # type: ignore # This is desired to produce TypeError + connection.receive("") # type: ignore # This is desired to produce TypeError - def test_remaining(self): + def test_remaining(self, connection): with pytest.raises(TypeError): - self.connection.remaining() + connection.remaining() - def test_read(self): - self.connection.socket.recv.return_value = bytearray.fromhex("7FAA") + def test_read(self, connection): + connection.socket.recv.return_value = bytearray.fromhex("7FAA") - assert self.connection.read(2) == bytearray.fromhex("7FAA") + assert connection.read(2) == bytearray.fromhex("7FAA") - def test_read_empty(self): - self.connection.socket.recv.return_value = bytearray.fromhex("") + def test_read_empty(self, connection): + connection.socket.recv.return_value = bytearray.fromhex("") with pytest.raises(IOError): - self.connection.read(2) + connection.read(2) - def test_write(self): - self.connection.write(bytearray.fromhex("7FAA")) + def test_write(self, connection): + connection.write(bytearray.fromhex("7FAA")) - self.connection.socket.send.assert_called_once_with(bytearray.fromhex("7FAA")) # type: ignore[attr-defined] + connection.socket.send.assert_called_once_with(bytearray.fromhex("7FAA")) # type: ignore[attr-defined] class TestUDPSocketConnection: - def setup_method(self): - self.test_addr = Address("localhost", 1234) + @pytest.fixture(scope="class") + def connection(self): + test_addr = Address("localhost", 1234) socket = Mock() socket.recvfrom = Mock() socket.sendto = Mock() with patch("socket.socket") as create_socket: create_socket.return_value = socket - self.connection = UDPSocketConnection(self.test_addr) + with UDPSocketConnection(test_addr) as connection: + yield connection - def test_flush(self): + def test_flush(self, connection): with pytest.raises(TypeError): - self.connection.flush() + connection.flush() - def test_receive(self): + def test_receive(self, connection): with pytest.raises(TypeError): - self.connection.receive("") # type: ignore # This is desired to produce TypeError + connection.receive("") # type: ignore # This is desired to produce TypeError - def test_remaining(self): - assert self.connection.remaining() == 65535 + def test_remaining(self, connection): + assert connection.remaining() == 65535 - def test_read(self): - self.connection.socket.recvfrom.return_value = [bytearray.fromhex("7FAA")] + def test_read(self, connection): + connection.socket.recvfrom.return_value = [bytearray.fromhex("7FAA")] - assert self.connection.read(2) == bytearray.fromhex("7FAA") + assert connection.read(2) == bytearray.fromhex("7FAA") - def test_write(self): - self.connection.write(bytearray.fromhex("7FAA")) + def test_write(self, connection): + connection.write(bytearray.fromhex("7FAA")) - self.connection.socket.sendto.assert_called_once_with( # type: ignore[attr-defined] + connection.socket.sendto.assert_called_once_with( # type: ignore[attr-defined] bytearray.fromhex("7FAA"), - self.test_addr, + Address("localhost", 1234), ) diff --git a/tests/test_async_support.py b/tests/test_async_support.py index 2b45570c..8e9491fc 100644 --- a/tests/test_async_support.py +++ b/tests/test_async_support.py @@ -4,7 +4,7 @@ def test_is_completely_asynchronous(): - conn = TCPAsyncSocketConnection() + conn = TCPAsyncSocketConnection assertions = 0 for attribute in dir(conn): if attribute.startswith("read_"): @@ -14,7 +14,7 @@ def test_is_completely_asynchronous(): def test_query_is_completely_asynchronous(): - conn = UDPAsyncSocketConnection() + conn = UDPAsyncSocketConnection assertions = 0 for attribute in dir(conn): if attribute.startswith("read_"): diff --git a/tests/test_server.py b/tests/test_server.py index e4029622..b8d828a8 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -102,7 +102,7 @@ def test_ping(self): self.socket.receive(bytearray.fromhex("09010000000001C54246")) with patch("mcstatus.server.TCPSocketConnection") as connection: - connection.return_value = self.socket + connection.return_value.__enter__.return_value = self.socket latency = self.server.ping(ping_token=29704774, version=47) assert self.socket.flush() == bytearray.fromhex("0F002F096C6F63616C686F737463DD0109010000000001C54246") @@ -110,13 +110,12 @@ def test_ping(self): assert latency >= 0 def test_ping_retry(self): - with patch("mcstatus.server.TCPSocketConnection") as connection: - connection.return_value = None - with patch("mcstatus.server.ServerPinger") as pinger: - pinger.side_effect = [Exception, Exception, Exception] - with pytest.raises(Exception): - self.server.ping() - assert pinger.call_count == 3 + # Use a blank mock for the connection, we don't want to actually create any connections + with patch("mcstatus.server.TCPSocketConnection"), patch("mcstatus.server.ServerPinger") as pinger: + pinger.side_effect = [Exception, Exception, Exception] + with pytest.raises(Exception): + self.server.ping() + assert pinger.call_count == 3 def test_status(self): self.socket.receive( @@ -128,7 +127,7 @@ def test_status(self): ) with patch("mcstatus.server.TCPSocketConnection") as connection: - connection.return_value = self.socket + connection.return_value.__enter__.return_value = self.socket info = self.server.status(version=47) assert self.socket.flush() == bytearray.fromhex("0F002F096C6F63616C686F737463DD010100") @@ -141,13 +140,12 @@ def test_status(self): assert info.latency >= 0 def test_status_retry(self): - with patch("mcstatus.server.TCPSocketConnection") as connection: - connection.return_value = None - with patch("mcstatus.server.ServerPinger") as pinger: - pinger.side_effect = [Exception, Exception, Exception] - with pytest.raises(Exception): - self.server.status() - assert pinger.call_count == 3 + # Use a blank mock for the connection, we don't want to actually create any connections + with patch("mcstatus.server.TCPSocketConnection"), patch("mcstatus.server.ServerPinger") as pinger: + pinger.side_effect = [Exception, Exception, Exception] + with pytest.raises(Exception): + self.server.status() + assert pinger.call_count == 3 def test_query(self): self.socket.receive(bytearray.fromhex("090000000035373033353037373800")) @@ -167,7 +165,7 @@ def test_query(self): with patch("mcstatus.server.UDPSocketConnection") as connection, patch.object( self.server.address, "resolve_ip" ) as resolve_ip: - connection.return_value = self.socket + connection.return_value.__enter__.return_value = self.socket resolve_ip.return_value = "127.0.0.1" info = self.server.query() @@ -187,14 +185,13 @@ def test_query(self): } def test_query_retry(self): - with patch("mcstatus.server.UDPSocketConnection") as connection: - connection.return_value = None - with patch("mcstatus.server.ServerQuerier") as querier: - querier.side_effect = [Exception, Exception, Exception] - with pytest.raises(Exception), patch.object(self.server.address, "resolve_ip") as resolve_ip: - resolve_ip.return_value = "127.0.0.1" - self.server.query() - assert querier.call_count == 3 + # Use a blank mock for the connection, we don't want to actually create any connections + with patch("mcstatus.server.UDPSocketConnection"), patch("mcstatus.server.ServerQuerier") as querier: + querier.side_effect = [Exception, Exception, Exception] + with pytest.raises(Exception), patch.object(self.server.address, "resolve_ip") as resolve_ip: + resolve_ip.return_value = "127.0.0.1" + self.server.query() + assert querier.call_count == 3 def test_lookup_constructor(self): s = JavaServer.lookup("example.org:4444") diff --git a/tests/test_timeout.py b/tests/test_timeout.py index bcea6e90..c6dc3548 100644 --- a/tests/test_timeout.py +++ b/tests/test_timeout.py @@ -20,13 +20,9 @@ async def fake_asyncio_asyncio_open_connection(hostname: str, port: int): class TestAsyncSocketConnection: - def setup_method(self): - self.tcp_async_socket = TCPAsyncSocketConnection() - self.test_addr = Address("dummy_address", 1234) - - def test_tcp_socket_read(self): + @pytest.mark.asyncio + async def test_tcp_socket_read(self): with patch("asyncio.open_connection", fake_asyncio_asyncio_open_connection): - asyncio.run(self.tcp_async_socket.connect(self.test_addr, timeout=0.01)) - - with pytest.raises(TimeoutError): - asyncio.run(self.tcp_async_socket.read(10)) + async with TCPAsyncSocketConnection(Address("dummy_address", 1234), timeout=0.01) as tcp_async_socket: + with pytest.raises(TimeoutError): + await tcp_async_socket.read(10)