Skip to content

Commit

Permalink
Merge pull request #5 from dmartinol/remote_offline
Browse files Browse the repository at this point in the history
Initial skeleton of unit test for offline server
  • Loading branch information
redhatHameed authored May 17, 2024
2 parents 77ae13c + c52cc51 commit b56b826
Showing 1 changed file with 237 additions and 0 deletions.
237 changes: 237 additions & 0 deletions sdk/python/tests/unit/test_offline_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import os
import tempfile
from datetime import datetime, timedelta

import assertpy
import pandas as pd
import pyarrow as pa
import pyarrow.flight as flight
import pytest

from feast import FeatureStore
from feast.infra.offline_stores.remote import (
RemoteOfflineStore,
RemoteOfflineStoreConfig,
)
from feast.offline_server import OfflineServer
from feast.repo_config import RepoConfig
from tests.utils.cli_repo_creator import CliRunner

PROJECT_NAME = "test_remote_offline"


@pytest.fixture
def empty_offline_server(environment):
store = environment.feature_store

location = "grpc+tcp://localhost:0"
return OfflineServer(store=store, location=location)


@pytest.fixture
def arrow_client(empty_offline_server):
return flight.FlightClient(f"grpc://localhost:{empty_offline_server.port}")


def test_offline_server_is_alive(environment, empty_offline_server, arrow_client):
server = empty_offline_server
client = arrow_client

assertpy.assert_that(server).is_not_none
assertpy.assert_that(server.port).is_not_equal_to(0)

actions = list(client.list_actions())
flights = list(client.list_flights())

assertpy.assert_that(actions).is_empty()
assertpy.assert_that(flights).is_empty()


def default_store(temp_dir):
runner = CliRunner()
result = runner.run(["init", PROJECT_NAME], cwd=temp_dir)
repo_path = os.path.join(temp_dir, PROJECT_NAME, "feature_repo")
assert result.returncode == 0

result = runner.run(["--chdir", repo_path, "apply"], cwd=temp_dir)
assert result.returncode == 0

fs = FeatureStore(repo_path=repo_path)
return fs


def remote_feature_store(offline_server):
offline_config = RemoteOfflineStoreConfig(
type="remote", host="0.0.0.0", port=offline_server.port
)

registry_path = os.path.join(
str(offline_server.store.repo_path),
offline_server.store.config.registry.path,
)
store = FeatureStore(
config=RepoConfig(
project=PROJECT_NAME,
registry=registry_path,
provider="local",
offline_store=offline_config,
entity_key_serialization_version=2,
)
)
return store


def test_get_historical_features():
with tempfile.TemporaryDirectory() as temp_dir:
store = default_store(str(temp_dir))
location = "grpc+tcp://localhost:0"
server = OfflineServer(store=store, location=location)

assertpy.assert_that(server).is_not_none
assertpy.assert_that(server.port).is_not_equal_to(0)

fs = remote_feature_store(server)

_test_get_historical_features_returns_data(fs)
_test_get_historical_features_returns_nan(fs)
_test_offline_write_batch(str(temp_dir), fs)
_test_write_logged_features(str(temp_dir), fs)
_test_pull_latest_from_table_or_query(str(temp_dir), fs)
_test_pull_all_from_table_or_query(str(temp_dir), fs)


def _test_get_historical_features_returns_data(fs: FeatureStore):
entity_df = pd.DataFrame.from_dict(
{
"driver_id": [1001, 1002, 1003],
"event_timestamp": [
datetime(2021, 4, 12, 10, 59, 42),
datetime(2021, 4, 12, 8, 12, 10),
datetime(2021, 4, 12, 16, 40, 26),
],
"label_driver_reported_satisfaction": [1, 5, 3],
"val_to_add": [1, 2, 3],
"val_to_add_2": [10, 20, 30],
}
)

features = [
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"driver_hourly_stats:avg_daily_trips",
"transformed_conv_rate:conv_rate_plus_val1",
"transformed_conv_rate:conv_rate_plus_val2",
]

training_df = fs.get_historical_features(entity_df, features).to_df()

assertpy.assert_that(training_df).is_not_none()
assertpy.assert_that(len(training_df)).is_equal_to(3)

for index, driver_id in enumerate(entity_df["driver_id"]):
assertpy.assert_that(training_df["driver_id"][index]).is_equal_to(driver_id)
for feature in features:
column_id = feature.split(":")[1]
value = training_df[column_id][index]
assertpy.assert_that(value).is_not_nan()


def _test_get_historical_features_returns_nan(fs: FeatureStore):
entity_df = pd.DataFrame.from_dict(
{
"driver_id": [1, 2, 3],
"event_timestamp": [
datetime(2021, 4, 12, 10, 59, 42),
datetime(2021, 4, 12, 8, 12, 10),
datetime(2021, 4, 12, 16, 40, 26),
],
"label_driver_reported_satisfaction": [1, 5, 3],
"val_to_add": [1, 2, 3],
"val_to_add_2": [10, 20, 30],
}
)

features = [
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"driver_hourly_stats:avg_daily_trips",
"transformed_conv_rate:conv_rate_plus_val1",
"transformed_conv_rate:conv_rate_plus_val2",
]

training_df = fs.get_historical_features(entity_df, features).to_df()

assertpy.assert_that(training_df).is_not_none()
assertpy.assert_that(len(training_df)).is_equal_to(3)

for index, driver_id in enumerate(entity_df["driver_id"]):
assertpy.assert_that(training_df["driver_id"][index]).is_equal_to(driver_id)
for feature in features:
column_id = feature.split(":")[1]
value = training_df[column_id][index]
assertpy.assert_that(value).is_nan()


def _test_offline_write_batch(temp_dir, fs: FeatureStore):
data_file = os.path.join(
temp_dir, fs.project, "feature_repo/data/driver_stats.parquet"
)
data_df = pd.read_parquet(data_file)
feature_view = fs.get_feature_view("driver_hourly_stats")

with pytest.raises(NotImplementedError):
RemoteOfflineStore.offline_write_batch(
fs.config, feature_view, pa.Table.from_pandas(data_df), progress=None
)


def _test_write_logged_features(temp_dir, fs: FeatureStore):
data_file = os.path.join(
temp_dir, fs.project, "feature_repo/data/driver_stats.parquet"
)
data_df = pd.read_parquet(data_file)
feature_service = fs.get_feature_service("driver_activity_v1")

with pytest.raises(NotImplementedError):
RemoteOfflineStore.write_logged_features(
config=fs.config,
data=pa.Table.from_pandas(data_df),
source=feature_service,
logging_config=None,
registry=fs.registry,
)


def _test_pull_latest_from_table_or_query(temp_dir, fs: FeatureStore):
data_source = fs.get_data_source("driver_hourly_stats_source")

end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
start_date = end_date - timedelta(days=15)
with pytest.raises(NotImplementedError):
RemoteOfflineStore.pull_latest_from_table_or_query(
config=fs.config,
data_source=data_source,
join_key_columns=[],
feature_name_columns=[],
timestamp_field="event_timestamp",
created_timestamp_column="created",
start_date=start_date,
end_date=end_date,
)


def _test_pull_all_from_table_or_query(temp_dir, fs: FeatureStore):
data_source = fs.get_data_source("driver_hourly_stats_source")

end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
start_date = end_date - timedelta(days=15)
with pytest.raises(NotImplementedError):
RemoteOfflineStore.pull_all_from_table_or_query(
config=fs.config,
data_source=data_source,
join_key_columns=[],
feature_name_columns=[],
timestamp_field="event_timestamp",
start_date=start_date,
end_date=end_date,
)

0 comments on commit b56b826

Please sign in to comment.