Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Random port allocation for python server in tests #2710

Merged
merged 1 commit into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions sdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,19 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):


@pytest.fixture(scope="session")
def python_server(environment):
assert not _check_port_open("localhost", environment.get_local_server_port())
def feature_server_endpoint(environment):
if (
not environment.python_feature_server
or environment.test_repo_config.provider != "local"
):
yield environment.feature_store.get_feature_server_endpoint()
return

port = _free_port()

proc = Process(
target=start_test_local_server,
args=(environment.feature_store.repo_path, environment.get_local_server_port()),
args=(environment.feature_store.repo_path, port),
)
if (
environment.python_feature_server
Expand All @@ -287,14 +294,10 @@ def python_server(environment):
proc.start()
# Wait for server to start
wait_retry_backoff(
lambda: (
None,
_check_port_open("localhost", environment.get_local_server_port()),
),
timeout_secs=10,
lambda: (None, _check_port_open("localhost", port)), timeout_secs=10,
)

yield
yield f"http://localhost:{port}"

if proc.is_alive():
proc.kill()
Expand All @@ -314,6 +317,12 @@ def _check_port_open(host, port) -> bool:
return sock.connect_ex((host, port)) == 0


def _free_port():
sock = socket.socket()
sock.bind(("", 0))
return sock.getsockname()[1]


@pytest.fixture(scope="session")
def universal_data_sources(environment) -> TestData:
return construct_universal_test_data(environment)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import importlib
import json
import os
import re
import tempfile
import uuid
from dataclasses import dataclass
Expand Down Expand Up @@ -328,29 +327,10 @@ class Environment:
worker_id: str
online_store_creator: Optional[OnlineStoreCreator] = None

next_id = 0

def __post_init__(self):
self.end_date = datetime.utcnow().replace(microsecond=0, second=0, minute=0)
self.start_date: datetime = self.end_date - timedelta(days=3)

Environment.next_id += 1
self.id = Environment.next_id

def get_feature_server_endpoint(self) -> str:
if self.python_feature_server and self.test_repo_config.provider == "local":
return f"http://localhost:{self.get_local_server_port()}"
return self.feature_store.get_feature_server_endpoint()

def get_local_server_port(self) -> int:
# Heuristic when running with xdist to extract unique ports for each worker
parsed_worker_id = re.findall("gw(\\d+)", self.worker_id)
if len(parsed_worker_id) != 0:
worker_id_num = int(parsed_worker_id[0])
else:
worker_id_num = 0
return 6000 + 100 * worker_id_num + self.id


def table_name_from_data_source(ds: DataSource) -> Optional[str]:
if hasattr(ds, "table_ref"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def _get_online_features_dict_remotely(

def get_online_features_dict(
environment: Environment,
endpoint: str,
features: Union[List[str], FeatureService],
entity_rows: List[Dict[str, Any]],
full_feature_names: bool = False,
Expand All @@ -305,7 +306,6 @@ def get_online_features_dict(
assertpy.assert_that(online_features).is_not_none()
dict1 = online_features.to_dict()

endpoint = environment.get_feature_server_endpoint()
# If endpoint is None, it means that a local / remote feature server aren't configured
if endpoint is not None:
dict2 = _get_online_features_dict_remotely(
Expand Down Expand Up @@ -447,7 +447,7 @@ def test_online_retrieval_with_event_timestamps(
@pytest.mark.goserver
@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
def test_online_retrieval(
environment, universal_data_sources, python_server, full_feature_names
environment, universal_data_sources, feature_server_endpoint, full_feature_names
):
fs = environment.feature_store
entities, datasets, data_sources = universal_data_sources
Expand Down Expand Up @@ -547,6 +547,7 @@ def test_online_retrieval(

online_features_dict = get_online_features_dict(
environment=environment,
endpoint=feature_server_endpoint,
features=feature_refs,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
Expand All @@ -556,6 +557,7 @@ def test_online_retrieval(
# feature isn't requested.
online_features_no_conv_rate = get_online_features_dict(
environment=environment,
endpoint=feature_server_endpoint,
features=[ref for ref in feature_refs if ref != "driver_stats:conv_rate"],
entity_rows=entity_rows,
full_feature_names=full_feature_names,
Expand Down Expand Up @@ -616,6 +618,7 @@ def test_online_retrieval(
# Check what happens for missing values
missing_responses_dict = get_online_features_dict(
environment=environment,
endpoint=feature_server_endpoint,
features=feature_refs,
entity_rows=[{"driver_id": 0, "customer_id": 0, "val_to_add": 100}],
full_feature_names=full_feature_names,
Expand All @@ -635,13 +638,15 @@ def test_online_retrieval(
with pytest.raises(RequestDataNotFoundInEntityRowsException):
get_online_features_dict(
environment=environment,
endpoint=feature_server_endpoint,
features=feature_refs,
entity_rows=[{"driver_id": 0, "customer_id": 0}],
full_feature_names=full_feature_names,
)

assert_feature_service_correctness(
environment,
feature_server_endpoint,
feature_service,
entity_rows,
full_feature_names,
Expand All @@ -659,6 +664,7 @@ def test_online_retrieval(
]
assert_feature_service_entity_mapping_correctness(
environment,
feature_server_endpoint,
feature_service_entity_mapping,
entity_rows,
full_feature_names,
Expand Down Expand Up @@ -856,6 +862,7 @@ def get_latest_feature_values_for_location_df(entity_row, origin_df, destination

def assert_feature_service_correctness(
environment,
endpoint,
feature_service,
entity_rows,
full_feature_names,
Expand All @@ -866,6 +873,7 @@ def assert_feature_service_correctness(
):
feature_service_online_features_dict = get_online_features_dict(
environment=environment,
endpoint=endpoint,
features=feature_service,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
Expand Down Expand Up @@ -905,6 +913,7 @@ def assert_feature_service_correctness(

def assert_feature_service_entity_mapping_correctness(
environment,
endpoint,
feature_service,
entity_rows,
full_feature_names,
Expand All @@ -914,6 +923,7 @@ def assert_feature_service_entity_mapping_correctness(
if full_feature_names:
feature_service_online_features_dict = get_online_features_dict(
environment=environment,
endpoint=endpoint,
features=feature_service,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
Expand Down Expand Up @@ -948,6 +958,7 @@ def assert_feature_service_entity_mapping_correctness(
with pytest.raises(FeatureNameCollisionError):
get_online_features_dict(
environment=environment,
endpoint=endpoint,
features=feature_service,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
Expand Down