diff --git a/beavers/engine.py b/beavers/engine.py index 97dbad9..178a8d4 100644 --- a/beavers/engine.py +++ b/beavers/engine.py @@ -16,6 +16,7 @@ O = typing.TypeVar("O") # noqa E741 _STATE_EMPTY = object() +_VALUE_EMPTY = object() _STATE_UNCHANGED = object() UTC_EPOCH = pd.to_datetime(0, utc=True) @@ -23,17 +24,17 @@ class _SourceStreamFunction(typing.Generic[T]): - def __init__(self, empty: T, name: str): - self._empty = empty + def __init__(self, empty_factory: typing.Callable[[], T], name: str): + self._empty_factory = empty_factory self._name = name - self._value = empty + self._value = empty_factory() def set(self, value: T): self._value = value def __call__(self) -> T: result = self._value - self._value = self._empty + self._value = self._empty_factory() return result @@ -211,7 +212,7 @@ class Node(typing.Generic[T]): _function: typing.Optional[typing.Callable[[...], T]] _inputs: _NodeInputs = dataclasses.field(repr=False) - _empty: typing.Any + _empty_factory: typing.Any _observers: list[Node] = dataclasses.field(repr=False) _runtime_data: _RuntimeNodeData @@ -220,20 +221,23 @@ def _create( value: T = None, function: typing.Optional[typing.Callable[[...], T]] = None, inputs: _NodeInputs = _NO_INPUTS, - empty: typing.Any = _STATE_EMPTY, + empty_factory: typing.Any = _STATE_EMPTY, notifications: int = 1, ) -> Node: return Node( _function=function, _inputs=inputs, - _empty=empty, + _empty_factory=empty_factory, _runtime_data=_RuntimeNodeData(value, notifications, 0), _observers=[], ) def get_value(self) -> T: """Return the value of the output for the last update.""" - return self._runtime_data.value + if self._runtime_data.value is _VALUE_EMPTY: + return self._empty_factory() + else: + return self._runtime_data.value def get_cycle_id(self) -> int: """Return id of the cycle at which this node last updated.""" @@ -261,15 +265,15 @@ def _clean(self, cycle_id: int) -> bool: return True else: if self._is_stream(): - self._runtime_data.value = self._empty + self._runtime_data.value = _VALUE_EMPTY self._runtime_data.notifications = 0 return False def _is_stream(self) -> bool: - return self._empty is not _STATE_EMPTY + return self._empty_factory is not _STATE_EMPTY def _is_state(self) -> bool: - return self._empty is _STATE_EMPTY + return self._empty_factory is _STATE_EMPTY def _should_recalculate(self) -> bool: return self._runtime_data.notifications != 0 @@ -374,7 +378,10 @@ def const(self, value: T) -> Node[T]: ) def source_stream( - self, empty: typing.Optional[T] = None, name: typing.Optional[str] = None + self, + empty: typing.Optional[T] = None, + empty_factory: typing.Optional[typing.Callable[[], T]] = None, + name: typing.Optional[str] = None, ) -> Node[T]: """ Add a source stream `Node`. @@ -384,21 +391,24 @@ def source_stream( empty: The value to which the stream reset to when there are no update. 
Must implement `__len__` and be empty + empty_factory: + A provider for the empty value + A callable returning an object that implements `__len__` and is empty name: The name of the source """ - empty = _check_empty(empty) + empty_factory = _check_empty(empty, empty_factory) existing = self._sources.get(name) if name else None if existing is not None: - if existing._empty != empty: + if existing._empty_factory != empty_factory: raise ValueError(f"Duplicate source: {name}") else: return existing else: node = self._add_stream( - function=_SourceStreamFunction(empty, name), - empty=empty, + function=_SourceStreamFunction(empty_factory, name), + empty_factory=empty_factory, inputs=_NO_INPUTS, ) if name: @@ -406,11 +416,19 @@ def source_stream( return node def stream( - self, function: typing.Callable[P, T], empty: typing.Optional[T] = None + self, + function: typing.Callable[P, T], + empty: typing.Optional[T] = None, + empty_factory: typing.Optional[typing.Callable[[], T]] = None, ) -> NodePrototype: """ Add a stream `NodePrototype`. + Stream nodes are reset to their empty value after each cycle. + Therefore, the user must provide an `empty` value or an `empty_factory` + + The default is to use `list` as the `empty_factory`. + Parameters ---------- function: @@ -418,13 +436,16 @@ def stream( empty: The value to which the stream reset to when there are no update. Must implement `__len__` and be empty + empty_factory: + A provider for the empty value + A callable returning an object that implements `__len__` and is empty """ - empty = _check_empty(empty) + empty_factory = _check_empty(empty, empty_factory) _check_function(function) def add_to_dag(inputs: _NodeInputs) -> Node: - return self._add_stream(function, empty, inputs) + return self._add_stream(function, empty_factory, inputs) return NodePrototype(add_to_dag) @@ -528,7 +549,7 @@ def silence(self, node: Node[T]) -> Node[T]: function=SilentUpdate, inputs=_NodeInputs.create([node], {}), value=node.get_value(), - empty=node._empty, + empty_factory=node._empty_factory, ) ) @@ -577,12 +598,19 @@ def flush_metrics(self) -> DagMetrics: return results def _add_stream( - self, function: typing.Callable[[...], T], empty: T, inputs: _NodeInputs + self, + function: typing.Callable[[...], T], + empty_factory: typing.Callable[[], T], + inputs: _NodeInputs, ) -> Node[T]: _check_function(function) - empty = _check_empty(empty) return self._add_node( - Node._create(value=empty, function=function, inputs=inputs, empty=empty) + Node._create( + value=empty_factory(), + function=function, + inputs=inputs, + empty_factory=empty_factory, + ) ) def _flush_timers(self, now: pd.Timestamp) -> int: @@ -614,15 +642,34 @@ def _add_node(self, node: Node) -> Node: return node -def _check_empty(empty: T) -> T: - if empty is None: - return [] - elif not isinstance(empty, collections.abc.Sized): - raise TypeError("`empty` should implement `__len__`") - elif len(empty) != 0: - raise TypeError("`len(empty)` should be 0") +def _check_empty( + empty: typing.Optional[T], empty_factory: typing.Optional[typing.Callable[[], T]] +) -> typing.Callable[[], T]: + if empty is not None and empty_factory is not None: + raise ValueError(f"Can't provide both {empty=} and {empty_factory=}") + elif empty is None and empty_factory is None: + return list + elif empty is not None: + if not isinstance(empty, collections.abc.Sized): + raise TypeError("`empty` should implement `__len__`") + elif len(empty) != 0: + raise TypeError("`len(empty)` should be 0") + else: + return lambda: empty else: - 
-        return empty
+        assert empty is None
+        if not callable(empty_factory):
+            raise TypeError(f"{empty_factory=} should be a callable")
+
+        empty_value = empty_factory()
+        if empty_value is None:
+            raise TypeError(f"{empty_factory=} should not return None")
+        elif not isinstance(empty_value, collections.abc.Sized):
+            raise TypeError(f"{empty_value=} should implement `__len__`")
+        elif len(empty_value) != 0:
+            raise TypeError("`len(empty)` should be 0")
+        else:
+            return empty_factory


 def _check_input(node: Node) -> Node:
diff --git a/docs/concepts/1_dag.md b/docs/concepts/1_dag.md
index 2d1e531..e72371b 100644
--- a/docs/concepts/1_dag.md
+++ b/docs/concepts/1_dag.md
@@ -2,7 +2,7 @@
 # DAG

 At its core, `beavers` executes a Directed Acyclic Graph (DAG), where each node is a python function.
-This section discuss the different type of nodes in the DAG.
+This section discusses the different types of nodes in the DAG.

 ## Stream Source

@@ -68,7 +68,7 @@ A state node retains its value from one DAG execution to the next, even if it di
 --8<-- "examples/dag_concepts.py:state_node"
 ```

-State nodes have an empty value
+Because they retain their value when they are not updated, state nodes don't require an empty value.

 ## Const Node

@@ -77,10 +77,12 @@ A const node is a node whose value doesn't change.
 --8<-- "examples/dag_concepts.py:const_node"
 ```

+Const nodes behave like state nodes (their value isn't reset when they don't update).
+
 ## Connecting Nodes (aka `map`)

 Nodes are connected by calling the `map` function.
-Stream nodes can be connected to state nodes, stream nodes or const nodes, and vice versa.
+Any stream or state node can be connected to state nodes, stream nodes or const nodes.

 > :warning: The `map` function doesn't execute the underlying node.
 > Instead it adds a node to the DAG
@@ -99,12 +101,14 @@ Or key word arguments:
 ## State vs Stream

 Stream Nodes:
+
 - need their return type to implement `collections.abc.Sized`
 - need an empty value to be specfied (which default to `[]`)
 - have their value reset to empty when they don't update
 - are not considered updated if they return empty

 State Nodes:
+
 - Can return any type
 - don't require an empty value
 - retain their value on cycle they don't update
diff --git a/docs/concepts/2_advanced.md b/docs/concepts/2_advanced.md
index 062f025..3ce419f 100644
--- a/docs/concepts/2_advanced.md
+++ b/docs/concepts/2_advanced.md
@@ -24,8 +24,8 @@
 This section discuss advanced features that control how updates propagate in the dag.

 ## Now node
-Beavers can be used in both live and replay mode.
-In replay mode, the wall clock isn't relevant.
+Beavers can be used in both `live` and `replay` mode.
+In `replay` mode, the wall clock isn't relevant.
 To access the current time of the replay, you should use the now node:

 ```python
@@ -52,6 +52,8 @@ In this case it's possible to silence them:
 --8<-- "examples/advanced_concepts.py:silence"
 ```

+`silence` returns a new silenced node (rather than modifying the existing node).
+
 ## Value Cutoff

 By default, state nodes will update everytime they are notified.
diff --git a/docs/concepts/3_replay.md b/docs/concepts/3_replay.md
index 642af91..704cae3 100644
--- a/docs/concepts/3_replay.md
+++ b/docs/concepts/3_replay.md
@@ -52,7 +52,7 @@ The `ReplayContext` contains timing information:

 A `DataSourceProvider` provides a way of creating `DataSource`.

-Assuming the data is stored in a csv file:
+For example, if the data is stored in a csv file:

 ```csv
 timestamp,message
@@ -60,7 +60,7 @@ timestamp,message
 2023-01-01 01:01:00+00:00,How are you
 ```

-Provided with the `ReplayContext`, it will load the and return a `DataSource`
+Provided with the `ReplayContext`, our `DataSourceProvider` will load the data and return a `DataSource`

 ```python
 --8<-- "examples/replay_concepts.py:data_source_provider"
diff --git a/examples/dag_concepts.py b/examples/dag_concepts.py
index 3425de8..8a01f2a 100644
--- a/examples/dag_concepts.py
+++ b/examples/dag_concepts.py
@@ -26,7 +26,7 @@
 # --8<-- [end:source_stream_name]

 # --8<-- [start:source_stream_empty]
-dict_source_stream = dag.source_stream(empty={})
+dict_source_stream = dag.source_stream(empty_factory=dict)
 dict_source_stream.set_stream({"hello": "world"})
 dag.execute()
 assert dict_source_stream.get_value() == {"hello": "world"}
@@ -55,7 +55,7 @@ def multiply_by_2(values: list[int]) -> list[int]:


 # --8<-- [start:stream_node_empty]
-set_stream_node = dag.stream(set, empty=set()).map(source_stream)
+set_stream_node = dag.stream(set, empty_factory=set).map(source_stream)
 source_stream.set_stream([1, 2, 3, 1, 2, 3])
 dag.execute()
 assert set_stream_node.get_value() == {1, 2, 3}
diff --git a/examples/etfs.py b/examples/etfs.py
index 3a124ad..289fee3 100644
--- a/examples/etfs.py
+++ b/examples/etfs.py
@@ -115,8 +115,8 @@ def get_updated_tickers(

 def create_dag() -> Dag:
     dag = Dag()
-    price_stream = dag.source_stream([], "price")
-    etf_composition_stream = dag.source_stream([], "etf_composition")
+    price_stream = dag.source_stream([], name="price")
+    etf_composition_stream = dag.source_stream([], name="etf_composition")
     price_latest = dag.state(GetLatest(attrgetter("ticker"))).map(price_stream)
     etf_composition_latest = dag.state(GetLatest(attrgetter("ticker"))).map(
         etf_composition_stream
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 7d4b423..975eacf 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -12,6 +12,7 @@
     Dag,
     DagMetrics,
     TimerManager,
+    _check_empty,
     _NodeInputs,
     _unchanged_callback,
 )
@@ -28,9 +29,9 @@

 def test_state_positional():
     dag = Dag()
-    x_source = dag.source_stream([], "x")
+    x_source = dag.source_stream([], name="x")
     x = dag.state(GetLatest(1)).map(x_source)
-    y_source = dag.source_stream([], "y")
+    y_source = dag.source_stream([], name="y")
     y = dag.state(GetLatest(2)).map(y_source)
     z = dag.state(add).map(x, y)

@@ -46,9 +47,9 @@ def test_state_positional():

 def test_map_state_key_word():
     dag = Dag()
-    x_source = dag.source_stream([], "x")
+    x_source = dag.source_stream([], name="x")
     x = dag.state(GetLatest(1)).map(x_source)
-    y_source = dag.source_stream([], "y")
+    y_source = dag.source_stream([], name="y")
     y = dag.state(GetLatest(2)).map(y_source)
     z = dag.state(add).map(left=x, right=y)

@@ -63,9 +64,9 @@

 def test_map_positional_and_key_word_not_valid():
     dag = Dag()
-    x_source = dag.source_stream([], "x")
+    x_source = dag.source_stream([], name="x")
     x = dag.state(GetLatest(1)).map(x_source)
-    y_source = dag.source_stream([], "y")
+    y_source = dag.source_stream([], name="y")
     y = dag.state(GetLatest(2)).map(y_source)
     dag.state(add).map(x, left=y)

@@ -76,9 +77,9 @@

 def test_map_runtime_failure():
     dag = Dag()
-    x_source = dag.source_stream([], "x")
+    x_source = dag.source_stream([], name="x")
     x = dag.state(GetLatest(40)).map(x_source)
-    y_source = dag.source_stream([], "y")
+    y_source =
dag.source_stream([], name="y") y = dag.state(GetLatest(1)).map(y_source) z = dag.state(add_no_42).map(x, y) @@ -93,9 +94,9 @@ def test_map_runtime_failure(): def test_using_lambda(): dag = Dag() - x_source = dag.source_stream([], "x") + x_source = dag.source_stream([], name="x") x = dag.state(GetLatest(40)).map(x_source) - y_source = dag.source_stream([], "y") + y_source = dag.source_stream([], name="y") y = dag.state(GetLatest(41)).map(y_source) z = dag.state(add).map(x, y) @@ -108,7 +109,7 @@ def test_using_lambda(): def test_scalar(): dag = Dag() x = dag.const(40) - y_source = dag.source_stream([], "y") + y_source = dag.source_stream([], name="y") y = dag.state(GetLatest(1)).map(y_source) z = dag.state(add).map(x, y) @@ -128,18 +129,18 @@ def test_stream_no_empty(): dag = Dag() source1 = dag.source_stream() assert source1.get_value() == [] - assert source1._empty == [] + assert source1._empty_factory is list stream = dag.stream(lambda x: x).map(source1) assert stream.get_value() == [] - assert stream._empty == [] + assert stream._empty_factory is list def test_stream_to_state(): dag = Dag() - hello_stream = dag.source_stream([], "hello") - world_stream = dag.source_stream([], "world") + hello_stream = dag.source_stream([], name="hello") + world_stream = dag.source_stream([], name="world") hello_count = dag.state(WordCount()).map(hello_stream) world_count = dag.state(WordCount()).map(world_stream) @@ -231,7 +232,7 @@ def run_get_squares(xs: list[int]) -> list[int]: def test_time(): dag = Dag() - source = dag.source_stream([], "x") + source = dag.source_stream([], name="x") add_time = dag.state(lambda x, t: [(v, t) for v in x]).map(source, dag.now()) time0 = pd.to_datetime("2022-09-15", utc=True) @@ -259,7 +260,7 @@ def test_time(): def test_cutoff_update(): dag = Dag() - x_source = dag.source_stream([], "x") + x_source = dag.source_stream([], name="x") x = dag.state(GetLatest(1)).map(x_source) x_change_only = dag.cutoff(x) @@ -282,7 +283,7 @@ def test_cutoff_update(): def test_cutoff_custom(): dag = Dag() - x_source = dag.source_stream([], "x") + x_source = dag.source_stream([], name="x") x = dag.state(GetLatest(1)).map(x_source) x_change_only = dag.cutoff(x, comparator=lambda x, y: abs(x - y) < 0.1) @@ -316,7 +317,7 @@ def test_cutoff_custom(): def test_cutoff_not_callable(): dag = Dag() - x_source = dag.source_stream([], "x") + x_source = dag.source_stream([], name="x") x = dag.state(GetLatest(1)).map(x_source) with pytest.raises(TypeError, match="`comparator` should be callable"): dag.cutoff(x, comparator="not a callable") @@ -324,7 +325,7 @@ def test_cutoff_not_callable(): def test_silence_state(): dag = Dag() - x_source = dag.source_stream([], "x") + x_source = dag.source_stream([], name="x") x = dag.state(GetLatest(1)).map(x_source) x_silent = dag.silence(x) @@ -345,7 +346,7 @@ def test_silence_state(): def test_silence_stream(): dag = Dag() - x_source = dag.source_stream([], "x") + x_source = dag.source_stream([], name="x") x_silent = dag.silence(x_source) x_source.set_stream(["a", "b"]) @@ -435,8 +436,8 @@ def test_timer_manager(): def test_sinks_and_sources(): dag = Dag() - source_1 = dag.source_stream([], "source_1") - source_2 = dag.source_stream([], "source_2") + source_1 = dag.source_stream([], name="source_1") + source_2 = dag.source_stream([], name="source_2") both = dag.stream(lambda left, right: left + right, []).map(source_1, source_2) sink = dag.sink("sink", both) @@ -464,22 +465,22 @@ def test_sinks_and_sources(): def test_duplicate_source(): dag = Dag() - source_1 = 
dag.source_stream([], "source") - source_2 = dag.source_stream([], "source") + source_1 = dag.source_stream(name="source") + source_2 = dag.source_stream(name="source") assert source_1 is source_2 def test_duplicate_source_different_empty(): dag = Dag() - dag.source_stream([], "source_1") + dag.source_stream([], name="source_1") with pytest.raises(ValueError, match=r"Duplicate source: source_1"): - dag.source_stream({}, "source_1") + dag.source_stream({}, name="source_1") def test_node_with_same_input_positional(): dag = Dag() - source_1 = dag.source_stream([], "source") + source_1 = dag.source_stream([], name="source") node = dag.stream(lambda a, b: a + b, []).map(source_1, source_1) assert node._inputs.positional == (source_1, source_1) assert node._inputs.key_word == {} @@ -488,7 +489,7 @@ def test_node_with_same_input_positional(): def test_node_with_same_input_key_word(): dag = Dag() - source_1 = dag.source_stream([], "source") + source_1 = dag.source_stream([], name="source") node = dag.stream(lambda a, b: a + b, []).map(a=source_1, b=source_1) assert node._inputs.positional == () assert node._inputs.key_word == {"a": source_1, "b": source_1} @@ -497,7 +498,7 @@ def test_node_with_same_input_key_word(): def test_node_with_same_input_mixed(): dag = Dag() - source_1 = dag.source_stream([], "source") + source_1 = dag.source_stream([], name="source") node = dag.stream(lambda a, b: a + b, []).map(source_1, b=source_1) assert node._inputs.positional == (source_1,) assert node._inputs.key_word == {"b": source_1} @@ -515,7 +516,7 @@ def test_wrong_usage(): def test_add_existing_node(): dag = Dag() - source = dag.source_stream([], "source") + source = dag.source_stream([], name="source") node = dag.stream(lambda x: x, []).map(source) with pytest.raises(ValueError, match="New Node can't have observers"): dag._add_node(source) @@ -526,14 +527,14 @@ def test_add_existing_node(): def test_mixed_dags(): dag = Dag() other_dag = Dag() - other_source = other_dag.source_stream([], "source") + other_source = other_dag.source_stream([], name="source") with pytest.raises(ValueError, match="Input Node not in dag"): dag.stream(lambda x: x, []).map(other_source) def test_get_sink_value_on_other_node(): dag = Dag() - source = dag.source_stream([], "source") + source = dag.source_stream([], name="source") node = dag.stream(lambda x: x, []).map(source) with pytest.raises(TypeError, match="Only _SinkFunction can be read"): node.get_sink_value() @@ -541,7 +542,7 @@ def test_get_sink_value_on_other_node(): def test_node_inputs_kwargs_not_str(): dag = Dag() - source = dag.source_stream([], "source") + source = dag.source_stream([], name="source") with pytest.raises(TypeError, match="class 'int'"): _NodeInputs.create([], {1: source}) @@ -598,9 +599,9 @@ def test_unchanged_callback(): def test_metrics(): dag = Dag() - x_source = dag.source_stream([], "x") + x_source = dag.source_stream([], name="x") x = dag.state(GetLatest(40)).map(x_source) - y_source = dag.source_stream([], "y") + y_source = dag.source_stream([], name="y") y = dag.state(GetLatest(41)).map(y_source) z = dag.state(add).map(x, y) @@ -621,3 +622,65 @@ def test_metrics(): x_source.set_stream([1, 2, 3]) dag.execute() assert dag.flush_metrics() == DagMetrics(4, 4, 1, 8) + + +def test_check_empty(): + assert _check_empty(None, None) is list + empty_list_empty = _check_empty([], None) + assert callable(empty_list_empty) + assert empty_list_empty() == [] + + with pytest.raises(TypeError, match=r"`len\(empty\)` should be 0"): + _check_empty([1], None) + + 
with pytest.raises(TypeError, match=r"`empty` should implement `__len__`"): + _check_empty(123, None) + + with pytest.raises( + ValueError, + match=r"Can't provide both empty=\[\] and empty_factory=", + ): + _check_empty([], list) + + with pytest.raises(TypeError, match=r"`len\(empty\)` should be 0"): + _check_empty(None, lambda: [1]) + + with pytest.raises(TypeError, match=r"empty_value=123 should implement `__len__`"): + _check_empty(None, lambda: 123) + + with pytest.raises(TypeError, match=r"empty_factory=123 should be a callable"): + _check_empty(None, 123) + + with pytest.raises(TypeError, match=r"should not return None"): + _check_empty(empty=None, empty_factory=lambda: None) + + assert _check_empty(empty=None, empty_factory=list) is list + + +def _modify(values: list[int], right: list[int]) -> list[int]: + values.extend(right) + return values + + +def test_mutate_inputs(): + dag = Dag() + source = dag.source_stream() + right = dag.source_stream() + modifier = dag.stream(_modify).map(source, right) + passthrough = dag.stream(lambda x, _: x).map(source, right) + + source.set_stream([1, 2, 3]) + right.set_stream([4]) + dag.execute() + assert modifier.get_value() == [1, 2, 3, 4] + assert passthrough.get_value() == [1, 2, 3, 4] + + dag.execute() + assert modifier.get_value() == [] # Not notified + assert passthrough.get_value() == [] + + right.set_stream([1]) + dag.execute() + assert modifier.get_value() == [1] + assert passthrough.get_value() == [] # Notified but got the factory list + assert passthrough.get_cycle_id() != dag.get_cycle_id() # considered not updated diff --git a/tests/test_kafka.py b/tests/test_kafka.py index fa432ac..1ea636c 100644 --- a/tests/test_kafka.py +++ b/tests/test_kafka.py @@ -571,7 +571,7 @@ def _timestamp_to_bytes(timestamp: pd.Timestamp) -> bytes: def test_kafka_driver_timer(): dag = Dag() - messages_stream = dag.source_stream([], "messages") + messages_stream = dag.source_stream([], name="messages") timestamp_stream = dag.stream( lambda x: [TimerEntry(pd.to_datetime(v), [1, 2, 3]) for v in x if v], [] ).map(messages_stream) diff --git a/tests/test_replay.py b/tests/test_replay.py index ba59888..bd108cd 100644 --- a/tests/test_replay.py +++ b/tests/test_replay.py @@ -255,7 +255,7 @@ def test_replay_read_sources(): ) dag = Dag() - dag.source_stream([], "hello") + dag.source_stream([], name="hello") driver = ReplayDriver.create( dag=dag, replay_context=ReplayContext( @@ -284,7 +284,7 @@ def test_replay_run_cycle(): ) dag = Dag() - dag.source_stream([], "hello") + dag.source_stream([], name="hello") driver = ReplayDriver.create( dag=dag, replay_context=ReplayContext( diff --git a/tests/test_util.py b/tests/test_util.py index c6255a5..06d6e7d 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -95,7 +95,7 @@ def __call__( def create_word_count_dag() -> tuple[Dag, WordCount]: dag = Dag() - messages_stream = dag.source_stream([], "messages") + messages_stream = dag.source_stream([], name="messages") word_count = WordCount() state = dag.state(word_count).map(messages_stream) changed_key = dag.stream(lambda x: sorted(set(x)), []).map(messages_stream)
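Below is a minimal usage sketch of the new `empty_factory` parameter, assembled from the examples and tests updated in this diff. It is illustrative only (not part of the patch), and the node names (`dict_source`, `unique_keys`, `default_source`) are made up.

```python
from beavers.engine import Dag

dag = Dag()

# A source stream whose empty value is built by calling `dict`:
# each cycle without an update gets a fresh, independent dict.
dict_source = dag.source_stream(empty_factory=dict)

# A stream node whose empty value is built by calling `set`.
unique_keys = dag.stream(set, empty_factory=set).map(dict_source)

dict_source.set_stream({"hello": "world"})
dag.execute()
assert dict_source.get_value() == {"hello": "world"}
assert unique_keys.get_value() == {"hello"}

# On a cycle where nothing updates, stream nodes read back as a freshly
# created empty value, so mutating a previous cycle's value cannot leak
# into later cycles.
dag.execute()
assert unique_keys.get_value() == set()

# When neither `empty` nor `empty_factory` is given, the factory defaults to `list`.
default_source = dag.source_stream()
assert default_source.get_value() == []
```

Passing both `empty` and `empty_factory` raises a `ValueError`, and a factory must return an empty object implementing `__len__` (see `_check_empty` above).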