Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Prevent overwriting existing file during persist #3088

Merged
merged 5 commits into from
Aug 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ def __init__(
)


class SavedDatasetLocationAlreadyExists(Exception):
def __init__(self, location: str):
super().__init__(f"Saved dataset location {location} already exists.")


class FeastOfflineStoreInvalidName(Exception):
def __init__(self, offline_store_class_name: str):
super().__init__(
Expand Down
11 changes: 10 additions & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,7 @@ def create_saved_dataset(
storage: SavedDatasetStorage,
tags: Optional[Dict[str, str]] = None,
feature_service: Optional[FeatureService] = None,
allow_overwrite: bool = False,
) -> SavedDataset:
"""
Execute provided retrieval job and persist its outcome in given storage.
Expand All @@ -1154,6 +1155,14 @@ def create_saved_dataset(
Name for the saved dataset should be unique within project, since it's possible to overwrite previously stored dataset
with the same name.

Args:
from_: The retrieval job whose result should be persisted.
name: The name of the saved dataset.
storage: The saved dataset storage object indicating where the result should be persisted.
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
feature_service (optional): The feature service that should be associated with this saved dataset.
allow_overwrite (optional): If True, the persisted result can overwrite an existing table or file.

Returns:
SavedDataset object with attached RetrievalJob

Expand Down Expand Up @@ -1186,7 +1195,7 @@ def create_saved_dataset(
dataset.min_event_timestamp = from_.metadata.min_event_timestamp
dataset.max_event_timestamp = from_.metadata.max_event_timestamp

from_.persist(storage)
from_.persist(storage=storage, allow_overwrite=allow_overwrite)

dataset = dataset.with_retrieval_job(
self._get_provider().retrieve_saved_dataset(
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/offline_stores/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def _execute_query(
block_until_done(client=self.client, bq_job=bq_job, timeout=timeout)
return bq_job

def persist(self, storage: SavedDatasetStorage):
def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
assert isinstance(storage, SavedDatasetBigQueryStorage)

self.to_bigquery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def _to_arrow_internal(self) -> pa.Table:
def metadata(self) -> Optional[RetrievalMetadata]:
return self._metadata

def persist(self, storage: SavedDatasetStorage):
def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
assert isinstance(storage, SavedDatasetAthenaStorage)
self.to_athena(table_name=storage.athena_options.table)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def _to_arrow_internal(self) -> pa.Table:
def metadata(self) -> Optional[RetrievalMetadata]:
return self._metadata

def persist(self, storage: SavedDatasetStorage):
def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
assert isinstance(storage, SavedDatasetPostgreSQLStorage)

df_to_postgres_table(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _to_arrow_internal(self) -> pyarrow.Table:
self.to_spark_df().write.parquet(temp_dir, mode="overwrite")
return pq.read_table(temp_dir)

def persist(self, storage: SavedDatasetStorage):
def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
"""
Run the retrieval and persist the results in the same offline store used for read.
Please note the persisting is done only within the scope of the spark session.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def to_trino(
self._client.execute_query(query_text=query)
return destination_table

def persist(self, storage: SavedDatasetStorage):
def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
"""
Run the retrieval and persist the results in the same offline store used for read.
"""
Expand Down
16 changes: 13 additions & 3 deletions sdk/python/feast/infra/offline_stores/file.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import uuid
from datetime import datetime
from pathlib import Path
Expand All @@ -11,13 +12,16 @@
import pytz
from pydantic.typing import Literal

from feast import FileSource, OnDemandFeatureView
from feast.data_source import DataSource
from feast.errors import FeastJoinKeysDuringMaterialization
from feast.errors import (
FeastJoinKeysDuringMaterialization,
SavedDatasetLocationAlreadyExists,
)
from feast.feature_logging import LoggingConfig, LoggingSource
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores.file_source import (
FileLoggingDestination,
FileSource,
SavedDatasetFileStorage,
)
from feast.infra.offline_stores.offline_store import (
Expand All @@ -30,6 +34,7 @@
get_pyarrow_schema_from_batch_source,
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.usage import log_exceptions_and_usage
Expand Down Expand Up @@ -83,8 +88,13 @@ def _to_arrow_internal(self):
df = self.evaluation_function().compute()
return pyarrow.Table.from_pandas(df)

def persist(self, storage: SavedDatasetStorage):
def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
assert isinstance(storage, SavedDatasetFileStorage)

# Check if the specified location already exists.
if not allow_overwrite and os.path.exists(storage.file_options.uri):
raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)

filesystem, path = FileSource.create_filesystem_and_path(
storage.file_options.uri,
storage.file_options.s3_endpoint_override,
Expand Down
42 changes: 35 additions & 7 deletions sdk/python/feast/infra/offline_stores/file_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,20 @@ def __eq__(self, other):
)

@property
def path(self):
"""
Returns the path of this file data source.
"""
def path(self) -> str:
"""Returns the path of this file data source."""
return self.file_options.uri

@property
def file_format(self) -> Optional[FileFormat]:
"""Returns the file format of this file data source."""
return self.file_options.file_format

@property
def s3_endpoint_override(self) -> Optional[str]:
"""Returns the s3 endpoint override of this file data source."""
return self.file_options.s3_endpoint_override

@staticmethod
def from_proto(data_source: DataSourceProto):
return FileSource(
Expand Down Expand Up @@ -177,24 +185,33 @@ def get_table_query_string(self) -> str:
class FileOptions:
"""
Configuration options for a file data source.

Attributes:
uri: File source url, e.g. s3:// or local file.
s3_endpoint_override: Custom s3 endpoint (used only with s3 uri).
file_format: File source format, e.g. parquet.
"""

uri: str
file_format: Optional[FileFormat]
s3_endpoint_override: str

def __init__(
self,
uri: str,
file_format: Optional[FileFormat],
s3_endpoint_override: Optional[str],
uri: Optional[str],
):
"""
Initializes a FileOptions object.

Args:
uri: File source url, e.g. s3:// or local file.
file_format (optional): File source format, e.g. parquet.
s3_endpoint_override (optional): Custom s3 endpoint (used only with s3 uri).
uri (optional): File source url, e.g. s3:// or local file.
"""
self.uri = uri
self.file_format = file_format
self.uri = uri or ""
self.s3_endpoint_override = s3_endpoint_override or ""

@classmethod
Expand Down Expand Up @@ -269,6 +286,17 @@ def to_data_source(self) -> DataSource:
s3_endpoint_override=self.file_options.s3_endpoint_override,
)

@staticmethod
def from_data_source(data_source: DataSource) -> "SavedDatasetStorage":
assert isinstance(data_source, FileSource)
return SavedDatasetFileStorage(
path=data_source.path,
file_format=data_source.file_format
if data_source.file_format
else ParquetFormat(),
s3_endpoint_override=data_source.s3_endpoint_override,
)


class FileLoggingDestination(LoggingDestination):
_proto_kind = "file_destination"
Expand Down
12 changes: 10 additions & 2 deletions sdk/python/feast/infra/offline_stores/offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,16 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
pass

@abstractmethod
def persist(self, storage: SavedDatasetStorage):
"""Synchronously executes the underlying query and persists the result in the same offline store."""
def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
"""
Synchronously executes the underlying query and persists the result in the same offline store
at the specified destination.

Args:
storage: The saved dataset storage object specifying where the result should be persisted.
allow_overwrite: If True, a pre-existing location (e.g. table or file) can be overwritten.
Currently not all individual offline store implementations make use of this parameter.
"""
pass

@property
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/offline_stores/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def to_redshift(self, table_name: str) -> None:
query,
)

def persist(self, storage: SavedDatasetStorage):
def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
assert isinstance(storage, SavedDatasetRedshiftStorage)
self.to_redshift(table_name=storage.redshift_options.table)

Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/offline_stores/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def to_arrow_chunks(self, arrow_options: Optional[Dict] = None) -> Optional[List

return arrow_batches

def persist(self, storage: SavedDatasetStorage):
def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
assert isinstance(storage, SavedDatasetSnowflakeStorage)
self.to_snowflake(table_name=storage.snowflake_options.table)

Expand Down
28 changes: 26 additions & 2 deletions sdk/python/feast/saved_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from feast.data_source import DataSource
from feast.dqm.profilers.profiler import Profile, Profiler
from feast.importer import import_class
from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto
from feast.protos.feast.core.SavedDataset_pb2 import SavedDatasetMeta, SavedDatasetSpec
from feast.protos.feast.core.SavedDataset_pb2 import (
Expand All @@ -31,6 +32,16 @@ def __new__(cls, name, bases, dct):
return kls


_DATA_SOURCE_TO_SAVED_DATASET_STORAGE = {
"FileSource": "feast.infra.offline_stores.file_source.SavedDatasetFileStorage",
}


def get_saved_dataset_storage_class_from_path(saved_dataset_storage_path: str):
module_name, class_name = saved_dataset_storage_path.rsplit(".", 1)
return import_class(module_name, class_name, "SavedDatasetStorage")


class SavedDatasetStorage(metaclass=_StorageRegistry):
_proto_attr_name: str

Expand All @@ -43,11 +54,24 @@ def from_proto(storage_proto: SavedDatasetStorageProto) -> "SavedDatasetStorage"

@abstractmethod
def to_proto(self) -> SavedDatasetStorageProto:
...
pass

@abstractmethod
def to_data_source(self) -> DataSource:
...
pass

@staticmethod
def from_data_source(data_source: DataSource) -> "SavedDatasetStorage":
data_source_type = type(data_source).__name__
if data_source_type in _DATA_SOURCE_TO_SAVED_DATASET_STORAGE:
cls = get_saved_dataset_storage_class_from_path(
_DATA_SOURCE_TO_SAVED_DATASET_STORAGE[data_source_type]
)
return cls.from_data_source(data_source)
else:
raise ValueError(
f"This method currently does not support {data_source_type}."
)


class SavedDataset:
Expand Down
4 changes: 4 additions & 0 deletions sdk/python/tests/integration/e2e/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_historical_retrieval_with_validation(environment, universal_data_source
from_=reference_job,
name="my_training_dataset",
storage=environment.data_source_creator.create_saved_dataset_destination(),
allow_overwrite=True,
)
saved_dataset = store.get_saved_dataset("my_training_dataset")

Expand Down Expand Up @@ -95,6 +96,7 @@ def test_historical_retrieval_fails_on_validation(environment, universal_data_so
from_=reference_job,
name="my_other_dataset",
storage=environment.data_source_creator.create_saved_dataset_destination(),
allow_overwrite=True,
)

job = store.get_historical_features(
Expand Down Expand Up @@ -172,6 +174,7 @@ def test_logged_features_validation(environment, universal_data_sources):
),
name="reference_for_validating_logged_features",
storage=environment.data_source_creator.create_saved_dataset_destination(),
allow_overwrite=True,
)

log_source_df = store.get_historical_features(
Expand Down Expand Up @@ -245,6 +248,7 @@ def test_e2e_validation_via_cli(environment, universal_data_sources):
from_=retrieval_job,
name="reference_for_validating_logged_features",
storage=environment.data_source_creator.create_saved_dataset_destination(),
allow_overwrite=True,
)
reference = saved_dataset.as_reference(
name="test_reference", profiler=configurable_profiler
Expand Down
54 changes: 54 additions & 0 deletions sdk/python/tests/integration/offline_store/test_persist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest

from feast.errors import SavedDatasetLocationAlreadyExists
from feast.saved_dataset import SavedDatasetStorage
from tests.integration.feature_repos.repo_configuration import (
construct_universal_feature_views,
)
from tests.integration.feature_repos.universal.entities import (
customer,
driver,
location,
)


@pytest.mark.integration
@pytest.mark.universal_offline_stores(only=["file"])
def test_persist_does_not_overwrite(environment, universal_data_sources):
"""
Tests that the persist method does not overwrite an existing location in the offline store.

This test currently is only run against the file offline store as it is the only implementation
that prevents overwriting. As more offline stores add this check, they should be added to this test.
"""
store = environment.feature_store
entities, datasets, data_sources = universal_data_sources
feature_views = construct_universal_feature_views(data_sources)
store.apply([driver(), customer(), location(), *feature_views.values()])

features = [
"customer_profile:current_balance",
"customer_profile:avg_passenger_count",
"customer_profile:lifetime_trip_count",
]

entity_df = datasets.entity_df.drop(
columns=["order_id", "origin_id", "destination_id"]
)
job = store.get_historical_features(
entity_df=entity_df,
features=features,
)

with pytest.raises(SavedDatasetLocationAlreadyExists):
# Copy data source destination to a saved dataset destination.
saved_dataset_destination = SavedDatasetStorage.from_data_source(
data_sources.customer
)

# This should fail since persisting to a preexisting location is not allowed.
store.create_saved_dataset(
from_=job,
name="my_training_dataset",
storage=saved_dataset_destination,
)
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ def test_historical_features_persisting(
name="saved_dataset",
storage=environment.data_source_creator.create_saved_dataset_destination(),
tags={"env": "test"},
allow_overwrite=True,
)

event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
Expand Down