feat: Add s3 remote storage export for duckdb (#4195)
add s3 remote export, tests for duckdb

Signed-off-by: tokoko <[email protected]>
tokoko authored May 16, 2024
1 parent a17725d commit 6a04c48
Showing 10 changed files with 191 additions and 26 deletions.
30 changes: 27 additions & 3 deletions sdk/python/feast/infra/offline_stores/duckdb.py
@@ -33,7 +33,11 @@ def _read_data_source(data_source: DataSource) -> Table:
if isinstance(data_source.file_format, ParquetFormat):
return ibis.read_parquet(data_source.path)
elif isinstance(data_source.file_format, DeltaFormat):
- return ibis.read_delta(data_source.path)
+ storage_options = {
+ "AWS_ENDPOINT_URL": data_source.s3_endpoint_override,
+ }
+
+ return ibis.read_delta(data_source.path, storage_options=storage_options)


def _write_data_source(
@@ -72,10 +76,18 @@ def _write_data_source(
new_table = pyarrow.concat_tables([table, prev_table])
ibis.memtable(new_table).to_parquet(file_options.uri)
elif isinstance(data_source.file_format, DeltaFormat):
storage_options = {
"AWS_ENDPOINT_URL": str(data_source.s3_endpoint_override),
}

if mode == "append":
from deltalake import DeltaTable

- prev_schema = DeltaTable(file_options.uri).schema().to_pyarrow()
+ prev_schema = (
+ DeltaTable(file_options.uri, storage_options=storage_options)
+ .schema()
+ .to_pyarrow()
+ )
table = table.cast(ibis.Schema.from_pyarrow(prev_schema))
write_mode = "append"
elif mode == "overwrite":
@@ -85,13 +97,19 @@
else "error"
)

- table.to_delta(file_options.uri, mode=write_mode)
+ table.to_delta(
+ file_options.uri, mode=write_mode, storage_options=storage_options
+ )


class DuckDBOfflineStoreConfig(FeastConfigBaseModel):
type: StrictStr = "duckdb"
# """ Offline store type selector"""

staging_location: Optional[str] = None

staging_location_endpoint_override: Optional[str] = None


class DuckDBOfflineStore(OfflineStore):
@staticmethod
@@ -116,6 +134,8 @@ def pull_latest_from_table_or_query(
end_date=end_date,
data_source_reader=_read_data_source,
data_source_writer=_write_data_source,
staging_location=config.offline_store.staging_location,
staging_location_endpoint_override=config.offline_store.staging_location_endpoint_override,
)

@staticmethod
@@ -138,6 +158,8 @@ def get_historical_features(
full_feature_names=full_feature_names,
data_source_reader=_read_data_source,
data_source_writer=_write_data_source,
staging_location=config.offline_store.staging_location,
staging_location_endpoint_override=config.offline_store.staging_location_endpoint_override,
)

@staticmethod
@@ -160,6 +182,8 @@ def pull_all_from_table_or_query(
end_date=end_date,
data_source_reader=_read_data_source,
data_source_writer=_write_data_source,
staging_location=config.offline_store.staging_location,
staging_location_endpoint_override=config.offline_store.staging_location_endpoint_override,
)

@staticmethod
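The two new DuckDBOfflineStoreConfig fields are what turn on remote export for DuckDB; a minimal sketch of setting them programmatically (the bucket, prefix, and MinIO-style endpoint below are placeholders, not values from this commit):

```python
from feast.infra.offline_stores.duckdb import DuckDBOfflineStoreConfig

# Placeholder bucket/prefix and endpoint; any S3-compatible endpoint works.
config = DuckDBOfflineStoreConfig(
    staging_location="s3://my-bucket/feast-staging",
    staging_location_endpoint_override="http://localhost:9000",
)
# Leaving staging_location unset keeps supports_remote_storage_export() False.
```

In a feature repo these values would typically live under the offline_store block of feature_store.yaml.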
10 changes: 9 additions & 1 deletion sdk/python/feast/infra/offline_stores/file_source.py
@@ -179,7 +179,15 @@ def get_table_column_names_and_types(
elif isinstance(self.file_format, DeltaFormat):
from deltalake import DeltaTable

- schema = DeltaTable(self.path).schema().to_pyarrow()
+ storage_options = {
+ "AWS_ENDPOINT_URL": str(self.s3_endpoint_override),
+ }
+
+ schema = (
+ DeltaTable(self.path, storage_options=storage_options)
+ .schema()
+ .to_pyarrow()
+ )
else:
raise Exception(f"Unknown FileFormat -> {self.file_format}")

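The endpoint override that get_table_column_names_and_types now forwards to DeltaTable comes from the FileSource definition itself; a sketch of a Delta-on-S3 source (the path, field name, and endpoint are placeholders):

```python
from feast import FileSource
from feast.data_format import DeltaFormat

# Placeholder path and endpoint; mirrors the FileSource the S3 test creator builds below.
driver_stats_source = FileSource(
    file_format=DeltaFormat(),
    path="s3://my-bucket/driver_stats",
    timestamp_field="event_timestamp",
    s3_endpoint_override="http://localhost:9000",
)
```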
45 changes: 45 additions & 0 deletions sdk/python/feast/infra/offline_stores/ibis.py
@@ -47,6 +47,8 @@ def pull_latest_from_table_or_query_ibis(
end_date: datetime,
data_source_reader: Callable[[DataSource], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource], None],
staging_location: Optional[str] = None,
staging_location_endpoint_override: Optional[str] = None,
) -> RetrievalJob:
fields = join_key_columns + feature_name_columns + [timestamp_field]
if created_timestamp_column:
@@ -82,6 +84,8 @@ def pull_latest_from_table_or_query_ibis(
full_feature_names=False,
metadata=None,
data_source_writer=data_source_writer,
staging_location=staging_location,
staging_location_endpoint_override=staging_location_endpoint_override,
)


@@ -140,6 +144,8 @@ def get_historical_features_ibis(
data_source_reader: Callable[[DataSource], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource], None],
full_feature_names: bool = False,
staging_location: Optional[str] = None,
staging_location_endpoint_override: Optional[str] = None,
) -> RetrievalJob:
entity_schema = _get_entity_schema(
entity_df=entity_df,
@@ -231,6 +237,8 @@ def read_fv(
max_event_timestamp=timestamp_range[1],
),
data_source_writer=data_source_writer,
staging_location=staging_location,
staging_location_endpoint_override=staging_location_endpoint_override,
)


@@ -244,6 +252,8 @@ def pull_all_from_table_or_query_ibis(
end_date: datetime,
data_source_reader: Callable[[DataSource], Table],
data_source_writer: Callable[[pyarrow.Table, DataSource], None],
staging_location: Optional[str] = None,
staging_location_endpoint_override: Optional[str] = None,
) -> RetrievalJob:
fields = join_key_columns + feature_name_columns + [timestamp_field]
start_date = start_date.astimezone(tz=utc)
@@ -270,6 +280,8 @@ def pull_all_from_table_or_query_ibis(
full_feature_names=False,
metadata=None,
data_source_writer=data_source_writer,
staging_location=staging_location,
staging_location_endpoint_override=staging_location_endpoint_override,
)


@@ -411,6 +423,23 @@ def point_in_time_join(
return acc_table


def list_s3_files(path: str, endpoint_url: str) -> List[str]:
import boto3

s3 = boto3.client("s3", endpoint_url=endpoint_url)
if path.startswith("s3://"):
path = path[len("s3://") :]
bucket, prefix = path.split("/", 1)
objects = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
contents = objects["Contents"]
files = [
f"s3://{bucket}/{content['Key']}"
for content in contents
if content["Key"].endswith("parquet")
]
return files


class IbisRetrievalJob(RetrievalJob):
def __init__(
self,
@@ -419,6 +448,8 @@ def __init__(
full_feature_names,
metadata,
data_source_writer,
staging_location,
staging_location_endpoint_override,
) -> None:
super().__init__()
self.table = table
@@ -428,6 +459,8 @@ def __init__(
self._full_feature_names = full_feature_names
self._metadata = metadata
self.data_source_writer = data_source_writer
self.staging_location = staging_location
self.staging_location_endpoint_override = staging_location_endpoint_override

def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
return self.table.execute()
@@ -456,3 +489,15 @@ def persist(
@property
def metadata(self) -> Optional[RetrievalMetadata]:
return self._metadata

def supports_remote_storage_export(self) -> bool:
return self.staging_location is not None

def to_remote_storage(self) -> List[str]:
path = self.staging_location + f"/{str(uuid.uuid4())}"

storage_options = {"AWS_ENDPOINT_URL": self.staging_location_endpoint_override}

self.table.to_delta(path, storage_options=storage_options)

return list_s3_files(path, self.staging_location_endpoint_override)
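Taken together, supports_remote_storage_export() and to_remote_storage() give callers a staged Parquet export; a hypothetical helper wrapping them (the function name and error message are illustrative, not part of this commit):

```python
from typing import List

from feast.infra.offline_stores.ibis import IbisRetrievalJob


def export_to_s3(job: IbisRetrievalJob) -> List[str]:
    """Hypothetical wrapper around the new remote-export hooks."""
    if not job.supports_remote_storage_export():
        raise ValueError("offline store has no staging_location configured")
    # Writes the job result as a Delta table under <staging_location>/<uuid>
    # and returns the s3:// URIs of the Parquet files inside it.
    return job.to_remote_storage()
```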
8 changes: 7 additions & 1 deletion sdk/python/pytest.ini
@@ -5,4 +5,10 @@ markers =

env =
FEAST_USAGE=False
IS_TEST=True

filterwarnings =
ignore::DeprecationWarning:pyspark.sql.pandas.*:
ignore::DeprecationWarning:pyspark.sql.connect.*:
ignore::DeprecationWarning:httpx.*:
ignore::FutureWarning:ibis_substrait.compiler.*:
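Each filterwarnings entry uses the standard action:message:category:module:lineno filter syntax; as a rough standard-library equivalent of the first pyspark line (a sketch, not part of the commit):

```python
import warnings

# Roughly what `ignore::DeprecationWarning:pyspark.sql.pandas.*:` asks pytest to apply.
warnings.filterwarnings(
    "ignore",
    category=DeprecationWarning,
    module=r"pyspark\.sql\.pandas\..*",
)
```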
8 changes: 7 additions & 1 deletion sdk/python/tests/conftest.py
@@ -13,11 +13,13 @@
# limitations under the License.
import logging
import multiprocessing
import os
import random
from datetime import datetime, timedelta
from multiprocessing import Process
from sys import platform
from typing import Any, Dict, List, Tuple, no_type_check
from unittest import mock

import pandas as pd
import pytest
@@ -180,7 +182,11 @@ def environment(request, worker_id):
request.param, worker_id=worker_id, fixture_request=request
)

- yield e
+ if hasattr(e.data_source_creator, "mock_environ"):
+ with mock.patch.dict(os.environ, e.data_source_creator.mock_environ):
+ yield e
+ else:
+ yield e

e.feature_store.teardown()
e.data_source_creator.teardown()
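The conftest change scopes the creator's fake AWS credentials to a single test via unittest.mock.patch.dict; a standalone sketch of that pattern (the credential values are placeholders):

```python
import os
from unittest import mock

fake_creds = {"AWS_ACCESS_KEY_ID": "minioadmin", "AWS_SECRET_ACCESS_KEY": "minioadmin"}

with mock.patch.dict(os.environ, fake_creds):
    # Inside the block, anything reading os.environ sees the injected values.
    assert os.environ["AWS_ACCESS_KEY_ID"] == "minioadmin"
# On exit, os.environ is restored to its previous contents.
```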
@@ -33,6 +33,7 @@
from tests.integration.feature_repos.universal.data_sources.file import (
DuckDBDataSourceCreator,
DuckDBDeltaDataSourceCreator,
DuckDBDeltaS3DataSourceCreator,
FileDataSourceCreator,
)
from tests.integration.feature_repos.universal.data_sources.redshift import (
@@ -122,6 +123,14 @@
("local", DuckDBDeltaDataSourceCreator),
]

if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True":
AVAILABLE_OFFLINE_STORES.extend(
[
("local", DuckDBDeltaS3DataSourceCreator),
]
)


AVAILABLE_ONLINE_STORES: Dict[
str, Tuple[Union[str, Dict[Any, Any]], Optional[Type[OnlineStoreCreator]]]
] = {
@@ -10,6 +10,7 @@
from minio import Minio
from testcontainers.core.generic import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
from testcontainers.minio import MinioContainer

from feast import FileSource
from feast.data_format import DeltaFormat, ParquetFormat
@@ -134,6 +135,74 @@ def create_logged_features_destination(self) -> LoggingDestination:
return FileLoggingDestination(path=d)


class DeltaS3FileSourceCreator(FileDataSourceCreator):
def __init__(self, project_name: str, *args, **kwargs):
super().__init__(project_name)
self.minio = MinioContainer()
self.minio.start()
client = self.minio.get_client()
client.make_bucket("test")
host_ip = self.minio.get_container_host_ip()
exposed_port = self.minio.get_exposed_port(self.minio.port)
self.endpoint_url = f"http://{host_ip}:{exposed_port}"

self.mock_environ = {
"AWS_ACCESS_KEY_ID": self.minio.access_key,
"AWS_SECRET_ACCESS_KEY": self.minio.secret_key,
"AWS_EC2_METADATA_DISABLED": "true",
"AWS_REGION": "us-east-1",
"AWS_ALLOW_HTTP": "true",
"AWS_S3_ALLOW_UNSAFE_RENAME": "true",
}

def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
timestamp_field: Optional[str] = "ts",
) -> DataSource:
from deltalake.writer import write_deltalake

destination_name = self.get_prefixed_table_name(destination_name)

storage_options = {
"AWS_ACCESS_KEY_ID": self.minio.access_key,
"AWS_SECRET_ACCESS_KEY": self.minio.secret_key,
"AWS_ENDPOINT_URL": self.endpoint_url,
}

path = f"s3://test/{str(uuid.uuid4())}/{destination_name}"

write_deltalake(path, df, storage_options=storage_options)

return FileSource(
file_format=DeltaFormat(),
path=path,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
field_mapping=field_mapping or {"ts_1": "ts"},
s3_endpoint_override=self.endpoint_url,
)

def create_saved_dataset_destination(self) -> SavedDatasetFileStorage:
return SavedDatasetFileStorage(
path=f"s3://test/{str(uuid.uuid4())}",
file_format=DeltaFormat(),
s3_endpoint_override=self.endpoint_url,
)

# LoggingDestination is parquet-only
def create_logged_features_destination(self) -> LoggingDestination:
d = tempfile.mkdtemp(prefix=self.project_name)
self.keep.append(d)
return FileLoggingDestination(path=d)

def teardown(self):
self.minio.stop()


class FileParquetDatasetSourceCreator(FileDataSourceCreator):
def create_data_source(
self,
@@ -273,3 +342,12 @@ class DuckDBDeltaDataSourceCreator(DeltaFileSourceCreator):
def create_offline_store_config(self):
self.duckdb_offline_store_config = DuckDBOfflineStoreConfig()
return self.duckdb_offline_store_config


class DuckDBDeltaS3DataSourceCreator(DeltaS3FileSourceCreator):
def create_offline_store_config(self):
self.duckdb_offline_store_config = DuckDBOfflineStoreConfig(
staging_location="s3://test/staging",
staging_location_endpoint_override=self.endpoint_url,
)
return self.duckdb_offline_store_config
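DeltaS3FileSourceCreator boots a throwaway MinIO container and points every source and staging location at it; the core of that bootstrap, distilled (the bucket name is illustrative):

```python
from testcontainers.minio import MinioContainer

with MinioContainer() as minio:
    minio.get_client().make_bucket("demo")
    endpoint = (
        f"http://{minio.get_container_host_ip()}:{minio.get_exposed_port(minio.port)}"
    )
    # Pass this as s3_endpoint_override / staging_location_endpoint_override.
    print(endpoint)
```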
@@ -31,7 +31,7 @@ def test_spark_materialization_consistency():
batch_engine={"type": "spark.engine", "partitions": 10},
)
spark_environment = construct_test_environment(
- spark_config, None, entity_key_serialization_version=1
+ spark_config, None, entity_key_serialization_version=2
)

df = create_basic_driver_dataset()
@@ -139,8 +139,7 @@ def test_historical_features(

if job_from_df.supports_remote_storage_export():
files = job_from_df.to_remote_storage()
- print(files)
- assert len(files) > 0  # This test should be way more detailed
+ assert len(files) > 0  # This test should be way more detailed

start_time = datetime.utcnow()
actual_df_from_df_entities = job_from_df.to_df()