Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for service responses when calling or creating services. #495

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions custom_components/pyscript/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import yaml

from homeassistant.core import SupportsResponse
from homeassistant.const import SERVICE_RELOAD
from homeassistant.helpers.service import async_set_service_schema

Expand Down Expand Up @@ -377,6 +378,7 @@ async def trigger_init(self, trig_ctx, func_name):
"time_trigger": {"kwargs": {dict}},
"task_unique": {"kill_me": {bool, int}},
"time_active": {"hold_off": {int, float}},
"service": {"supports_response": {str}},
"state_trigger": {
"kwargs": {dict},
"state_hold": {int, float},
Expand Down Expand Up @@ -485,11 +487,14 @@ async def pyscript_service_handler(call):
func_args.update(call.data)

async def do_service_call(func, ast_ctx, data):
await func.call(ast_ctx, **data)
retval = await func.call(ast_ctx, **data)
if ast_ctx.get_exception_obj():
ast_ctx.get_logger().error(ast_ctx.get_exception_long())
return retval

Function.create_task(do_service_call(func, ast_ctx, func_args))
task = Function.create_task(do_service_call(func, ast_ctx, func_args))
await task
return task.result()

return pyscript_service_handler

Expand All @@ -500,7 +505,7 @@ async def do_service_call(func, ast_ctx, data):
if name in (SERVICE_RELOAD, SERVICE_JUPYTER_KERNEL_START):
raise SyntaxError(f"{exc_mesg}: @service conflicts with builtin service")
Function.service_register(
trig_ctx_name, domain, name, pyscript_service_factory(func_name, self)
trig_ctx_name, domain, name, pyscript_service_factory(func_name, self), dec_kwargs.get("supports_response", SupportsResponse.NONE)
)
async_set_service_schema(Function.hass, domain, name, service_desc)
self.trigger_service.add(srv_name)
Expand Down
26 changes: 21 additions & 5 deletions custom_components/pyscript/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import traceback

from homeassistant.core import Context
from homeassistant.core import Context, SupportsResponse

from .const import LOGGER_PATH

Expand Down Expand Up @@ -324,14 +324,22 @@ async def service_call(cls, domain, name, **kwargs):
for keyword, typ, default in [
("context", [Context], cls.task2context.get(curr_task, None)),
("blocking", [bool], None),
("return_response", [bool], None),
("limit", [float, int], None),
]:
if keyword in kwargs and type(kwargs[keyword]) in typ:
hass_args[keyword] = kwargs.pop(keyword)
elif default:
hass_args[keyword] = default

await cls.hass.services.async_call(domain, name, kwargs, **hass_args)
if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
hass_args["blocking"] = True
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, name) == SupportsResponse.ONLY:
hass_args["return_response"] = True
if "blocking" not in hass_args:
hass_args["blocking"] = True

return await cls.hass.services.async_call(domain, name, kwargs, **hass_args)

@classmethod
async def service_completions(cls, root):
Expand Down Expand Up @@ -394,6 +402,7 @@ async def service_call(*args, **kwargs):
for keyword, typ, default in [
("context", [Context], cls.task2context.get(curr_task, None)),
("blocking", [bool], None),
("return_response", [bool], None),
("limit", [float, int], None),
]:
if keyword in kwargs and type(kwargs[keyword]) in typ:
Expand All @@ -404,7 +413,14 @@ async def service_call(*args, **kwargs):
if len(args) != 0:
raise TypeError(f"service {domain}.{service} takes only keyword arguments")

await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
hass_args["blocking"] = True
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY:
hass_args["return_response"] = True
if "blocking" not in hass_args:
hass_args["blocking"] = True

return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)

return service_call

Expand Down Expand Up @@ -450,7 +466,7 @@ def create_task(cls, coro, ast_ctx=None):
return cls.hass.loop.create_task(cls.run_coro(coro, ast_ctx=ast_ctx))

@classmethod
def service_register(cls, global_ctx_name, domain, service, callback):
def service_register(cls, global_ctx_name, domain, service, callback, supports_response = SupportsResponse.NONE):
"""Register a new service callback."""
key = f"{domain}.{service}"
if key not in cls.service_cnt:
Expand All @@ -462,7 +478,7 @@ def service_register(cls, global_ctx_name, domain, service, callback):
f"{global_ctx_name}: can't register service {key}; already defined in {cls.service2global_ctx[key]}"
)
cls.service_cnt[key] += 1
cls.hass.services.async_register(domain, service, callback)
cls.hass.services.async_register(domain, service, callback, supports_response = supports_response)

@classmethod
def service_remove(cls, global_ctx_name, domain, service):
Expand Down
13 changes: 11 additions & 2 deletions custom_components/pyscript/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging

from homeassistant.core import Context
from homeassistant.core import Context, SupportsResponse
from homeassistant.helpers.restore_state import DATA_RESTORE_STATE
from homeassistant.helpers.service import async_get_all_descriptions

Expand Down Expand Up @@ -290,6 +290,7 @@ async def service_call(*args, **kwargs):
for keyword, typ, default in [
("context", [Context], Function.task2context.get(curr_task, None)),
("blocking", [bool], None),
("return_response", [bool], None),
("limit", [float, int], None),
]:
if keyword in kwargs and type(kwargs[keyword]) in typ:
Expand All @@ -306,7 +307,15 @@ async def service_call(*args, **kwargs):
kwargs[param_name] = args[0]
elif len(args) != 0:
raise TypeError(f"service {domain}.{service} takes no positional arguments")
await cls.hass.services.async_call(domain, service, kwargs, **hass_args)

if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
hass_args["blocking"] = True
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY:
hass_args["return_response"] = True
if "blocking" not in hass_args:
hass_args["blocking"] = True

return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)

return service_call

Expand Down
Loading