-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support cohort targeting for local evaluation (#47)
- Loading branch information
Showing
25 changed files
with
1,255 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
parameterized~=0.9.0 | ||
python-dotenv~=0.21.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from dataclasses import dataclass, field | ||
from typing import ClassVar, Set | ||
|
||
# Group type used for cohorts whose members are individual users; it is the
# default ``group_type`` of a ``Cohort``.
# Annotated as a plain ``str``: ``ClassVar`` is only valid inside a class body
# and type checkers reject it on module-level names.
USER_GROUP_TYPE: str = "User"


@dataclass
class Cohort:
    """A cohort together with the identifiers of its members.

    Attributes:
        id: Identifier of the cohort.
        last_modified: Modification marker of this cohort version.
        size: Number of members reported for the cohort.
        member_ids: Identifiers of the cohort's members.
        group_type: Kind of entity the members are; defaults to
            ``USER_GROUP_TYPE``.
    """

    id: str
    last_modified: int
    size: int
    member_ids: Set[str]
    group_type: str = field(default=USER_GROUP_TYPE)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import time | ||
import logging | ||
import base64 | ||
import json | ||
from http.client import HTTPResponse | ||
from typing import Optional | ||
from ..version import __version__ | ||
|
||
from .cohort import Cohort | ||
from ..connection_pool import HTTPConnectionPool | ||
from ..exception import HTTPErrorResponseException, CohortTooLargeException | ||
|
||
# Delay between successive cohort download attempts, in milliseconds.
COHORT_REQUEST_RETRY_DELAY_MILLIS = 100


class CohortDownloadApi:
    """Interface for objects that can fetch a cohort by id."""

    def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Optional[Cohort]:
        """Fetch the cohort identified by ``cohort_id``.

        Parameters:
            cohort_id (str): Identifier of the cohort to fetch.
            cohort (Optional[Cohort]): The currently held version of the
                cohort, if any.

        Returns:
            The fetched cohort, or ``None`` when there is nothing new to
            return (implementations may report "not modified" this way).

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError
|
||
|
||
class DirectCohortDownloadApi(CohortDownloadApi):
    """Downloads cohorts from the cohort server over a pooled HTTP connection."""

    def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, server_url: str, logger: logging.Logger):
        """Store the credentials and limits and create the connection pool.

        Parameters:
            api_key (str): Key used as the basic-auth user name.
            secret_key (str): Key used as the basic-auth password.
            max_cohort_size (int): Value sent as the ``maxCohortSize`` query parameter.
            server_url (str): Base URL of the cohort server, ``scheme://host``.
            logger (logging.Logger): Logger receiving debug output.
        """
        super().__init__()
        self.api_key = api_key
        self.secret_key = secret_key
        self.max_cohort_size = max_cohort_size
        self.server_url = server_url
        self.logger = logger
        self.__setup_connection_pool()

    def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Optional[Cohort]:
        """Download a cohort, polling until the server gives a final answer.

        A 202 response, and any failure below the error limit, is followed by
        a short sleep and another attempt.

        Parameters:
            cohort_id (str): Identifier of the cohort to download.
            cohort (Optional[Cohort]): Currently held version; its
                ``last_modified`` is sent so the server can answer 204.

        Returns:
            The downloaded ``Cohort`` on 200, or ``None`` on 204 (not modified).

        Raises:
            CohortTooLargeException: On a 413 response; never retried.
            HTTPErrorResponseException: On any other unexpected status once
                three errors have been counted.
            Exception: Any other failure (transport error, malformed body)
                once three errors have been counted.
        """
        self.logger.debug(f"getCohortMembers({cohort_id}): start")
        last_modified = None if cohort is None else cohort.last_modified
        errors = 0
        while True:
            response = None
            try:
                response = self._get_cohort_members_request(cohort_id, last_modified)
                status = response.status
                self.logger.debug(f"getCohortMembers({cohort_id}): status={status}")
                if status == 200:
                    cohort_info = json.loads(response.read().decode("utf8"))
                    self.logger.debug(f"getCohortMembers({cohort_id}): end - resultSize={cohort_info['size']}")
                    return Cohort(
                        id=cohort_info['cohortId'],
                        last_modified=cohort_info['lastModified'],
                        size=cohort_info['size'],
                        member_ids=set(cohort_info['memberIds']),
                        group_type=cohort_info['groupType'],
                    )
                if status == 204:
                    self.logger.debug(f"getCohortMembers({cohort_id}): Cohort not modified")
                    return None
                if status == 413:
                    raise CohortTooLargeException(
                        f"Cohort exceeds max cohort size of {self.max_cohort_size}: {status}")
                if status != 202:
                    raise HTTPErrorResponseException(status, f"Unexpected response code: {status}")
                # 202: no final answer yet; fall through to the sleep and poll again.
            except Exception as e:
                # Only a 429 response is exempt from the error budget. Failures
                # that happen before any response exists (for example a
                # transport error) must count as well; otherwise the loop
                # would retry forever.
                rate_limited = (
                    isinstance(e, HTTPErrorResponseException)
                    and response is not None
                    and response.status == 429
                )
                if not rate_limited:
                    errors += 1
                self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}")
                if errors >= 3 or isinstance(e, CohortTooLargeException):
                    raise
            time.sleep(COHORT_REQUEST_RETRY_DELAY_MILLIS / 1000)

    def _get_cohort_members_request(self, cohort_id: str, last_modified: Optional[int]) -> HTTPResponse:
        """Issue the GET request for a cohort and return the raw response.

        ``lastModified`` is added to the query string only when
        ``last_modified`` is not ``None``. The connection is always handed
        back to the pool.
        """
        headers = {
            'Authorization': f'Basic {self._get_basic_auth()}',
            'X-Amp-Exp-Library': f"experiment-python-server/{__version__}"
        }
        conn = self._connection_pool.acquire()
        try:
            url = f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}'
            if last_modified is not None:
                url += f'&lastModified={last_modified}'
            return conn.request('GET', url, headers=headers)
        finally:
            self._connection_pool.release(conn)

    def _get_basic_auth(self) -> str:
        """Return the base64-encoded ``api_key:secret_key`` pair."""
        credentials = f'{self.api_key}:{self.secret_key}'
        return base64.b64encode(credentials.encode('utf-8')).decode('utf-8')

    def __setup_connection_pool(self):
        """Create the connection pool from the scheme and host of ``server_url``."""
        # Strip trailing slashes so "scheme://host/" unpacks the same way as
        # "scheme://host" instead of failing with a ValueError.
        scheme, _, host = self.server_url.rstrip('/').split('/', 3)
        timeout = 10
        self._connection_pool = HTTPConnectionPool(host, max_size=10, idle_timeout=30, read_timeout=timeout,
                                                   scheme=scheme)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import logging | ||
from typing import Dict, Set | ||
from concurrent.futures import ThreadPoolExecutor, Future, as_completed | ||
import threading | ||
|
||
from .cohort import Cohort | ||
from .cohort_download_api import CohortDownloadApi | ||
from .cohort_storage import CohortStorage | ||
from ..exception import CohortsDownloadException | ||
|
||
|
||
class CohortLoader:
    """Downloads cohorts on a thread pool and writes them to storage.

    Concurrent requests for the same cohort id share one in-flight job.
    """

    def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage):
        """Keep the collaborators and create the executor and job registry.

        Parameters:
            cohort_download_api (CohortDownloadApi): Source of cohort data.
            cohort_storage (CohortStorage): Destination for downloaded cohorts.
        """
        self.cohort_download_api = cohort_download_api
        self.cohort_storage = cohort_storage
        # In-flight jobs keyed by cohort id; guarded by ``lock_jobs``.
        self.jobs: Dict[str, Future] = {}
        self.lock_jobs = threading.Lock()
        self.executor = ThreadPoolExecutor(
            max_workers=32,
            thread_name_prefix='CohortLoaderExecutor'
        )

    def load_cohort(self, cohort_id: str) -> Future:
        """Return the in-flight job for ``cohort_id``, starting one if needed.

        Returns:
            A future that completes when the cohort has been downloaded and
            stored, or that carries the download error.
        """
        with self.lock_jobs:
            existing = self.jobs.get(cohort_id)
            if existing is not None:
                return existing
            future = self.executor.submit(self.__load_cohort_internal, cohort_id)
            # Record the job before the completion callback can possibly run,
            # so the callback always finds and removes it.
            self.jobs[cohort_id] = future
        # Attached outside the lock: if the job has already finished, the
        # callback runs inline in this thread and needs to take the
        # (non-reentrant) lock itself.
        future.add_done_callback(lambda _: self._remove_job(cohort_id))
        return future

    def _remove_job(self, cohort_id: str):
        """Forget the job registered for ``cohort_id``, if any."""
        with self.lock_jobs:
            self.jobs.pop(cohort_id, None)

    def download_cohort(self, cohort_id: str) -> Cohort:
        """Fetch ``cohort_id``, passing along the stored version if there is one."""
        cohort = self.cohort_storage.get_cohort(cohort_id)
        return self.cohort_download_api.get_cohort(cohort_id, cohort)

    def download_cohorts(self, cohort_ids: Set[str]) -> Future:
        """Load several cohorts and report every failure together.

        Returns:
            A future that completes once all cohorts were attempted. It
            raises ``CohortsDownloadException`` with a list of
            ``(cohort_id, exception)`` pairs when at least one load failed.
        """
        def update_task(task_cohort_ids):
            # Remember which id each future belongs to. The shared ``jobs``
            # registry cannot be used for this: finished jobs are removed
            # from it concurrently.
            id_by_future = {}
            for cohort_id in task_cohort_ids:
                id_by_future[self.load_cohort(cohort_id)] = cohort_id

            errors = []
            for future in as_completed(id_by_future):
                try:
                    future.result()
                except Exception as e:
                    errors.append((id_by_future[future], e))

            if errors:
                raise CohortsDownloadException(errors)

        # NOTE(review): this task waits on jobs that run on the same executor;
        # confirm the pool cannot be filled entirely by waiting tasks.
        return self.executor.submit(update_task, cohort_ids)

    def __load_cohort_internal(self, cohort_id):
        """Download one cohort and store it unless the download returned ``None``."""
        cohort = self.download_cohort(cohort_id)
        if cohort is not None:
            self.cohort_storage.put_cohort(cohort)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import Dict, Set, Optional | ||
from threading import RLock | ||
|
||
from .cohort import Cohort, USER_GROUP_TYPE | ||
|
||
|
||
class CohortStorage:
    """Interface for a store of cohorts; every method must be overridden."""

    def get_cohort(self, cohort_id: str):
        """Look up a single cohort by its id."""
        raise NotImplementedError()

    def get_cohorts(self):
        """Return the stored cohorts."""
        raise NotImplementedError()

    def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]:
        """Return the ids, among ``cohort_ids``, of cohorts that contain ``user_id``."""
        raise NotImplementedError()

    def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]:
        """Return the ids, among ``cohort_ids``, of ``group_type`` cohorts that contain ``group_name``."""
        raise NotImplementedError()

    def put_cohort(self, cohort_description: Cohort):
        """Store the given cohort."""
        raise NotImplementedError()

    def delete_cohort(self, group_type: str, cohort_id: str):
        """Remove the cohort with the given group type and id."""
        raise NotImplementedError()

    def get_cohort_ids(self) -> Set[str]:
        """Return the ids of the stored cohorts."""
        raise NotImplementedError()
|
||
|
||
class InMemoryCohortStorage(CohortStorage):
    """Cohort store held in process memory and guarded by a re-entrant lock."""

    def __init__(self):
        self.lock = RLock()
        # Index: group type -> ids of the stored cohorts of that group type.
        self.group_to_cohort_store: Dict[str, Set[str]] = {}
        # Cohort id -> stored cohort.
        self.cohort_store: Dict[str, Cohort] = {}

    def get_cohort(self, cohort_id: str):
        """Return the stored cohort with this id, or ``None`` if absent."""
        with self.lock:
            return self.cohort_store.get(cohort_id)

    def get_cohorts(self):
        """Return a shallow copy of the id -> cohort mapping."""
        # Copy under the lock, like every other accessor of the store.
        with self.lock:
            return self.cohort_store.copy()

    def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]:
        """Return the ids, among ``cohort_ids``, of user cohorts containing ``user_id``."""
        return self.get_cohorts_for_group(USER_GROUP_TYPE, user_id, cohort_ids)

    def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]:
        """Return the ids, among ``cohort_ids``, of ``group_type`` cohorts containing ``group_name``."""
        result = set()
        with self.lock:
            for cohort_id in self.group_to_cohort_store.get(group_type, set()):
                # Check the requested ids first so members are only looked up
                # for cohorts the caller asked about.
                if cohort_id not in cohort_ids:
                    continue
                cohort = self.cohort_store.get(cohort_id)
                if cohort is not None and group_name in cohort.member_ids:
                    result.add(cohort_id)
        return result

    def put_cohort(self, cohort: Cohort):
        """Insert or replace a cohort and keep the group index in step."""
        with self.lock:
            previous = self.cohort_store.get(cohort.id)
            if previous is not None and previous.group_type != cohort.group_type:
                # The cohort moved to another group type: drop the old index
                # entry so it is no longer matched under the previous type.
                self.group_to_cohort_store.get(previous.group_type, set()).discard(cohort.id)
            self.group_to_cohort_store.setdefault(cohort.group_type, set()).add(cohort.id)
            self.cohort_store[cohort.id] = cohort

    def delete_cohort(self, group_type: str, cohort_id: str):
        """Remove a cohort from the store and from the group index."""
        with self.lock:
            removed = self.cohort_store.pop(cohort_id, None)
            group_types = {group_type}
            if removed is not None:
                # Also clean the index under the cohort's recorded group type,
                # in case the caller passed a different one.
                group_types.add(removed.group_type)
            for indexed_type in group_types:
                self.group_to_cohort_store.get(indexed_type, set()).discard(cohort_id)

    def get_cohort_ids(self):
        """Return a new set with the ids of all stored cohorts."""
        with self.lock:
            return set(self.cohort_store.keys())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
DEFAULT_COHORT_SYNC_URL = 'https://cohort-v2.lab.amplitude.com'
EU_COHORT_SYNC_URL = 'https://cohort-v2.lab.eu.amplitude.com'


class CohortSyncConfig:
    """Settings for cohort synchronization.

    Holds the credentials, size limit, polling interval and server endpoint
    used when downloading cohorts.

    Parameters:
        api_key (str): The project API Key
        secret_key (str): The project Secret Key
        max_cohort_size (int): Largest cohort size that may be downloaded
        cohort_polling_interval_millis (int): Milliseconds between polls for
            cohort updates; values below 60000 are raised to 60000
        cohort_server_url (str): Endpoint from which cohorts are requested
    """

    def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 2147483647,
                 cohort_polling_interval_millis: int = 60000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL):
        # Enforce the minimum polling interval.
        interval = 60000 if cohort_polling_interval_millis < 60000 else cohort_polling_interval_millis

        self.api_key = api_key
        self.secret_key = secret_key
        self.max_cohort_size = max_cohort_size
        self.cohort_polling_interval_millis = interval
        self.cohort_server_url = cohort_server_url
Oops, something went wrong.