Merge pull request #34 from tradewelltech/25-add-empty-factory
Add possibility to specify empty with a factory
0x26res authored Sep 19, 2023
2 parents 1a2527a + a7669fa commit c013ecc
Showing 10 changed files with 198 additions and 82 deletions.
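
As a quick illustration of what this change enables (editorial sketch, not part of the diff): the snippet below uses the new `empty_factory` argument introduced in the diff; the `beavers.engine` import path is assumed from the file being modified, and the reset-to-empty behaviour on an idle cycle follows the `get_value` change shown further down.

```python
from beavers.engine import Dag  # import path assumed from the modified file

dag = Dag()

# Before: an empty *value* had to be passed, e.g. dag.source_stream(empty={})
# Now: a factory can be passed instead, producing a fresh empty value on reset.
words = dag.source_stream(empty_factory=dict, name="words")

words.set_stream({"hello": "world"})
dag.execute()
assert words.get_value() == {"hello": "world"}

dag.execute()                     # no update this cycle...
assert words.get_value() == {}    # ...so the stream resets to empty_factory()
```
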
109 changes: 78 additions & 31 deletions beavers/engine.py
@@ -16,24 +16,25 @@
O = typing.TypeVar("O") # noqa E741

_STATE_EMPTY = object()
_VALUE_EMPTY = object()
_STATE_UNCHANGED = object()

UTC_EPOCH = pd.to_datetime(0, utc=True)
UTC_MAX = pd.Timestamp.max.tz_localize("UTC")


class _SourceStreamFunction(typing.Generic[T]):
def __init__(self, empty: T, name: str):
self._empty = empty
def __init__(self, empty_factory: typing.Callable[[], T], name: str):
self._empty_factory = empty_factory
self._name = name
self._value = empty
self._value = empty_factory()

def set(self, value: T):
self._value = value

def __call__(self) -> T:
result = self._value
self._value = self._empty
self._value = self._empty_factory()
return result
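
One reason a factory can be preferable to a single shared `empty` value (editorial sketch, not part of the diff; the motivation is an assumption, but the Python behaviour shown is exact): with a shared mutable empty object, a downstream consumer could mutate the very sentinel that every reset hands back, whereas a factory produces a fresh object each time.

```python
# Hazard with a single shared empty value:
shared_empty: list = []
batch = shared_empty               # value handed out on a quiet cycle
batch.append("oops")               # downstream mutation...
assert shared_empty == ["oops"]    # ...corrupts the shared sentinel

# With a factory, every reset hands out a brand-new object:
factory = list
fresh = factory()
fresh.append("fine")
assert factory() == []             # the next reset is unaffected
```
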


@@ -211,7 +212,7 @@ class Node(typing.Generic[T]):

_function: typing.Optional[typing.Callable[[...], T]]
_inputs: _NodeInputs = dataclasses.field(repr=False)
_empty: typing.Any
_empty_factory: typing.Any
_observers: list[Node] = dataclasses.field(repr=False)
_runtime_data: _RuntimeNodeData

@@ -220,20 +221,23 @@ def _create(
value: T = None,
function: typing.Optional[typing.Callable[[...], T]] = None,
inputs: _NodeInputs = _NO_INPUTS,
empty: typing.Any = _STATE_EMPTY,
empty_factory: typing.Any = _STATE_EMPTY,
notifications: int = 1,
) -> Node:
return Node(
_function=function,
_inputs=inputs,
_empty=empty,
_empty_factory=empty_factory,
_runtime_data=_RuntimeNodeData(value, notifications, 0),
_observers=[],
)

def get_value(self) -> T:
"""Return the value of the output for the last update."""
return self._runtime_data.value
if self._runtime_data.value is _VALUE_EMPTY:
return self._empty_factory()
else:
return self._runtime_data.value

def get_cycle_id(self) -> int:
"""Return id of the cycle at which this node last updated."""
@@ -261,15 +265,15 @@ def _clean(self, cycle_id: int) -> bool:
return True
else:
if self._is_stream():
self._runtime_data.value = self._empty
self._runtime_data.value = _VALUE_EMPTY
self._runtime_data.notifications = 0
return False

def _is_stream(self) -> bool:
return self._empty is not _STATE_EMPTY
return self._empty_factory is not _STATE_EMPTY

def _is_state(self) -> bool:
return self._empty is _STATE_EMPTY
return self._empty_factory is _STATE_EMPTY

def _should_recalculate(self) -> bool:
return self._runtime_data.notifications != 0
@@ -374,7 +378,10 @@ def const(self, value: T) -> Node[T]:
)

def source_stream(
self, empty: typing.Optional[T] = None, name: typing.Optional[str] = None
self,
empty: typing.Optional[T] = None,
empty_factory: typing.Optional[typing.Callable[[], T]] = None,
name: typing.Optional[str] = None,
) -> Node[T]:
"""
Add a source stream `Node`.
@@ -384,47 +391,61 @@ def source_stream(
empty:
The value the stream resets to when there are no updates.
Must implement `__len__` and be empty
empty_factory:
A provider for the empty value
A callable returning an object that implements `__len__` and is empty
name:
The name of the source
"""
empty = _check_empty(empty)
empty_factory = _check_empty(empty, empty_factory)
existing = self._sources.get(name) if name else None
if existing is not None:
if existing._empty != empty:
if existing._empty_factory != empty_factory:
raise ValueError(f"Duplicate source: {name}")
else:
return existing
else:
node = self._add_stream(
function=_SourceStreamFunction(empty, name),
empty=empty,
function=_SourceStreamFunction(empty_factory, name),
empty_factory=empty_factory,
inputs=_NO_INPUTS,
)
if name:
self._sources[name] = node
return node

def stream(
self, function: typing.Callable[P, T], empty: typing.Optional[T] = None
self,
function: typing.Callable[P, T],
empty: typing.Optional[T] = None,
empty_factory: typing.Optional[typing.Callable[[], T]] = None,
) -> NodePrototype:
"""
Add a stream `NodePrototype`.
Stream nodes are reset to their empty value after each cycle.
The user can provide either an `empty` value or an `empty_factory`;
the default is to use `list` as the `empty_factory`.
Parameters
----------
function:
The processing function of the `Node`.
empty:
The value the stream resets to when there are no updates.
Must implement `__len__` and be empty
empty_factory:
A provider for the empty value
A callable returning an object that implements `__len__` and is empty
"""
empty = _check_empty(empty)
empty_factory = _check_empty(empty, empty_factory)
_check_function(function)

def add_to_dag(inputs: _NodeInputs) -> Node:
return self._add_stream(function, empty, inputs)
return self._add_stream(function, empty_factory, inputs)

return NodePrototype(add_to_dag)

@@ -528,7 +549,7 @@ def silence(self, node: Node[T]) -> Node[T]:
function=SilentUpdate,
inputs=_NodeInputs.create([node], {}),
value=node.get_value(),
empty=node._empty,
empty_factory=node._empty_factory,
)
)

@@ -577,12 +598,19 @@ def flush_metrics(self) -> DagMetrics:
return results

def _add_stream(
self, function: typing.Callable[[...], T], empty: T, inputs: _NodeInputs
self,
function: typing.Callable[[...], T],
empty_factory: typing.Callable[[], T],
inputs: _NodeInputs,
) -> Node[T]:
_check_function(function)
empty = _check_empty(empty)
return self._add_node(
Node._create(value=empty, function=function, inputs=inputs, empty=empty)
Node._create(
value=empty_factory(),
function=function,
inputs=inputs,
empty_factory=empty_factory,
)
)

def _flush_timers(self, now: pd.Timestamp) -> int:
@@ -614,15 +642,34 @@ def _add_node(self, node: Node) -> Node:
return node


def _check_empty(empty: T) -> T:
if empty is None:
return []
elif not isinstance(empty, collections.abc.Sized):
raise TypeError("`empty` should implement `__len__`")
elif len(empty) != 0:
raise TypeError("`len(empty)` should be 0")
def _check_empty(
empty: typing.Optional[T], empty_factory: typing.Optional[typing.Callable[[], T]]
) -> typing.Callable[[], T]:
if empty is not None and empty_factory is not None:
raise ValueError(f"Can't provide both {empty=} and {empty_factory=}")
elif empty is None and empty_factory is None:
return list
elif empty is not None:
if not isinstance(empty, collections.abc.Sized):
raise TypeError("`empty` should implement `__len__`")
elif len(empty) != 0:
raise TypeError("`len(empty)` should be 0")
else:
return lambda: empty
else:
return empty
assert empty is None
if not callable(empty_factory):
raise TypeError(f"{empty_factory=} should be a callable")

empty_value = empty_factory()
if empty_value is None:
raise TypeError(f"{empty_factory=} should not return None")
elif not isinstance(empty_value, collections.abc.Sized):
raise TypeError(f"{empty_value=} should implement `__len__`")
elif len(empty_value) != 0:
raise TypeError("`len(empty)` should be 0")
else:
return empty_factory
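
For reference (editorial sketch, not part of the diff), the reworked `_check_empty` above is expected to behave roughly as follows; the import of the private helper is an assumption made for illustration.

```python
from beavers.engine import _check_empty  # private helper shown above; path assumed

assert _check_empty(None, None) is list          # neither given: default to list
assert _check_empty([], None)() == []            # empty value wrapped in a factory
assert _check_empty(None, set) is set            # factory validated and returned
# _check_empty({}, dict)         -> ValueError: can't provide both
# _check_empty(object(), None)   -> TypeError: `empty` should implement `__len__`
# _check_empty(None, lambda: 1)  -> TypeError: factory result must implement `__len__`
```
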


def _check_input(node: Node) -> Node:
10 changes: 7 additions & 3 deletions docs/concepts/1_dag.md
@@ -2,7 +2,7 @@
# DAG

At its core, `beavers` executes a Directed Acyclic Graph (DAG), where each node is a python function.
This section discuss the different type of nodes in the DAG.
This section discusses the different types of nodes in the DAG.

## Stream Source

@@ -68,7 +68,7 @@ A state node retains its value from one DAG execution to the next, even if it di
--8<-- "examples/dag_concepts.py:state_node"
```

State nodes have an empty value
Because they retain their value when they are not updated, state nodes don't require an empty value.

## Const Node

@@ -77,10 +77,12 @@ A const node is a node whose value doesn't change.
--8<-- "examples/dag_concepts.py:const_node"
```

Const nodes behave like state nodes (their value isn't reset when they don't update).

## Connecting Nodes (aka `map`)

Nodes are connected by calling the `map` function.
Stream nodes can be connected to state nodes, stream nodes or const nodes, and vice versa.
Any stream or state node can be connected to state nodes, stream nodes or const nodes.

> :warning: The `map` function doesn't execute the underlying node.
> Instead it adds a node to the DAG
@@ -99,12 +101,14 @@ Or key word arguments:
## State vs Stream

Stream Nodes:

- need their return type to implement `collections.abc.Sized`
- need an empty value to be specified (which defaults to `[]`)
- have their value reset to empty when they don't update
- are not considered updated if they return empty

State Nodes:

- Can return any type
- don't require an empty value
- retain their value on cycles they don't update (the sketch below illustrates the difference)
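
To make the contrast concrete, here is an editorial sketch (the import path and the `dag.state(sum)` usage are assumptions based on the other examples in this repository):

```python
from beavers.engine import Dag  # import path assumed

dag = Dag()
numbers = dag.source_stream(name="numbers")   # stream node (empty defaults to [])
total = dag.state(sum).map(numbers)           # state node

numbers.set_stream([1, 2, 3])
dag.execute()
assert numbers.get_value() == [1, 2, 3] and total.get_value() == 6

dag.execute()                                 # cycle with no update
assert numbers.get_value() == []              # the stream resets to empty
assert total.get_value() == 6                 # the state retains its value
```
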
6 changes: 4 additions & 2 deletions docs/concepts/2_advanced.md
@@ -24,8 +24,8 @@ This section discuss advanced features that control how updates propagate in the

## Now node

Beavers can be used in both live and replay mode.
In replay mode, the wall clock isn't relevant.
Beavers can be used in both `live` and `replay` mode.
In `replay` mode, the wall clock isn't relevant.
To access the current time of the replay, you should use the now node:

```python
@@ -52,6 +52,8 @@ In this case it's possible to silence them:
--8<-- "examples/advanced_concepts.py:silence"
```

`silence` returns a new, silenced node (rather than modifying the existing node).
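
A minimal editorial sketch of that point (hypothetical node names, import path assumed):

```python
from beavers.engine import Dag  # import path assumed

dag = Dag()
raw = dag.source_stream(name="raw")  # hypothetical source
silent = dag.silence(raw)            # returns a *new*, silenced node wrapping `raw`
assert silent is not raw             # the original node is left untouched
```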

## Value Cutoff

By default, state nodes will update every time they are notified.
4 changes: 2 additions & 2 deletions docs/concepts/3_replay.md
@@ -52,15 +52,15 @@ The `ReplayContext` contains timing information:

A `DataSourceProvider` provides a way of creating `DataSource`.

Assuming the data is stored in a csv file:
For example, if the data is stored in a csv file:

```csv
timestamp,message
2023-01-01 01:00:00+00:00,Hello
2023-01-01 01:01:00+00:00,How are you
```

Provided with the `ReplayContext`, it will load the and return a `DataSource`
Provided with the `ReplayContext`, our `DataSourceProvider` will load the data and return a `DataSource`:

```python
--8<-- "examples/replay_concepts.py:data_source_provider"
4 changes: 2 additions & 2 deletions examples/dag_concepts.py
@@ -26,7 +26,7 @@
# --8<-- [end:source_stream_name]

# --8<-- [start:source_stream_empty]
dict_source_stream = dag.source_stream(empty={})
dict_source_stream = dag.source_stream(empty_factory=dict)
dict_source_stream.set_stream({"hello": "world"})
dag.execute()
assert dict_source_stream.get_value() == {"hello": "world"}
@@ -55,7 +55,7 @@ def multiply_by_2(values: list[int]) -> list[int]:


# --8<-- [start:stream_node_empty]
set_stream_node = dag.stream(set, empty=set()).map(source_stream)
set_stream_node = dag.stream(set, empty_factory=set).map(source_stream)
source_stream.set_stream([1, 2, 3, 1, 2, 3])
dag.execute()
assert set_stream_node.get_value() == {1, 2, 3}
4 changes: 2 additions & 2 deletions examples/etfs.py
@@ -115,8 +115,8 @@ def get_updated_tickers(

def create_dag() -> Dag:
dag = Dag()
price_stream = dag.source_stream([], "price")
etf_composition_stream = dag.source_stream([], "etf_composition")
price_stream = dag.source_stream([], name="price")
etf_composition_stream = dag.source_stream([], name="etf_composition")
price_latest = dag.state(GetLatest(attrgetter("ticker"))).map(price_stream)
etf_composition_latest = dag.state(GetLatest(attrgetter("ticker"))).map(
etf_composition_stream