ODFV UDFs should handle list types (#2002)
* ODFV UDFs should handle list types

ODFV UDFs handling list types (e.g., embeddings/vectors) should be registered without error.

Signed-off-by: Jeff <[email protected]>

* handle all value type names that end in _LIST

Signed-off-by: Jeff <[email protected]>

* clearly define dummy vector for driver embedding test data

Signed-off-by: Jeff <[email protected]>

* example embedding in test_write_to_online_store()

Signed-off-by: Jeff <[email protected]>

* map Arrow list types to Redshift super type

Signed-off-by: Jeff <[email protected]>

* ensure float list types in ODFV UDFs can be applied

Signed-off-by: Jeff <[email protected]>

* isolate ODFV list type feature test to smaller code changes

Signed-off-by: Jeff <[email protected]>
Agent007 authored Nov 12, 2021
1 parent fbc1d61 commit b456c46
Showing 4 changed files with 109 additions and 1 deletion.
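The changes below are easiest to read against a small usage sketch of the pattern this commit enables: an on-demand feature view whose UDF consumes a list-typed request field (an embedding), which apply() should now register without error. This is illustrative only; the names embedding_input, query_embedding, and embedding_norm are made up for the example, and the RequestDataSource import path is an assumption for this Feast version.

import pandas as pd

from feast import Feature, OnDemandFeatureView, ValueType
from feast.data_source import RequestDataSource  # import path assumed; may vary by Feast version


def embedding_norm(features_df: pd.DataFrame) -> pd.DataFrame:
    # List-typed feature values arrive as object-dtype columns holding Python lists.
    df = pd.DataFrame()
    df["norm"] = features_df["query_embedding"].apply(
        lambda v: sum(x * x for x in v) ** 0.5
    )
    return df


# Request data carrying an embedding as a DOUBLE_LIST value.
embedding_request = RequestDataSource(
    name="embedding_input",
    schema={"query_embedding": ValueType.DOUBLE_LIST},
)

norm_odfv = OnDemandFeatureView(
    name="embedding_norm",
    inputs={"embedding_input": embedding_request},
    features=[Feature("norm", ValueType.DOUBLE)],
    udf=embedding_norm,
)
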
5 changes: 5 additions & 0 deletions sdk/python/feast/type_map.py
@@ -90,6 +90,8 @@ def feast_value_type_to_pandas_type(value_type: ValueType) -> Any:
ValueType.BOOL: "bool",
ValueType.UNIX_TIMESTAMP: "datetime",
}
if value_type.name.endswith("_LIST"):
return "object"
if value_type in value_type_to_pandas_type:
return value_type_to_pandas_type[value_type]
raise TypeError(
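
The new _LIST branch above means any list-valued type (FLOAT_LIST, DOUBLE_LIST, etc.) is represented on the pandas side as an object-dtype column of Python lists. A minimal sketch of that behaviour (illustrative, not part of the diff):

import pandas as pd

# A column of Python lists has no dedicated pandas dtype; it is stored as "object",
# which is why feast_value_type_to_pandas_type returns "object" for *_LIST types.
embeddings = pd.Series([[1.0, 2.0], [3.0, 4.0]])
print(embeddings.dtype)  # object
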
@@ -451,6 +453,9 @@ def pa_to_redshift_value_type(pa_type: pyarrow.DataType) -> str:
# PyArrow decimal types (e.g. "decimal(38,37)") luckily directly map to the Redshift type.
return pa_type_as_str

if pa_type_as_str.startswith("list"):
return "super"

# We have to take into account how arrow types map to parquet types as well.
# For example, null type maps to int32 in parquet, so we have to use int4 in Redshift.
# Other mappings have also been adjusted accordingly.
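
The Redshift branch above keys on PyArrow's string rendering of list types, which starts with "list", and maps them to Redshift's SUPER type for semi-structured values. A quick illustrative check (not part of the diff):

import pyarrow as pa

# Arrow list types render as e.g. "list<item: double>", so a startswith("list")
# check is enough to route them to Redshift's SUPER type.
embedding_type = pa.list_(pa.float64())
print(str(embedding_type))                     # list<item: double>
print(str(embedding_type).startswith("list"))  # True
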
4 changes: 4 additions & 0 deletions sdk/python/tests/integration/feature_repos/universal/entities.py
@@ -16,3 +16,7 @@ def customer():

def location():
return Entity(name="location_id", value_type=ValueType.INT64)


def item():
return Entity(name="item_id", value_type=ValueType.INT64)
61 changes: 61 additions & 0 deletions sdk/python/tests/integration/feature_repos/universal/feature_views.py
@@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd

from feast import Feature, FeatureView, OnDemandFeatureView, ValueType
@@ -68,6 +69,40 @@ def conv_rate_plus_100_feature_view(
)


def similarity(features_df: pd.DataFrame) -> pd.DataFrame:
if features_df.size == 0:
# give hint to Feast about return type
df = pd.DataFrame({"cos_double": [0.0]})
df["cos_float"] = df["cos_double"].astype(np.float32)
return df
vectors_a = features_df["embedding_double"].apply(np.array)
vectors_b = features_df["vector_double"].apply(np.array)
dot_products = vectors_a.mul(vectors_b).apply(sum)
norms_q = vectors_a.apply(np.linalg.norm)
norms_doc = vectors_b.apply(np.linalg.norm)
df = pd.DataFrame()
df["cos_double"] = dot_products / (norms_q * norms_doc)
df["cos_float"] = df["cos_double"].astype(np.float32)
return df
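
Because the UDF above is plain pandas, it can be smoke-tested directly on a hand-built frame of list columns, without a feature store. A hedged example with made-up values (not part of the diff):

import pandas as pd

# Two rows of embedding/vector pairs: identical vectors, then vectors 45 degrees apart.
frame = pd.DataFrame(
    {
        "embedding_double": [[1.0, 0.0], [1.0, 1.0]],
        "vector_double": [[1.0, 0.0], [1.0, 0.0]],
    }
)
result = similarity(frame)
print(result["cos_double"].tolist())  # [1.0, ~0.7071]
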


def similarity_feature_view(
inputs: Dict[str, Union[RequestDataSource, FeatureView]],
infer_features: bool = False,
features: Optional[List[Feature]] = None,
) -> OnDemandFeatureView:
_features = features or [
Feature("cos_double", ValueType.DOUBLE),
Feature("cos_float", ValueType.FLOAT),
]
return OnDemandFeatureView(
name=similarity.__name__,
inputs=inputs,
features=[] if infer_features else _features,
udf=similarity,
)


def create_driver_age_request_feature_view():
return RequestFeatureView(
name="driver_age",
@@ -83,6 +118,32 @@ def create_conv_rate_request_data_source():
)


def create_similarity_request_data_source():
return RequestDataSource(
name="similarity_input",
schema={
"vector_double": ValueType.DOUBLE_LIST,
"vector_float": ValueType.FLOAT_LIST,
},
)


def create_item_embeddings_feature_view(source, infer_features: bool = False):
item_embeddings_feature_view = FeatureView(
name="item_embeddings",
entities=["item"],
features=None
if infer_features
else [
Feature(name="embedding_double", dtype=ValueType.DOUBLE_LIST),
Feature(name="embedding_float", dtype=ValueType.FLOAT_LIST),
],
batch_source=source,
ttl=timedelta(hours=2),
)
return item_embeddings_feature_view


def create_driver_hourly_stats_feature_view(source, infer_features: bool = False):
driver_stats_feature_view = FeatureView(
name="driver_stats",
40 changes: 39 additions & 1 deletion in the ODFV inference test module
@@ -1,12 +1,19 @@
from datetime import datetime

import pandas as pd
import pytest

from feast import Feature, ValueType
from feast.errors import SpecifiedFeaturesNotPresentError
from tests.integration.feature_repos.universal.entities import customer, driver
from feast.infra.offline_stores.file_source import FileSource
from tests.integration.feature_repos.universal.entities import customer, driver, item
from tests.integration.feature_repos.universal.feature_views import (
conv_rate_plus_100_feature_view,
create_conv_rate_request_data_source,
create_driver_hourly_stats_feature_view,
create_item_embeddings_feature_view,
create_similarity_request_data_source,
similarity_feature_view,
)


@@ -33,6 +40,37 @@ def test_infer_odfv_features(environment, universal_data_sources, infer_features
assert len(odfv.features) == 3


@pytest.mark.integration
@pytest.mark.parametrize("infer_features", [True, False], ids=lambda v: str(v))
def test_infer_odfv_list_features(environment, infer_features, tmp_path):
fake_embedding = [1.0, 1.0]
items_df = pd.DataFrame(
data={
"item_id": [0],
"embedding_float": [fake_embedding],
"embedding_double": [fake_embedding],
"event_timestamp": [pd.Timestamp(datetime.utcnow())],
"created": [pd.Timestamp(datetime.utcnow())],
}
)
output_path = f"{tmp_path}/items.parquet"
items_df.to_parquet(output_path)
fake_items_src = FileSource(
path=output_path,
event_timestamp_column="event_timestamp",
created_timestamp_column="created",
)
items = create_item_embeddings_feature_view(fake_items_src)
sim_odfv = similarity_feature_view(
{"items": items, "input_request": create_similarity_request_data_source()},
infer_features=infer_features,
)
store = environment.feature_store
store.apply([item(), items, sim_odfv])
odfv = store.get_on_demand_feature_view("similarity")
assert len(odfv.features) == 2
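
One detail the fixture above relies on: pandas writes a column of Python lists to parquet as an Arrow list type, which is what lets the list-typed features round-trip through the FileSource. A small illustrative check (file path made up):

import pandas as pd
import pyarrow.parquet as pq

df = pd.DataFrame({"item_id": [0], "embedding_double": [[1.0, 1.0]]})
df.to_parquet("/tmp/items_example.parquet")
# The embedding column is stored as an Arrow list type, e.g. list<item: double>.
print(pq.read_schema("/tmp/items_example.parquet").field("embedding_double").type)
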


@pytest.mark.integration
@pytest.mark.universal
def test_infer_odfv_features_with_error(environment, universal_data_sources):
