diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py
index f8a288940a..01e8ecbaca 100644
--- a/sdk/python/feast/errors.py
+++ b/sdk/python/feast/errors.py
@@ -204,6 +204,11 @@ def __init__(
         )
 
 
+class SavedDatasetLocationAlreadyExists(Exception):
+    def __init__(self, location: str):
+        super().__init__(f"Saved dataset location {location} already exists.")
+
+
 class FeastOfflineStoreInvalidName(Exception):
     def __init__(self, offline_store_class_name: str):
         super().__init__(
diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index ea13c3a8db..02225a7b52 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -1146,6 +1146,7 @@ def create_saved_dataset(
         storage: SavedDatasetStorage,
         tags: Optional[Dict[str, str]] = None,
         feature_service: Optional[FeatureService] = None,
+        allow_overwrite: bool = False,
     ) -> SavedDataset:
         """
         Execute provided retrieval job and persist its outcome in given storage.
@@ -1154,6 +1155,14 @@ def create_saved_dataset(
         Name for the saved dataset should be unique within project, since it's possible
         to overwrite previously stored dataset with the same name.
 
+        Args:
+            from_: The retrieval job whose result should be persisted.
+            name: The name of the saved dataset.
+            storage: The saved dataset storage object indicating where the result should be persisted.
+            tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
+            feature_service (optional): The feature service that should be associated with this saved dataset.
+            allow_overwrite (optional): If True, the persisted result can overwrite an existing table or file.
+
         Returns:
             SavedDataset object with attached RetrievalJob
 
@@ -1186,7 +1195,7 @@ def create_saved_dataset(
         dataset.min_event_timestamp = from_.metadata.min_event_timestamp
         dataset.max_event_timestamp = from_.metadata.max_event_timestamp
 
-        from_.persist(storage)
+        from_.persist(storage=storage, allow_overwrite=allow_overwrite)
 
         dataset = dataset.with_retrieval_job(
             self._get_provider().retrieve_saved_dataset(
diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py
index da19bff5ec..5c3535071a 100644
--- a/sdk/python/feast/infra/offline_stores/bigquery.py
+++ b/sdk/python/feast/infra/offline_stores/bigquery.py
@@ -493,7 +493,7 @@ def _execute_query(
         block_until_done(client=self.client, bq_job=bq_job, timeout=timeout)
         return bq_job
 
-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         assert isinstance(storage, SavedDatasetBigQueryStorage)
 
         self.to_bigquery(
diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py
index d7f11fb39f..92e133d02e 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py
@@ -402,7 +402,7 @@ def _to_arrow_internal(self) -> pa.Table:
     def metadata(self) -> Optional[RetrievalMetadata]:
         return self._metadata
 
-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         assert isinstance(storage, SavedDatasetAthenaStorage)
         self.to_athena(table_name=storage.athena_options.table)
diff --git a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py
index 1347f8b37c..80b1e089a1 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py
@@ -297,7 +297,7 @@ def _to_arrow_internal(self) -> pa.Table:
     def metadata(self) -> Optional[RetrievalMetadata]:
         return self._metadata
 
-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         assert isinstance(storage, SavedDatasetPostgreSQLStorage)
 
         df_to_postgres_table(
diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
index 26d414232f..7c19b1e4e3 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
@@ -275,7 +275,7 @@ def _to_arrow_internal(self) -> pyarrow.Table:
         self.to_spark_df().write.parquet(temp_dir, mode="overwrite")
         return pq.read_table(temp_dir)
 
-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         """
         Run the retrieval and persist the results in the same offline store used for read.
         Please note the persisting is done only within the scope of the spark session.
diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py
index b5f0b1f950..5a3a9737d3 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py
@@ -126,7 +126,7 @@ def to_trino(
         self._client.execute_query(query_text=query)
         return destination_table
 
-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         """
         Run the retrieval and persist the results in the same offline store used for read.
         """
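Taken together, the signature changes above thread a single flag from the public API down to each store's `persist` implementation. A minimal usage sketch of the new parameter (the repo path, entity dataframe, and feature reference are illustrative assumptions, not part of this diff):

```python
import pandas as pd

from feast import FeatureStore
from feast.infra.offline_stores.file_source import SavedDatasetFileStorage

store = FeatureStore(repo_path=".")  # assumes an initialized feature repo

entity_df = pd.DataFrame(
    {"customer_id": [1001], "event_timestamp": [pd.Timestamp.utcnow()]}
)
job = store.get_historical_features(
    entity_df=entity_df,
    features=["customer_profile:current_balance"],  # hypothetical feature ref
)

# allow_overwrite is forwarded to RetrievalJob.persist, so re-running this
# call may replace data already stored at the destination path.
dataset = store.create_saved_dataset(
    from_=job,
    name="my_training_dataset",
    storage=SavedDatasetFileStorage(path="data/my_training_dataset.parquet"),
    allow_overwrite=True,
)
```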
""" diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index ca945c3ff3..742366d42e 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -1,3 +1,4 @@ +import os import uuid from datetime import datetime from pathlib import Path @@ -11,13 +12,16 @@ import pytz from pydantic.typing import Literal -from feast import FileSource, OnDemandFeatureView from feast.data_source import DataSource -from feast.errors import FeastJoinKeysDuringMaterialization +from feast.errors import ( + FeastJoinKeysDuringMaterialization, + SavedDatasetLocationAlreadyExists, +) from feast.feature_logging import LoggingConfig, LoggingSource from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView from feast.infra.offline_stores.file_source import ( FileLoggingDestination, + FileSource, SavedDatasetFileStorage, ) from feast.infra.offline_stores.offline_store import ( @@ -30,6 +34,7 @@ get_pyarrow_schema_from_batch_source, ) from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage from feast.usage import log_exceptions_and_usage @@ -83,8 +88,13 @@ def _to_arrow_internal(self): df = self.evaluation_function().compute() return pyarrow.Table.from_pandas(df) - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): assert isinstance(storage, SavedDatasetFileStorage) + + # Check if the specified location already exists. + if not allow_overwrite and os.path.exists(storage.file_options.uri): + raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri) + filesystem, path = FileSource.create_filesystem_and_path( storage.file_options.uri, storage.file_options.s3_endpoint_override, diff --git a/sdk/python/feast/infra/offline_stores/file_source.py b/sdk/python/feast/infra/offline_stores/file_source.py index 52a687a52c..135409ed04 100644 --- a/sdk/python/feast/infra/offline_stores/file_source.py +++ b/sdk/python/feast/infra/offline_stores/file_source.py @@ -96,12 +96,20 @@ def __eq__(self, other): ) @property - def path(self): - """ - Returns the path of this file data source. - """ + def path(self) -> str: + """Returns the path of this file data source.""" return self.file_options.uri + @property + def file_format(self) -> Optional[FileFormat]: + """Returns the file format of this file data source.""" + return self.file_options.file_format + + @property + def s3_endpoint_override(self) -> Optional[str]: + """Returns the s3 endpoint override of this file data source.""" + return self.file_options.s3_endpoint_override + @staticmethod def from_proto(data_source: DataSourceProto): return FileSource( @@ -177,24 +185,33 @@ def get_table_query_string(self) -> str: class FileOptions: """ Configuration options for a file data source. + + Attributes: + uri: File source url, e.g. s3:// or local file. + s3_endpoint_override: Custom s3 endpoint (used only with s3 uri). + file_format: File source format, e.g. parquet. """ + uri: str + file_format: Optional[FileFormat] + s3_endpoint_override: str + def __init__( self, + uri: str, file_format: Optional[FileFormat], s3_endpoint_override: Optional[str], - uri: Optional[str], ): """ Initializes a FileOptions object. Args: + uri: File source url, e.g. s3:// or local file. 
diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py
index 741b97e2fd..b3b17eaed3 100644
--- a/sdk/python/feast/infra/offline_stores/offline_store.py
+++ b/sdk/python/feast/infra/offline_stores/offline_store.py
@@ -173,8 +173,16 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
         pass
 
     @abstractmethod
-    def persist(self, storage: SavedDatasetStorage):
-        """Synchronously executes the underlying query and persists the result in the same offline store."""
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
+        """
+        Synchronously executes the underlying query and persists the result in the same offline store
+        at the specified destination.
+
+        Args:
+            storage: The saved dataset storage object specifying where the result should be persisted.
+            allow_overwrite: If True, a pre-existing location (e.g. table or file) can be overwritten.
+                Currently not all individual offline store implementations make use of this parameter.
+        """
         pass
 
     @property
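Because the abstract method defaults `allow_overwrite` to `False`, third-party `RetrievalJob` subclasses stay source-compatible once they accept the parameter. A sketch of what such an override might look like (the class is hypothetical, and the remaining abstract members are elided):

```python
from feast.infra.offline_stores.offline_store import RetrievalJob
from feast.saved_dataset import SavedDatasetStorage


class PluginRetrievalJob(RetrievalJob):
    # _to_df_internal, _to_arrow_internal, full_feature_names, and
    # on_demand_feature_views are omitted from this sketch; a real job
    # must implement them.

    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
        # Stores that can detect a pre-existing destination should raise
        # SavedDatasetLocationAlreadyExists when allow_overwrite is False;
        # others may accept the flag and ignore it, as several stores in
        # this diff do for now.
        ...
```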
+ """ pass @property diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 2acf06017d..1c20ff0c5a 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -483,7 +483,7 @@ def to_redshift(self, table_name: str) -> None: query, ) - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): assert isinstance(storage, SavedDatasetRedshiftStorage) self.to_redshift(table_name=storage.redshift_options.table) diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index 98db97b179..8239aec34c 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -460,7 +460,7 @@ def to_arrow_chunks(self, arrow_options: Optional[Dict] = None) -> Optional[List return arrow_batches - def persist(self, storage: SavedDatasetStorage): + def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False): assert isinstance(storage, SavedDatasetSnowflakeStorage) self.to_snowflake(table_name=storage.snowflake_options.table) diff --git a/sdk/python/feast/saved_dataset.py b/sdk/python/feast/saved_dataset.py index e2004d15f4..4a3043a873 100644 --- a/sdk/python/feast/saved_dataset.py +++ b/sdk/python/feast/saved_dataset.py @@ -8,6 +8,7 @@ from feast.data_source import DataSource from feast.dqm.profilers.profiler import Profile, Profiler +from feast.importer import import_class from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto from feast.protos.feast.core.SavedDataset_pb2 import SavedDatasetMeta, SavedDatasetSpec from feast.protos.feast.core.SavedDataset_pb2 import ( @@ -31,6 +32,16 @@ def __new__(cls, name, bases, dct): return kls +_DATA_SOURCE_TO_SAVED_DATASET_STORAGE = { + "FileSource": "feast.infra.offline_stores.file_source.SavedDatasetFileStorage", +} + + +def get_saved_dataset_storage_class_from_path(saved_dataset_storage_path: str): + module_name, class_name = saved_dataset_storage_path.rsplit(".", 1) + return import_class(module_name, class_name, "SavedDatasetStorage") + + class SavedDatasetStorage(metaclass=_StorageRegistry): _proto_attr_name: str @@ -43,11 +54,24 @@ def from_proto(storage_proto: SavedDatasetStorageProto) -> "SavedDatasetStorage" @abstractmethod def to_proto(self) -> SavedDatasetStorageProto: - ... + pass @abstractmethod def to_data_source(self) -> DataSource: - ... + pass + + @staticmethod + def from_data_source(data_source: DataSource) -> "SavedDatasetStorage": + data_source_type = type(data_source).__name__ + if data_source_type in _DATA_SOURCE_TO_SAVED_DATASET_STORAGE: + cls = get_saved_dataset_storage_class_from_path( + _DATA_SOURCE_TO_SAVED_DATASET_STORAGE[data_source_type] + ) + return cls.from_data_source(data_source) + else: + raise ValueError( + f"This method currently does not support {data_source_type}." 
diff --git a/sdk/python/tests/integration/e2e/test_validation.py b/sdk/python/tests/integration/e2e/test_validation.py
index 26b46d9648..771061b206 100644
--- a/sdk/python/tests/integration/e2e/test_validation.py
+++ b/sdk/python/tests/integration/e2e/test_validation.py
@@ -65,6 +65,7 @@ def test_historical_retrieval_with_validation(environment, universal_data_source
         from_=reference_job,
         name="my_training_dataset",
         storage=environment.data_source_creator.create_saved_dataset_destination(),
+        allow_overwrite=True,
     )
 
     saved_dataset = store.get_saved_dataset("my_training_dataset")
@@ -95,6 +96,7 @@ def test_historical_retrieval_fails_on_validation(environment, universal_data_so
         from_=reference_job,
         name="my_other_dataset",
         storage=environment.data_source_creator.create_saved_dataset_destination(),
+        allow_overwrite=True,
     )
 
     job = store.get_historical_features(
@@ -172,6 +174,7 @@ def test_logged_features_validation(environment, universal_data_sources):
         ),
         name="reference_for_validating_logged_features",
         storage=environment.data_source_creator.create_saved_dataset_destination(),
+        allow_overwrite=True,
     )
 
     log_source_df = store.get_historical_features(
@@ -245,6 +248,7 @@ def test_e2e_validation_via_cli(environment, universal_data_sources):
         from_=retrieval_job,
         name="reference_for_validating_logged_features",
         storage=environment.data_source_creator.create_saved_dataset_destination(),
+        allow_overwrite=True,
     )
     reference = saved_dataset.as_reference(
         name="test_reference", profiler=configurable_profiler
diff --git a/sdk/python/tests/integration/offline_store/test_persist.py b/sdk/python/tests/integration/offline_store/test_persist.py
new file mode 100644
index 0000000000..8e6f182917
--- /dev/null
+++ b/sdk/python/tests/integration/offline_store/test_persist.py
@@ -0,0 +1,54 @@
+import pytest
+
+from feast.errors import SavedDatasetLocationAlreadyExists
+from feast.saved_dataset import SavedDatasetStorage
+from tests.integration.feature_repos.repo_configuration import (
+    construct_universal_feature_views,
+)
+from tests.integration.feature_repos.universal.entities import (
+    customer,
+    driver,
+    location,
+)
+
+
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores(only=["file"])
+def test_persist_does_not_overwrite(environment, universal_data_sources):
+    """
+    Tests that the persist method does not overwrite an existing location in the offline store.
+
+    This test currently is only run against the file offline store as it is the only implementation
+    that prevents overwriting. As more offline stores add this check, they should be added to this test.
+    """
+    store = environment.feature_store
+    entities, datasets, data_sources = universal_data_sources
+    feature_views = construct_universal_feature_views(data_sources)
+    store.apply([driver(), customer(), location(), *feature_views.values()])
+
+    features = [
+        "customer_profile:current_balance",
+        "customer_profile:avg_passenger_count",
+        "customer_profile:lifetime_trip_count",
+    ]
+
+    entity_df = datasets.entity_df.drop(
+        columns=["order_id", "origin_id", "destination_id"]
+    )
+    job = store.get_historical_features(
+        entity_df=entity_df,
+        features=features,
+    )
+
+    with pytest.raises(SavedDatasetLocationAlreadyExists):
+        # Copy data source destination to a saved dataset destination.
+        saved_dataset_destination = SavedDatasetStorage.from_data_source(
+            data_sources.customer
+        )
+
+        # This should fail since persisting to a preexisting location is not allowed.
+        store.create_saved_dataset(
+            from_=job,
+            name="my_training_dataset",
+            storage=saved_dataset_destination,
+        )
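The same guard can be exercised outside the test harness; `store` and `job` are assumed to be set up as in the earlier usage sketch, and the target file is assumed to already exist:

```python
import pytest

from feast.errors import SavedDatasetLocationAlreadyExists
from feast.infra.offline_stores.file_source import SavedDatasetFileStorage

existing = SavedDatasetFileStorage(path="data/customer_profiles.parquet")

# Persisting onto an existing file fails by default...
with pytest.raises(SavedDatasetLocationAlreadyExists):
    store.create_saved_dataset(from_=job, name="dup_dataset", storage=existing)

# ...and succeeds only when the caller opts in to overwriting.
store.create_saved_dataset(
    from_=job, name="dup_dataset", storage=existing, allow_overwrite=True
)
```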
diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
index cd61822e1c..73c5152d47 100644
--- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
+++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
@@ -381,6 +381,7 @@ def test_historical_features_persisting(
         name="saved_dataset",
         storage=environment.data_source_creator.create_saved_dataset_destination(),
         tags={"env": "test"},
+        allow_overwrite=True,
     )
 
     event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL