Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add return type to json_loads #85672

Merged
merged 18 commits into from
Feb 7, 2023
6 changes: 3 additions & 3 deletions homeassistant/components/conversation/default_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from homeassistant import core, setup
from homeassistant.helpers import area_registry, entity_registry, intent, template
from homeassistant.helpers.json import json_loads
from homeassistant.helpers.json import JsonObjectType, json_loads_object

from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import DOMAIN
Expand All @@ -29,9 +29,9 @@
REGEX_TYPE = type(re.compile(""))


def json_load(fp: IO[str]) -> dict[str, Any]:
def json_load(fp: IO[str]) -> JsonObjectType:
"""Wrap json_loads for get_intents."""
return json_loads(fp.read())
return json_loads_object(fp.read())


@dataclass
Expand Down
10 changes: 5 additions & 5 deletions homeassistant/components/mobile_app/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from homeassistant.const import ATTR_DEVICE_ID, CONTENT_TYPE_JSON
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.json import JSONEncoder, json_loads
from homeassistant.helpers.json import JSONEncoder, JsonObjectType, json_loads_object

from .const import (
ATTR_APP_DATA,
Expand Down Expand Up @@ -71,7 +71,7 @@ def _decrypt_payload_helper(
ciphertext: str,
get_key_bytes: Callable[[str, int], str | bytes],
key_encoder,
) -> dict[str, str] | None:
) -> JsonObjectType | None:
"""Decrypt encrypted payload."""
try:
keylen, decrypt = setup_decrypt(key_encoder)
Expand All @@ -86,12 +86,12 @@ def _decrypt_payload_helper(
key_bytes = get_key_bytes(key, keylen)

msg_bytes = decrypt(ciphertext, key_bytes)
message = json_loads(msg_bytes)
message = json_loads_object(msg_bytes)
_LOGGER.debug("Successfully decrypted mobile_app payload")
return message


def _decrypt_payload(key: str | None, ciphertext: str) -> dict[str, str] | None:
def _decrypt_payload(key: str | None, ciphertext: str) -> JsonObjectType | None:
"""Decrypt encrypted payload."""

def get_key_bytes(key: str, keylen: int) -> str:
Expand All @@ -100,7 +100,7 @@ def get_key_bytes(key: str, keylen: int) -> str:
return _decrypt_payload_helper(key, ciphertext, get_key_bytes, HexEncoder)


def _decrypt_payload_legacy(key: str | None, ciphertext: str) -> dict[str, str] | None:
def _decrypt_payload_legacy(key: str | None, ciphertext: str) -> JsonObjectType | None:
"""Decrypt encrypted payload."""

def get_key_bytes(key: str, keylen: int) -> bytes:
Expand Down
6 changes: 3 additions & 3 deletions homeassistant/components/mqtt/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.json import json_loads
from homeassistant.helpers.json import json_loads_object
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo
from homeassistant.helpers.typing import DiscoveryInfoType
from homeassistant.loader import async_get_mqtt
Expand Down Expand Up @@ -126,7 +126,7 @@ async def async_discovery_message_received(msg: ReceiveMessage) -> None:

if payload:
try:
discovery_payload = MQTTDiscoveryPayload(json_loads(payload))
discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload))
except ValueError:
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
return
Expand Down Expand Up @@ -279,7 +279,7 @@ async def discovery_done(_: Any) -> None:
mqtt_data.last_discovery = time.time()
mqtt_integrations = await async_get_mqtt(hass)

for (integration, topics) in mqtt_integrations.items():
for integration, topics in mqtt_integrations.items():

async def async_integration_message_received(
integration: str, msg: ReceiveMessage
Expand Down
10 changes: 5 additions & 5 deletions homeassistant/components/mqtt/light/schema_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.json import json_dumps, json_loads
from homeassistant.helpers.json import json_dumps, json_loads_object
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
import homeassistant.util.color as color_util
Expand Down Expand Up @@ -349,7 +349,7 @@ def _prepare_subscribe_topics(self) -> None:
@log_messages(self.hass, self.entity_id)
def state_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
values: dict[str, Any] = json_loads(msg.payload)
values = json_loads_object(msg.payload)

if values["state"] == "ON":
self._attr_is_on = True
Expand All @@ -375,7 +375,7 @@ def state_received(msg: ReceiveMessage) -> None:
if brightness_supported(self.supported_color_modes):
try:
self._attr_brightness = int(
values["brightness"]
values["brightness"] # type: ignore[operator]
/ float(self._config[CONF_BRIGHTNESS_SCALE])
* 255
)
Expand All @@ -397,7 +397,7 @@ def state_received(msg: ReceiveMessage) -> None:
if values["color_temp"] is None:
self._attr_color_temp = None
else:
self._attr_color_temp = int(values["color_temp"])
self._attr_color_temp = int(values["color_temp"]) # type: ignore[arg-type]
except KeyError:
pass
except ValueError:
Expand All @@ -408,7 +408,7 @@ def state_received(msg: ReceiveMessage) -> None:

if self.supported_features and LightEntityFeature.EFFECT:
with suppress(KeyError):
self._attr_effect = values["effect"]
self._attr_effect = cast(str, values["effect"])

get_mqtt_data(self.hass).state_write_requests.write_state_request(self)

Expand Down
8 changes: 6 additions & 2 deletions homeassistant/components/mqtt/siren.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_dumps, json_loads
from homeassistant.helpers.json import (
JSON_DECODE_EXCEPTIONS,
json_dumps,
json_loads_object,
)
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType

Expand Down Expand Up @@ -245,7 +249,7 @@ def state_message_received(msg: ReceiveMessage) -> None:
json_payload = {STATE: payload}
else:
try:
json_payload = json_loads(payload)
json_payload = json_loads_object(payload)
_LOGGER.debug(
(
"JSON payload detected after processing payload '%s' on"
Expand Down
10 changes: 5 additions & 5 deletions homeassistant/components/mqtt/vacuum/schema_state.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Support for a State MQTT vacuum."""
from __future__ import annotations

from typing import Any
from typing import Any, cast

import voluptuous as vol

Expand All @@ -25,7 +25,7 @@
from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.json import json_dumps, json_loads
from homeassistant.helpers.json import json_dumps, json_loads_object
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType

from .. import subscription
Expand Down Expand Up @@ -240,12 +240,12 @@ def _prepare_subscribe_topics(self) -> None:
@log_messages(self.hass, self.entity_id)
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle state MQTT message."""
payload: dict[str, Any] = json_loads(msg.payload)
payload = json_loads_object(msg.payload)
if STATE in payload and (
payload[STATE] in POSSIBLE_STATES or payload[STATE] is None
(state := payload[STATE]) in POSSIBLE_STATES or state is None
):
self._attr_state = (
POSSIBLE_STATES[payload[STATE]] if payload[STATE] else None
POSSIBLE_STATES[cast(str, state)] if payload[STATE] else None
)
del payload[STATE]
self._update_state_attributes(payload)
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/recorder/db_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
json_bytes,
json_bytes_strip_null,
json_loads,
json_loads_object,
)
import homeassistant.util.dt as dt_util

Expand Down Expand Up @@ -209,7 +210,7 @@ def to_native(self, validate_entity_id: bool = True) -> Event | None:
try:
return Event(
self.event_type,
json_loads(self.event_data) if self.event_data else {},
json_loads_object(self.event_data) if self.event_data else {},
EventOrigin(self.origin)
if self.origin
else EVENT_ORIGIN_ORDER[self.origin_idx],
Expand Down Expand Up @@ -356,7 +357,7 @@ def to_native(self, validate_entity_id: bool = True) -> State | None:
parent_id=self.context_parent_id,
)
try:
attrs = json_loads(self.attributes) if self.attributes else {}
attrs = json_loads_object(self.attributes) if self.attributes else {}
except JSON_DECODE_EXCEPTIONS:
# When json_loads fails
_LOGGER.exception("Error converting row to state: %s", self)
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/recorder/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
COMPRESSED_STATE_STATE,
)
from homeassistant.core import Context, State
from homeassistant.helpers.json import json_loads
from homeassistant.helpers.json import json_loads_object
import homeassistant.util.dt as dt_util

# pylint: disable=invalid-name
Expand Down Expand Up @@ -343,7 +343,7 @@ def decode_attributes_from_row(
if not source or source == EMPTY_JSON_OBJECT:
return {}
try:
attr_cache[source] = attributes = json_loads(source)
attr_cache[source] = attributes = json_loads_object(source)
except ValueError:
_LOGGER.exception("Error converting row to state attributes: %s", source)
attr_cache[source] = attributes = {}
Expand Down
19 changes: 19 additions & 0 deletions homeassistant/helpers/json.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
"""Helpers to help with encoding Home Assistant objects in JSON."""
from collections.abc import Callable
import datetime
import json
from pathlib import Path
from typing import Any, Final

import orjson

JsonValueType = (
dict[str, "JsonValueType"] | list["JsonValueType"] | str | int | float | bool | None
)
"""Any data that can be returned by the standard JSON deserializing process."""
JsonObjectType = dict[str, JsonValueType]
"""Dictionary that can be returned by the standard JSON deserializing process."""

JSON_ENCODE_EXCEPTIONS = (TypeError, ValueError)
JSON_DECODE_EXCEPTIONS = (orjson.JSONDecodeError,)

Expand Down Expand Up @@ -132,7 +140,18 @@ def json_dumps_sorted(data: Any) -> str:
).decode("utf-8")


json_loads: Callable[[bytes | bytearray | memoryview | str], JsonValueType]
json_loads = orjson.loads
"""Parse JSON data."""


def json_loads_object(__obj: bytes | bytearray | memoryview | str) -> JsonObjectType:
"""Parse JSON data and ensure result is a dictionary."""
value: JsonValueType = json_loads(__obj)
# Avoid isinstance overhead as we are not interested in dict subclasses
if type(value) is dict: # pylint: disable=unidiomatic-typecheck
return value
raise ValueError(f"Expected JSON to be parsed as a dict got {type(value)}")


JSON_DUMP: Final = json_dumps
4 changes: 2 additions & 2 deletions homeassistant/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ async def async_get_integration_descriptions(
config_flow_path = pathlib.Path(base) / "integrations.json"

flow = await hass.async_add_executor_job(config_flow_path.read_text)
core_flows: dict[str, Any] = json_loads(flow)
core_flows = cast(dict[str, Any], json_loads(flow))
custom_integrations = await async_get_custom_components(hass)
custom_flows: dict[str, Any] = {
"integration": {},
Expand Down Expand Up @@ -476,7 +476,7 @@ def resolve_from_root(
continue

try:
manifest = json_loads(manifest_path.read_text())
manifest = cast(Manifest, json_loads(manifest_path.read_text()))
except JSON_DECODE_EXCEPTIONS as err:
_LOGGER.error(
"Error parsing manifest.json file at %s: %s", manifest_path, err
Expand Down
18 changes: 18 additions & 0 deletions tests/helpers/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
json_bytes_strip_null,
json_dumps,
json_dumps_sorted,
json_loads_object,
)
from homeassistant.util import dt as dt_util
from homeassistant.util.color import RGBColor
Expand Down Expand Up @@ -135,3 +136,20 @@ def test_json_bytes_strip_null():
json_bytes_strip_null([[{"k1": {"k2": ["silly\0stuff"]}}]])
== b'[[{"k1":{"k2":["silly"]}}]]'
)


def test_json_loads_object():
"""Test json_loads_object validates result."""
assert json_loads_object('{"c":1.2}') == {"c": 1.2}
with pytest.raises(
ValueError, match="Expected JSON to be parsed as a dict got <class 'list'>"
):
json_loads_object("[]")
with pytest.raises(
ValueError, match="Expected JSON to be parsed as a dict got <class 'bool'>"
):
json_loads_object("true")
with pytest.raises(
ValueError, match="Expected JSON to be parsed as a dict got <class 'NoneType'>"
):
json_loads_object("null")