Skip to content

Commit

Permalink
Add test for SFV online retrieval
Browse files Browse the repository at this point in the history
Signed-off-by: Felix Wang <[email protected]>
  • Loading branch information
felixwang9817 committed Jun 16, 2022
1 parent 47d1c45 commit d71be53
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions sdk/python/tests/integration/online_store/test_universal_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,82 @@ def test_online_retrieval_with_event_timestamps(
)


@pytest.mark.integration
@pytest.mark.universal_online_stores
@pytest.mark.goserver
@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
def test_stream_feature_view_online_retrieval(
environment, universal_data_sources, feature_server_endpoint, full_feature_names
):
"""
Tests materialization and online retrieval for stream feature views.
This test is separate from test_online_retrieval since combining feature views and
stream feature views into a single test resulted in test flakiness. This is tech
debt that should be resolved soon.
"""
# Set up feature store.
fs = environment.feature_store
entities, datasets, data_sources = universal_data_sources
feature_views = construct_universal_feature_views(data_sources)
pushable_feature_view = feature_views.pushed_locations
fs.apply([location(), pushable_feature_view])

# Materialize.
fs.materialize(
environment.start_date - timedelta(days=1),
environment.end_date + timedelta(days=1),
)

# Get online features by randomly sampling 10 entities that exist in the batch source.
sample_locations = datasets.location_df.sample(10)["location_id"]
entity_rows = [
{"location_id": sample_location} for sample_location in sample_locations
]

feature_refs = [
"pushable_location_stats:temperature",
]
unprefixed_feature_refs = [f.rsplit(":", 1)[-1] for f in feature_refs if ":" in f]

online_features_dict = get_online_features_dict(
environment=environment,
endpoint=feature_server_endpoint,
features=feature_refs,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
)

# Check that the response has the expected set of keys.
keys = set(online_features_dict.keys())
expected_keys = set(
f.replace(":", "__") if full_feature_names else f.split(":")[-1]
for f in feature_refs
) | {"location_id"}
assert (
keys == expected_keys
), f"Response keys are different from expected: {keys - expected_keys} (extra) and {expected_keys - keys} (missing)"

# Check that the feature values match.
tc = unittest.TestCase()
for i, entity_row in enumerate(entity_rows):
df_features = get_latest_feature_values_from_location_df(
entity_row, datasets.location_df
)

assert df_features["location_id"] == online_features_dict["location_id"][i]
for unprefixed_feature_ref in unprefixed_feature_refs:
tc.assertAlmostEqual(
df_features[unprefixed_feature_ref],
online_features_dict[
response_feature_name(
unprefixed_feature_ref, feature_refs, full_feature_names
)
][i],
delta=0.0001,
)


@pytest.mark.integration
@pytest.mark.universal_online_stores
@pytest.mark.goserver
Expand Down Expand Up @@ -859,6 +935,10 @@ def get_latest_feature_values_for_location_df(entity_row, origin_df, destination
}


def get_latest_feature_values_from_location_df(entity_row, location_df):
return get_latest_row(entity_row, location_df, "location_id", "location_id")


def assert_feature_service_correctness(
environment,
endpoint,
Expand Down

0 comments on commit d71be53

Please sign in to comment.