Merge pull request #10 from tmihalac/implement-remaining-offline-methods
Implement remaining Remote Offline Store methods
tmihalac authored Jun 7, 2024
2 parents a14cd59 + 6d22f18, commit 73938d6
Showing 9 changed files with 585 additions and 197 deletions.
364 changes: 294 additions & 70 deletions sdk/python/feast/infra/offline_stores/remote.py

Large diffs are not rendered by default.
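The remote.py diff is collapsed above, but the server-side handlers below imply the client protocol: each call is serialized into a JSON command carried by an Arrow Flight descriptor. A minimal sketch of such a client call in Python (the endpoint, data source name, and field values are assumptions, and the actual remote.py may structure the exchange differently):

    import json

    import pyarrow as pa
    import pyarrow.flight as fl

    # Hypothetical command; the keys mirror what offline_server.py reads below.
    command = {
        "api": "pull_latest_from_table_or_query",
        "data_source_name": "driver_hourly_stats_source",  # assumed registry name
        "join_key_columns": ["driver_id"],
        "feature_name_columns": ["conv_rate", "acc_rate"],
        "timestamp_field": "event_timestamp",
        "created_timestamp_column": "created",
        "start_date": "2024-01-01T00:00:00",  # ISO strings, parsed server-side
        "end_date": "2024-06-01T00:00:00",
    }

    client = fl.FlightClient("grpc://localhost:8815")  # assumed server location
    descriptor = fl.FlightDescriptor.for_command(json.dumps(command).encode())

    # Register the command server-side. An empty table suffices for pull_* calls;
    # get_historical_features would upload the entity dataframe here instead.
    empty = pa.table({})
    writer, _ = client.do_put(descriptor, empty.schema)
    writer.write_table(empty)
    writer.close()

    info = client.get_flight_info(descriptor)
    reader = client.do_get(info.endpoints[0].ticket)
    table = reader.read_all()  # the latest rows per join key, as an Arrow table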

213 changes: 159 additions & 54 deletions sdk/python/feast/offline_server.py
@@ -2,13 +2,17 @@
 import json
 import logging
 import traceback
+from datetime import datetime
 from typing import Any, Dict, List
 
 import pyarrow as pa
 import pyarrow.flight as fl
 
-from feast import FeatureStore, FeatureView
+from feast import FeatureStore, FeatureView, utils
+from feast.feature_logging import FeatureServiceLoggingSource
 from feast.feature_view import DUMMY_ENTITY_NAME
+from feast.infra.offline_stores.offline_utils import get_offline_store_from_config
+from feast.saved_dataset import SavedDatasetStorage
 
 logger = logging.getLogger(__name__)
 
@@ -20,6 +24,7 @@ def __init__(self, store: FeatureStore, location: str, **kwargs):
         # A dictionary of configured flights, e.g. API calls received and not yet served
         self.flights: Dict[str, Any] = {}
         self.store = store
+        self.offline_store = get_offline_store_from_config(store.config.offline_store)
 
     @classmethod
     def descriptor_to_key(self, descriptor):
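The new offline_store attribute is the pivotal change here: get_offline_store_from_config resolves the configured offline store (for example, the file store) to its concrete implementation, so the handlers below can delegate to the OfflineStore interface directly instead of going through the provider. A sketch of standing up the server with that wiring (repo path and port are assumptions; the feast CLI may wrap this differently):

    from feast import FeatureStore
    from feast.offline_server import OfflineServer

    store = FeatureStore(repo_path=".")  # assumes a local feature repo
    # OfflineServer extends pyarrow.flight.FlightServerBase, so serve() blocks
    # and dispatches do_get/do_put/do_action calls to the methods in this diff.
    server = OfflineServer(store, location="grpc://0.0.0.0:8815")
    server.serve()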
@@ -126,67 +131,167 @@ def do_get(self, context, ticket):
         api = command["api"]
         logger.debug(f"get command is {command}")
         logger.debug(f"requested api is {api}")
-        if api == "get_historical_features":
-            # Extract parameters from the internal flights dictionary
-            entity_df_value = self.flights[key]
-            entity_df = pa.Table.to_pandas(entity_df_value)
-            logger.debug(f"do_get: entity_df is {entity_df}")
-
-            feature_view_names = command["feature_view_names"]
-            logger.debug(f"do_get: feature_view_names is {feature_view_names}")
-            name_aliases = command["name_aliases"]
-            logger.debug(f"do_get: name_aliases is {name_aliases}")
-            feature_refs = command["feature_refs"]
-            logger.debug(f"do_get: feature_refs is {feature_refs}")
-            project = command["project"]
-            logger.debug(f"do_get: project is {project}")
-            full_feature_names = command["full_feature_names"]
-            feature_views = self.list_feature_views_by_name(
-                feature_view_names=feature_view_names,
-                name_aliases=name_aliases,
-                project=project,
-            )
-            logger.debug(f"do_get: feature_views is {feature_views}")
-
-            logger.info(
-                f"get_historical_features for: entity_df from {entity_df.index[0]} to {entity_df.index[len(entity_df)-1]}, "
-                f"feature_views is {[(fv.name, fv.entities) for fv in feature_views]}"
-                f"feature_refs is {feature_refs}"
-            )
-
-            try:
-                training_df = (
-                    self.store._get_provider()
-                    .get_historical_features(
-                        config=self.store.config,
-                        feature_views=feature_views,
-                        feature_refs=feature_refs,
-                        entity_df=entity_df,
-                        registry=self.store._registry,
-                        project=project,
-                        full_feature_names=full_feature_names,
-                    )
-                    .to_df()
-                )
-                logger.debug(f"Len of training_df is {len(training_df)}")
-                table = pa.Table.from_pandas(training_df)
-            except Exception as e:
-                logger.exception(e)
-                traceback.print_exc()
-                raise e
-
-            # Get service is consumed, so we clear the corresponding flight and data
-            del self.flights[key]
-
-            return fl.RecordBatchStream(table)
-        else:
-            raise NotImplementedError
+        try:
+            if api == OfflineServer.get_historical_features.__name__:
+                df = self.get_historical_features(command, key).to_df()
+            elif api == OfflineServer.pull_all_from_table_or_query.__name__:
+                df = self.pull_all_from_table_or_query(command).to_df()
+            elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
+                df = self.pull_latest_from_table_or_query(command).to_df()
+            else:
+                raise NotImplementedError
+        except Exception as e:
+            logger.exception(e)
+            traceback.print_exc()
+            raise e
+
+        table = pa.Table.from_pandas(df)
+
+        # Get service is consumed, so we clear the corresponding flight and data
+        del self.flights[key]
+        return fl.RecordBatchStream(table)
+
+    def offline_write_batch(self, command, key):
+        feature_view_names = command["feature_view_names"]
+        assert (
+            len(feature_view_names) == 1
+        ), "feature_view_names list should only have one item"
+        name_aliases = command["name_aliases"]
+        assert len(name_aliases) == 1, "name_aliases list should only have one item"
+        project = self.store.config.project
+        feature_views = self.list_feature_views_by_name(
+            feature_view_names=feature_view_names,
+            name_aliases=name_aliases,
+            project=project,
+        )
+
+        assert len(feature_views) == 1
+        table = self.flights[key]
+        self.offline_store.offline_write_batch(
+            self.store.config, feature_views[0], table, command["progress"]
+        )
+
+    def write_logged_features(self, command, key):
+        table = self.flights[key]
+        feature_service = self.store.get_feature_service(
+            command["feature_service_name"]
+        )
+
+        self.offline_store.write_logged_features(
+            config=self.store.config,
+            data=table,
+            source=FeatureServiceLoggingSource(
+                feature_service, self.store.config.project
+            ),
+            logging_config=feature_service.logging_config,
+            registry=self.store.registry,
+        )
+
+    def pull_all_from_table_or_query(self, command):
+        return self.offline_store.pull_all_from_table_or_query(
+            self.store.config,
+            self.store.get_data_source(command["data_source_name"]),
+            command["join_key_columns"],
+            command["feature_name_columns"],
+            command["timestamp_field"],
+            utils.make_tzaware(datetime.fromisoformat(command["start_date"])),
+            utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
+        )
+
+    def pull_latest_from_table_or_query(self, command):
+        return self.offline_store.pull_latest_from_table_or_query(
+            self.store.config,
+            self.store.get_data_source(command["data_source_name"]),
+            command["join_key_columns"],
+            command["feature_name_columns"],
+            command["timestamp_field"],
+            command["created_timestamp_column"],
+            utils.make_tzaware(datetime.fromisoformat(command["start_date"])),
+            utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
+        )
 
     def list_actions(self, context):
-        return []
+        return [
+            (
+                OfflineServer.offline_write_batch.__name__,
+                "Writes the specified arrow table to the data source underlying the specified feature view.",
+            ),
+            (
+                OfflineServer.write_logged_features.__name__,
+                "Writes logged features to a specified destination in the offline store.",
+            ),
+            (
+                OfflineServer.persist.__name__,
+                "Synchronously executes the underlying query and persists the result in the same offline store at the "
+                "specified destination.",
+            ),
+        ]
+
+    def get_historical_features(self, command, key):
+        # Extract parameters from the internal flights dictionary
+        entity_df_value = self.flights[key]
+        entity_df = pa.Table.to_pandas(entity_df_value)
+        feature_view_names = command["feature_view_names"]
+        name_aliases = command["name_aliases"]
+        feature_refs = command["feature_refs"]
+        project = command["project"]
+        full_feature_names = command["full_feature_names"]
+        feature_views = self.list_feature_views_by_name(
+            feature_view_names=feature_view_names,
+            name_aliases=name_aliases,
+            project=project,
+        )
+        retJob = self.offline_store.get_historical_features(
+            config=self.store.config,
+            feature_views=feature_views,
+            feature_refs=feature_refs,
+            entity_df=entity_df,
+            registry=self.store.registry,
+            project=project,
+            full_feature_names=full_feature_names,
+        )
+        return retJob
+
+    def persist(self, command, key):
+        try:
+            api = command["api"]
+            if api == OfflineServer.get_historical_features.__name__:
+                ret_job = self.get_historical_features(command, key)
+            elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
+                ret_job = self.pull_latest_from_table_or_query(command)
+            elif api == OfflineServer.pull_all_from_table_or_query.__name__:
+                ret_job = self.pull_all_from_table_or_query(command)
+            else:
+                raise NotImplementedError
+
+            data_source = self.store.get_data_source(command["data_source_name"])
+            storage = SavedDatasetStorage.from_data_source(data_source)
+            ret_job.persist(storage, command["allow_overwrite"], command["timeout"])
+        except Exception as e:
+            logger.exception(e)
+            traceback.print_exc()
+            raise e
 
     def do_action(self, context, action):
-        raise NotImplementedError
+        command_descriptor = fl.FlightDescriptor.deserialize(action.body.to_pybytes())
+
+        key = OfflineServer.descriptor_to_key(command_descriptor)
+        command = json.loads(key[1])
+        logger.info(f"do_action command is {command}")
+
+        try:
+            if action.type == OfflineServer.offline_write_batch.__name__:
+                self.offline_write_batch(command, key)
+            elif action.type == OfflineServer.write_logged_features.__name__:
+                self.write_logged_features(command, key)
+            elif action.type == OfflineServer.persist.__name__:
+                self.persist(command, key)
+            else:
+                raise NotImplementedError
+        except Exception as e:
+            logger.exception(e)
+            traceback.print_exc()
+            raise e
 
     def do_drop_dataset(self, dataset):
         pass
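Note the split in the Flight surface: reads (get_historical_features, pull_*) stream results back through do_get, while side-effecting operations (offline_write_batch, write_logged_features, persist) are Flight actions whose body is a serialized descriptor carrying the same JSON command. A hypothetical client-side persist call (endpoint and names assumed; the real remote.py client may differ):

    import json

    import pyarrow.flight as fl

    command = {
        "api": "pull_all_from_table_or_query",  # the job that persist() rebuilds
        "data_source_name": "driver_hourly_stats_source",  # assumed registry name
        "join_key_columns": ["driver_id"],
        "feature_name_columns": ["conv_rate"],
        "timestamp_field": "event_timestamp",
        "start_date": "2024-01-01T00:00:00",
        "end_date": "2024-06-01T00:00:00",
        "allow_overwrite": True,
        "timeout": 600,
    }

    client = fl.FlightClient("grpc://localhost:8815")
    descriptor = fl.FlightDescriptor.for_command(json.dumps(command).encode())
    # do_action returns an iterator of results; draining it surfaces server errors.
    for _ in client.do_action(fl.Action("persist", descriptor.serialize())):
        pass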
1 change: 1 addition & 0 deletions sdk/python/feast/templates/local/bootstrap.py
@@ -24,6 +24,7 @@ def bootstrap():

     example_py_file = repo_path / "example_repo.py"
     replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
+    replace_str_in_file(example_py_file, "%LOGGING_PATH%", str(data_path))
 
 
 if __name__ == "__main__":
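replace_str_in_file is the helper the template uses to fill placeholders during feast init. Presumably it is a literal in-place substitution along these lines (a sketch, not the actual feast helper):

    from pathlib import Path

    def replace_str_in_file(file_path, match_str, sub_str):
        # Read the file, replace every occurrence, and write it back in place.
        path = Path(file_path)
        path.write_text(path.read_text().replace(match_str, sub_str))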
5 changes: 5 additions & 0 deletions sdk/python/feast/templates/local/feature_repo/example_repo.py
@@ -13,6 +13,8 @@
     PushSource,
     RequestSource,
 )
+from feast.feature_logging import LoggingConfig
+from feast.infra.offline_stores.file_source import FileLoggingDestination
 from feast.on_demand_feature_view import on_demand_feature_view
 from feast.types import Float32, Float64, Int64
 
@@ -88,6 +90,9 @@ def transformed_conv_rate(inputs: pd.DataFrame) -> pd.DataFrame:
         driver_stats_fv[["conv_rate"]],  # Sub-selects a feature from a feature view
         transformed_conv_rate,  # Selects all features from the feature view
     ],
+    logging_config=LoggingConfig(
+        destination=FileLoggingDestination(path="%LOGGING_PATH%")
+    ),
 )
 driver_activity_v2 = FeatureService(
     name="driver_activity_v2", features=[driver_stats_fv, transformed_conv_rate]
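After bootstrap.py substitutes %LOGGING_PATH% (see above), the rendered service in a scaffolded repo reads roughly as follows; the path shown is an example of what feast init would fill in:

    driver_activity_v1 = FeatureService(
        name="driver_activity_v1",
        features=[
            driver_stats_fv[["conv_rate"]],
            transformed_conv_rate,
        ],
        logging_config=LoggingConfig(
            # Example rendered value: the repo's own data directory.
            destination=FileLoggingDestination(path="/home/user/feature_repo/data")
        ),
    )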
13 changes: 11 additions & 2 deletions sdk/python/tests/integration/offline_store/test_feature_logging.py
@@ -34,8 +34,6 @@ def test_feature_service_logging(environment, universal_data_sources, pass_as_pa
     (_, datasets, data_sources) = universal_data_sources
 
     feature_views = construct_universal_feature_views(data_sources)
-    store.apply([customer(), driver(), location(), *feature_views.values()])
-
     feature_service = FeatureService(
         name="test_service",
         features=[
@@ -49,6 +47,17 @@
         ),
     )
 
+    store.apply(
+        [customer(), driver(), location(), *feature_views.values()], feature_service
+    )
+
+    # Added to handle the case that the offline store is remote
+    store.registry.apply_feature_service(feature_service, store.config.project)
+    store.registry.apply_data_source(
+        feature_service.logging_config.destination.to_data_source(),
+        store.config.project,
+    )
+
     driver_df = datasets.driver_df
     driver_df["val_to_add"] = 50
     driver_df = driver_df.join(conv_rate_plus_100(driver_df))
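The explicit registry writes above are needed because the remote server resolves everything by name: write_logged_features calls store.get_feature_service(...) and the persist path calls store.get_data_source(...), so the feature service and the logging destination's data source must already exist in the registry the server reads. For illustration, the client-side action that ultimately exercises this path might look like the following (wire details assumed from offline_server.py above; the real remote offline store wraps this for you):

    import json

    import pyarrow as pa
    import pyarrow.flight as fl

    command = {"api": "write_logged_features", "feature_service_name": "test_service"}
    descriptor = fl.FlightDescriptor.for_command(json.dumps(command).encode())

    logs = pa.table({"driver_id": [1001], "conv_rate": [0.5]})  # toy logged rows
    client = fl.FlightClient("grpc://localhost:8815")
    # Upload the table under the command's key, then trigger the action; the
    # server reads the table back from self.flights[key].
    writer, _ = client.do_put(descriptor, logs.schema)
    writer.write_table(logs)
    writer.close()
    action = fl.Action("write_logged_features", descriptor.serialize())
    for _ in client.do_action(action):
        pass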
sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
@@ -19,6 +19,9 @@
     construct_universal_feature_views,
     table_name_from_data_source,
 )
+from tests.integration.feature_repos.universal.data_sources.file import (
+    RemoteOfflineStoreDataSourceCreator,
+)
 from tests.integration.feature_repos.universal.data_sources.snowflake import (
     SnowflakeDataSourceCreator,
 )
@@ -157,22 +160,25 @@ def test_historical_features_main(
         timestamp_precision=timedelta(milliseconds=1),
     )
 
-    assert_feature_service_correctness(
-        store,
-        feature_service,
-        full_feature_names,
-        entity_df_with_request_data,
-        expected_df,
-        event_timestamp,
-    )
-    assert_feature_service_entity_mapping_correctness(
-        store,
-        feature_service_entity_mapping,
-        full_feature_names,
-        entity_df_with_request_data,
-        full_expected_df,
-        event_timestamp,
-    )
+    if not isinstance(
+        environment.data_source_creator, RemoteOfflineStoreDataSourceCreator
+    ):
+        assert_feature_service_correctness(
+            store,
+            feature_service,
+            full_feature_names,
+            entity_df_with_request_data,
+            expected_df,
+            event_timestamp,
+        )
+        assert_feature_service_entity_mapping_correctness(
+            store,
+            feature_service_entity_mapping,
+            full_feature_names,
+            entity_df_with_request_data,
+            full_expected_df,
+            event_timestamp,
+        )
     table_from_df_entities: pd.DataFrame = job_from_df.to_arrow().to_pandas()
 
     validate_dataframes(
@@ -375,8 +381,13 @@ def test_historical_features_persisting(
     (entities, datasets, data_sources) = universal_data_sources
     feature_views = construct_universal_feature_views(data_sources)
 
+    storage = environment.data_source_creator.create_saved_dataset_destination()
+
     store.apply([driver(), customer(), location(), *feature_views.values()])
 
+    # Added to handle the case that the offline store is remote
+    store.registry.apply_data_source(storage.to_data_source(), store.config.project)
+
     entity_df = datasets.entity_df.drop(
         columns=["order_id", "origin_id", "destination_id"]
     )
@@ -398,7 +409,7 @@
     saved_dataset = store.create_saved_dataset(
         from_=job,
         name="saved_dataset",
-        storage=environment.data_source_creator.create_saved_dataset_destination(),
+        storage=storage,
         tags={"env": "test"},
         allow_overwrite=True,
     )
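The storage object is now created up front and its underlying source registered so that, when the offline store is remote, the server can rebuild the destination on its side: persist() in offline_server.py looks the source up by name and reconstructs the storage via SavedDatasetStorage.from_data_source. An illustrative round-trip for a file-based store (types assumed from the file offline store):

    from feast.infra.offline_stores.file_source import SavedDatasetFileStorage
    from feast.saved_dataset import SavedDatasetStorage

    storage = SavedDatasetFileStorage(path="data/saved_dataset.parquet")
    data_source = storage.to_data_source()  # client side: what gets registered
    rebuilt = SavedDatasetStorage.from_data_source(data_source)  # server side
    assert isinstance(rebuilt, SavedDatasetFileStorage)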
(Diffs for the remaining 3 changed files are not shown.)