diff --git a/sdk/python/feast/infra/offline_stores/duckdb.py b/sdk/python/feast/infra/offline_stores/duckdb.py
index 8e392425ea..a639d54add 100644
--- a/sdk/python/feast/infra/offline_stores/duckdb.py
+++ b/sdk/python/feast/infra/offline_stores/duckdb.py
@@ -33,7 +33,11 @@ def _read_data_source(data_source: DataSource) -> Table:
     if isinstance(data_source.file_format, ParquetFormat):
         return ibis.read_parquet(data_source.path)
     elif isinstance(data_source.file_format, DeltaFormat):
-        return ibis.read_delta(data_source.path)
+        storage_options = {
+            "AWS_ENDPOINT_URL": data_source.s3_endpoint_override,
+        }
+
+        return ibis.read_delta(data_source.path, storage_options=storage_options)


 def _write_data_source(
@@ -72,10 +76,18 @@ def _write_data_source(
             new_table = pyarrow.concat_tables([table, prev_table])
             ibis.memtable(new_table).to_parquet(file_options.uri)
     elif isinstance(data_source.file_format, DeltaFormat):
+        storage_options = {
+            "AWS_ENDPOINT_URL": str(data_source.s3_endpoint_override),
+        }
+
         if mode == "append":
             from deltalake import DeltaTable

-            prev_schema = DeltaTable(file_options.uri).schema().to_pyarrow()
+            prev_schema = (
+                DeltaTable(file_options.uri, storage_options=storage_options)
+                .schema()
+                .to_pyarrow()
+            )
             table = table.cast(ibis.Schema.from_pyarrow(prev_schema))
             write_mode = "append"
         elif mode == "overwrite":
@@ -85,13 +97,19 @@ def _write_data_source(
                 else "error"
             )

-        table.to_delta(file_options.uri, mode=write_mode)
+        table.to_delta(
+            file_options.uri, mode=write_mode, storage_options=storage_options
+        )


 class DuckDBOfflineStoreConfig(FeastConfigBaseModel):
     type: StrictStr = "duckdb"
     # """ Offline store type selector"""

+    staging_location: Optional[str] = None
+
+    staging_location_endpoint_override: Optional[str] = None
+

 class DuckDBOfflineStore(OfflineStore):
     @staticmethod
@@ -116,6 +134,8 @@ def pull_latest_from_table_or_query(
             end_date=end_date,
             data_source_reader=_read_data_source,
             data_source_writer=_write_data_source,
+            staging_location=config.offline_store.staging_location,
+            staging_location_endpoint_override=config.offline_store.staging_location_endpoint_override,
         )

     @staticmethod
@@ -138,6 +158,8 @@ def get_historical_features(
             full_feature_names=full_feature_names,
             data_source_reader=_read_data_source,
             data_source_writer=_write_data_source,
+            staging_location=config.offline_store.staging_location,
+            staging_location_endpoint_override=config.offline_store.staging_location_endpoint_override,
         )

     @staticmethod
@@ -160,6 +182,8 @@ def pull_all_from_table_or_query(
             end_date=end_date,
             data_source_reader=_read_data_source,
             data_source_writer=_write_data_source,
+            staging_location=config.offline_store.staging_location,
+            staging_location_endpoint_override=config.offline_store.staging_location_endpoint_override,
         )

     @staticmethod
diff --git a/sdk/python/feast/infra/offline_stores/file_source.py b/sdk/python/feast/infra/offline_stores/file_source.py
index 596f3464a9..3fdc6cba31 100644
--- a/sdk/python/feast/infra/offline_stores/file_source.py
+++ b/sdk/python/feast/infra/offline_stores/file_source.py
@@ -179,7 +179,15 @@ def get_table_column_names_and_types(
         elif isinstance(self.file_format, DeltaFormat):
             from deltalake import DeltaTable

-            schema = DeltaTable(self.path).schema().to_pyarrow()
+            storage_options = {
+                "AWS_ENDPOINT_URL": str(self.s3_endpoint_override),
+            }
+
+            schema = (
+                DeltaTable(self.path, storage_options=storage_options)
+                .schema()
+                .to_pyarrow()
+            )
         else:
             raise Exception(f"Unknown FileFormat -> {self.file_format}")

diff --git a/sdk/python/feast/infra/offline_stores/ibis.py b/sdk/python/feast/infra/offline_stores/ibis.py
index b9efb87a36..6cc1606a45 100644
--- a/sdk/python/feast/infra/offline_stores/ibis.py
+++ b/sdk/python/feast/infra/offline_stores/ibis.py
@@ -47,6 +47,8 @@ def pull_latest_from_table_or_query_ibis(
     end_date: datetime,
     data_source_reader: Callable[[DataSource], Table],
     data_source_writer: Callable[[pyarrow.Table, DataSource], None],
+    staging_location: Optional[str] = None,
+    staging_location_endpoint_override: Optional[str] = None,
 ) -> RetrievalJob:
     fields = join_key_columns + feature_name_columns + [timestamp_field]
     if created_timestamp_column:
@@ -82,6 +84,8 @@ def pull_latest_from_table_or_query_ibis(
         full_feature_names=False,
         metadata=None,
         data_source_writer=data_source_writer,
+        staging_location=staging_location,
+        staging_location_endpoint_override=staging_location_endpoint_override,
     )


@@ -140,6 +144,8 @@ def get_historical_features_ibis(
     data_source_reader: Callable[[DataSource], Table],
     data_source_writer: Callable[[pyarrow.Table, DataSource], None],
     full_feature_names: bool = False,
+    staging_location: Optional[str] = None,
+    staging_location_endpoint_override: Optional[str] = None,
 ) -> RetrievalJob:
     entity_schema = _get_entity_schema(
         entity_df=entity_df,
@@ -231,6 +237,8 @@ def read_fv(
             max_event_timestamp=timestamp_range[1],
         ),
         data_source_writer=data_source_writer,
+        staging_location=staging_location,
+        staging_location_endpoint_override=staging_location_endpoint_override,
     )


@@ -244,6 +252,8 @@ def pull_all_from_table_or_query_ibis(
     end_date: datetime,
     data_source_reader: Callable[[DataSource], Table],
     data_source_writer: Callable[[pyarrow.Table, DataSource], None],
+    staging_location: Optional[str] = None,
+    staging_location_endpoint_override: Optional[str] = None,
 ) -> RetrievalJob:
     fields = join_key_columns + feature_name_columns + [timestamp_field]
     start_date = start_date.astimezone(tz=utc)
@@ -270,6 +280,8 @@ def pull_all_from_table_or_query_ibis(
         full_feature_names=False,
         metadata=None,
         data_source_writer=data_source_writer,
+        staging_location=staging_location,
+        staging_location_endpoint_override=staging_location_endpoint_override,
     )


@@ -411,6 +423,23 @@ def point_in_time_join(
     return acc_table


+def list_s3_files(path: str, endpoint_url: str) -> List[str]:
+    import boto3
+
+    s3 = boto3.client("s3", endpoint_url=endpoint_url)
+    if path.startswith("s3://"):
+        path = path[len("s3://") :]
+    bucket, prefix = path.split("/", 1)
+    objects = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
+    contents = objects["Contents"]
+    files = [
+        f"s3://{bucket}/{content['Key']}"
+        for content in contents
+        if content["Key"].endswith("parquet")
+    ]
+    return files
+
+
 class IbisRetrievalJob(RetrievalJob):
     def __init__(
         self,
@@ -419,6 +448,8 @@ def __init__(
         full_feature_names,
         metadata,
         data_source_writer,
+        staging_location,
+        staging_location_endpoint_override,
     ) -> None:
         super().__init__()
         self.table = table
@@ -428,6 +459,8 @@ def __init__(
         self._full_feature_names = full_feature_names
         self._metadata = metadata
         self.data_source_writer = data_source_writer
+        self.staging_location = staging_location
+        self.staging_location_endpoint_override = staging_location_endpoint_override

     def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
         return self.table.execute()
@@ -456,3 +489,15 @@ def persist(
     @property
     def metadata(self) -> Optional[RetrievalMetadata]:
         return self._metadata
+
+    def supports_remote_storage_export(self) -> bool:
+        return self.staging_location is not None
+
+    def to_remote_storage(self) -> List[str]:
+        path = self.staging_location + f"/{str(uuid.uuid4())}"
+
+        storage_options = {"AWS_ENDPOINT_URL": self.staging_location_endpoint_override}
+
+        self.table.to_delta(path, storage_options=storage_options)
+
+        return list_s3_files(path, self.staging_location_endpoint_override)
diff --git a/sdk/python/pytest.ini b/sdk/python/pytest.ini
index 83317d36c9..8a16294322 100644
--- a/sdk/python/pytest.ini
+++ b/sdk/python/pytest.ini
@@ -5,4 +5,10 @@ markers =

 env =
     FEAST_USAGE=False
-    IS_TEST=True
\ No newline at end of file
+    IS_TEST=True
+
+filterwarnings =
+    ignore::DeprecationWarning:pyspark.sql.pandas.*:
+    ignore::DeprecationWarning:pyspark.sql.connect.*:
+    ignore::DeprecationWarning:httpx.*:
+    ignore::FutureWarning:ibis_substrait.compiler.*:
diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py
index 6abe30822f..c4a62be0c0 100644
--- a/sdk/python/tests/conftest.py
+++ b/sdk/python/tests/conftest.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 import logging
 import multiprocessing
+import os
 import random
 from datetime import datetime, timedelta
 from multiprocessing import Process
 from sys import platform
 from typing import Any, Dict, List, Tuple, no_type_check
+from unittest import mock

 import pandas as pd
 import pytest
@@ -180,7 +182,11 @@ def environment(request, worker_id):
         request.param, worker_id=worker_id, fixture_request=request
     )

-    yield e
+    if hasattr(e.data_source_creator, "mock_environ"):
+        with mock.patch.dict(os.environ, e.data_source_creator.mock_environ):
+            yield e
+    else:
+        yield e

     e.feature_store.teardown()
     e.data_source_creator.teardown()
diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py
index 4007106a06..311325536e 100644
--- a/sdk/python/tests/integration/feature_repos/repo_configuration.py
+++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py
@@ -33,6 +33,7 @@
 from tests.integration.feature_repos.universal.data_sources.file import (
     DuckDBDataSourceCreator,
     DuckDBDeltaDataSourceCreator,
+    DuckDBDeltaS3DataSourceCreator,
     FileDataSourceCreator,
 )
 from tests.integration.feature_repos.universal.data_sources.redshift import (
@@ -122,6 +123,14 @@
     ("local", DuckDBDeltaDataSourceCreator),
 ]

+if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True":
+    AVAILABLE_OFFLINE_STORES.extend(
+        [
+            ("local", DuckDBDeltaS3DataSourceCreator),
+        ]
+    )
+
+
 AVAILABLE_ONLINE_STORES: Dict[
     str, Tuple[Union[str, Dict[Any, Any]], Optional[Type[OnlineStoreCreator]]]
 ] = {
diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py
index 9cdc91a6c8..6f0ac02a00 100644
--- a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py
+++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py
@@ -10,6 +10,7 @@
 from minio import Minio
 from testcontainers.core.generic import DockerContainer
 from testcontainers.core.waiting_utils import wait_for_logs
+from testcontainers.minio import MinioContainer

 from feast import FileSource
 from feast.data_format import DeltaFormat, ParquetFormat
@@ -134,6 +135,74 @@ def create_logged_features_destination(self) -> LoggingDestination:
         return FileLoggingDestination(path=d)


+class DeltaS3FileSourceCreator(FileDataSourceCreator):
+    def __init__(self, project_name: str, *args, **kwargs):
+        super().__init__(project_name)
+        self.minio = MinioContainer()
+        self.minio.start()
+        client = self.minio.get_client()
+        client.make_bucket("test")
+        host_ip = self.minio.get_container_host_ip()
+        exposed_port = self.minio.get_exposed_port(self.minio.port)
+        self.endpoint_url = f"http://{host_ip}:{exposed_port}"
+
+        self.mock_environ = {
+            "AWS_ACCESS_KEY_ID": self.minio.access_key,
+            "AWS_SECRET_ACCESS_KEY": self.minio.secret_key,
+            "AWS_EC2_METADATA_DISABLED": "true",
+            "AWS_REGION": "us-east-1",
+            "AWS_ALLOW_HTTP": "true",
+            "AWS_S3_ALLOW_UNSAFE_RENAME": "true",
+        }
+
+    def create_data_source(
+        self,
+        df: pd.DataFrame,
+        destination_name: str,
+        created_timestamp_column="created_ts",
+        field_mapping: Optional[Dict[str, str]] = None,
+        timestamp_field: Optional[str] = "ts",
+    ) -> DataSource:
+        from deltalake.writer import write_deltalake
+
+        destination_name = self.get_prefixed_table_name(destination_name)
+
+        storage_options = {
+            "AWS_ACCESS_KEY_ID": self.minio.access_key,
+            "AWS_SECRET_ACCESS_KEY": self.minio.secret_key,
+            "AWS_ENDPOINT_URL": self.endpoint_url,
+        }
+
+        path = f"s3://test/{str(uuid.uuid4())}/{destination_name}"
+
+        write_deltalake(path, df, storage_options=storage_options)
+
+        return FileSource(
+            file_format=DeltaFormat(),
+            path=path,
+            timestamp_field=timestamp_field,
+            created_timestamp_column=created_timestamp_column,
+            field_mapping=field_mapping or {"ts_1": "ts"},
+            s3_endpoint_override=self.endpoint_url,
+        )
+
+    def create_saved_dataset_destination(self) -> SavedDatasetFileStorage:
+        return SavedDatasetFileStorage(
+            path=f"s3://test/{str(uuid.uuid4())}",
+            file_format=DeltaFormat(),
+            s3_endpoint_override=self.endpoint_url,
+        )
+
+    # LoggingDestination is parquet-only
+    def create_logged_features_destination(self) -> LoggingDestination:
+        d = tempfile.mkdtemp(prefix=self.project_name)
+        self.keep.append(d)
+        return FileLoggingDestination(path=d)
+
+    def teardown(self):
+        self.minio.stop()
+
+
 class FileParquetDatasetSourceCreator(FileDataSourceCreator):
     def create_data_source(
         self,
@@ -273,3 +342,12 @@ class DuckDBDeltaDataSourceCreator(DeltaFileSourceCreator):
     def create_offline_store_config(self):
         self.duckdb_offline_store_config = DuckDBOfflineStoreConfig()
         return self.duckdb_offline_store_config
+
+
+class DuckDBDeltaS3DataSourceCreator(DeltaS3FileSourceCreator):
+    def create_offline_store_config(self):
+        self.duckdb_offline_store_config = DuckDBOfflineStoreConfig(
+            staging_location="s3://test/staging",
+            staging_location_endpoint_override=self.endpoint_url,
+        )
+        return self.duckdb_offline_store_config
diff --git a/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py b/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py
index e85c1d7311..bb4c4e63fc 100644
--- a/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py
+++ b/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py
@@ -31,7 +31,7 @@ def test_spark_materialization_consistency():
         batch_engine={"type": "spark.engine", "partitions": 10},
     )
     spark_environment = construct_test_environment(
-        spark_config, None, entity_key_serialization_version=1
+        spark_config, None, entity_key_serialization_version=2
     )

     df = create_basic_driver_dataset()
diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
index 958b829a60..6d355e093c 100644
--- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
+++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
@@ -139,8 +139,7 @@ def test_historical_features(

     if job_from_df.supports_remote_storage_export():
         files = job_from_df.to_remote_storage()
-        print(files)
-        assert len(files) > 0  # This test should be way more detailed
+        assert len(files)  # 0  # This test should be way more detailed

     start_time = datetime.utcnow()
     actual_df_from_df_entities = job_from_df.to_df()
diff --git a/sdk/python/tests/integration/registration/test_registry.py b/sdk/python/tests/integration/registration/test_registry.py
index 70d118ecf9..9ad1a98a05 100644
--- a/sdk/python/tests/integration/registration/test_registry.py
+++ b/sdk/python/tests/integration/registration/test_registry.py
@@ -18,7 +18,7 @@

 import pytest
 from pytest_lazyfixture import lazy_fixture
-from testcontainers.core.container import DockerContainer
+from testcontainers.minio import MinioContainer

 from feast import FileSource
 from feast.data_format import ParquetFormat
@@ -64,25 +64,15 @@ def s3_registry() -> Registry:

 @pytest.fixture
 def minio_registry() -> Registry:
-    minio_user = "minio99"
-    minio_password = "minio123"
     bucket_name = "test-bucket"

-    container: DockerContainer = (
-        DockerContainer("quay.io/minio/minio")
-        .with_exposed_ports(9000, 9001)
-        .with_env("MINIO_ROOT_USER", minio_user)
-        .with_env("MINIO_ROOT_PASSWORD", minio_password)
-        .with_command('server /data --console-address ":9001"')
-        .with_exposed_ports()
-    )
-
+    container = MinioContainer()
     container.start()
+    client = container.get_client()
+    client.make_bucket(bucket_name)

-    exposed_port = container.get_exposed_port("9000")
     container_host = container.get_container_host_ip()
-
-    container.exec(f"mkdir /data/{bucket_name}")
+    exposed_port = container.get_exposed_port(container.port)

     registry_config = RegistryConfig(
         path=f"s3://{bucket_name}/registry.db", cache_ttl_seconds=600
@@ -90,8 +80,8 @@ def minio_registry() -> Registry:

     mock_environ = {
         "FEAST_S3_ENDPOINT_URL": f"http://{container_host}:{exposed_port}",
-        "AWS_ACCESS_KEY_ID": minio_user,
-        "AWS_SECRET_ACCESS_KEY": minio_password,
+        "AWS_ACCESS_KEY_ID": container.access_key,
+        "AWS_SECRET_ACCESS_KEY": container.secret_key,
         "AWS_SESSION_TOKEN": "",
     }