Skip to content

Commit

Permalink
Modernize Slots management into dataclasses
Browse files Browse the repository at this point in the history
Originally slots were using just nested ararys indexed by position, but
we can do better now.

This project was originally written in Python 3.6 before many more
modern things like dataclasses exist, so now we can use
dataclass(slots=True) for automatic slots generation as well as easily
use dataclasses for inner-encapsulated data management.
  • Loading branch information
mattsta committed Aug 24, 2024
1 parent e6faadc commit eaf4c3f
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 94 deletions.
239 changes: 147 additions & 92 deletions eventkit/event.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,121 @@
from __future__ import annotations

import asyncio
import itertools
import logging
import types
import weakref
from dataclasses import dataclass, field
from typing import Any as AnyType
from typing import AsyncIterable, Awaitable, Iterable, List, Optional, Tuple, Union
from typing import (
AsyncIterable,
Awaitable,
Callable,
Final,
Iterable,
List,
Optional,
Tuple,
Union,
)

from .util import NO_VALUE, _NoValue, get_event_loop


@dataclass(slots=True)
class Slot:
obj: AnyType
weakref: Callable[[AnyType], AnyType] | None
func: Callable[[AnyType], AnyType]


@dataclass(slots=True)
class Slots:
slots: list[Slot] = field(default_factory=list)

def add(self, obj, weakref, func) -> None:
self.slots.append(Slot(obj, weakref, func))

def remove(self, obj, func):
"""Remove a specific obj/func combination from the slots collection."""
self.slots = list(
itertools.filterfalse(
lambda x: (x.obj is obj or x.weakref and x.weakref() is obj)
and x.func is func,
self.slots,
)
)

def remove_obj(self, obj):
self.slots = list(
itertools.filterfalse(
lambda x: x.obj is obj or x.weakref and x.weakref() is obj,
self.slots,
)
)

def remove_ref(self, ref):
self.slots = list(
itertools.filterfalse(
lambda x: x.weakref is ref,
self.slots,
)
)

def exists(self, obj, func):
return any(
[
(x.obj is obj or x.weakref and x.weakref() is obj) and x.func is func
for x in self.slots
]
)

@property
def count(self) -> int:
return len(self.slots)

from .util import NO_VALUE, get_event_loop
def clear(self) -> None:
self.slots = []

def __call__(self, caller, *args, **kwargs):
"""Loop over all active callbacks and call them"""
for slot in self.slots.copy():
ref = slot.weakref
func = slot.func

try:
if ref:
obj = ref()
else:
obj = slot.obj

result = None
if obj is None:
if func:
result = func(*args, **kwargs)
else:
if func:
result = func(obj, *args, **kwargs)
else:
result = obj(*args, **kwargs)

# even though asyncio.iscoroutine() would also work here,
# this manual hasattr() check performs better.
if result and hasattr(result, "__await__"):
asyncio.ensure_future(result, loop=get_event_loop())
except Exception as error:
# It's not really clear in the documentation or usage that exceptions
# get returned via an 'error_event' callback. We should make sure
# people know this clearly so event handler callback errors are noticed.
if len(caller.error_event):
caller.error_event.emit(caller, error)
else:
Event.logger.exception(
f"Value {args} caused exception for event {caller}"
)


@dataclass(slots=False)
class Event:
"""
Enable event passing between loosely coupled components.
Expand All @@ -18,50 +126,43 @@ class Event:
name: Name to use for this event.
"""

__slots__ = (
"error_event",
"done_event",
"_name",
"_value",
"_slots",
"_done",
"_source",
"__weakref__",
)
_name: str = ""
_with_error_done_events: bool = True

NO_VALUE = NO_VALUE
logger = logging.getLogger(__name__)
# Sub event that emits errors from this event as ``emit(source, exception)``.
error_event: Event | None = None

error_event: Optional["Event"]
done_event: Optional["Event"]
_name: str
_value: AnyType
_slots: List[List]
_done: bool
_source: Optional["Event"]
# Sub event that emits when this event is done as ``emit(source)``.
done_event: Event | None = None

def __init__(self, name: str = "", _with_error_done_events: bool = True):
self.error_event = None
"""
Sub event that emits errors from this event as
``emit(source, exception)``.
"""
logger: logging.Logger = field(default_factory=lambda: logging.getLogger(__name__))

self.done_event = None
"""
Sub event that emits when this event is done as
``emit(source)``.
"""
_value: AnyType = NO_VALUE
_slots: Final[Slots] = field(default_factory=Slots)
_done: bool = False
_source: Event | None = None
__weakref__: AnyType = None
_task: AnyType = None

NO_VALUE: Final[_NoValue] = NO_VALUE

if _with_error_done_events:
def __post_init__(self) -> None:
if self._with_error_done_events:
self.error_event = Event("error", False)
self.done_event = Event("done", False)

self._slots = [] # list of [obj, weakref, func] sublists
self._name = name or self.__class__.__qualname__
self._value = NO_VALUE
self._done = False
self._source = None
if not self._name:
self._name = self.__class__.__qualname__

def __hash__(self) -> int:
return hash(
(
self.name,
self._with_error_done_events,
self.error_event,
self.done_event,
)
)

def name(self) -> str:
"""
Expand Down Expand Up @@ -145,8 +246,7 @@ def g(a, b):
else:
ref = None

slot = [obj, ref, func]
self._slots.append(slot)
self._slots.add(obj, ref, func)

if self.done_event and done is not None:
self.done_event.connect(done)
Expand All @@ -170,12 +270,7 @@ def disconnect(self, listener, error=None, done=None):
done: The done callback to disconnect.
"""
obj, func = self._split(listener)
for slot in self._slots:
if (slot[0] is obj or slot[1] and slot[1]() is obj) and slot[2] is func:
slot[0] = slot[1] = slot[2] = None
break

self._slots = [s for s in self._slots if s != [None, None, None]]
self._slots.remove(obj, func)

if error is not None:
self.error_event.disconnect(error)
Expand All @@ -194,11 +289,7 @@ def disconnect_obj(self, obj):
obj: The target object that is to be completely removed from
this event.
"""
for slot in self._slots:
if slot[0] is obj or slot[1] and slot[1]() is obj:
slot[0] = slot[1] = slot[2] = None

self._slots = [s for s in self._slots if s != [None, None, None]]
self._slots.remove_obj(obj)

if self.error_event is not None:
self.error_event.disconnect_obj(obj)
Expand All @@ -214,33 +305,7 @@ def emit(self, *args):
args: Argument values to emit to listeners.
"""
self._value = args

for obj, ref, func in self._slots.copy():
try:
if ref:
obj = ref()

result = None
if obj is None:
if func:
result = func(*args)
else:
if func:
result = func(obj, *args)
else:
result = obj(*args)

# even though asyncio.iscoroutine() would also work here,
# this manual hasattr() check performs better.
if result and hasattr(result, "__await__"):
asyncio.ensure_future(result, loop=get_event_loop())
except Exception as error:
if len(self.error_event):
self.error_event.emit(self, error)
else:
Event.logger.exception(
f"Value {args} caused exception for event {self}"
)
self._slots(self, *args)

def emit_threadsafe(self, *args):
"""
Expand All @@ -253,10 +318,7 @@ def clear(self):
"""
Disconnect all listeners.
"""
for slot in self._slots:
slot[0] = slot[1] = slot[2] = None

self._slots = []
self._slots.clear()

def run(self) -> List:
"""
Expand Down Expand Up @@ -340,11 +402,7 @@ def set_source(self, source):
self._source = source

def _onFinalize(self, ref):
for slot in self._slots:
if slot[1] is ref:
slot[0] = slot[1] = slot[2] = None

self._slots = [s for s in self._slots if s != [None, None, None]]
self._slots.remove_ref(ref)

@staticmethod
def _split(c):
Expand Down Expand Up @@ -440,7 +498,7 @@ def __repr__(self):
return f"Event<{self.name()}, {self._slots}>"

def __len__(self):
return len(self._slots)
return self._slots.count

def __bool__(self):
return True
Expand Down Expand Up @@ -502,10 +560,7 @@ def __contains__(self, c):
See if callable is already connected.
"""
obj, func = self._split(c)
return any(
(s[0] is obj or s[1] and s[1]() is obj) and s[2] is func
for s in self._slots
)
return self._slots.exists(obj, func)

def __reduce__(self):
"""
Expand Down
4 changes: 2 additions & 2 deletions eventkit/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import datetime as dt
import functools
from typing import AsyncIterator
from typing import AsyncIterator, Final


class _NoValue:
Expand All @@ -14,7 +14,7 @@ def __repr__(self):
__str__ = __repr__


NO_VALUE = _NoValue()
NO_VALUE: Final = _NoValue()


@functools.cache
Expand Down

0 comments on commit eaf4c3f

Please sign in to comment.