Skip to content

Commit

Permalink
Use with blocks for closing and opening connections
Browse files Browse the repository at this point in the history
Fixes #337.
  • Loading branch information
PerchunPak committed Oct 16, 2022
1 parent 94bd9ef commit 4afc931
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 120 deletions.
61 changes: 30 additions & 31 deletions mcstatus/protocol/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from mcstatus.address import Address

if TYPE_CHECKING:
from typing_extensions import SupportsIndex, TypeAlias
from typing_extensions import Self, SupportsIndex, TypeAlias

BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]"

Expand Down Expand Up @@ -512,12 +512,11 @@ def close(self) -> None:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()

def __del__(self) -> None:
"""Close self.socket."""
try:
self.close()
except OSError: # Probably, the socket was already closed by the OS
pass
def __enter__(self) -> Self:
return self

def __exit__(self, _, __, ___) -> None: # noqa: ANN001
self.close()


class TCPSocketConnection(SocketConnection):
Expand Down Expand Up @@ -588,18 +587,18 @@ def write(self, data: Connection | str | bytes | bytearray) -> None:
class TCPAsyncSocketConnection(BaseAsyncReadSyncWriteConnection):
"""Asynchronous TCP Connection class"""

__slots__ = ("reader", "writer", "timeout")
__slots__ = ("reader", "writer", "timeout", "_addr")

def __init__(self) -> None:
def __init__(self, addr: Address, timeout: float = 3) -> None:
# These will only be None until connect is called, ignore the None type assignment
self.reader: asyncio.StreamReader = None # type: ignore[assignment]
self.writer: asyncio.StreamWriter = None # type: ignore[assignment]
self.timeout: float = None # type: ignore[assignment]
self.timeout: float = timeout
self._addr = addr

async def connect(self, addr: Address, timeout: float = 3) -> None:
async def connect(self) -> None:
"""Use asyncio to open a connection to address. Timeout is in seconds."""
self.timeout = timeout
conn = asyncio.open_connection(addr[0], addr[1])
conn = asyncio.open_connection(*self._addr)
self.reader, self.writer = await asyncio.wait_for(conn, timeout=self.timeout)

async def read(self, length: int) -> bytearray:
Expand All @@ -623,30 +622,30 @@ def write(self, data: Connection | str | bytes | bytearray) -> None:
def close(self) -> None:
"""Close self.writer."""
if self.writer is not None: # If initialized
try:
self.writer.close()
except RuntimeError: # There is a case where event loop is closed
pass
self.writer.close()

async def __aenter__(self) -> Self:
await self.connect()
return self

def __del__(self) -> None:
"""Close self."""
async def __aexit__(self, _, __, ___) -> None: # noqa: ANN001
self.close()


class UDPAsyncSocketConnection(BaseAsyncConnection):
"""Asynchronous UDP Connection class"""

__slots__ = ("stream", "timeout")
__slots__ = ("stream", "timeout", "_addr")

def __init__(self) -> None:
def __init__(self, addr: Address, timeout: float = 3) -> None:
# This will only be None until connect is called, ignore the None type assignment
self.stream: asyncio_dgram.aio.DatagramClient = None # type: ignore[assignment]
self.timeout: float = None # type: ignore[assignment]
self.timeout: float = timeout
self._addr = addr

async def connect(self, addr: Address, timeout: float = 3) -> None:
async def connect(self) -> None:
"""Connect to address. Timeout is in seconds."""
self.timeout = timeout
conn = asyncio_dgram.connect((addr[0], addr[1]))
conn = asyncio_dgram.connect(self._addr)
self.stream = await asyncio.wait_for(conn, timeout=self.timeout)

def remaining(self) -> int:
Expand All @@ -671,9 +670,9 @@ def close(self) -> None:
if self.stream is not None: # If initialized
self.stream.close()

def __del__(self) -> None:
"""Close self.stream"""
try:
self.close()
except (asyncio.exceptions.CancelledError, asyncio.exceptions.TimeoutError):
return
async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(self, _, __, ___) -> None: # noqa: ANN001
self.close()
35 changes: 16 additions & 19 deletions mcstatus/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def ping(self, **kwargs) -> float:
:return: The latency between the Minecraft Server and you.
"""

connection = TCPSocketConnection(self.address, self.timeout)
return self._retry_ping(connection, **kwargs)
with TCPSocketConnection(self.address, self.timeout) as connection:
return self._retry_ping(connection, **kwargs)

@retry(tries=3)
def _retry_ping(self, connection: TCPSocketConnection, **kwargs) -> float:
Expand All @@ -107,9 +107,8 @@ async def async_ping(self, **kwargs) -> float:
:return: The latency between the Minecraft Server and you.
"""

connection = TCPAsyncSocketConnection()
await connection.connect(self.address, self.timeout)
return await self._retry_async_ping(connection, **kwargs)
async with TCPAsyncSocketConnection(self.address, self.timeout) as connection:
return await self._retry_async_ping(connection, **kwargs)

@retry(tries=3)
async def _retry_async_ping(self, connection: TCPAsyncSocketConnection, **kwargs) -> float:
Expand All @@ -125,8 +124,8 @@ def status(self, **kwargs) -> PingResponse:
:return: Status information in a `PingResponse` instance.
"""

connection = TCPSocketConnection(self.address, self.timeout)
return self._retry_status(connection, **kwargs)
with TCPSocketConnection(self.address, self.timeout) as connection:
return self._retry_status(connection, **kwargs)

@retry(tries=3)
def _retry_status(self, connection: TCPSocketConnection, **kwargs) -> PingResponse:
Expand All @@ -142,9 +141,8 @@ async def async_status(self, **kwargs) -> PingResponse:
:return: Status information in a `PingResponse` instance.
"""

connection = TCPAsyncSocketConnection()
await connection.connect(self.address, self.timeout)
return await self._retry_async_status(connection, **kwargs)
async with TCPAsyncSocketConnection(self.address, self.timeout) as connection:
return await self._retry_async_status(connection, **kwargs)

@retry(tries=3)
async def _retry_async_status(self, connection: TCPAsyncSocketConnection, **kwargs) -> PingResponse:
Expand All @@ -169,10 +167,10 @@ def query(self) -> QueryResponse:

@retry(tries=3)
def _retry_query(self, addr: Address) -> QueryResponse:
connection = UDPSocketConnection(addr, self.timeout)
querier = ServerQuerier(connection)
querier.handshake()
return querier.read_query()
with UDPSocketConnection(addr, self.timeout) as connection:
querier = ServerQuerier(connection)
querier.handshake()
return querier.read_query()

async def async_query(self) -> QueryResponse:
"""Asynchronously checks the status of a Minecraft Java Edition server via the query protocol."""
Expand All @@ -190,11 +188,10 @@ async def async_query(self) -> QueryResponse:

@retry(tries=3)
async def _retry_async_query(self, address: Address) -> QueryResponse:
connection = UDPAsyncSocketConnection()
await connection.connect(address, self.timeout)
querier = AsyncServerQuerier(connection)
await querier.handshake()
return await querier.read_query()
async with UDPAsyncSocketConnection(address, self.timeout) as connection:
querier = AsyncServerQuerier(connection)
await querier.handshake()
return await querier.read_query()


class BedrockServer(MCServer):
Expand Down
72 changes: 38 additions & 34 deletions tests/protocol/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,76 +206,80 @@ def test_write_buffer(self):


class TestTCPSocketConnection:
def setup_method(self):
self.test_addr = Address("localhost", 1234)
@pytest.fixture(scope="class")
def connection(self):
test_addr = Address("localhost", 1234)

socket = Mock()
socket.recv = Mock()
socket.send = Mock()
with patch("socket.create_connection") as create_connection:
create_connection.return_value = socket
self.connection = TCPSocketConnection(self.test_addr)
with TCPSocketConnection(test_addr) as connection:
yield connection

def test_flush(self):
def test_flush(self, connection):
with pytest.raises(TypeError):
self.connection.flush()
connection.flush()

def test_receive(self):
def test_receive(self, connection):
with pytest.raises(TypeError):
self.connection.receive("") # type: ignore # This is desired to produce TypeError
connection.receive("") # type: ignore # This is desired to produce TypeError

def test_remaining(self):
def test_remaining(self, connection):
with pytest.raises(TypeError):
self.connection.remaining()
connection.remaining()

def test_read(self):
self.connection.socket.recv.return_value = bytearray.fromhex("7FAA")
def test_read(self, connection):
connection.socket.recv.return_value = bytearray.fromhex("7FAA")

assert self.connection.read(2) == bytearray.fromhex("7FAA")
assert connection.read(2) == bytearray.fromhex("7FAA")

def test_read_empty(self):
self.connection.socket.recv.return_value = bytearray.fromhex("")
def test_read_empty(self, connection):
connection.socket.recv.return_value = bytearray.fromhex("")

with pytest.raises(IOError):
self.connection.read(2)
connection.read(2)

def test_write(self):
self.connection.write(bytearray.fromhex("7FAA"))
def test_write(self, connection):
connection.write(bytearray.fromhex("7FAA"))

self.connection.socket.send.assert_called_once_with(bytearray.fromhex("7FAA")) # type: ignore[attr-defined]
connection.socket.send.assert_called_once_with(bytearray.fromhex("7FAA")) # type: ignore[attr-defined]


class TestUDPSocketConnection:
def setup_method(self):
self.test_addr = Address("localhost", 1234)
@pytest.fixture(scope="class")
def connection(self):
test_addr = Address("localhost", 1234)

socket = Mock()
socket.recvfrom = Mock()
socket.sendto = Mock()
with patch("socket.socket") as create_socket:
create_socket.return_value = socket
self.connection = UDPSocketConnection(self.test_addr)
with UDPSocketConnection(test_addr) as connection:
yield connection

def test_flush(self):
def test_flush(self, connection):
with pytest.raises(TypeError):
self.connection.flush()
connection.flush()

def test_receive(self):
def test_receive(self, connection):
with pytest.raises(TypeError):
self.connection.receive("") # type: ignore # This is desired to produce TypeError
connection.receive("") # type: ignore # This is desired to produce TypeError

def test_remaining(self):
assert self.connection.remaining() == 65535
def test_remaining(self, connection):
assert connection.remaining() == 65535

def test_read(self):
self.connection.socket.recvfrom.return_value = [bytearray.fromhex("7FAA")]
def test_read(self, connection):
connection.socket.recvfrom.return_value = [bytearray.fromhex("7FAA")]

assert self.connection.read(2) == bytearray.fromhex("7FAA")
assert connection.read(2) == bytearray.fromhex("7FAA")

def test_write(self):
self.connection.write(bytearray.fromhex("7FAA"))
def test_write(self, connection):
connection.write(bytearray.fromhex("7FAA"))

self.connection.socket.sendto.assert_called_once_with( # type: ignore[attr-defined]
connection.socket.sendto.assert_called_once_with( # type: ignore[attr-defined]
bytearray.fromhex("7FAA"),
self.test_addr,
Address("localhost", 1234),
)
4 changes: 2 additions & 2 deletions tests/test_async_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def test_is_completely_asynchronous():
conn = TCPAsyncSocketConnection()
conn = TCPAsyncSocketConnection
assertions = 0
for attribute in dir(conn):
if attribute.startswith("read_"):
Expand All @@ -14,7 +14,7 @@ def test_is_completely_asynchronous():


def test_query_is_completely_asynchronous():
conn = UDPAsyncSocketConnection()
conn = UDPAsyncSocketConnection
assertions = 0
for attribute in dir(conn):
if attribute.startswith("read_"):
Expand Down
Loading

0 comments on commit 4afc931

Please sign in to comment.