Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of explict close #467

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pyhap/accessory.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ def to_HAP(self, include_value: bool = True) -> Dict[str, Any]:
"""
return {
HAP_REPR_AID: self.aid,
HAP_REPR_SERVICES: [s.to_HAP(include_value=include_value) for s in self.services],
HAP_REPR_SERVICES: [
s.to_HAP(include_value=include_value) for s in self.services
],
}

def setup_message(self):
Expand Down Expand Up @@ -391,7 +393,10 @@ def to_HAP(self, include_value: bool = True) -> List[Dict[str, Any]]:

.. seealso:: Accessory.to_HAP
"""
return [acc.to_HAP(include_value=include_value) for acc in (super(), *self.accessories.values())]
return [
acc.to_HAP(include_value=include_value)
for acc in (super(), *self.accessories.values())
]

def get_characteristic(self, aid: int, iid: int) -> Optional["Characteristic"]:
""".. seealso:: Accessory.to_HAP"""
Expand Down
26 changes: 16 additions & 10 deletions pyhap/hap_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
connections: Dict[str, "HAPServerProtocol"],
accessory_driver: "AccessoryDriver",
) -> None:
self.loop: asyncio.AbstractEventLoop = loop
self.loop = loop
self.conn = h11.Connection(h11.SERVER)
self.connections = connections
self.accessory_driver = accessory_driver
Expand All @@ -55,7 +55,7 @@ def __init__(
self.transport: Optional[asyncio.Transport] = None

self.request: Optional[h11.Request] = None
self.request_body: Optional[bytes] = None
self.request_body: List[bytes] = []
self.response: Optional[HAPResponse] = None

self.last_activity: Optional[float] = None
Expand Down Expand Up @@ -246,27 +246,33 @@ def _process_one_event(self) -> bool:
logger.debug(
"%s (%s): h11 Event: %s", self.peername, self.handler.client_uuid, event
)
if event in (h11.NEED_DATA, h11.ConnectionClosed):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bug was missed because there was no explicit test and it looked like it was covered because we did match on NEED_DATA

if event is h11.NEED_DATA:
return False

if event is h11.PAUSED:
self.conn.start_next_cycle()
return True

if isinstance(event, h11.Request):
event_type = type(event)
if event_type is h11.ConnectionClosed:
return False

if event_type is h11.Request:
self.request = event
self.request_body = b""
self.request_body = []
return True

if isinstance(event, h11.Data):
self.request_body += event.data
if event_type is h11.Data:
if TYPE_CHECKING:
assert isinstance(event, h11.Data) # nosec
self.request_body.append(event.data)
return True

if isinstance(event, h11.EndOfMessage):
response = self.handler.dispatch(self.request, bytes(self.request_body))
if event_type is h11.EndOfMessage:
response = self.handler.dispatch(self.request, b"".join(self.request_body))
self._process_response(response)
self.request = None
self.request_body = None
self.request_body = []
return True

return self._handle_invalid_conn_state(f"Unexpected event: {event}")
Expand Down
30 changes: 19 additions & 11 deletions pyhap/hap_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
The HAPServer is the point of contact to and from the world.
"""

import asyncio
import logging
import time
from typing import TYPE_CHECKING, Dict, Optional, Tuple

from .hap_protocol import HAPServerProtocol
from .util import callback

if TYPE_CHECKING:
from .accessory_driver import AccessoryDriver

logger = logging.getLogger(__name__)

IDLE_CONNECTION_CHECK_INTERVAL_SECONDS = 120
Expand All @@ -28,17 +33,18 @@ class HAPServer:
implements exclusive access to the send methods.
"""

def __init__(self, addr_port, accessory_handler):
def __init__(
self, addr_port: Tuple[str, int], accessory_handler: "AccessoryDriver"
) -> None:
"""Create a HAP Server."""
self._addr_port = addr_port
self.connections = {} # (address, port): socket
self.connections: Dict[Tuple[str, int], HAPServerProtocol] = {}
self.accessory_handler = accessory_handler
self.server = None
self._serve_task = None
self._connection_cleanup = None
self.loop = None
self.server: Optional[asyncio.Server] = None
self._connection_cleanup: Optional[asyncio.TimerHandle] = None
self.loop: Optional[asyncio.AbstractEventLoop] = None

async def async_start(self, loop):
async def async_start(self, loop: asyncio.AbstractEventLoop) -> None:
"""Start the http-hap server."""
self.loop = loop
self.server = await loop.create_server(
Expand All @@ -49,7 +55,7 @@ async def async_start(self, loop):
self.async_cleanup_connections()

@callback
def async_cleanup_connections(self):
def async_cleanup_connections(self) -> None:
"""Cleanup stale connections."""
now = time.time()
for hap_proto in list(self.connections.values()):
Expand All @@ -59,7 +65,7 @@ def async_cleanup_connections(self):
)

@callback
def async_stop(self):
def async_stop(self) -> None:
"""Stop the server.

This method must be run in the event loop.
Expand All @@ -70,10 +76,12 @@ def async_stop(self):
self.server.close()
self.connections.clear()

def push_event(self, data, client_addr, immediate=False):
def push_event(
self, data: bytes, client_addr: Tuple[str, int], immediate: bool = False
) -> bool:
"""Queue an event to the current connection with the provided data.

:param data: The charateristic changes
:param data: The characteristic changes
:type data: dict

:param client_addr: A client (address, port) tuple to which to send the data.
Expand Down
60 changes: 60 additions & 0 deletions tests/test_hap_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,30 @@

from pyhap import hap_handler, hap_protocol
from pyhap.accessory import Accessory, Bridge
from pyhap.accessory_driver import AccessoryDriver
from pyhap.hap_handler import HAPResponse


class MockTransport(asyncio.Transport): # pylint: disable=abstract-method
"""A mock transport."""

_is_closing: bool = False

def set_write_buffer_limits(self, high=None, low=None):
"""Set the write buffer limits."""

def write_eof(self) -> None:
"""Write EOF to the stream."""

def close(self) -> None:
"""Close the stream."""
self._is_closing = True

def is_closing(self) -> bool:
"""Return True if the transport is closing or closed."""
return self._is_closing


class MockHAPCrypto:
"""Mock HAPCrypto that only returns plaintext."""

Expand Down Expand Up @@ -734,3 +755,42 @@ async def test_does_not_timeout(driver):
assert writer.call_args_list[0][0][0].startswith(b"HTTP/1.1 200 OK\r\n") is True
hap_proto.check_idle(time.time())
assert hap_proto_close.called is False


def test_explicit_close(driver: AccessoryDriver):
"""Test an explicit connection close."""
loop = MagicMock()

transport = MockTransport()
connections = {}

acc = Accessory(driver, "TestAcc", aid=1)
assert acc.aid == 1
service = acc.driver.loader.get_service("TemperatureSensor")
acc.add_service(service)
driver.add_accessory(acc)

hap_proto = hap_protocol.HAPServerProtocol(loop, connections, driver)
hap_proto.connection_made(transport)

hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True
assert hap_proto.transport.is_closing() is False

with patch.object(hap_proto.transport, "write") as writer:
hap_proto.data_received(
b"GET /characteristics?id=3762173001.7 HTTP/1.1\r\nHost: HASS\\032Bridge\\032YPHW\\032B223AD._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long
)
hap_proto.data_received(
b"GET /characteristics?id=1.5 HTTP/1.1\r\nConnection: close\r\nHost: HASS\\032Bridge\\032YPHW\\032B223AD._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long
)

assert b"Content-Length:" in writer.call_args_list[0][0][0]
assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[0][0][0]
assert b"-70402" in writer.call_args_list[0][0][0]

assert b"Content-Length:" in writer.call_args_list[1][0][0]
assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[1][0][0]
assert b"TestAcc" in writer.call_args_list[1][0][0]

assert hap_proto.transport.is_closing() is True
Loading