From f4315b9fddf242278280927a0d5175a245ad04cc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 22 Oct 2023 09:13:58 -1000 Subject: [PATCH] Fix handling of explict close The check for next_event was not correct for explict closes as https://github.com/python-hyper/h11/blob/a2c68948accadc3876dffcf979d98002e4a4ed27/h11/_connection.py#L445 will only return h11.ConnectionClosed as an object and not a type --- pyhap/accessory.py | 9 ++++-- pyhap/hap_protocol.py | 26 ++++++++++------- pyhap/hap_server.py | 30 ++++++++++++------- tests/test_hap_protocol.py | 60 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 102 insertions(+), 23 deletions(-) diff --git a/pyhap/accessory.py b/pyhap/accessory.py index f8a95f57..902fc7f7 100644 --- a/pyhap/accessory.py +++ b/pyhap/accessory.py @@ -241,7 +241,9 @@ def to_HAP(self, include_value: bool = True) -> Dict[str, Any]: """ return { HAP_REPR_AID: self.aid, - HAP_REPR_SERVICES: [s.to_HAP(include_value=include_value) for s in self.services], + HAP_REPR_SERVICES: [ + s.to_HAP(include_value=include_value) for s in self.services + ], } def setup_message(self): @@ -391,7 +393,10 @@ def to_HAP(self, include_value: bool = True) -> List[Dict[str, Any]]: .. seealso:: Accessory.to_HAP """ - return [acc.to_HAP(include_value=include_value) for acc in (super(), *self.accessories.values())] + return [ + acc.to_HAP(include_value=include_value) + for acc in (super(), *self.accessories.values()) + ] def get_characteristic(self, aid: int, iid: int) -> Optional["Characteristic"]: """.. seealso:: Accessory.to_HAP""" diff --git a/pyhap/hap_protocol.py b/pyhap/hap_protocol.py index b54db5e1..0f51dec9 100644 --- a/pyhap/hap_protocol.py +++ b/pyhap/hap_protocol.py @@ -46,7 +46,7 @@ def __init__( connections: Dict[str, "HAPServerProtocol"], accessory_driver: "AccessoryDriver", ) -> None: - self.loop: asyncio.AbstractEventLoop = loop + self.loop = loop self.conn = h11.Connection(h11.SERVER) self.connections = connections self.accessory_driver = accessory_driver @@ -55,7 +55,7 @@ def __init__( self.transport: Optional[asyncio.Transport] = None self.request: Optional[h11.Request] = None - self.request_body: Optional[bytes] = None + self.request_body: List[bytes] = [] self.response: Optional[HAPResponse] = None self.last_activity: Optional[float] = None @@ -246,27 +246,33 @@ def _process_one_event(self) -> bool: logger.debug( "%s (%s): h11 Event: %s", self.peername, self.handler.client_uuid, event ) - if event in (h11.NEED_DATA, h11.ConnectionClosed): + if event is h11.NEED_DATA: return False if event is h11.PAUSED: self.conn.start_next_cycle() return True - if isinstance(event, h11.Request): + event_type = type(event) + if event_type is h11.ConnectionClosed: + return False + + if event_type is h11.Request: self.request = event - self.request_body = b"" + self.request_body = [] return True - if isinstance(event, h11.Data): - self.request_body += event.data + if event_type is h11.Data: + if TYPE_CHECKING: + assert isinstance(event, h11.Data) # nosec + self.request_body.append(event.data) return True - if isinstance(event, h11.EndOfMessage): - response = self.handler.dispatch(self.request, bytes(self.request_body)) + if event_type is h11.EndOfMessage: + response = self.handler.dispatch(self.request, b"".join(self.request_body)) self._process_response(response) self.request = None - self.request_body = None + self.request_body = [] return True return self._handle_invalid_conn_state(f"Unexpected event: {event}") diff --git a/pyhap/hap_server.py b/pyhap/hap_server.py index 5016355e..1e8414a8 100644 --- a/pyhap/hap_server.py +++ b/pyhap/hap_server.py @@ -3,12 +3,17 @@ The HAPServer is the point of contact to and from the world. """ +import asyncio import logging import time +from typing import TYPE_CHECKING, Dict, Optional, Tuple from .hap_protocol import HAPServerProtocol from .util import callback +if TYPE_CHECKING: + from .accessory_driver import AccessoryDriver + logger = logging.getLogger(__name__) IDLE_CONNECTION_CHECK_INTERVAL_SECONDS = 120 @@ -28,17 +33,18 @@ class HAPServer: implements exclusive access to the send methods. """ - def __init__(self, addr_port, accessory_handler): + def __init__( + self, addr_port: Tuple[str, int], accessory_handler: "AccessoryDriver" + ) -> None: """Create a HAP Server.""" self._addr_port = addr_port - self.connections = {} # (address, port): socket + self.connections: Dict[Tuple[str, int], HAPServerProtocol] = {} self.accessory_handler = accessory_handler - self.server = None - self._serve_task = None - self._connection_cleanup = None - self.loop = None + self.server: Optional[asyncio.Server] = None + self._connection_cleanup: Optional[asyncio.TimerHandle] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None - async def async_start(self, loop): + async def async_start(self, loop: asyncio.AbstractEventLoop) -> None: """Start the http-hap server.""" self.loop = loop self.server = await loop.create_server( @@ -49,7 +55,7 @@ async def async_start(self, loop): self.async_cleanup_connections() @callback - def async_cleanup_connections(self): + def async_cleanup_connections(self) -> None: """Cleanup stale connections.""" now = time.time() for hap_proto in list(self.connections.values()): @@ -59,7 +65,7 @@ def async_cleanup_connections(self): ) @callback - def async_stop(self): + def async_stop(self) -> None: """Stop the server. This method must be run in the event loop. @@ -70,10 +76,12 @@ def async_stop(self): self.server.close() self.connections.clear() - def push_event(self, data, client_addr, immediate=False): + def push_event( + self, data: bytes, client_addr: Tuple[str, int], immediate: bool = False + ) -> bool: """Queue an event to the current connection with the provided data. - :param data: The charateristic changes + :param data: The characteristic changes :type data: dict :param client_addr: A client (address, port) tuple to which to send the data. diff --git a/tests/test_hap_protocol.py b/tests/test_hap_protocol.py index 0d2a58c3..06bd71b9 100644 --- a/tests/test_hap_protocol.py +++ b/tests/test_hap_protocol.py @@ -8,9 +8,30 @@ from pyhap import hap_handler, hap_protocol from pyhap.accessory import Accessory, Bridge +from pyhap.accessory_driver import AccessoryDriver from pyhap.hap_handler import HAPResponse +class MockTransport(asyncio.Transport): # pylint: disable=abstract-method + """A mock transport.""" + + _is_closing: bool = False + + def set_write_buffer_limits(self, high=None, low=None): + """Set the write buffer limits.""" + + def write_eof(self) -> None: + """Write EOF to the stream.""" + + def close(self) -> None: + """Close the stream.""" + self._is_closing = True + + def is_closing(self) -> bool: + """Return True if the transport is closing or closed.""" + return self._is_closing + + class MockHAPCrypto: """Mock HAPCrypto that only returns plaintext.""" @@ -734,3 +755,42 @@ async def test_does_not_timeout(driver): assert writer.call_args_list[0][0][0].startswith(b"HTTP/1.1 200 OK\r\n") is True hap_proto.check_idle(time.time()) assert hap_proto_close.called is False + + +def test_explicit_close(driver: AccessoryDriver): + """Test an explicit connection close.""" + loop = MagicMock() + + transport = MockTransport() + connections = {} + + acc = Accessory(driver, "TestAcc", aid=1) + assert acc.aid == 1 + service = acc.driver.loader.get_service("TemperatureSensor") + acc.add_service(service) + driver.add_accessory(acc) + + hap_proto = hap_protocol.HAPServerProtocol(loop, connections, driver) + hap_proto.connection_made(transport) + + hap_proto.hap_crypto = MockHAPCrypto() + hap_proto.handler.is_encrypted = True + assert hap_proto.transport.is_closing() is False + + with patch.object(hap_proto.transport, "write") as writer: + hap_proto.data_received( + b"GET /characteristics?id=3762173001.7 HTTP/1.1\r\nHost: HASS\\032Bridge\\032YPHW\\032B223AD._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long + ) + hap_proto.data_received( + b"GET /characteristics?id=1.5 HTTP/1.1\r\nConnection: close\r\nHost: HASS\\032Bridge\\032YPHW\\032B223AD._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long + ) + + assert b"Content-Length:" in writer.call_args_list[0][0][0] + assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[0][0][0] + assert b"-70402" in writer.call_args_list[0][0][0] + + assert b"Content-Length:" in writer.call_args_list[1][0][0] + assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[1][0][0] + assert b"TestAcc" in writer.call_args_list[1][0][0] + + assert hap_proto.transport.is_closing() is True