Skip to content

Commit

Permalink
feat: properly deal with roaming devices over multiple interfaces (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmoesbergen authored Jul 21, 2024
1 parent 046df9f commit e20487e
Showing 1 changed file with 48 additions and 25 deletions.
73 changes: 48 additions & 25 deletions presence-detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, List, Callable, Optional, Tuple
from urllib import request

VERSION = "2.1.1"
VERSION = "2.2.0"


class Logger:
Expand Down Expand Up @@ -83,6 +83,7 @@ class Action(IntEnum):
QUIT = 3

device: str
interface: str
action: Action


Expand All @@ -96,7 +97,10 @@ def __init__(self, config_file: str) -> None:
self._queue: Queue = Queue()
self._watchers: List[UbusWatcher] = []
self._killed = False
self._last_seen_clients: set[str] = set([])
self._last_seen_clients: set[tuple[str, str]] = set()
self._online_clients: dict[str, set[str]] = {}
for interface in self._settings.interfaces:
self._online_clients[interface] = set()

@staticmethod
def _post(url: str, data: dict, headers: dict) -> Tuple[str, bool]:
Expand Down Expand Up @@ -134,21 +138,34 @@ def _ha_seen(self, device: str, seen: bool = True) -> bool:

return ok

def set_device_away(self, device: str) -> None:
def set_device_away(self, interface: str, device: str) -> None:
"""Mark a client as away in HA"""
if not self._should_handle_device(device):
return
self._queue.put(QueueItem(device, QueueItem.Action.DELETE))
self._logger.log(f"Device {device} is now away")
if device in self._online_clients[interface]:
self._online_clients[interface].remove(device)
for intf in set(self._settings.interfaces) - {interface}:
if device in self._online_clients[intf]:
# Device is still connected to another interface -> ignore
self._logger.log(
f"Device {device} still connected to {intf}, ignoring away event.",
True,
)
return
self._queue.put(QueueItem(device, interface, QueueItem.Action.DELETE))
self._logger.log(f"Device {device} on {interface} is now away")

def set_device_home(self, device: str) -> None:
def set_device_home(self, interface: str, device: str) -> None:
"""Add client to the 'add' queue"""
if not self._should_handle_device(device):
return
self._queue.put(QueueItem(device, QueueItem.Action.ADD))
self._logger.log(f"Device {device} is now at {self._settings.location}")
self._queue.put(QueueItem(device, interface, QueueItem.Action.ADD))
self._online_clients[interface].add(device)
self._logger.log(
f"Device {device} on {interface} is now at {self._settings.location}"
)

def _get_all_online_devices(self) -> List[str]:
def _get_all_online_devices(self) -> List[Tuple[str, str]]:
"""Call ubus and get all online devices"""
devices = []
for interface in self._settings.interfaces:
Expand All @@ -164,7 +181,7 @@ def _get_all_online_devices(self) -> List[str]:
)
continue
response: dict = json.loads(process.stdout)
devices.extend(response["clients"].keys())
devices.extend([(interface, key) for key in response["clients"].keys()])
return devices

def _should_handle_device(self, device: str) -> bool:
Expand Down Expand Up @@ -196,18 +213,18 @@ def stop(self, _signum: Optional[int] = None, _frame: Optional[int] = None):
self._logger.log("Stopping...")
self.stop_watchers()
self._killed = True
self._queue.put(QueueItem("quit", QueueItem.Action.QUIT))
self._queue.put(QueueItem("quit", "", QueueItem.Action.QUIT))

def _do_full_sync(self, away_only=False):
"""Perform a full sync of all currently online devices compared to last time"""
seen_now = set(self._get_all_online_devices())
away = self._last_seen_clients - seen_now
self._last_seen_clients = seen_now
if not away_only:
for client in seen_now:
self.set_device_home(client)
for client in away:
self.set_device_away(client)
for interface, client in seen_now:
if not away_only:
self.set_device_home(interface, client)
for interface, client in away:
self.set_device_away(interface, client)

def _update_version_entity(self):
"""Create a script version entity in home assistant"""
Expand All @@ -218,11 +235,15 @@ def _update_version_entity(self):
)
entity_id = f"sensor.{ap_name}_presence_detector_version"

response, ok = self._post(
f"{self._settings.hass_url}/api/states/{entity_id}",
data={"state": VERSION},
headers={"Authorization": f"Bearer {self._settings.hass_token}"},
)
try:
response, ok = self._post(
f"{self._settings.hass_url}/api/states/{entity_id}",
data={"state": VERSION},
headers={"Authorization": f"Bearer {self._settings.hass_token}"},
)
except Exception as ex: # pylint: disable=broad-except
ok = False
response = str(ex)
if not ok:
self._logger.log(
f"Unable to create/update version entity in HA: {response}"
Expand Down Expand Up @@ -283,8 +304,8 @@ class UbusWatcher(Thread):
def __init__(
self,
interface: str,
on_join: Callable[[str], None],
on_leave: Callable[[str], None],
on_join: Callable[[str, str], None],
on_leave: Callable[[str, str], None],
) -> None:
super().__init__()
self._on_join = on_join
Expand Down Expand Up @@ -323,9 +344,11 @@ def run(self) -> None:
# Ignore incomplete / invalid json
pass
if "assoc" in event:
self._on_join(event["assoc"]["address"].lower())
self._on_join(self._interface, event["assoc"]["address"].lower())
elif "disassoc" in event:
self._on_leave(event["disassoc"]["address"].lower())
self._on_leave(
self._interface, event["disassoc"]["address"].lower()
)
ubus.terminate()
ubus.wait()

Expand Down

0 comments on commit e20487e

Please sign in to comment.