diff --git a/tests/test_function.py b/tests/test_function.py index 08dfa90..332e1c3 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -13,7 +13,7 @@ from custom_components.pyscript.const import CONF_ALLOW_ALL_IMPORTS, CONF_HASS_IS_GLOBAL, DOMAIN, FOLDER from custom_components.pyscript.function import Function from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_STATE_CHANGED -from homeassistant.core import Context +from homeassistant.core import Context, ServiceRegistry from homeassistant.setup import async_setup_component @@ -95,10 +95,9 @@ async def test_func_completions( @pytest.mark.asyncio async def test_service_completions(root, expected, hass, services): # pylint: disable=redefined-outer-name """Test service name completion.""" - with patch.object(Function, "hass", hass): - for domain, service_set in services.items(): - for service in service_set: - hass.services.async_register(domain, service, None) + with patch.object(ServiceRegistry, "async_services", return_value=services), patch.object( + Function, "hass", hass + ): words = await Function.service_completions(root) assert words == expected @@ -1248,48 +1247,42 @@ def service_call_exception(): @pytest.mark.asyncio async def test_service_call_params(hass): """Test that hass params get set properly on service calls.""" - try: - with patch.object(hass.services, "async_call") as call, patch.object( - Function, "service_has_service", return_value=True - ), patch.object( - hass.services, - "supports_response", - return_value="none", - ): - Function.init(hass) - await Function.service_call( - "test", "test", context=Context(id="test"), blocking=True, other_service_data="test" - ) - assert call.called - assert call.call_args[0] == ("test", "test", {"other_service_data": "test"}) - assert call.call_args[1] == {"context": Context(id="test"), "blocking": True} - call.reset_mock() - - await Function.service_call( - "test", "test", context=Context(id="test"), blocking=False, other_service_data="test" - ) - assert call.called - assert call.call_args[0] == ("test", "test", {"other_service_data": "test"}) - assert call.call_args[1] == {"context": Context(id="test"), "blocking": False} - call.reset_mock() - - await Function.get("test.test")( - context=Context(id="test"), blocking=True, other_service_data="test" - ) - assert call.called - assert call.call_args[0] == ("test", "test", {"other_service_data": "test"}) - assert call.call_args[1] == {"context": Context(id="test"), "blocking": True} - call.reset_mock() - - await Function.get("test.test")( - context=Context(id="test"), blocking=False, other_service_data="test" - ) - assert call.called - assert call.call_args[0] == ("test", "test", {"other_service_data": "test"}) - assert call.call_args[1] == {"context": Context(id="test"), "blocking": False} - except AttributeError as e: - # ignore cleanup exception - assert str(e) == "'ServiceRegistry' object attribute 'async_call' is read-only" + with patch.object(ServiceRegistry, "async_call") as call, patch.object( + Function, "service_has_service", return_value=True + ), patch.object( + ServiceRegistry, + "supports_response", + return_value="none", + ): + Function.init(hass) + await Function.service_call( + "test", "test", context=Context(id="test"), blocking=True, other_service_data="test" + ) + assert call.called + assert call.call_args[0] == ("test", "test", {"other_service_data": "test"}) + assert call.call_args[1] == {"context": Context(id="test"), "blocking": True} + call.reset_mock() + + await Function.service_call( + "test", "test", context=Context(id="test"), blocking=False, other_service_data="test" + ) + assert call.called + assert call.call_args[0] == ("test", "test", {"other_service_data": "test"}) + assert call.call_args[1] == {"context": Context(id="test"), "blocking": False} + call.reset_mock() + + await Function.get("test.test")(context=Context(id="test"), blocking=True, other_service_data="test") + assert call.called + assert call.call_args[0] == ("test", "test", {"other_service_data": "test"}) + assert call.call_args[1] == {"context": Context(id="test"), "blocking": True} + call.reset_mock() + + await Function.get("test.test")( + context=Context(id="test"), blocking=False, other_service_data="test" + ) + assert call.called + assert call.call_args[0] == ("test", "test", {"other_service_data": "test"}) + assert call.call_args[1] == {"context": Context(id="test"), "blocking": False} # Stop all tasks to avoid conflicts with other tests await Function.waiter_stop() diff --git a/tests/test_state.py b/tests/test_state.py index 957104c..5a07070 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -6,55 +6,48 @@ from custom_components.pyscript.function import Function from custom_components.pyscript.state import State -from homeassistant.core import Context +from homeassistant.core import Context, ServiceRegistry, StateMachine from homeassistant.helpers.state import State as HassState @pytest.mark.asyncio async def test_service_call(hass): """Test calling a service using the entity_id as a property.""" - try: - with patch( - "custom_components.pyscript.state.async_get_all_descriptions", - return_value={ - "test": { - "test": { - "description": None, - "fields": {"entity_id": "blah", "other_service_data": "blah"}, - } - } - }, - ), patch.object(hass.states, "get", return_value=HassState("test.entity", "True")), patch.object( - hass.services, "async_call" - ) as call: - State.init(hass) - Function.init(hass) - await State.get_service_params() - - func = State.get("test.entity.test") - await func(context=Context(id="test"), blocking=True, limit=1, other_service_data="test") - assert call.called - assert call.call_args[0] == ( - "test", - "test", - {"other_service_data": "test", "entity_id": "test.entity"}, - ) - assert call.call_args[1] == {"context": Context(id="test"), "blocking": True, "limit": 1} - call.reset_mock() - - func = State.get("test.entity.test") - await func(context=Context(id="test"), blocking=False, other_service_data="test") - assert call.called - assert call.call_args[0] == ( - "test", - "test", - {"other_service_data": "test", "entity_id": "test.entity"}, - ) - assert call.call_args[1] == {"context": Context(id="test"), "blocking": False} - except AttributeError as e: - # ignore cleanup exception - assert str(e) == "'StateMachine' object attribute 'get' is read-only" - - # Stop all tasks to avoid conflicts with other tests - await Function.waiter_stop() - await Function.reaper_stop() + with patch( + "custom_components.pyscript.state.async_get_all_descriptions", + return_value={ + "test": { + "test": {"description": None, "fields": {"entity_id": "blah", "other_service_data": "blah"}} + } + }, + ), patch.object(StateMachine, "get", return_value=HassState("test.entity", "True")), patch.object( + ServiceRegistry, "async_call" + ) as call: + State.init(hass) + Function.init(hass) + await State.get_service_params() + + func = State.get("test.entity.test") + await func(context=Context(id="test"), blocking=True, limit=1, other_service_data="test") + assert call.called + assert call.call_args[0] == ( + "test", + "test", + {"other_service_data": "test", "entity_id": "test.entity"}, + ) + assert call.call_args[1] == {"context": Context(id="test"), "blocking": True, "limit": 1} + call.reset_mock() + + func = State.get("test.entity.test") + await func(context=Context(id="test"), blocking=False, other_service_data="test") + assert call.called + assert call.call_args[0] == ( + "test", + "test", + {"other_service_data": "test", "entity_id": "test.entity"}, + ) + assert call.call_args[1] == {"context": Context(id="test"), "blocking": False} + + # Stop all tasks to avoid conflicts with other tests + await Function.waiter_stop() + await Function.reaper_stop()