Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor config entry storage and index #107590

Merged
merged 12 commits into from
Jan 13, 2024
191 changes: 136 additions & 55 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable, Coroutine, Generator, Iterable, Mapping
from collections import UserDict
from collections.abc import (
Callable,
Coroutine,
Generator,
Iterable,
Mapping,
ValuesView,
)
from contextvars import ContextVar
from copy import deepcopy
from enum import Enum, StrEnum
Expand Down Expand Up @@ -336,6 +344,13 @@ def __init__(
self._tries = 0
self._setup_again_job: HassJob | None = None

def __repr__(self) -> str:
"""Representation of ConfigEntry."""
return (
f"<ConfigEntry entry_id={self.entry_id} version={self.version} domain={self.domain} "
f"title={self.title} state={self.state} unique_id={self.unique_id}>"
)

async def async_setup(
self,
hass: HomeAssistant,
Expand Down Expand Up @@ -1057,6 +1072,67 @@ def _async_discovery(self) -> None:
)


class ConfigEntryItems(UserDict[str, ConfigEntry]):
"""Container for config items, maps config_entry_id -> entry.

Maintains two additional indexes:
- domain -> list[ConfigEntry]
- domain -> unique_id -> ConfigEntry
"""

def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._domain_index: dict[str, list[ConfigEntry]] = {}
self._domain_unique_id_index: dict[str, dict[str, ConfigEntry]] = {}

def values(self) -> ValuesView[ConfigEntry]:
"""Return the underlying values to avoid __iter__ overhead."""
return self.data.values()

def __setitem__(self, entry_id: str, entry: ConfigEntry) -> None:
"""Add an item."""
data = self.data
if entry_id in data:
# This is likely a bug in a test that is adding the same entry twice.
# In the future, once we have fixed the tests, this will raise HomeAssistantError.
_LOGGER.error("An entry with the id %s already exists", entry_id)
self._unindex_entry(entry_id)
data[entry_id] = entry
self._domain_index.setdefault(entry.domain, []).append(entry)
if entry.unique_id is not None:
self._domain_unique_id_index.setdefault(entry.domain, {})[
entry.unique_id
] = entry

def _unindex_entry(self, entry_id: str) -> None:
"""Unindex an entry."""
entry = self.data[entry_id]
domain = entry.domain
self._domain_index[domain].remove(entry)
if not self._domain_index[domain]:
del self._domain_index[domain]
if (unique_id := entry.unique_id) is not None:
del self._domain_unique_id_index[domain][unique_id]
if not self._domain_unique_id_index[domain]:
del self._domain_unique_id_index[domain]

def __delitem__(self, entry_id: str) -> None:
"""Remove an item."""
self._unindex_entry(entry_id)
super().__delitem__(entry_id)

def get_entries_for_domain(self, domain: str) -> list[ConfigEntry]:
"""Get entries for a domain."""
return self._domain_index.get(domain, [])

def get_entry_by_domain_and_unique_id(
self, domain: str, unique_id: str
) -> ConfigEntry | None:
"""Get entry by domain and unique id."""
return self._domain_unique_id_index.get(domain, {}).get(unique_id)


class ConfigEntries:
"""Manage the configuration entries.

Expand All @@ -1069,8 +1145,7 @@ def __init__(self, hass: HomeAssistant, hass_config: ConfigType) -> None:
self.flow = ConfigEntriesFlowManager(hass, self, hass_config)
self.options = OptionsFlowManager(hass)
self._hass_config = hass_config
self._entries: dict[str, ConfigEntry] = {}
self._domain_index: dict[str, list[ConfigEntry]] = {}
self._entries = ConfigEntryItems()
self._store = storage.Store[dict[str, list[dict[str, Any]]]](
hass, STORAGE_VERSION, STORAGE_KEY
)
Expand All @@ -1093,23 +1168,29 @@ def async_domains(
@callback
def async_get_entry(self, entry_id: str) -> ConfigEntry | None:
"""Return entry with matching entry_id."""
return self._entries.get(entry_id)
return self._entries.data.get(entry_id)

@callback
def async_entries(self, domain: str | None = None) -> list[ConfigEntry]:
"""Return all entries or entries for a specific domain."""
if domain is None:
return list(self._entries.values())
return list(self._domain_index.get(domain, []))
return list(self._entries.get_entries_for_domain(domain))

@callback
def async_entry_for_domain_unique_id(
self, domain: str, unique_id: str
) -> ConfigEntry | None:
"""Return entry for a domain with a matching unique id."""
return self._entries.get_entry_by_domain_and_unique_id(domain, unique_id)

async def async_add(self, entry: ConfigEntry) -> None:
"""Add and setup an entry."""
if entry.entry_id in self._entries:
if entry.entry_id in self._entries.data:
raise HomeAssistantError(
f"An entry with the id {entry.entry_id} already exists."
)
self._entries[entry.entry_id] = entry
self._domain_index.setdefault(entry.domain, []).append(entry)
self._async_dispatch(ConfigEntryChange.ADDED, entry)
await self.async_setup(entry.entry_id)
self._async_schedule_save()
Expand All @@ -1127,9 +1208,6 @@ async def async_remove(self, entry_id: str) -> dict[str, Any]:
await entry.async_remove(self.hass)

del self._entries[entry.entry_id]
self._domain_index[entry.domain].remove(entry)
if not self._domain_index[entry.domain]:
del self._domain_index[entry.domain]
self._async_schedule_save()

dev_reg = device_registry.async_get(self.hass)
Expand Down Expand Up @@ -1189,13 +1267,10 @@ async def async_initialize(self) -> None:
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._async_shutdown)

if config is None:
self._entries = {}
self._domain_index = {}
self._entries = ConfigEntryItems()
return

entries = {}
domain_index: dict[str, list[ConfigEntry]] = {}

entries: ConfigEntryItems = ConfigEntryItems()
for entry in config["entries"]:
pref_disable_new_entities = entry.get("pref_disable_new_entities")

Expand Down Expand Up @@ -1230,9 +1305,7 @@ async def async_initialize(self) -> None:
pref_disable_polling=entry.get("pref_disable_polling"),
)
entries[entry_id] = config_entry
domain_index.setdefault(domain, []).append(config_entry)

self._domain_index = domain_index
self._entries = entries

async def async_setup(self, entry_id: str) -> bool:
Expand Down Expand Up @@ -1365,8 +1438,15 @@ def async_update_entry(
"""
changed = False

if unique_id is not UNDEFINED and entry.unique_id != unique_id:
# Reindex the entry if the unique_id has changed
bdraco marked this conversation as resolved.
Show resolved Hide resolved
entry_id = entry.entry_id
del self._entries[entry_id]
entry.unique_id = unique_id
self._entries[entry_id] = entry
changed = True

for attr, value in (
("unique_id", unique_id),
("title", title),
("pref_disable_new_entities", pref_disable_new_entities),
("pref_disable_polling", pref_disable_polling),
Expand Down Expand Up @@ -1579,38 +1659,41 @@ def _abort_if_unique_id_configured(
if self.unique_id is None:
return

for entry in self._async_current_entries(include_ignore=True):
if entry.unique_id != self.unique_id:
continue
should_reload = False
if (
updates is not None
and self.hass.config_entries.async_update_entry(
entry, data={**entry.data, **updates}
)
and reload_on_update
and entry.state
in (ConfigEntryState.LOADED, ConfigEntryState.SETUP_RETRY)
):
# Existing config entry present, and the
# entry data just changed
should_reload = True
elif (
self.source in DISCOVERY_SOURCES
and entry.state is ConfigEntryState.SETUP_RETRY
):
# Existing config entry present in retry state, and we
# just discovered the unique id so we know its online
should_reload = True
# Allow ignored entries to be configured on manual user step
if entry.source == SOURCE_IGNORE and self.source == SOURCE_USER:
continue
if should_reload:
self.hass.async_create_task(
self.hass.config_entries.async_reload(entry.entry_id),
f"config entry reload {entry.title} {entry.domain} {entry.entry_id}",
)
raise data_entry_flow.AbortFlow(error)
if not (
entry := self.hass.config_entries.async_entry_for_domain_unique_id(
self.handler, self.unique_id
)
):
return

should_reload = False
if (
updates is not None
and self.hass.config_entries.async_update_entry(
entry, data={**entry.data, **updates}
)
and reload_on_update
and entry.state in (ConfigEntryState.LOADED, ConfigEntryState.SETUP_RETRY)
):
# Existing config entry present, and the
# entry data just changed
should_reload = True
elif (
self.source in DISCOVERY_SOURCES
and entry.state is ConfigEntryState.SETUP_RETRY
):
# Existing config entry present in retry state, and we
# just discovered the unique id so we know its online
should_reload = True
# Allow ignored entries to be configured on manual user step
if entry.source == SOURCE_IGNORE and self.source == SOURCE_USER:
return
if should_reload:
self.hass.async_create_task(
self.hass.config_entries.async_reload(entry.entry_id),
f"config entry reload {entry.title} {entry.domain} {entry.entry_id}",
)
raise data_entry_flow.AbortFlow(error)

async def async_set_unique_id(
self, unique_id: str | None = None, *, raise_on_progress: bool = True
Expand Down Expand Up @@ -1639,11 +1722,9 @@ async def async_set_unique_id(
):
self.hass.config_entries.flow.async_abort(progress["flow_id"])

for entry in self._async_current_entries(include_ignore=True):
if entry.unique_id == unique_id:
return entry

return None
return self.hass.config_entries.async_entry_for_domain_unique_id(
self.handler, unique_id
)

@callback
def _set_confirm_only(
Expand Down
2 changes: 0 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,12 +939,10 @@ def __init__(
def add_to_hass(self, hass: HomeAssistant) -> None:
"""Test helper to add entry to hass."""
hass.config_entries._entries[self.entry_id] = self
hass.config_entries._domain_index.setdefault(self.domain, []).append(self)

def add_to_manager(self, manager: config_entries.ConfigEntries) -> None:
"""Test helper to add entry to entry manager."""
manager._entries[self.entry_id] = self
manager._domain_index.setdefault(self.domain, []).append(self)


def patch_yaml_files(files_dict, endswith=True):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3123,6 +3123,9 @@ async def test_updating_entry_with_and_without_changes(
state=config_entries.ConfigEntryState.SETUP_ERROR,
)
entry.add_to_manager(manager)
assert "abc123" in str(entry)

assert manager.async_entry_for_domain_unique_id("test", "abc123") is entry

assert manager.async_update_entry(entry) is False

Expand All @@ -3138,6 +3141,10 @@ async def test_updating_entry_with_and_without_changes(
assert manager.async_update_entry(entry, **change) is True
assert manager.async_update_entry(entry, **change) is False

assert manager.async_entry_for_domain_unique_id("test", "abc123") is None
assert manager.async_entry_for_domain_unique_id("test", "abcd1234") is entry
assert "abcd1234" in str(entry)


async def test_entry_reload_calls_on_unload_listeners(
hass: HomeAssistant, manager: config_entries.ConfigEntries
Expand Down Expand Up @@ -4127,3 +4134,13 @@ async def async_step_user_confirm(self, user_input=None):
)

assert result["preview"] is None


def test_raise_trying_to_add_same_config_entry_twice(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test we log an error if trying to add same config entry twice."""
entry = MockConfigEntry(domain="test")
entry.add_to_hass(hass)
entry.add_to_hass(hass)
assert f"An entry with the id {entry.entry_id} already exists" in caplog.text