diff --git a/docs/public_api_test.py b/docs/public_api_test.py index aceb3b6d9..254083123 100644 --- a/docs/public_api_test.py +++ b/docs/public_api_test.py @@ -30,6 +30,7 @@ "has_leak", "event_set", "input_node", + "input_node_from_schema", "plot", "compile", "config", diff --git a/docs/src/reference/temporian/input_node_from_schema.md b/docs/src/reference/temporian/input_node_from_schema.md new file mode 100644 index 000000000..e69de29bb diff --git a/temporian/__init__.py b/temporian/__init__.py index a2d745a78..07aa3ae57 100644 --- a/temporian/__init__.py +++ b/temporian/__init__.py @@ -39,7 +39,7 @@ # EventSetNodes from temporian.core.data.node import EventSetNode -from temporian.core.data.node import input_node +from temporian.core.data.node import input_node, input_node_from_schema # Dtypes from temporian.core.data.dtype import float64 diff --git a/temporian/beam/evaluation.py b/temporian/beam/evaluation.py index 673879699..a8f642bd2 100644 --- a/temporian/beam/evaluation.py +++ b/temporian/beam/evaluation.py @@ -137,6 +137,14 @@ def run_multi_io( data = {**inputs} + # Check that operators implementations are available + needed_operators = set() + for step in schedule.steps: + needed_operators.add(step.op.definition.key) + implementation_lib.check_operators_implementations_are_available( + needed_operators + ) + num_steps = len(schedule.steps) for step_idx, step in enumerate(schedule.steps): operator_def = step.op.definition diff --git a/temporian/beam/implementation_lib.py b/temporian/beam/implementation_lib.py index 874dbe9b4..ee50c7886 100644 --- a/temporian/beam/implementation_lib.py +++ b/temporian/beam/implementation_lib.py @@ -14,7 +14,7 @@ """Registering mechanism for operator implementation classes.""" -from typing import Any, Dict +from typing import Any, Dict, Set _OPERATOR_IMPLEMENTATIONS = {} @@ -22,6 +22,18 @@ # registration. 
+def check_operators_implementations_are_available(needed: Set[str]): + """Checks if operator implementations are available.""" + missing = set(needed) - set(_OPERATOR_IMPLEMENTATIONS.keys()) + if missing: + raise ValueError( + f"Unknown operator implementations '{missing}' for Beam backend. It" + " seems this operator is only available for the in-process" + " Temporian backend. Available Beam operator implementations are:" + f" {list(_OPERATOR_IMPLEMENTATIONS.keys())}." + ) + + def register_operator_implementation( operator_class, operator_implementation_class ): diff --git a/temporian/core/data/node.py b/temporian/core/data/node.py index 948c15a32..97bc88cbe 100644 --- a/temporian/core/data/node.py +++ b/temporian/core/data/node.py @@ -176,8 +176,10 @@ def __repr__(self) -> str: def input_node( - features: List[Tuple[str, DType]], - indexes: Optional[List[Tuple[str, IndexDType]]] = None, + features: Union[List[FeatureSchema], List[Tuple[str, DType]]], + indexes: Optional[ + Union[List[IndexSchema], List[Tuple[str, IndexDType]]] + ] = None, is_unix_timestamp: bool = False, same_sampling_as: Optional[EventSetNode] = None, name: Optional[str] = None, @@ -245,6 +247,43 @@ def input_node( ) +def input_node_from_schema( + schema: Schema, + same_sampling_as: Optional[EventSetNode] = None, + name: Optional[str] = None, +) -> EventSetNode: + """Creates an input [`EventSetNode`][temporian.EventSetNode] from a schema. + + Usage example: + + ```python + >>> # Create two nodes with the same schema. + >>> a = tp.input_node(features=[("f1", tp.float64), ("f2", tp.str_)]) + >>> b = tp.input_node_from_schema(a.schema) + + ``` + + Args: + schema: Schema of the node. + same_sampling_as: If set, the created EventSetNode is guaranteed to have the + same sampling as `same_sampling_as`. In this case, `indexes` and + `is_unix_timestamp` should not be provided. Some operators require + input EventSetNodes to have the same sampling. + name: Name for the EventSetNode. 
+ + Returns: + EventSetNode with the given specifications. + """ + + return input_node( + features=schema.features, + indexes=schema.indexes, + is_unix_timestamp=schema.is_unix_timestamp, + same_sampling_as=same_sampling_as, + name=name, + ) + + @dataclass class Sampling: """A sampling is a reference to the way data is sampled.""" diff --git a/temporian/core/data/schema.py b/temporian/core/data/schema.py index 04602f5c1..82c04e40b 100644 --- a/temporian/core/data/schema.py +++ b/temporian/core/data/schema.py @@ -19,6 +19,7 @@ from typing import List, Tuple, Dict, Union from temporian.core.data.dtype import DType, IndexDType +from google.protobuf import text_format @dataclass(frozen=True) @@ -134,6 +135,53 @@ def check_compatible_features(self, other: Schema, check_order: bool): f"{self.feature_names} and {other.feature_names}." ) + def to_proto(self) -> "serialization.pb.Schema": + """Converts the schema into a protobuf. + + Usage example: + ``` + schema = tp.Schema(features=[("f1",int), ("f2", float)]) + proto_schema = schema.to_proto() + restored_schema = tp.Schema.from_proto(proto_schema) + ``` + """ + from temporian.core import serialization + + return serialization._serialize_schema(self) + + def to_proto_file(self, path: str) -> None: + """Saves the schema to a file in text protobuf format. 
+ + Usage example: + ``` + schema = tp.Schema(features=[("f1",int), ("f2", float)]) + path = "/tmp/my_schema.pbtxt" + schema.to_proto_file(path) + restored_schema = tp.Schema.from_proto_file(path) + ``` + """ + proto = self.to_proto() + with open(path, "wb") as f: + f.write(text_format.MessageToBytes(proto)) + + @classmethod + def from_proto(cls, proto: "serialization.pb.Schema") -> "Schema": + """Creates a schema from a protobuf.""" + + from temporian.core import serialization + + return serialization._unserialize_schema(proto) + + @classmethod + def from_proto_file(cls, path: str) -> "Schema": + """Creates a schema from a text protobuf file.""" + + from temporian.core import serialization + + with open(path, "rb") as f: + proto = text_format.Parse(f.read(), serialization.pb.Schema()) + return Schema.from_proto(proto) + def _normalize_feature(x): if isinstance(x, FeatureSchema): diff --git a/temporian/core/data/test/BUILD b/temporian/core/data/test/BUILD index f233a1995..443909e6f 100644 --- a/temporian/core/data/test/BUILD +++ b/temporian/core/data/test/BUILD @@ -27,3 +27,16 @@ py_test( "//temporian/core/data:dtype", ], ) + + +py_test( + name = "schema_test", + srcs = ["schema_test.py"], + srcs_version = "PY3", + deps = [ + # already_there/absl/testing:absltest + # already_there/absl/testing:parameterized + "//temporian/core/data:schema", + "//temporian/core:serialization", + ], +) diff --git a/temporian/core/data/test/schema_test.py b/temporian/core/data/test/schema_test.py new file mode 100644 index 000000000..3674c74e4 --- /dev/null +++ b/temporian/core/data/test/schema_test.py @@ -0,0 +1,29 @@ +# Copyright 2021 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from temporian.core.data import schema +from temporian.core.data.dtype import DType + + +class SchemaTest(absltest.TestCase): + def test_proto(self): + a = schema.Schema(features=[("f1", DType.INT32), ("f2", DType.FLOAT64)]) + p = a.to_proto() + b = schema.Schema.from_proto(p) + self.assertEqual(a, b) + + +if __name__ == "__main__": + absltest.main() diff --git a/temporian/core/data/test/test_node.py b/temporian/core/data/test/test_node.py index 6e0a4a920..015503eae 100644 --- a/temporian/core/data/test/test_node.py +++ b/temporian/core/data/test/test_node.py @@ -16,6 +16,9 @@ from temporian.core.test import utils from temporian.implementation.numpy.data.event_set import EventSet +from temporian.core.data.node import ( + input_node_from_schema, +) class EventSetNodeTest(absltest.TestCase): @@ -26,6 +29,11 @@ def test_run_input(self): self.assertIsInstance(result, EventSet) self.assertTrue(result is evset) + def test_input_node_from_schema(self): + node = utils.create_input_node() + other_node = input_node_from_schema(node.schema) + self.assertEqual(node.schema, other_node.schema) + def test_run_single_operator(self): evset = utils.create_input_event_set() result = evset.node().simple_moving_average(10) diff --git a/temporian/core/operators/filter_moving_count.py b/temporian/core/operators/filter_moving_count.py index b53ccabde..8d8e3c7f0 100644 --- a/temporian/core/operators/filter_moving_count.py +++ b/temporian/core/operators/filter_moving_count.py @@ -48,7 +48,7 @@ def __init__( self.add_output( "output", 
create_node_new_features_new_sampling( - features=[], + features=input.schema.features, indexes=input.schema.indexes, is_unix_timestamp=input.schema.is_unix_timestamp, creator=self, diff --git a/temporian/core/operators/test/test_filter_moving_count.py b/temporian/core/operators/test/test_filter_moving_count.py index f49d5891f..4c24c83e2 100644 --- a/temporian/core/operators/test/test_filter_moving_count.py +++ b/temporian/core/operators/test/test_filter_moving_count.py @@ -61,6 +61,16 @@ def test_base(self, input_timestamps, expected_timestamps, win_length): self, result, expected_output, check_sampling=False ) + def test_with_feature( + self, + ): + evset = event_set([1, 2, 4], {"f": [10, 11, 14]}) + expected_output = event_set([1, 4], {"f": [10, 14]}) + result = evset.filter_moving_count(window_length=1.5) + assertOperatorResult( + self, result, expected_output, check_sampling=False + ) + if __name__ == "__main__": absltest.main() diff --git a/temporian/implementation/numpy/operators/filter_moving_count.py b/temporian/implementation/numpy/operators/filter_moving_count.py index 3a188122c..58ec910ec 100644 --- a/temporian/implementation/numpy/operators/filter_moving_count.py +++ b/temporian/implementation/numpy/operators/filter_moving_count.py @@ -45,7 +45,7 @@ def __call__(self, input: EventSet) -> Dict[str, EventSet]: # Fill output EventSet's data for index_key, index_data in input.data.items(): - dst_timestamps = operators_cc.filter_moving_count( + sampling_idx = operators_cc.filter_moving_count( index_data.timestamps, window_length=window_length, ) @@ -53,8 +53,8 @@ def __call__(self, input: EventSet) -> Dict[str, EventSet]: output_evset.set_index_value( index_key, IndexData( - features=[], - timestamps=dst_timestamps, + features=[f[sampling_idx] for f in index_data.features], + timestamps=index_data.timestamps[sampling_idx], schema=output_schema, ), ) diff --git a/temporian/implementation/numpy_cc/operators/filter_moving_count.cc 
b/temporian/implementation/numpy_cc/operators/filter_moving_count.cc index e96225b13..40f167a23 100644 --- a/temporian/implementation/numpy_cc/operators/filter_moving_count.cc +++ b/temporian/implementation/numpy_cc/operators/filter_moving_count.cc @@ -12,7 +12,7 @@ namespace { namespace py = pybind11; -py::array_t filter_moving_count( +py::array_t filter_moving_count( const py::array_t &event_timestamps, const double window_length) { // Input size const Idx n_event = event_timestamps.shape(0); @@ -20,7 +20,7 @@ py::array_t filter_moving_count( // Access raw input / output data auto v_event = event_timestamps.unchecked<1>(); - std::vector output; + std::vector output; // Index of the last emitted event. If -1, no event was emitted so far. Idx last_emitted_idx = -1; @@ -31,7 +31,7 @@ py::array_t filter_moving_count( (t - v_event[last_emitted_idx]) >= window_length) { // Emitting event. last_emitted_idx = event_idx; - output.push_back(t); + output.push_back(event_idx); } } diff --git a/temporian/test/api_test.py b/temporian/test/api_test.py index 6ce4a1c5a..d6639aec8 100644 --- a/temporian/test/api_test.py +++ b/temporian/test/api_test.py @@ -217,6 +217,20 @@ def test_duration_to_string(self): "2d3h", ) + def test_schema_to_from_proto(self): + a = tp.Schema(features=[("f1", tp.int32), ("f2", tp.float64)]) + p = a.to_proto() + b = tp.Schema.from_proto(p) + self.assertEqual(a, b) + + def test_schema_to_from_proto_file(self): + with tempfile.TemporaryDirectory() as tempdir: + path = os.path.join(tempdir, "schema.pbtxt") + a = tp.Schema(features=[("f1", tp.int32), ("f2", tp.float64)]) + a.to_proto_file(path) + b = tp.Schema.from_proto_file(path) + self.assertEqual(a, b) + if __name__ == "__main__": absltest.main()