diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py
index 1211edd54a..3bc9d98e62 100644
--- a/sdk/python/feast/data_source.py
+++ b/sdk/python/feast/data_source.py
@@ -503,7 +503,7 @@ def __hash__(self):
     @staticmethod
     def from_proto(data_source: DataSourceProto):
         watermark = None
-        if data_source.kafka_options.HasField("watermark"):
+        if data_source.kafka_options.watermark:
            watermark = (
                timedelta(days=0)
                if data_source.kafka_options.watermark.ToNanoseconds() == 0
diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py
index 12d7f9b74b..214ab083ab 100644
--- a/sdk/python/feast/stream_feature_view.py
+++ b/sdk/python/feast/stream_feature_view.py
@@ -1,3 +1,4 @@
+import copy
 import functools
 import warnings
 from datetime import timedelta
@@ -9,7 +10,7 @@
 
 from feast import utils
 from feast.aggregation import Aggregation
-from feast.data_source import DataSource, KafkaSource
+from feast.data_source import DataSource, KafkaSource, PushSource
 from feast.entity import Entity
 from feast.feature_view import FeatureView
 from feast.field import Field
@@ -39,6 +40,26 @@ class StreamFeatureView(FeatureView):
     """
     NOTE: Stream Feature Views are not yet fully implemented and exist to allow users to register their stream sources and schemas with Feast.
+
+    Attributes:
+        name: str. The unique name of the stream feature view.
+        entities: Union[List[Entity], List[str]]. List of entities or entity join keys.
+        ttl: timedelta. The amount of time this group of features lives. A ttl of 0 indicates that
+            this group of features lives forever. Note that large ttl's or a ttl of 0
+            can result in extremely computationally intensive queries.
+        tags: Dict[str, str]. A dictionary of key-value pairs to store arbitrary metadata.
+        online: bool. Defines whether this stream feature view is used in online feature retrieval.
+        description: str. A human-readable description.
+        owner: str. The owner of the stream feature view, typically the email of the primary
+            maintainer.
+        schema: List[Field]. The schema of the feature view, including feature, timestamp, and entity
+            columns. If not specified, can be inferred from the underlying data source.
+        source: DataSource. The stream source of data where this group of features
+            is stored.
+        aggregations (optional): List[Aggregation]. List of aggregations registered with the stream feature view.
+        mode (optional): str. The mode of execution.
+        timestamp_field (optional): str. Must be specified if aggregations are specified. Defines the timestamp column on which to aggregate windows.
+        udf (optional): MethodType. The user-defined transformation function, which should import everything it uses inside the function body.
""" def __init__( @@ -54,8 +75,8 @@ def __init__( schema: Optional[List[Field]] = None, source: Optional[DataSource] = None, aggregations: Optional[List[Aggregation]] = None, - mode: Optional[str] = "spark", # Mode of ingestion/transformation - timestamp_field: Optional[str] = "", # Timestamp for aggregation + mode: Optional[str] = "spark", + timestamp_field: Optional[str] = "", udf: Optional[MethodType] = None, ): warnings.warn( @@ -63,9 +84,10 @@ def __init__( "Some functionality may still be unstable so functionality can change in the future.", RuntimeWarning, ) + if source is None: - raise ValueError("Stream Feature views need a source specified") - # source uses the batch_source of the kafkasource in feature_view + raise ValueError("Stream Feature views need a source to be specified") + if ( type(source).__name__ not in SUPPORTED_STREAM_SOURCES and source.to_proto().type != DataSourceProto.SourceType.CUSTOM_SOURCE @@ -74,18 +96,26 @@ def __init__( f"Stream feature views need a stream source, expected one of {SUPPORTED_STREAM_SOURCES} " f"or CUSTOM_SOURCE, got {type(source).__name__}: {source.name} instead " ) + + if aggregations and not timestamp_field: + raise ValueError( + "aggregations must have a timestamp field associated with them to perform the aggregations" + ) + self.aggregations = aggregations or [] - self.mode = mode - self.timestamp_field = timestamp_field + self.mode = mode or "" + self.timestamp_field = timestamp_field or "" self.udf = udf _batch_source = None - if isinstance(source, KafkaSource): + if isinstance(source, KafkaSource) or isinstance(source, PushSource): _batch_source = source.batch_source if source.batch_source else None - + _ttl = ttl + if not _ttl: + _ttl = timedelta(days=0) super().__init__( name=name, entities=entities, - ttl=ttl, + ttl=_ttl, batch_source=_batch_source, stream_source=source, tags=tags, @@ -102,7 +132,10 @@ def __eq__(self, other): if not super().__eq__(other): return False - + if not self.udf: + return not other.udf + if not other.udf: + return False if ( self.mode != other.mode or self.timestamp_field != other.timestamp_field @@ -113,13 +146,14 @@ def __eq__(self, other): return True - def __hash__(self): + def __hash__(self) -> int: return super().__hash__() def to_proto(self): meta = StreamFeatureViewMetaProto(materialization_intervals=[]) if self.created_timestamp: meta.created_timestamp.FromDatetime(self.created_timestamp) + if self.last_updated_timestamp: meta.last_updated_timestamp.FromDatetime(self.last_updated_timestamp) @@ -134,6 +168,7 @@ def to_proto(self): ttl_duration = Duration() ttl_duration.FromTimedelta(self.ttl) + batch_source_proto = None if self.batch_source: batch_source_proto = self.batch_source.to_proto() batch_source_proto.data_source_class_type = f"{self.batch_source.__class__.__module__}.{self.batch_source.__class__.__name__}" @@ -143,23 +178,24 @@ def to_proto(self): stream_source_proto = self.stream_source.to_proto() stream_source_proto.data_source_class_type = f"{self.stream_source.__class__.__module__}.{self.stream_source.__class__.__name__}" + udf_proto = None + if self.udf: + udf_proto = UserDefinedFunctionProto( + name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True), + ) spec = StreamFeatureViewSpecProto( name=self.name, entities=self.entities, entity_columns=[field.to_proto() for field in self.entity_columns], features=[field.to_proto() for field in self.schema], - user_defined_function=UserDefinedFunctionProto( - name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True), - ) - if 
-            else None,
+            user_defined_function=udf_proto,
             description=self.description,
             tags=self.tags,
             owner=self.owner,
-            ttl=(ttl_duration if ttl_duration is not None else None),
+            ttl=ttl_duration,
             online=self.online,
             batch_source=batch_source_proto or None,
-            stream_source=stream_source_proto,
+            stream_source=stream_source_proto or None,
             timestamp_field=self.timestamp_field,
             aggregations=[agg.to_proto() for agg in self.aggregations],
             mode=self.mode,
@@ -239,6 +275,25 @@ def from_proto(cls, sfv_proto):
 
         return sfv_feature_view
 
+    def __copy__(self):
+        fv = StreamFeatureView(
+            name=self.name,
+            schema=self.schema,
+            entities=self.entities,
+            ttl=self.ttl,
+            tags=self.tags,
+            online=self.online,
+            description=self.description,
+            owner=self.owner,
+            aggregations=self.aggregations,
+            mode=self.mode,
+            timestamp_field=self.timestamp_field,
+            source=self.stream_source,
+            udf=self.udf,
+        )
+        fv.projection = copy.copy(self.projection)
+        return fv
+
 
 def stream_feature_view(
     *,
@@ -251,11 +306,13 @@
     schema: Optional[List[Field]] = None,
     source: Optional[DataSource] = None,
     aggregations: Optional[List[Aggregation]] = None,
-    mode: Optional[str] = "spark",  # Mode of ingestion/transformation
-    timestamp_field: Optional[str] = "",  # Timestamp for aggregation
+    mode: Optional[str] = "spark",
+    timestamp_field: Optional[str] = "",
 ):
     """
     Creates an StreamFeatureView object with the given user function as udf.
+    Please make sure the udf imports all non-built-in modules within the function body
+    so that executing a deserialized function does not miss any imports.
     """
 
     def mainify(obj):
diff --git a/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py b/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py
index e19641f291..29cd2f1c26 100644
--- a/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py
+++ b/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py
@@ -70,3 +70,71 @@ def simple_sfv(df):
     assert features["test_key"] == [1001]
     assert "dummy_field" in features
     assert features["dummy_field"] == [None]
+
+
+@pytest.mark.integration
+def test_stream_feature_view_udf(environment) -> None:
+    """
+    Test that StreamFeatureView udfs are serialized correctly on apply and remain usable.
+ """ + fs = environment.feature_store + + # Create Feature Views + entity = Entity(name="driver_entity", join_keys=["test_key"]) + + stream_source = KafkaSource( + name="kafka", + timestamp_field="event_timestamp", + bootstrap_servers="", + message_format=AvroFormat(""), + topic="topic", + batch_source=FileSource(path="test_path", timestamp_field="event_timestamp"), + watermark=timedelta(days=1), + ) + + @stream_feature_view( + entities=[entity], + ttl=timedelta(days=30), + owner="test@example.com", + online=True, + schema=[Field(name="dummy_field", dtype=Float32)], + description="desc", + aggregations=[ + Aggregation( + column="dummy_field", function="max", time_window=timedelta(days=1), + ), + Aggregation( + column="dummy_field2", function="count", time_window=timedelta(days=24), + ), + ], + timestamp_field="event_timestamp", + mode="spark", + source=stream_source, + tags={}, + ) + def pandas_view(pandas_df): + import pandas as pd + + assert type(pandas_df) == pd.DataFrame + df = pandas_df.transform(lambda x: x + 10, axis=1) + df.insert(2, "C", [20.2, 230.0, 34.0], True) + return df + + import pandas as pd + + df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]}) + + fs.apply([entity, pandas_view]) + stream_feature_views = fs.list_stream_feature_views() + assert len(stream_feature_views) == 1 + assert stream_feature_views[0].name == "pandas_view" + assert stream_feature_views[0] == pandas_view + + sfv = stream_feature_views[0] + + new_df = sfv.udf(df) + + expected_df = pd.DataFrame( + {"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]} + ) + assert new_df.equals(expected_df) diff --git a/sdk/python/tests/unit/test_feature_views.py b/sdk/python/tests/unit/test_feature_views.py index 904260dfe6..64b23edd2c 100644 --- a/sdk/python/tests/unit/test_feature_views.py +++ b/sdk/python/tests/unit/test_feature_views.py @@ -9,7 +9,7 @@ from feast.entity import Entity from feast.field import Field from feast.infra.offline_stores.file_source import FileSource -from feast.stream_feature_view import StreamFeatureView +from feast.stream_feature_view import StreamFeatureView, stream_feature_view from feast.types import Float32 @@ -129,3 +129,75 @@ def test_stream_feature_view_serialization(): new_sfv = StreamFeatureView.from_proto(sfv_proto=sfv_proto) assert new_sfv == sfv + + +def test_stream_feature_view_udfs(): + entity = Entity(name="driver_entity", join_keys=["test_key"]) + stream_source = KafkaSource( + name="kafka", + timestamp_field="event_timestamp", + bootstrap_servers="", + message_format=AvroFormat(""), + topic="topic", + batch_source=FileSource(path="some path"), + ) + + @stream_feature_view( + entities=[entity], + ttl=timedelta(days=30), + owner="test@example.com", + online=True, + schema=[Field(name="dummy_field", dtype=Float32)], + description="desc", + aggregations=[ + Aggregation( + column="dummy_field", function="max", time_window=timedelta(days=1), + ) + ], + timestamp_field="event_timestamp", + source=stream_source, + ) + def pandas_udf(pandas_df): + import pandas as pd + + assert type(pandas_df) == pd.DataFrame + df = pandas_df.transform(lambda x: x + 10, axis=1) + return df + + import pandas as pd + + df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]}) + sfv = pandas_udf + sfv_proto = sfv.to_proto() + new_sfv = StreamFeatureView.from_proto(sfv_proto) + new_df = new_sfv.udf(df) + + expected_df = pd.DataFrame({"A": [11, 12, 13], "B": [20, 30, 40]}) + + assert new_df.equals(expected_df) + + +def test_stream_feature_view_initialization_with_optional_fields_omitted(): 
+    entity = Entity(name="driver_entity", join_keys=["test_key"])
+    stream_source = KafkaSource(
+        name="kafka",
+        timestamp_field="event_timestamp",
+        bootstrap_servers="",
+        message_format=AvroFormat(""),
+        topic="topic",
+        batch_source=FileSource(path="some path"),
+    )
+
+    sfv = StreamFeatureView(
+        name="test kafka stream feature view",
+        entities=[entity],
+        schema=[],
+        description="desc",
+        timestamp_field="event_timestamp",
+        source=stream_source,
+        tags={},
+    )
+    sfv_proto = sfv.to_proto()
+
+    new_sfv = StreamFeatureView.from_proto(sfv_proto=sfv_proto)
+    assert new_sfv == sfv
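Beyond the Kafka-based tests above, this patch also makes `StreamFeatureView` accept a `PushSource` (picking up its `batch_source` the same way as a `KafkaSource`'s) and defaults a missing `ttl` to `timedelta(days=0)`. A minimal usage sketch of that new path, with illustrative names that are not part of this patch or its tests:

```python
from datetime import timedelta

from feast import Entity, Field, FileSource, PushSource
from feast.stream_feature_view import StreamFeatureView
from feast.types import Float32

# Hypothetical example; entity, source names, and path are illustrative only.
driver = Entity(name="driver_entity", join_keys=["test_key"])

push_source = PushSource(
    name="driver_push_source",
    batch_source=FileSource(path="some path", timestamp_field="event_timestamp"),
)

# With this change, a PushSource passes the stream-source check and its
# batch_source is wired into the view; omitting ttl falls back to
# timedelta(days=0) instead of None.
sfv = StreamFeatureView(
    name="driver_push_view",
    entities=[driver],
    schema=[Field(name="dummy_field", dtype=Float32)],
    source=push_source,
)
```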
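For context on why the docstrings insist the udf import everything it uses inside the function body: `to_proto` serializes the udf with `dill.dumps(..., recurse=True)` into a `UserDefinedFunctionProto`, and the bytes may be deserialized in a process that never ran the defining module's top-level imports. A standalone sketch of that round-trip using plain dill, outside Feast:

```python
import dill


def pandas_udf(pandas_df):
    # Import inside the function so the deserialized bytecode can resolve
    # pandas even if the loading process never imported it at module level.
    import pandas as pd

    assert isinstance(pandas_df, pd.DataFrame)
    return pandas_df.transform(lambda x: x + 10, axis=1)


# Serialize the way to_proto() does, then restore with dill.loads on the
# consumer side.
payload = dill.dumps(pandas_udf, recurse=True)
restored = dill.loads(payload)

import pandas as pd

df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
assert restored(df).equals(pd.DataFrame({"A": [11, 12, 13], "B": [20, 30, 40]}))
```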