Skip to content

Commit

Permalink
Add concurrency to coordinator
Browse files Browse the repository at this point in the history
  • Loading branch information
Snuffy2 committed Sep 28, 2024
1 parent c75bc74 commit b97e59b
Showing 1 changed file with 184 additions and 182 deletions.
366 changes: 184 additions & 182 deletions custom_components/opnsense/coordinator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from collections.abc import Mapping
import copy
import logging
Expand Down Expand Up @@ -51,58 +52,44 @@ async def inner(self, *args, **kwargs):

return inner

@_log_timing
async def _get_system_info(self) -> Mapping[str, Any]:
return await self._client.get_system_info()

@_log_timing
async def _get_firmware_update_info(self):
try:
return await self._client.get_firmware_update_info()
except Exception as e:
_LOGGER.error(
f"Error in get_firmware_update_info. {e.__class__.__qualname__}: {e}"
)
return None
# raise err

@_log_timing
async def _get_telemetry(self):
return await self._client.get_telemetry()

@_log_timing
async def _get_host_firmware_version(self) -> None | str:
return await self._client.get_host_firmware_version()

@_log_timing
async def _get_config(self):
return await self._client.get_config()

@_log_timing
async def _get_services(self):
return await self._client.get_services()

@_log_timing
async def _get_carp_interfaces(self):
return await self._client.get_carp_interfaces()

@_log_timing
async def _get_carp_status(self):
return await self._client.get_carp_status()

@_log_timing
async def _get_dhcp_leases(self):
return await self._client.get_dhcp_leases()

@_log_timing
async def _get_notices(self):
return await self._client.get_notices()
async def _get_states(self, categories: list) -> Mapping[str, Any]:
state: Mapping[str, Any] = {}
tasks: list = []
for cat in categories:
method = getattr(self._client, cat.get("function", ""), None)
if method:
tasks.append(method())
else:
_LOGGER.error(f"Method {cat.get('function','')} not found.")

results: list = await asyncio.gather(*tasks, return_exceptions=True)

for i, cat in enumerate(categories):
if not isinstance(results[i], Exception):
state[cat.get("state_key")] = results[i]
else:
_LOGGER.error(
f"Error getting {cat.get('state_key')}. "
f"{results[i].__class__.__qualname__}: {results[i]}"
)
return state

async def _get_dhcp_stats(self, leases: list) -> Mapping[str, Any]:
lease_stats: Mapping[str, Any] = {"total": 0, "online": 0, "offline": 0}
for lease in leases:
if not isinstance(lease, Mapping) or lease.get("act", "") == "expired":
continue

lease_stats["total"] += 1
if "online" in lease:
if lease["online"] == "online":
lease_stats["online"] += 1
if lease["online"] == "offline":
lease_stats["offline"] += 1
return lease_stats

@_log_timing
async def _get_arp_table(self):
return await self._client.get_arp_table(True)

async def _async_update_data(self):
async def _async_update_data(self) -> Mapping[str, Any]:
"""Fetch the latest state from OPNsense."""
_LOGGER.info(
f"{'DT ' if self._device_tracker_coordinator else ''}Updating Data"
Expand All @@ -119,144 +106,159 @@ async def _async_update_data(self):
self._state["update_time"] = current_time
self._state["previous_state"] = previous_state

self._state["system_info"] = await self._get_system_info()
self._state["host_firmware_version"] = await self._get_host_firmware_version()

if self._device_tracker_coordinator:
try:
self._state["arp_table"] = await self._get_arp_table()
except Exception as e:
_LOGGER.error(
f"Error getting arp table. {e.__class__.__qualname__}: {e}"
)
else:
self._state["firmware_update_info"] = await self._get_firmware_update_info()
self._state["telemetry"] = await self._get_telemetry()
self._state["config"] = await self._get_config()
self._state["services"] = await self._get_services()
self._state["carp_interfaces"] = await self._get_carp_interfaces()
self._state["carp_status"] = await self._get_carp_status()
# self._state["dhcp_leases"] = await self._client.get_dhcp_leases()
self._state["dhcp_leases"] = []
self._state["dhcp_stats"] = {}
self._state["notices"] = await self._get_notices()
self._state[ATTR_UNBOUND_BLOCKLIST] = (
await self._client.get_unbound_blocklist()
)

lease_stats: Mapping[str, int] = {"total": 0, "online": 0, "offline": 0}
for lease in self._state["dhcp_leases"]:
if "act" in lease.keys() and lease["act"] == "expired":
categories: list = [
{"function": "get_system_info", "state_key": "system_info"},
{
"function": "get_host_firmware_version",
"state_key": "host_firmware_version",
},
{
"function": "get_arp_table",
"state_key": "arp_table",
},
]
self._state.update(await self._get_states(categories))
return self._state

categories: list = [
{"function": "get_system_info", "state_key": "system_info"},
{
"function": "get_host_firmware_version",
"state_key": "host_firmware_version",
},
{
"function": "get_firmware_update_info",
"state_key": "firmware_update_info",
},
{"function": "get_telemetry", "state_key": "telemetry"},
{"function": "get_config", "state_key": "config"},
{"function": "get_services", "state_key": "services"},
{"function": "get_carp_interfaces", "state_key": "carp_interfaces"},
{"function": "get_carp_status", "state_key": "carp_status"},
{"function": "get_notices", "state_key": "notices"},
{
"function": "get_unbound_blocklist",
"state_key": ATTR_UNBOUND_BLOCKLIST,
},
]

self._state.update(await self._get_states(categories))

# self._state["dhcp_leases"] = []
self._state["dhcp_stats"] = {}
self._state["dhcp_stats"]["leases"] = await self._get_dhcp_stats(
self._state.get("dhcp_leases", [])
)

# calcule pps and kbps
update_time = dict_get(self._state, "update_time")
previous_update_time = dict_get(self._state, "previous_state.update_time")

if previous_update_time is not None:
elapsed_time = update_time - previous_update_time

for interface_name, interface in dict_get(
self._state, "telemetry.interfaces", {}
).items():
previous_interface = dict_get(
self._state,
f"previous_state.telemetry.interfaces.{interface_name}",
)
if previous_interface is None:
continue

lease_stats["total"] += 1
if "online" in lease.keys():
if lease["online"] == "online":
lease_stats["online"] += 1
if lease["online"] == "offline":
lease_stats["offline"] += 1

self._state["dhcp_stats"]["leases"] = lease_stats

# calcule pps and kbps
update_time = dict_get(self._state, "update_time")
previous_update_time = dict_get(self._state, "previous_state.update_time")

if previous_update_time is not None:
elapsed_time = update_time - previous_update_time

for interface_name in dict_get(
self._state, "telemetry.interfaces", {}
).keys():
interface = dict_get(
self._state, f"telemetry.interfaces.{interface_name}"
)
previous_interface = dict_get(
self._state,
f"previous_state.telemetry.interfaces.{interface_name}",
)
if previous_interface is None:
break

for prop_name in [
"inbytes",
"outbytes",
# "inbytespass",
# "outbytespass",
# "inbytesblock",
# "outbytesblock",
"inpkts",
"outpkts",
# "inpktspass",
# "outpktspass",
# "inpktsblock",
# "outpktsblock",
]:
current_parent_value = interface[prop_name]
previous_parent_value = previous_interface[prop_name]
change = abs(current_parent_value - previous_parent_value)
rate = change / elapsed_time

value = 0
if "pkts" in prop_name:
label = "packets_per_second"
value = rate
if "bytes" in prop_name:
label = "kilobytes_per_second"
# 1 Byte = 8 bits
# 1 byte is equal to 0.001 kilobytes
KBs = rate / 1000
# Kbs = KBs * 8
value = KBs

new_property = f"{prop_name}_{label}"
interface[new_property] = int(round(value, 0))

for server_name in dict_get(
self._state, "telemetry.openvpn.servers", {}
).keys():
if (
server_name
not in dict_get(
self._state, "telemetry.openvpn.servers", {}
).keys()
):
for prop_name in [
"inbytes",
"outbytes",
# "inbytespass",
# "outbytespass",
# "inbytesblock",
# "outbytesblock",
"inpkts",
"outpkts",
# "inpktspass",
# "outpktspass",
# "inpktsblock",
# "outpktsblock",
]:
try:
current_parent_value: float = interface[prop_name]
previous_parent_value: float = previous_interface[prop_name]
change: float = abs(
current_parent_value - previous_parent_value
)
rate: float = change / elapsed_time
except (TypeError, KeyError, ZeroDivisionError):
rate: float = 0

value: float = 0
if "pkts" in prop_name:
label = "packets_per_second"
value = rate
elif "bytes" in prop_name:
label = "kilobytes_per_second"
# 1 Byte = 8 bits
# 1 byte is equal to 0.001 kilobytes
KBs: float = rate / 1000
# Kbs = KBs * 8
value = KBs
else:
continue

if (
server_name
not in dict_get(
self._state, "previous_state.telemetry.openvpn.servers", {}
).keys()
):
new_property = f"{prop_name}_{label}"
interface[new_property] = int(round(value, 0))

for server_name in dict_get(self._state, "telemetry.openvpn.servers", {}):

if server_name not in dict_get(
self._state, "previous_state.telemetry.openvpn.servers", {}
):
continue

server: Mapping[str, Any] = (
self._state.get("telemetry", {})
.get("openvpn", {})
.get("servers", {})
.get(server_name, {})
)
previous_server: Mapping[str, Any] = (
self._state.get("previous_state", {})
.get("telemetry", {})
.get("openvpn", {})
.get("servers", {})
.get(server_name, {})
)

for prop_name in [
"total_bytes_recv",
"total_bytes_sent",
]:
try:
current_parent_value: float = server[prop_name]
previous_parent_value: float = previous_server[prop_name]
change: float = abs(
current_parent_value - previous_parent_value
)
rate: float = change / elapsed_time
except (TypeError, KeyError, ZeroDivisionError):
rate: float = 0

value: float = 0
if "pkts" in prop_name:
label = "packets_per_second"
value = rate
elif "bytes" in prop_name:
label = "kilobytes_per_second"
# 1 Byte = 8 bits
# 1 byte is equal to 0.001 kilobytes
KBs: float = rate / 1000
# Kbs = KBs * 8
value = KBs
else:
continue

server = self._state["telemetry"]["openvpn"]["servers"][server_name]
previous_server = self._state["previous_state"]["telemetry"][
"openvpn"
]["servers"][server_name]

for prop_name in [
"total_bytes_recv",
"total_bytes_sent",
]:
current_parent_value = server[prop_name]
previous_parent_value = previous_server[prop_name]
change = abs(current_parent_value - previous_parent_value)
rate = change / elapsed_time

value = 0
if "pkts" in prop_name:
label = "packets_per_second"
value = rate
if "bytes" in prop_name:
label = "kilobytes_per_second"
# 1 Byte = 8 bits
# 1 byte is equal to 0.001 kilobytes
KBs = rate / 1000
# Kbs = KBs * 8
value = KBs

new_property: str = f"{prop_name}_{label}"
server[new_property] = int(round(value, 0))
new_property: str = f"{prop_name}_{label}"
server[new_property] = int(round(value, 0))
return self._state

0 comments on commit b97e59b

Please sign in to comment.