feat: Make arrow primary interchange for offline ODFV execution (#4083)
tokoko authored Apr 15, 2024
1 parent a05cdbc commit 9ed0a09
Showing 6 changed files with 104 additions and 46 deletions.
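The core of the change: `RetrievalJob.to_df` no longer applies on-demand transforms to a pandas DataFrame itself; it delegates to `to_arrow()` and only converts to pandas at the very end. A minimal sketch of that control flow (illustrative only, not Feast code; the class and sample data are hypothetical):

```python
import pyarrow as pa


class SketchRetrievalJob:
    """Toy stand-in for a RetrievalJob, showing the new Arrow-first flow."""

    def _to_arrow_internal(self, timeout=None) -> pa.Table:
        # In Feast this would run the actual offline query; here it is canned data.
        return pa.table({"driver_id": [1, 2], "conv_rate": [0.1, 0.2]})

    def to_arrow(self, validation_reference=None, timeout=None) -> pa.Table:
        # ODFV transforms and validation now happen on the pyarrow.Table here.
        return self._to_arrow_internal(timeout=timeout)

    def to_df(self, validation_reference=None, timeout=None):
        # to_df() is now just a pandas view over the Arrow result.
        return (
            self.to_arrow(validation_reference=validation_reference, timeout=timeout)
            .to_pandas()
            .reset_index(drop=True)
        )


print(SketchRetrievalJob().to_df())
```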
64 changes: 19 additions & 45 deletions sdk/python/feast/infra/offline_stores/offline_store.py
@@ -76,36 +76,11 @@ def to_df(
validation_reference (optional): The validation to apply against the retrieved dataframe.
timeout (optional): The query timeout if applicable.
"""
features_df = self._to_df_internal(timeout=timeout)

if self.on_demand_feature_views:
# TODO(adchia): Fix requirement to specify dependent feature views in feature_refs
for odfv in self.on_demand_feature_views:
if odfv.mode not in {"pandas", "substrait"}:
raise Exception(
f'OnDemandFeatureView mode "{odfv.mode}" not supported for offline processing.'
)
features_df = features_df.join(
odfv.get_transformed_features_df(
features_df,
self.full_feature_names,
)
)

if validation_reference:
if not flags_helper.is_test():
warnings.warn(
"Dataset validation is an experimental feature. "
"This API is unstable and it could and most probably will be changed in the future. "
"We do not guarantee that future changes will maintain backward compatibility.",
RuntimeWarning,
)

validation_result = validation_reference.profile.validate(features_df)
if not validation_result.is_success:
raise ValidationFailed(validation_result)

return features_df
return (
self.to_arrow(validation_reference=validation_reference, timeout=timeout)
.to_pandas()
.reset_index(drop=True)
)

def to_arrow(
self,
@@ -122,23 +97,20 @@ def to_arrow(
validation_reference (optional): The validation to apply against the retrieved dataframe.
timeout (optional): The query timeout if applicable.
"""
if not self.on_demand_feature_views and not validation_reference:
return self._to_arrow_internal(timeout=timeout)

features_df = self._to_df_internal(timeout=timeout)
features_table = self._to_arrow_internal(timeout=timeout)
if self.on_demand_feature_views:
for odfv in self.on_demand_feature_views:
if odfv.mode not in {"pandas", "substrait"}:
raise Exception(
f'OnDemandFeatureView mode "{odfv.mode}" not supported for offline processing.'
)
features_df = features_df.join(
odfv.get_transformed_features_df(
features_df,
self.full_feature_names,
)
transformed_arrow = odfv.transform_arrow(
features_table, self.full_feature_names
)

for col in transformed_arrow.column_names:
if col.startswith("__index"):
continue
features_table = features_table.append_column(
col, transformed_arrow[col]
)

if validation_reference:
if not flags_helper.is_test():
warnings.warn(
@@ -148,11 +120,13 @@
RuntimeWarning,
)

validation_result = validation_reference.profile.validate(features_df)
validation_result = validation_reference.profile.validate(
features_table.to_pandas()
)
if not validation_result.is_success:
raise ValidationFailed(validation_result)

return pyarrow.Table.from_pandas(features_df)
return features_table

def to_sql(self) -> str:
"""
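For reference, a self-contained sketch of the column-append loop above: the transformed ODFV output arrives as a second `pyarrow.Table`, pandas-derived index columns (e.g. `__index_level_0__`) are skipped, and everything else is appended onto the base table. The sample tables and names below are made up for illustration.

```python
import pandas as pd
import pyarrow as pa

base = pa.table({"driver_id": [1, 2], "conv_rate": [0.1, 0.2]})

# ODFV output produced via pandas can carry index columns such as
# "__index_level_0__"; the loop below copies only the real feature columns.
transformed = pa.Table.from_pandas(
    pd.DataFrame({"conv_rate_plus_1": [1.1, 1.2]}, index=[5, 6]),
    preserve_index=True,
)

for col in transformed.column_names:
    if col.startswith("__index"):
        continue
    base = base.append_column(col, transformed[col])

print(base.column_names)  # ['driver_id', 'conv_rate', 'conv_rate_plus_1']
```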
55 changes: 55 additions & 0 deletions sdk/python/feast/on_demand_feature_view.py
@@ -8,6 +8,7 @@

import dill
import pandas as pd
import pyarrow
from typeguard import typechecked

from feast.base_feature_view import BaseFeatureView
@@ -391,6 +392,60 @@ def get_request_data_schema(self) -> Dict[str, ValueType]:
def _get_projected_feature_name(self, feature: str) -> str:
return f"{self.projection.name_to_use()}__{feature}"

def transform_arrow(
self,
pa_table: pyarrow.Table,
full_feature_names: bool = False,
) -> pyarrow.Table:
if not isinstance(pa_table, pyarrow.Table):
raise TypeError("transform_arrow only accepts pyarrow.Table")
columns_to_cleanup = []
for source_fv_projection in self.source_feature_view_projections.values():
for feature in source_fv_projection.features:
full_feature_ref = f"{source_fv_projection.name}__{feature.name}"
if full_feature_ref in pa_table.column_names:
# Make sure the partial feature name is always present
pa_table = pa_table.append_column(
feature.name, pa_table[full_feature_ref]
)
# pa_table[feature.name] = pa_table[full_feature_ref]
columns_to_cleanup.append(feature.name)
elif feature.name in pa_table.column_names:
# Make sure the full feature name is always present
# pa_table[full_feature_ref] = pa_table[feature.name]
pa_table = pa_table.append_column(
full_feature_ref, pa_table[feature.name]
)
columns_to_cleanup.append(full_feature_ref)

df_with_transformed_features: pyarrow.Table = (
self.feature_transformation.transform_arrow(pa_table)
)

# Work out whether the correct column names are used.
rename_columns: Dict[str, str] = {}
for feature in self.features:
short_name = feature.name
long_name = self._get_projected_feature_name(feature.name)
if (
short_name in df_with_transformed_features.column_names
and full_feature_names
):
rename_columns[short_name] = long_name
elif not full_feature_names:
rename_columns[long_name] = short_name

# Cleanup extra columns used for transformation
for col in columns_to_cleanup:
if col in df_with_transformed_features.column_names:
df_with_transformed_features = df_with_transformed_features.drop(col)
return df_with_transformed_features.rename_columns(
[
rename_columns.get(c, c)
for c in df_with_transformed_features.column_names
]
)

def get_transformed_features_df(
self,
df_with_features: pd.DataFrame,
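The aliasing and renaming that `transform_arrow` performs can be seen in isolation below: the source column is exposed under both its fully qualified and short names for the transformation, the helper alias is dropped afterwards, and `rename_columns` (which takes the complete list of names in order) applies the `full_feature_names` convention. The column and view names here are invented for the example.

```python
import pyarrow as pa

# The source table only carries the fully qualified column name.
table = pa.table({"driver_stats__conv_rate": [0.1, 0.2]})

# Alias it so a transformation expecting the short name also works.
table = table.append_column("conv_rate", table["driver_stats__conv_rate"])

# Pretend the ODFV transformation produced this output column.
table = table.append_column("conv_rate_plus_1", pa.array([1.1, 1.2]))

# Drop the helper alias, then rename; rename_columns takes the full list of
# names in order, so unmapped names pass through unchanged.
table = table.drop(["conv_rate"])
rename = {"conv_rate_plus_1": "my_odfv__conv_rate_plus_1"}  # full_feature_names=True
table = table.rename_columns([rename.get(c, c) for c in table.column_names])

print(table.column_names)  # ['driver_stats__conv_rate', 'my_odfv__conv_rate_plus_1']
```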
14 changes: 14 additions & 0 deletions sdk/python/feast/transformation/pandas_transformation.py
@@ -3,6 +3,7 @@

import dill
import pandas as pd
import pyarrow

from feast.field import Field, from_value_type
from feast.protos.feast.core.Transformation_pb2 import (
@@ -26,6 +27,19 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
self.udf = udf
self.udf_string = udf_string

def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
if not isinstance(pa_table, pyarrow.Table):
raise TypeError(
f"pa_table should be type pyarrow.Table but got {type(pa_table).__name__}"
)
output_df = self.udf.__call__(pa_table.to_pandas())
output_df = pyarrow.Table.from_pandas(output_df)
if not isinstance(output_df, pyarrow.Table):
raise TypeError(
f"output_df should be type pyarrow.Table but got {type(output_df).__name__}"
)
return output_df

def transform(self, input_df: pd.DataFrame) -> pd.DataFrame:
if not isinstance(input_df, pd.DataFrame):
raise TypeError(
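`PandasTransformation.transform_arrow` is a simple round trip: Arrow in, pandas for the user's UDF, Arrow back out. A minimal sketch with a toy UDF (the function and column names are hypothetical):

```python
import pandas as pd
import pyarrow as pa


def plus_one(df: pd.DataFrame) -> pd.DataFrame:
    # Stand-in for a user-defined pandas-mode ODFV transformation.
    return pd.DataFrame({"conv_rate_plus_1": df["conv_rate"] + 1})


def transform_arrow_like(udf, pa_table: pa.Table) -> pa.Table:
    # Arrow -> pandas for the UDF -> Arrow again, mirroring the method above.
    return pa.Table.from_pandas(udf(pa_table.to_pandas()))


out = transform_arrow_like(plus_one, pa.table({"conv_rate": [0.1, 0.2]}))
print(out.column_names)  # ['conv_rate_plus_1']
```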
6 changes: 6 additions & 0 deletions sdk/python/feast/transformation/python_transformation.py
@@ -2,6 +2,7 @@
from typing import Any, Dict, List

import dill
import pyarrow

from feast.field import Field, from_value_type
from feast.protos.feast.core.Transformation_pb2 import (
@@ -24,6 +25,11 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
self.udf = udf
self.udf_string = udf_string

def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
raise Exception(
'OnDemandFeatureView mode "python" not supported for offline processing.'
)

def transform(self, input_dict: Dict) -> Dict:
if not isinstance(input_dict, Dict):
raise TypeError(
9 changes: 9 additions & 0 deletions sdk/python/feast/transformation/substrait_transformation.py
@@ -34,6 +34,15 @@ def table_provider(names, schema: pyarrow.Schema):
).read_all()
return table.to_pandas()

def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
def table_provider(names, schema: pyarrow.Schema):
return pa_table.select(schema.names)

table: pyarrow.Table = pyarrow.substrait.run_query(
self.substrait_plan, table_provider=table_provider
).read_all()
return table

def infer_features(self, random_input: Dict[str, List[Any]]) -> List[Field]:
df = pd.DataFrame.from_dict(random_input)
output_df: pd.DataFrame = self.transform(df)
@@ -216,7 +216,7 @@ def test_to_sql():

@pytest.mark.parametrize("timeout", (None, 30))
def test_to_df_timeout(retrieval_job, timeout: Optional[int]):
with patch.object(retrieval_job, "_to_df_internal") as mock_to_df_internal:
with patch.object(retrieval_job, "_to_arrow_internal") as mock_to_df_internal:
retrieval_job.to_df(timeout=timeout)
mock_to_df_internal.assert_called_once_with(timeout=timeout)

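The test tweak above reflects the new routing: `to_df` should now hit `_to_arrow_internal` rather than `_to_df_internal`. A stripped-down version of that assertion pattern, using a toy job class instead of Feast's `RetrievalJob`:

```python
from unittest.mock import patch


class FakeRetrievalJob:
    """Toy stand-in, only to illustrate the patch.object assertion."""

    def _to_arrow_internal(self, timeout=None):
        raise NotImplementedError

    def to_df(self, timeout=None):
        # The pandas path is expected to delegate to the Arrow path.
        return self._to_arrow_internal(timeout=timeout)


job = FakeRetrievalJob()
with patch.object(job, "_to_arrow_internal") as mock_internal:
    job.to_df(timeout=30)
    mock_internal.assert_called_once_with(timeout=30)
```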
