diff --git a/homeassistant/components/swiss_public_transport/__init__.py b/homeassistant/components/swiss_public_transport/__init__.py index 74a7d90cfb2673..1242c95269e199 100644 --- a/homeassistant/components/swiss_public_transport/__init__.py +++ b/homeassistant/components/swiss_public_transport/__init__.py @@ -14,8 +14,9 @@ from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.aiohttp_client import async_get_clientsession -from .const import CONF_DESTINATION, CONF_START, DOMAIN +from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, PLACEHOLDERS from .coordinator import SwissPublicTransportDataUpdateCoordinator +from .helper import unique_id_from_config _LOGGER = logging.getLogger(__name__) @@ -33,19 +34,28 @@ async def async_setup_entry( destination = config[CONF_DESTINATION] session = async_get_clientsession(hass) - opendata = OpendataTransport(start, destination, session) + opendata = OpendataTransport(start, destination, session, via=config.get(CONF_VIA)) try: await opendata.async_get_data() except OpendataTransportConnectionError as e: raise ConfigEntryNotReady( - f"Timeout while connecting for entry '{start} {destination}'" + translation_domain=DOMAIN, + translation_key="request_timeout", + translation_placeholders={ + "config_title": entry.title, + "error": e, + }, ) from e except OpendataTransportError as e: raise ConfigEntryError( - f"Setup failed for entry '{start} {destination}' with invalid data, check " - "at http://transport.opendata.ch/examples/stationboard.html if your " - "station names are valid" + translation_domain=DOMAIN, + translation_key="invalid_data", + translation_placeholders={ + **PLACEHOLDERS, + "config_title": entry.title, + "error": e, + }, ) from e coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata) @@ -72,15 +82,13 @@ async def async_migrate_entry( """Migrate config entry.""" _LOGGER.debug("Migrating from version %s", config_entry.version) - if config_entry.minor_version > 3: + if config_entry.version > 2: # This means the user has downgraded from a future version return False - if config_entry.minor_version == 1: + if config_entry.version == 1 and config_entry.minor_version == 1: # Remove wrongly registered devices and entries - new_unique_id = ( - f"{config_entry.data[CONF_START]} {config_entry.data[CONF_DESTINATION]}" - ) + new_unique_id = unique_id_from_config(config_entry.data) entity_registry = er.async_get(hass) device_registry = dr.async_get(hass) device_entries = dr.async_entries_for_config_entry( @@ -109,6 +117,10 @@ async def async_migrate_entry( config_entry, unique_id=new_unique_id, minor_version=2 ) + if config_entry.version < 2: + # Via stations now available, which are not backwards compatible if used, changes unique id + hass.config_entries.async_update_entry(config_entry, version=2, minor_version=1) + _LOGGER.debug( "Migration to version %s.%s successful", config_entry.version, diff --git a/homeassistant/components/swiss_public_transport/config_flow.py b/homeassistant/components/swiss_public_transport/config_flow.py index bb852efd211287..74c6223f1d99a5 100644 --- a/homeassistant/components/swiss_public_transport/config_flow.py +++ b/homeassistant/components/swiss_public_transport/config_flow.py @@ -13,12 +13,24 @@ from homeassistant.config_entries import ConfigFlow, ConfigFlowResult from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.selector import ( + TextSelector, + TextSelectorConfig, + TextSelectorType, +) -from .const import CONF_DESTINATION, CONF_START, DOMAIN, PLACEHOLDERS +from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, MAX_VIA, PLACEHOLDERS +from .helper import unique_id_from_config DATA_SCHEMA = vol.Schema( { vol.Required(CONF_START): cv.string, + vol.Optional(CONF_VIA): TextSelector( + TextSelectorConfig( + type=TextSelectorType.TEXT, + multiple=True, + ), + ), vol.Required(CONF_DESTINATION): cv.string, } ) @@ -29,8 +41,8 @@ class SwissPublicTransportConfigFlow(ConfigFlow, domain=DOMAIN): """Swiss public transport config flow.""" - VERSION = 1 - MINOR_VERSION = 2 + VERSION = 2 + MINOR_VERSION = 1 async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -38,29 +50,34 @@ async def async_step_user( """Async user step to set up the connection.""" errors: dict[str, str] = {} if user_input is not None: - await self.async_set_unique_id( - f"{user_input[CONF_START]} {user_input[CONF_DESTINATION]}" - ) + unique_id = unique_id_from_config(user_input) + await self.async_set_unique_id(unique_id) self._abort_if_unique_id_configured() - session = async_get_clientsession(self.hass) - opendata = OpendataTransport( - user_input[CONF_START], user_input[CONF_DESTINATION], session - ) - try: - await opendata.async_get_data() - except OpendataTransportConnectionError: - errors["base"] = "cannot_connect" - except OpendataTransportError: - errors["base"] = "bad_config" - except Exception: - _LOGGER.exception("Unknown error") - errors["base"] = "unknown" + if CONF_VIA in user_input and len(user_input[CONF_VIA]) > MAX_VIA: + errors["base"] = "too_many_via_stations" else: - return self.async_create_entry( - title=f"{user_input[CONF_START]} {user_input[CONF_DESTINATION]}", - data=user_input, + session = async_get_clientsession(self.hass) + opendata = OpendataTransport( + user_input[CONF_START], + user_input[CONF_DESTINATION], + session, + via=user_input.get(CONF_VIA), ) + try: + await opendata.async_get_data() + except OpendataTransportConnectionError: + errors["base"] = "cannot_connect" + except OpendataTransportError: + errors["base"] = "bad_config" + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Unknown error") + errors["base"] = "unknown" + else: + return self.async_create_entry( + title=unique_id, + data=user_input, + ) return self.async_show_form( step_id="user", diff --git a/homeassistant/components/swiss_public_transport/const.py b/homeassistant/components/swiss_public_transport/const.py index 6ae3cc9fd2f337..32b6427ced57c4 100644 --- a/homeassistant/components/swiss_public_transport/const.py +++ b/homeassistant/components/swiss_public_transport/const.py @@ -1,12 +1,16 @@ """Constants for the swiss_public_transport integration.""" +from typing import Final + DOMAIN = "swiss_public_transport" -CONF_DESTINATION = "to" -CONF_START = "from" +CONF_DESTINATION: Final = "to" +CONF_START: Final = "from" +CONF_VIA: Final = "via" DEFAULT_NAME = "Next Destination" +MAX_VIA = 5 SENSOR_CONNECTIONS_COUNT = 3 diff --git a/homeassistant/components/swiss_public_transport/helper.py b/homeassistant/components/swiss_public_transport/helper.py new file mode 100644 index 00000000000000..af03f7ad193a25 --- /dev/null +++ b/homeassistant/components/swiss_public_transport/helper.py @@ -0,0 +1,15 @@ +"""Helper functions for swiss_public_transport.""" + +from types import MappingProxyType +from typing import Any + +from .const import CONF_DESTINATION, CONF_START, CONF_VIA + + +def unique_id_from_config(config: MappingProxyType[str, Any] | dict[str, Any]) -> str: + """Build a unique id from a config entry.""" + return f"{config[CONF_START]} {config[CONF_DESTINATION]}" + ( + " via " + ", ".join(config[CONF_VIA]) + if CONF_VIA in config and len(config[CONF_VIA]) > 0 + else "" + ) diff --git a/homeassistant/components/swiss_public_transport/strings.json b/homeassistant/components/swiss_public_transport/strings.json index 4732bb0f5274ec..4f4bc0522fc183 100644 --- a/homeassistant/components/swiss_public_transport/strings.json +++ b/homeassistant/components/swiss_public_transport/strings.json @@ -3,6 +3,7 @@ "error": { "cannot_connect": "Cannot connect to server", "bad_config": "Request failed due to bad config: Check at [stationboard]({stationboard_url}) if your station names are valid", + "too_many_via_stations": "Too many via stations, only up to 5 via stations are allowed per connection.", "unknown": "An unknown error was raised by python-opendata-transport" }, "abort": { @@ -15,9 +16,10 @@ "user": { "data": { "from": "Start station", - "to": "End station" + "to": "End station", + "via": "List of up to 5 via stations" }, - "description": "Provide start and end station for your connection\n\nCheck the [stationboard]({stationboard_url}) for valid stations.", + "description": "Provide start and end station for your connection,\nand optionally up to 5 via stations.\n\nCheck the [stationboard]({stationboard_url}) for valid stations.", "title": "Swiss Public Transport" } } @@ -46,5 +48,13 @@ "name": "Delay" } } + }, + "exceptions": { + "invalid_data": { + "message": "Setup failed for entry {config_title} with invalid data, check at the [stationboard]({stationboard_url}) if your station names are valid.\n{error}" + }, + "request_timeout": { + "message": "Timeout while connecting for entry {config_title}.\n{error}" + } } } diff --git a/tests/components/swiss_public_transport/test_config_flow.py b/tests/components/swiss_public_transport/test_config_flow.py index b728c87d4b08a6..027336e28a675a 100644 --- a/tests/components/swiss_public_transport/test_config_flow.py +++ b/tests/components/swiss_public_transport/test_config_flow.py @@ -12,7 +12,10 @@ from homeassistant.components.swiss_public_transport.const import ( CONF_DESTINATION, CONF_START, + CONF_VIA, + MAX_VIA, ) +from homeassistant.components.swiss_public_transport.helper import unique_id_from_config from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType @@ -25,8 +28,36 @@ CONF_DESTINATION: "test_destination", } +MOCK_DATA_STEP_ONE_VIA = { + **MOCK_DATA_STEP, + CONF_VIA: ["via_station"], +} + +MOCK_DATA_STEP_MANY_VIA = { + **MOCK_DATA_STEP, + CONF_VIA: ["via_station_1", "via_station_2", "via_station_3"], +} + +MOCK_DATA_STEP_TOO_MANY_STATIONS = { + **MOCK_DATA_STEP, + CONF_VIA: MOCK_DATA_STEP_ONE_VIA[CONF_VIA] * (MAX_VIA + 1), +} + -async def test_flow_user_init_data_success(hass: HomeAssistant) -> None: +@pytest.mark.parametrize( + ("user_input", "config_title"), + [ + (MOCK_DATA_STEP, "test_start test_destination"), + (MOCK_DATA_STEP_ONE_VIA, "test_start test_destination via via_station"), + ( + MOCK_DATA_STEP_MANY_VIA, + "test_start test_destination via via_station_1, via_station_2, via_station_3", + ), + ], +) +async def test_flow_user_init_data_success( + hass: HomeAssistant, user_input, config_title +) -> None: """Test success response.""" result = await hass.config_entries.flow.async_init( config_flow.DOMAIN, context={"source": "user"} @@ -47,25 +78,26 @@ async def test_flow_user_init_data_success(hass: HomeAssistant) -> None: ) result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=MOCK_DATA_STEP, + user_input=user_input, ) - assert result["type"] is FlowResultType.CREATE_ENTRY - assert result["result"].title == "test_start test_destination" + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["result"].title == config_title - assert result["data"] == MOCK_DATA_STEP + assert result["data"] == user_input @pytest.mark.parametrize( - ("raise_error", "text_error"), + ("raise_error", "text_error", "user_input_error"), [ - (OpendataTransportConnectionError(), "cannot_connect"), - (OpendataTransportError(), "bad_config"), - (IndexError(), "unknown"), + (OpendataTransportConnectionError(), "cannot_connect", MOCK_DATA_STEP), + (OpendataTransportError(), "bad_config", MOCK_DATA_STEP), + (None, "too_many_via_stations", MOCK_DATA_STEP_TOO_MANY_STATIONS), + (IndexError(), "unknown", MOCK_DATA_STEP), ], ) async def test_flow_user_init_data_error_and_recover( - hass: HomeAssistant, raise_error, text_error + hass: HomeAssistant, raise_error, text_error, user_input_error ) -> None: """Test unknown errors.""" with patch( @@ -78,7 +110,7 @@ async def test_flow_user_init_data_error_and_recover( ) result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=MOCK_DATA_STEP, + user_input=user_input_error, ) assert result["type"] is FlowResultType.FORM @@ -92,7 +124,7 @@ async def test_flow_user_init_data_error_and_recover( user_input=MOCK_DATA_STEP, ) - assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["type"] == FlowResultType.CREATE_ENTRY assert result["result"].title == "test_start test_destination" assert result["data"] == MOCK_DATA_STEP @@ -104,7 +136,7 @@ async def test_flow_user_init_data_already_configured(hass: HomeAssistant) -> No entry = MockConfigEntry( domain=config_flow.DOMAIN, data=MOCK_DATA_STEP, - unique_id=f"{MOCK_DATA_STEP[CONF_START]} {MOCK_DATA_STEP[CONF_DESTINATION]}", + unique_id=unique_id_from_config(MOCK_DATA_STEP), ) entry.add_to_hass(hass) diff --git a/tests/components/swiss_public_transport/test_init.py b/tests/components/swiss_public_transport/test_init.py index e1b27cf5fe1641..47360f93cf21d1 100644 --- a/tests/components/swiss_public_transport/test_init.py +++ b/tests/components/swiss_public_transport/test_init.py @@ -2,22 +2,32 @@ from unittest.mock import AsyncMock, patch +import pytest + from homeassistant.components.swiss_public_transport.const import ( CONF_DESTINATION, CONF_START, + CONF_VIA, DOMAIN, ) +from homeassistant.components.swiss_public_transport.helper import unique_id_from_config +from homeassistant.config_entries import ConfigEntryState from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er from tests.common import MockConfigEntry -MOCK_DATA_STEP = { +MOCK_DATA_STEP_BASE = { CONF_START: "test_start", CONF_DESTINATION: "test_destination", } +MOCK_DATA_STEP_VIA = { + **MOCK_DATA_STEP_BASE, + CONF_VIA: ["via_station"], +} + CONNECTIONS = [ { "departure": "2024-01-06T18:03:00+0100", @@ -46,19 +56,38 @@ ] -async def test_migration_1_1_to_1_2( - hass: HomeAssistant, entity_registry: er.EntityRegistry +@pytest.mark.parametrize( + ( + "from_version", + "from_minor_version", + "config_data", + "overwrite_unique_id", + ), + [ + (1, 1, MOCK_DATA_STEP_BASE, "None_departure"), + (1, 2, MOCK_DATA_STEP_BASE, None), + (2, 1, MOCK_DATA_STEP_VIA, None), + ], +) +async def test_migration_from( + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + from_version, + from_minor_version, + config_data, + overwrite_unique_id, ) -> None: """Test successful setup.""" - config_entry_faulty = MockConfigEntry( + config_entry = MockConfigEntry( domain=DOMAIN, - data=MOCK_DATA_STEP, - title="MIGRATION_TEST", - version=1, - minor_version=1, + data=config_data, + title=f"MIGRATION_TEST from {from_version}.{from_minor_version}", + version=from_version, + minor_version=from_minor_version, + unique_id=overwrite_unique_id or unique_id_from_config(config_data), ) - config_entry_faulty.add_to_hass(hass) + config_entry.add_to_hass(hass) with patch( "homeassistant.components.swiss_public_transport.OpendataTransport", @@ -67,21 +96,53 @@ async def test_migration_1_1_to_1_2( mock().connections = CONNECTIONS # Setup the config entry - await hass.config_entries.async_setup(config_entry_faulty.entry_id) + unique_id = unique_id_from_config(config_entry.data) + await hass.config_entries.async_setup(config_entry.entry_id) await hass.async_block_till_done() assert entity_registry.async_is_registered( entity_registry.entities.get_entity_id( - (Platform.SENSOR, DOMAIN, "test_start test_destination_departure") + ( + Platform.SENSOR, + DOMAIN, + f"{unique_id}_departure", + ) ) ) - # Check change in config entry - assert config_entry_faulty.minor_version == 2 - assert config_entry_faulty.unique_id == "test_start test_destination" + # Check change in config entry and verify most recent version + assert config_entry.version == 2 + assert config_entry.minor_version == 1 + assert config_entry.unique_id == unique_id - # Check "None" is gone + # Check "None" is gone from version 1.1 to 1.2 assert not entity_registry.async_is_registered( entity_registry.entities.get_entity_id( (Platform.SENSOR, DOMAIN, "None_departure") ) ) + + +async def test_migrate_error_from_future(hass: HomeAssistant) -> None: + """Test a future version isn't migrated.""" + + mock_entry = MockConfigEntry( + domain=DOMAIN, + version=3, + minor_version=1, + unique_id="some_crazy_future_unique_id", + data=MOCK_DATA_STEP_BASE, + ) + + mock_entry.add_to_hass(hass) + + with patch( + "homeassistant.components.swiss_public_transport.OpendataTransport", + return_value=AsyncMock(), + ) as mock: + mock().connections = CONNECTIONS + + await hass.config_entries.async_setup(mock_entry.entry_id) + await hass.async_block_till_done() + + entry = hass.config_entries.async_get_entry(mock_entry.entry_id) + assert entry.state is ConfigEntryState.MIGRATION_ERROR