diff --git a/sdk/python/feast/infra/registry_stores/sql.py b/sdk/python/feast/infra/registry_stores/sql.py index 503aaf8688..f793ef7376 100644 --- a/sdk/python/feast/infra/registry_stores/sql.py +++ b/sdk/python/feast/infra/registry_stores/sql.py @@ -1,7 +1,7 @@ from datetime import datetime from pathlib import Path from threading import Lock -from typing import Any, List, Optional +from typing import Any, List, Optional, Union from sqlalchemy import ( # type: ignore BigInteger, @@ -39,6 +39,7 @@ FeatureService as FeatureServiceProto, ) from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto +from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( OnDemandFeatureView as OnDemandFeatureViewProto, ) @@ -138,6 +139,14 @@ Column("validation_reference_proto", LargeBinary, nullable=False), ) +managed_infra = Table( + "managed_infra", + metadata, + Column("infra_name", String(50), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("infra_proto", LargeBinary, nullable=False), +) + class SqlRegistry(BaseRegistry): def __init__( @@ -168,6 +177,7 @@ def teardown(self): conn.execute(stmt) def refresh(self): + # This method is a no-op since we're always reading the latest values from the db. pass def get_stream_feature_view( @@ -353,16 +363,7 @@ def apply_data_source( def apply_feature_view( self, feature_view: BaseFeatureView, project: str, commit: bool = True ): - if isinstance(feature_view, StreamFeatureView): - fv_table = stream_feature_views - elif isinstance(feature_view, FeatureView): - fv_table = feature_views - elif isinstance(feature_view, OnDemandFeatureView): - fv_table = on_demand_feature_views - elif isinstance(feature_view, RequestFeatureView): - fv_table = request_feature_views - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + fv_table = self._infer_fv_table(feature_view) return self._apply_object( fv_table, "feature_view_name", feature_view, "feature_view_proto" @@ -457,7 +458,25 @@ def apply_materialization( end_date: datetime, commit: bool = True, ): - pass + table = self._infer_fv_table(feature_view) + python_class, proto_class = self._infer_fv_classes(feature_view) + + if python_class in {RequestFeatureView, OnDemandFeatureView}: + raise ValueError( + f"Cannot apply materialization for feature {feature_view.name} of type {python_class}" + ) + fv: Union[FeatureView, StreamFeatureView] = self._get_object( + table, + feature_view.name, + project, + proto_class, + python_class, + "feature_view_name", + "feature_view_proto", + FeatureViewNotFoundException, + ) + fv.materialization_intervals.append((start_date, end_date)) + self._apply_object(table, "feature_view_name", fv, "feature_view_proto") def delete_validation_reference(self, name: str, project: str, commit: bool = True): self._delete_object( @@ -469,10 +488,21 @@ def delete_validation_reference(self, name: str, project: str, commit: bool = Tr ) def update_infra(self, infra: Infra, project: str, commit: bool = True): - pass + self._apply_object( + managed_infra, "infra_name", infra, "infra_proto", name="infra_obj" + ) def get_infra(self, project: str, allow_cache: bool = False) -> Infra: - return Infra() + return self._get_object( + managed_infra, + "infra_obj", + project, + InfraProto, + Infra, + "infra_name", + "infra_proto", + None, + ) def apply_user_metadata( self, @@ -480,16 +510,7 @@ def apply_user_metadata( feature_view: BaseFeatureView, metadata_bytes: Optional[bytes], ): - if isinstance(feature_view, StreamFeatureView): - table = stream_feature_views - elif isinstance(feature_view, FeatureView): - table = feature_views - elif isinstance(feature_view, OnDemandFeatureView): - table = on_demand_feature_views - elif isinstance(feature_view, RequestFeatureView): - table = request_feature_views - else: - raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + table = self._infer_fv_table(feature_view) name = feature_view.name with self.engine.connect() as conn: @@ -511,9 +532,7 @@ def apply_user_metadata( else: raise FeatureViewNotFoundException(feature_view.name, project=project) - def get_user_metadata( - self, project: str, feature_view: BaseFeatureView - ) -> Optional[bytes]: + def _infer_fv_table(self, feature_view): if isinstance(feature_view, StreamFeatureView): table = stream_feature_views elif isinstance(feature_view, FeatureView): @@ -524,6 +543,25 @@ def get_user_metadata( table = request_feature_views else: raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + return table + + def _infer_fv_classes(self, feature_view): + if isinstance(feature_view, StreamFeatureView): + python_class, proto_class = StreamFeatureView, StreamFeatureViewProto + elif isinstance(feature_view, FeatureView): + python_class, proto_class = FeatureView, FeatureViewProto + elif isinstance(feature_view, OnDemandFeatureView): + python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto + elif isinstance(feature_view, RequestFeatureView): + python_class, proto_class = RequestFeatureView, RequestFeatureViewProto + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + return python_class, proto_class + + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + table = self._infer_fv_table(feature_view) name = feature_view.name with self.engine.connect() as conn: @@ -556,12 +594,11 @@ def proto(self) -> RegistryProto: return r def commit(self): + # This method is a no-op since we're always writing values eagerly to the db. pass - def _apply_object( - self, table, id_field_name, obj, proto_field_name, - ): - name = obj.name + def _apply_object(self, table, id_field_name, obj, proto_field_name, name=None): + name = name or obj.name with self.engine.connect() as conn: stmt = select(table).where(getattr(table.c, id_field_name) == name) row = conn.execute(stmt).first()