Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use with blocks for closing and opening connections #422

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 30 additions & 31 deletions mcstatus/protocol/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from mcstatus.address import Address

if TYPE_CHECKING:
from typing_extensions import SupportsIndex, TypeAlias
from typing_extensions import Self, SupportsIndex, TypeAlias

BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]"

Expand Down Expand Up @@ -512,12 +512,11 @@ def close(self) -> None:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()

def __del__(self) -> None:
"""Close self.socket."""
try:
self.close()
except OSError: # Probably, the socket was already closed by the OS
pass
def __enter__(self) -> Self:
return self

def __exit__(self, *_) -> None:
self.close()


class TCPSocketConnection(SocketConnection):
Expand Down Expand Up @@ -588,18 +587,18 @@ def write(self, data: Connection | str | bytes | bytearray) -> None:
class TCPAsyncSocketConnection(BaseAsyncReadSyncWriteConnection):
"""Asynchronous TCP Connection class"""

__slots__ = ("reader", "writer", "timeout")
__slots__ = ("reader", "writer", "timeout", "_addr")

def __init__(self) -> None:
def __init__(self, addr: Address, timeout: float = 3) -> None:
# These will only be None until connect is called, ignore the None type assignment
self.reader: asyncio.StreamReader = None # type: ignore[assignment]
self.writer: asyncio.StreamWriter = None # type: ignore[assignment]
self.timeout: float = None # type: ignore[assignment]
self.timeout: float = timeout
self._addr = addr

async def connect(self, addr: Address, timeout: float = 3) -> None:
async def connect(self) -> None:
"""Use asyncio to open a connection to address. Timeout is in seconds."""
self.timeout = timeout
conn = asyncio.open_connection(addr[0], addr[1])
conn = asyncio.open_connection(*self._addr)
self.reader, self.writer = await asyncio.wait_for(conn, timeout=self.timeout)

async def read(self, length: int) -> bytearray:
Expand All @@ -623,30 +622,30 @@ def write(self, data: Connection | str | bytes | bytearray) -> None:
def close(self) -> None:
"""Close self.writer."""
if self.writer is not None: # If initialized
try:
self.writer.close()
except RuntimeError: # There is a case where event loop is closed
pass
self.writer.close()

async def __aenter__(self) -> Self:
await self.connect()
return self

def __del__(self) -> None:
"""Close self."""
async def __aexit__(self, *_) -> None:
self.close()


class UDPAsyncSocketConnection(BaseAsyncConnection):
"""Asynchronous UDP Connection class"""

__slots__ = ("stream", "timeout")
__slots__ = ("stream", "timeout", "_addr")

def __init__(self) -> None:
def __init__(self, addr: Address, timeout: float = 3) -> None:
# This will only be None until connect is called, ignore the None type assignment
self.stream: asyncio_dgram.aio.DatagramClient = None # type: ignore[assignment]
self.timeout: float = None # type: ignore[assignment]
self.timeout: float = timeout
self._addr = addr

async def connect(self, addr: Address, timeout: float = 3) -> None:
async def connect(self) -> None:
"""Connect to address. Timeout is in seconds."""
self.timeout = timeout
conn = asyncio_dgram.connect((addr[0], addr[1]))
conn = asyncio_dgram.connect(self._addr)
self.stream = await asyncio.wait_for(conn, timeout=self.timeout)

def remaining(self) -> int:
Expand All @@ -671,9 +670,9 @@ def close(self) -> None:
if self.stream is not None: # If initialized
self.stream.close()

def __del__(self) -> None:
"""Close self.stream"""
try:
self.close()
except (asyncio.exceptions.CancelledError, asyncio.exceptions.TimeoutError):
return
async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(self, *_) -> None:
self.close()
35 changes: 16 additions & 19 deletions mcstatus/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def ping(self, **kwargs) -> float:
:return: The latency between the Minecraft Server and you.
"""

connection = TCPSocketConnection(self.address, self.timeout)
return self._retry_ping(connection, **kwargs)
with TCPSocketConnection(self.address, self.timeout) as connection:
return self._retry_ping(connection, **kwargs)

@retry(tries=3)
def _retry_ping(self, connection: TCPSocketConnection, **kwargs) -> float:
Expand All @@ -107,9 +107,8 @@ async def async_ping(self, **kwargs) -> float:
:return: The latency between the Minecraft Server and you.
"""

connection = TCPAsyncSocketConnection()
await connection.connect(self.address, self.timeout)
return await self._retry_async_ping(connection, **kwargs)
async with TCPAsyncSocketConnection(self.address, self.timeout) as connection:
return await self._retry_async_ping(connection, **kwargs)

@retry(tries=3)
async def _retry_async_ping(self, connection: TCPAsyncSocketConnection, **kwargs) -> float:
Expand All @@ -125,8 +124,8 @@ def status(self, **kwargs) -> PingResponse:
:return: Status information in a `PingResponse` instance.
"""

connection = TCPSocketConnection(self.address, self.timeout)
return self._retry_status(connection, **kwargs)
with TCPSocketConnection(self.address, self.timeout) as connection:
return self._retry_status(connection, **kwargs)

@retry(tries=3)
def _retry_status(self, connection: TCPSocketConnection, **kwargs) -> PingResponse:
Expand All @@ -142,9 +141,8 @@ async def async_status(self, **kwargs) -> PingResponse:
:return: Status information in a `PingResponse` instance.
"""

connection = TCPAsyncSocketConnection()
await connection.connect(self.address, self.timeout)
return await self._retry_async_status(connection, **kwargs)
async with TCPAsyncSocketConnection(self.address, self.timeout) as connection:
return await self._retry_async_status(connection, **kwargs)

@retry(tries=3)
async def _retry_async_status(self, connection: TCPAsyncSocketConnection, **kwargs) -> PingResponse:
Expand All @@ -169,10 +167,10 @@ def query(self) -> QueryResponse:

@retry(tries=3)
def _retry_query(self, addr: Address) -> QueryResponse:
connection = UDPSocketConnection(addr, self.timeout)
querier = ServerQuerier(connection)
querier.handshake()
return querier.read_query()
with UDPSocketConnection(addr, self.timeout) as connection:
querier = ServerQuerier(connection)
querier.handshake()
return querier.read_query()

async def async_query(self) -> QueryResponse:
"""Asynchronously checks the status of a Minecraft Java Edition server via the query protocol."""
Expand All @@ -190,11 +188,10 @@ async def async_query(self) -> QueryResponse:

@retry(tries=3)
async def _retry_async_query(self, address: Address) -> QueryResponse:
connection = UDPAsyncSocketConnection()
await connection.connect(address, self.timeout)
querier = AsyncServerQuerier(connection)
await querier.handshake()
return await querier.read_query()
async with UDPAsyncSocketConnection(address, self.timeout) as connection:
querier = AsyncServerQuerier(connection)
await querier.handshake()
return await querier.read_query()


class BedrockServer(MCServer):
Expand Down
72 changes: 38 additions & 34 deletions tests/protocol/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,76 +206,80 @@ def test_write_buffer(self):


class TestTCPSocketConnection:
def setup_method(self):
self.test_addr = Address("localhost", 1234)
@pytest.fixture(scope="class")
def connection(self):
test_addr = Address("localhost", 1234)

socket = Mock()
socket.recv = Mock()
socket.send = Mock()
with patch("socket.create_connection") as create_connection:
create_connection.return_value = socket
self.connection = TCPSocketConnection(self.test_addr)
with TCPSocketConnection(test_addr) as connection:
yield connection

def test_flush(self):
def test_flush(self, connection):
with pytest.raises(TypeError):
self.connection.flush()
connection.flush()

def test_receive(self):
def test_receive(self, connection):
with pytest.raises(TypeError):
self.connection.receive("") # type: ignore # This is desired to produce TypeError
connection.receive("") # type: ignore # This is desired to produce TypeError

def test_remaining(self):
def test_remaining(self, connection):
with pytest.raises(TypeError):
self.connection.remaining()
connection.remaining()

def test_read(self):
self.connection.socket.recv.return_value = bytearray.fromhex("7FAA")
def test_read(self, connection):
connection.socket.recv.return_value = bytearray.fromhex("7FAA")

assert self.connection.read(2) == bytearray.fromhex("7FAA")
assert connection.read(2) == bytearray.fromhex("7FAA")

def test_read_empty(self):
self.connection.socket.recv.return_value = bytearray.fromhex("")
def test_read_empty(self, connection):
connection.socket.recv.return_value = bytearray.fromhex("")

with pytest.raises(IOError):
self.connection.read(2)
connection.read(2)

def test_write(self):
self.connection.write(bytearray.fromhex("7FAA"))
def test_write(self, connection):
connection.write(bytearray.fromhex("7FAA"))

self.connection.socket.send.assert_called_once_with(bytearray.fromhex("7FAA")) # type: ignore[attr-defined]
connection.socket.send.assert_called_once_with(bytearray.fromhex("7FAA")) # type: ignore[attr-defined]


class TestUDPSocketConnection:
def setup_method(self):
self.test_addr = Address("localhost", 1234)
@pytest.fixture(scope="class")
def connection(self):
test_addr = Address("localhost", 1234)

socket = Mock()
socket.recvfrom = Mock()
socket.sendto = Mock()
with patch("socket.socket") as create_socket:
create_socket.return_value = socket
self.connection = UDPSocketConnection(self.test_addr)
with UDPSocketConnection(test_addr) as connection:
yield connection

def test_flush(self):
def test_flush(self, connection):
with pytest.raises(TypeError):
self.connection.flush()
connection.flush()

def test_receive(self):
def test_receive(self, connection):
with pytest.raises(TypeError):
self.connection.receive("") # type: ignore # This is desired to produce TypeError
connection.receive("") # type: ignore # This is desired to produce TypeError

def test_remaining(self):
assert self.connection.remaining() == 65535
def test_remaining(self, connection):
assert connection.remaining() == 65535

def test_read(self):
self.connection.socket.recvfrom.return_value = [bytearray.fromhex("7FAA")]
def test_read(self, connection):
connection.socket.recvfrom.return_value = [bytearray.fromhex("7FAA")]

assert self.connection.read(2) == bytearray.fromhex("7FAA")
assert connection.read(2) == bytearray.fromhex("7FAA")

def test_write(self):
self.connection.write(bytearray.fromhex("7FAA"))
def test_write(self, connection):
connection.write(bytearray.fromhex("7FAA"))

self.connection.socket.sendto.assert_called_once_with( # type: ignore[attr-defined]
connection.socket.sendto.assert_called_once_with( # type: ignore[attr-defined]
bytearray.fromhex("7FAA"),
self.test_addr,
Address("localhost", 1234),
)
4 changes: 2 additions & 2 deletions tests/test_async_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def test_is_completely_asynchronous():
conn = TCPAsyncSocketConnection()
conn = TCPAsyncSocketConnection
assertions = 0
for attribute in dir(conn):
if attribute.startswith("read_"):
Expand All @@ -14,7 +14,7 @@ def test_is_completely_asynchronous():


def test_query_is_completely_asynchronous():
conn = UDPAsyncSocketConnection()
conn = UDPAsyncSocketConnection
assertions = 0
for attribute in dir(conn):
if attribute.startswith("read_"):
Expand Down
Loading