Skip to content

Commit

Permalink
fix: Fix on demand feature view crash from inference when it uses df.…
Browse files Browse the repository at this point in the history
…apply (#2713)

* fix: Fix on demand feature view crash from inference when transformation uses df.apply

Signed-off-by: Danny Chiao <[email protected]>

* Fix inference

Signed-off-by: Danny Chiao <[email protected]>

* Fix test

Signed-off-by: Danny Chiao <[email protected]>
  • Loading branch information
adchia authored May 17, 2022
1 parent cebf609 commit c5539fd
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 3 deletions.
18 changes: 15 additions & 3 deletions sdk/python/feast/on_demand_feature_view.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import copy
import functools
import warnings
from datetime import datetime
from types import MethodType
from typing import Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union

import dill
import pandas as pd
Expand Down Expand Up @@ -442,18 +443,29 @@ def infer_features(self):
Raises:
RegistryInferenceFailure: The set of features could not be inferred.
"""
rand_df_value: Dict[str, Any] = {
"float": 1.0,
"int": 1,
"str": "hello world",
"bytes": str.encode("hello world"),
"bool": True,
"datetime64[ns]": datetime.utcnow(),
}

df = pd.DataFrame()
for feature_view_projection in self.source_feature_view_projections.values():
for feature in feature_view_projection.features:
dtype = feast_value_type_to_pandas_type(feature.dtype.to_value_type())
df[f"{feature_view_projection.name}__{feature.name}"] = pd.Series(
dtype=dtype
)
df[f"{feature.name}"] = pd.Series(dtype=dtype)
sample_val = rand_df_value[dtype] if dtype in rand_df_value else None
df[f"{feature.name}"] = pd.Series(data=sample_val, dtype=dtype)
for request_data in self.source_request_sources.values():
for field in request_data.schema:
dtype = feast_value_type_to_pandas_type(field.dtype.to_value_type())
df[f"{field.name}"] = pd.Series(dtype=dtype)
sample_val = rand_df_value[dtype] if dtype in rand_df_value else None
df[f"{field.name}"] = pd.Series(sample_val, dtype=dtype)
output_df: pd.DataFrame = self.udf.__call__(df)
inferred_features = []
for f, dt in zip(output_df.columns, output_df.dtypes):
Expand Down
48 changes: 48 additions & 0 deletions sdk/python/tests/example_repos/on_demand_feature_view_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from datetime import timedelta

import pandas as pd

from feast import FeatureView, Field, FileSource
from feast.on_demand_feature_view import on_demand_feature_view
from feast.types import Float32, String

driver_stats = FileSource(
name="driver_stats_source",
path="data/driver_stats_lat_lon.parquet",
timestamp_field="event_timestamp",
created_timestamp_column="created",
description="A table describing the stats of a driver based on hourly logs",
owner="[email protected]",
)

driver_daily_features_view = FeatureView(
name="driver_daily_features",
entities=["driver"],
ttl=timedelta(seconds=8640000000),
schema=[
Field(name="daily_miles_driven", dtype=Float32),
Field(name="lat", dtype=Float32),
Field(name="lon", dtype=Float32),
Field(name="string_feature", dtype=String),
],
online=True,
source=driver_stats,
tags={"production": "True"},
owner="[email protected]",
)


@on_demand_feature_view(
sources=[driver_daily_features_view],
schema=[
Field(name="first_char", dtype=String),
Field(name="concat_string", dtype=String),
],
)
def location_features_from_push(inputs: pd.DataFrame) -> pd.DataFrame:
df = pd.DataFrame()
df["concat_string"] = inputs.apply(
lambda x: x.string_feature + "hello", axis=1
).astype("string")
df["first_char"] = inputs["string_feature"].str[:1].astype("string")
return df
31 changes: 31 additions & 0 deletions sdk/python/tests/integration/registration/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,37 @@ def test_nullable_online_store(test_nullable_online_store) -> None:
runner.run(["teardown"], cwd=repo_path)


@pytest.mark.integration
@pytest.mark.universal_offline_stores
def test_odfv_apply(environment) -> None:
project = f"test_odfv_apply{str(uuid.uuid4()).replace('-', '')[:8]}"
runner = CliRunner()

with tempfile.TemporaryDirectory() as repo_dir_name:
try:
repo_path = Path(repo_dir_name)
feature_store_yaml = make_feature_store_yaml(
project, environment.test_repo_config, repo_path
)

repo_config = repo_path / "feature_store.yaml"

repo_config.write_text(dedent(feature_store_yaml))

repo_example = repo_path / "example.py"
repo_example.write_text(get_example_repo("on_demand_feature_view_repo.py"))
result = runner.run(["apply"], cwd=repo_path)
assertpy.assert_that(result.returncode).is_equal_to(0)

# entity & feature view list commands should succeed
result = runner.run(["entities", "list"], cwd=repo_path)
assertpy.assert_that(result.returncode).is_equal_to(0)
result = runner.run(["on-demand-feature-views", "list"], cwd=repo_path)
assertpy.assert_that(result.returncode).is_equal_to(0)
finally:
runner.run(["teardown"], cwd=repo_path)


@contextmanager
def setup_third_party_provider_repo(provider_name: str):
with tempfile.TemporaryDirectory() as repo_dir_name:
Expand Down
74 changes: 74 additions & 0 deletions sdk/python/tests/integration/registration/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,80 @@ def test_apply_feature_view_success(test_registry):
test_registry._get_registry_proto()


@pytest.mark.parametrize(
"test_registry", [lazy_fixture("local_registry")],
)
def test_apply_on_demand_feature_view_success(test_registry):
# Create Feature Views
driver_stats = FileSource(
name="driver_stats_source",
path="data/driver_stats_lat_lon.parquet",
timestamp_field="event_timestamp",
created_timestamp_column="created",
description="A table describing the stats of a driver based on hourly logs",
owner="[email protected]",
)

driver_daily_features_view = FeatureView(
name="driver_daily_features",
entities=["driver"],
ttl=timedelta(seconds=8640000000),
schema=[
Field(name="daily_miles_driven", dtype=Float32),
Field(name="lat", dtype=Float32),
Field(name="lon", dtype=Float32),
Field(name="string_feature", dtype=String),
],
online=True,
source=driver_stats,
tags={"production": "True"},
owner="[email protected]",
)

@on_demand_feature_view(
sources=[driver_daily_features_view],
schema=[Field(name="first_char", dtype=String)],
)
def location_features_from_push(inputs: pd.DataFrame) -> pd.DataFrame:
df = pd.DataFrame()
df["first_char"] = inputs["string_feature"].str[:1].astype("string")
return df

project = "project"

# Register Feature View
test_registry.apply_feature_view(location_features_from_push, project)

feature_views = test_registry.list_on_demand_feature_views(project)

# List Feature Views
assert (
len(feature_views) == 1
and feature_views[0].name == "location_features_from_push"
and feature_views[0].features[0].name == "first_char"
and feature_views[0].features[0].dtype == String
)

feature_view = test_registry.get_on_demand_feature_view(
"location_features_from_push", project
)
assert (
feature_view.name == "location_features_from_push"
and feature_view.features[0].name == "first_char"
and feature_view.features[0].dtype == String
)

test_registry.delete_feature_view("location_features_from_push", project)
feature_views = test_registry.list_on_demand_feature_views(project)
assert len(feature_views) == 0

test_registry.teardown()

# Will try to reload registry, which will fail because the file has been deleted
with pytest.raises(FileNotFoundError):
test_registry._get_registry_proto()


@pytest.mark.parametrize(
"test_registry", [lazy_fixture("local_registry")],
)
Expand Down

0 comments on commit c5539fd

Please sign in to comment.