Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support local update file for OTA #884

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions matter_server/server/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,13 +972,15 @@ async def update_node(self, node_id: int, software_version: int | str) -> None:

# Add update to the OTA provider
ota_provider = ExternalOtaProvider(
self.server.vendor_id, self._ota_provider_dir / f"{node_id}"
self.server.vendor_id,
self._ota_provider_dir,
self._ota_provider_dir / f"{node_id}",
)

await ota_provider.initialize()

node_logger.info("Downloading update from '%s'", update["otaUrl"])
await ota_provider.download_update(update)
await ota_provider.fetch_update(update)

self._attribute_update_callbacks.setdefault(node_id, []).append(
ota_provider.check_update_state
Expand Down
89 changes: 57 additions & 32 deletions matter_server/server/ota/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,12 @@ class ExternalOtaProvider:

ENDPOINT_ID: Final[int] = 0

def __init__(self, vendor_id: int, ota_provider_dir: Path) -> None:
def __init__(
self, vendor_id: int, ota_provider_base_dir: Path, ota_provider_dir: Path
) -> None:
"""Initialize the OTA provider."""
self._vendor_id: int = vendor_id
self._ota_provider_base_dir: Path = ota_provider_base_dir
self._ota_provider_dir: Path = ota_provider_dir
self._ota_file_path: Path | None = None
self._ota_provider_proc: Process | None = None
Expand Down Expand Up @@ -261,10 +264,11 @@ async def stop(self) -> None:
self._ota_provider_proc = None
self._ota_provider_task = None

async def download_update(self, update_desc: dict) -> None:
"""Download update file from OTA Path and add it to the OTA provider."""
async def _download_update(
self, url: str, checksum_alg: hashlib._Hash | None
) -> Path:
"""Download update file from OTA URL."""

url = update_desc["otaUrl"]
parsed_url = urlparse(url)
file_name = unquote(Path(parsed_url.path).name)

Expand All @@ -273,20 +277,6 @@ async def download_update(self, update_desc: dict) -> None:
file_path = self._ota_provider_dir / file_name

try:
checksum_alg = None
if (
"otaChecksum" in update_desc
and "otaChecksumType" in update_desc
and update_desc["otaChecksumType"] in CHECHKSUM_TYPES
):
checksum_alg = hashlib.new(
CHECHKSUM_TYPES[update_desc["otaChecksumType"]]
)
else:
LOGGER.warning(
"No OTA checksum type or not supported, OTA will not be checked."
)

async with ClientSession(raise_for_status=True) as session:
# fetch the paa certificates list
LOGGER.debug("Download update from '%s'.", url)
Expand All @@ -300,20 +290,6 @@ async def download_update(self, update_desc: dict) -> None:
if checksum_alg:
checksum_alg.update(chunk)

# Download finished, check checksum if necessary
if checksum_alg:
checksum = b64encode(checksum_alg.digest()).decode("ascii")
checksum_expected = update_desc["otaChecksum"].strip()
if checksum != checksum_expected:
LOGGER.error(
"Checksum mismatch for file '%s', expected: '%s', got: '%s'",
file_name,
checksum_expected,
checksum,
)
await loop.run_in_executor(None, file_path.unlink)
raise UpdateError("Checksum mismatch!")

LOGGER.info(
"Update file '%s' downloaded to '%s'",
file_name,
Expand All @@ -326,6 +302,55 @@ async def download_update(self, update_desc: dict) -> None:
)
raise UpdateError("Fetching software version failed") from err

return file_path

async def fetch_update(self, update_desc: dict) -> None:
"""Fetch update file from OTA URL."""
url = update_desc["otaUrl"]
parsed_url = urlparse(url)
file_name = unquote(Path(parsed_url.path).name)

loop = asyncio.get_running_loop()

checksum_alg = None
if (
"otaChecksum" in update_desc
and "otaChecksumType" in update_desc
and update_desc["otaChecksumType"] in CHECHKSUM_TYPES
):
checksum_alg = hashlib.new(CHECHKSUM_TYPES[update_desc["otaChecksumType"]])
else:
LOGGER.warning(
"No OTA checksum type or not supported, OTA will not be checked."
)

if parsed_url.scheme in ["http", "https"]:
file_path = await self._download_update(url, checksum_alg)
elif parsed_url.scheme in ["file"]:
file_path = self._ota_provider_base_dir / Path(parsed_url.path[1:])
if not file_path.exists():
agners marked this conversation as resolved.
Show resolved Hide resolved
logging.warning("Local update file not found: %s", file_path)
raise UpdateError("Local update file not found")
if checksum_alg:
checksum_alg.update(
await loop.run_in_executor(None, file_path.read_bytes)
agners marked this conversation as resolved.
Show resolved Hide resolved
)
else:
raise UpdateError("Unsupported OTA URL scheme")

# Download finished, check checksum if necessary
if checksum_alg:
checksum_expected = update_desc["otaChecksum"].strip()
checksum = b64encode(checksum_alg.digest()).decode("ascii")
if checksum != checksum_expected:
LOGGER.error(
"Checksum mismatch for file '%s', expected: '%s', got: '%s'",
file_name,
checksum_expected,
checksum,
)
raise UpdateError("Checksum mismatch!")

self._ota_file_path = file_path

async def check_update_state(
Expand Down