Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support cohort targeting for local evaluation #47

Merged
merged 44 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
def9924
feat: Enable use of project API key for default deployments
tyiuhc Feb 1, 2024
405dcb6
initial commit
tyiuhc Jun 6, 2024
eac1cd3
update local eval client
tyiuhc Jun 6, 2024
7f864d5
fix imports
tyiuhc Jun 6, 2024
3ef92da
refactor
tyiuhc Jun 10, 2024
473fe7b
fix tests, add logging config
tyiuhc Jun 10, 2024
3d5dfe9
add CohortNotModifiedException
tyiuhc Jun 12, 2024
93eb012
update user transformation to evaluation context
tyiuhc Jun 13, 2024
981de9f
refactor and simplify to not use cohort_description
tyiuhc Jun 14, 2024
080e8f0
nit: fix formatting
tyiuhc Jun 14, 2024
23c8b25
handle flag fetch fail
tyiuhc Jun 14, 2024
f249b44
Use lastModified instead of lastComputed
tyiuhc Jun 17, 2024
04d35f6
add cohort_request_delay_millis to config
tyiuhc Jun 17, 2024
b4402eb
fix DirectCohortDownloadApi constructor
tyiuhc Jun 17, 2024
a7ffbb6
Simplify deployment_runner, clean up comments
tyiuhc Jun 25, 2024
f88d56e
revert default deployment changes
tyiuhc Jun 25, 2024
f98f09f
Update cohort sync config with comments and server_url config
tyiuhc Jun 26, 2024
7243d9b
fix EU flag url
tyiuhc Jun 26, 2024
4fe1fac
export CohortSyncConfig and ServerZone
tyiuhc Jun 26, 2024
a51067b
nit: simplify logic
tyiuhc Jun 27, 2024
f0e899b
Handle 204 errors
tyiuhc Jun 28, 2024
5130cdd
update deployment_runner flag/cohort update logic, update tests, fix …
tyiuhc Jul 2, 2024
5ed0a98
Update logger requirement for classes
tyiuhc Jul 3, 2024
916e5c1
Refactor cohort_loader update_storage_cohorts
tyiuhc Jul 3, 2024
013ffc9
fix lint
tyiuhc Jul 3, 2024
57e1cc2
remove unnecessary import
tyiuhc Jul 22, 2024
93d2f15
update test.yml
tyiuhc Jul 23, 2024
5b30cb7
add client cohort ci tests
tyiuhc Jul 24, 2024
12267a0
update requirements-dev dotenv version
tyiuhc Jul 24, 2024
03b9081
debug env vars
tyiuhc Jul 24, 2024
832a00c
test yml set env vars
tyiuhc Jul 24, 2024
135a286
test cases use os.environ for secrets
tyiuhc Jul 24, 2024
e1ff4a2
test-arm.yml env syntax fix
tyiuhc Jul 25, 2024
9d6d62f
update client tests
tyiuhc Jul 25, 2024
7fddc96
cohort not modified should not throw exception
tyiuhc Jul 30, 2024
8cfb128
nit: update test name
tyiuhc Jul 31, 2024
c71e0c7
do not throw exception upon start() if cohort download fails, log war…
tyiuhc Aug 1, 2024
85c6cf3
fix deployment runner logging
tyiuhc Aug 1, 2024
646dd5a
nit: fix test name
tyiuhc Aug 1, 2024
9864e46
update error log and test
tyiuhc Aug 1, 2024
7043732
update_stored_cohorts using load_cohort
tyiuhc Aug 5, 2024
1d974f1
refresh cohorts based on flag configs in storage
tyiuhc Aug 6, 2024
06e693e
update cohort_sync_config fields: include polling and remove request …
tyiuhc Aug 6, 2024
a9006cf
add SDK+version to cohort request header
tyiuhc Aug 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/amplitude_experiment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@
from .cookie import AmplitudeCookie
from .local.client import LocalEvaluationClient
from .local.config import LocalEvaluationConfig
from .local.config import ServerZone
from .assignment import AssignmentConfig
from .cohort.cohort_sync_config import CohortSyncConfig
13 changes: 13 additions & 0 deletions src/amplitude_experiment/cohort/cohort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from dataclasses import dataclass, field
from typing import ClassVar, Set

USER_GROUP_TYPE: ClassVar[str] = "User"


@dataclass
class Cohort:
id: str
last_modified: int
size: int
member_ids: Set[str]
group_type: str = field(default=USER_GROUP_TYPE)
90 changes: 90 additions & 0 deletions src/amplitude_experiment/cohort/cohort_download_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import time
import logging
import base64
import json
from http.client import HTTPResponse
from typing import Optional

from .cohort import Cohort
from ..connection_pool import HTTPConnectionPool
from ..exception import HTTPErrorResponseException, CohortTooLargeException, CohortNotModifiedException


class CohortDownloadApi:

def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort:
raise NotImplementedError


class DirectCohortDownloadApi(CohortDownloadApi):
def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, cohort_request_delay_millis: int,
server_url: str, debug: bool):
super().__init__()
self.api_key = api_key
self.secret_key = secret_key
self.max_cohort_size = max_cohort_size
self.cohort_request_delay_millis = cohort_request_delay_millis
self.logger = logging.getLogger("Amplitude")
self.logger.addHandler(logging.StreamHandler())
self.server_url = server_url
if debug:
self.logger.setLevel(logging.DEBUG)
self.__setup_connection_pool()

def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort:
self.logger.debug(f"getCohortMembers({cohort_id}): start")
errors = 0
while True:
response = None
try:
last_modified = None if cohort is None else cohort.last_modified
response = self._get_cohort_members_request(cohort_id, last_modified)
self.logger.debug(f"getCohortMembers({cohort_id}): status={response.status}")
if response.status == 200:
cohort_info = json.loads(response.read().decode("utf8"))
self.logger.debug(f"getCohortMembers({cohort_id}): end - resultSize={cohort_info['size']}")
return Cohort(
id=cohort_info['cohortId'],
last_modified=cohort_info['lastModified'],
size=cohort_info['size'],
member_ids=set(cohort_info['memberIds']),
group_type=cohort_info['groupType'],
)
elif response.status == 204:
raise CohortNotModifiedException(f"Cohort not modified: {response.status}")
elif response.status == 413:
raise CohortTooLargeException(f"Cohort exceeds max cohort size: {response.status}")
elif response.status != 202:
raise HTTPErrorResponseException(response.status,
f"Unexpected response code: {response.status}")
except Exception as e:
if response and not isinstance(e, HTTPErrorResponseException) and response.status != 429:
errors += 1
self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}")
if errors >= 3 or isinstance(e, CohortNotModifiedException) or isinstance(e, CohortTooLargeException):
raise e
time.sleep(self.cohort_request_delay_millis/1000)

def _get_cohort_members_request(self, cohort_id: str, last_modified: int) -> HTTPResponse:
headers = {
'Authorization': f'Basic {self._get_basic_auth()}',
}
conn = self._connection_pool.acquire()
try:
url = f'/sdk/v1/cohort/{cohort_id}?maxCohortSize={self.max_cohort_size}'
if last_modified is not None:
url += f'&lastModified={last_modified}'
response = conn.request('GET', url, headers=headers)
return response
finally:
self._connection_pool.release(conn)

def _get_basic_auth(self) -> str:
credentials = f'{self.api_key}:{self.secret_key}'
return base64.b64encode(credentials.encode('utf-8')).decode('utf-8')

def __setup_connection_pool(self):
scheme, _, host = self.server_url.split('/', 3)
timeout = 10
self._connection_pool = HTTPConnectionPool(host, max_size=10, idle_timeout=30, read_timeout=timeout,
scheme=scheme)
42 changes: 42 additions & 0 deletions src/amplitude_experiment/cohort/cohort_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Dict, Set
from concurrent.futures import ThreadPoolExecutor, Future
import threading

from .cohort import Cohort
from .cohort_download_api import CohortDownloadApi
from .cohort_storage import CohortStorage


class CohortLoader:
def __init__(self, cohort_download_api: CohortDownloadApi, cohort_storage: CohortStorage):
self.cohort_download_api = cohort_download_api
self.cohort_storage = cohort_storage
self.jobs: Dict[str, Future] = {}
self.lock_jobs = threading.Lock()
self.executor = ThreadPoolExecutor(
max_workers=32,
thread_name_prefix='CohortLoaderExecutor'
)

def load_cohort(self, cohort_id: str) -> Future:
with self.lock_jobs:
if cohort_id not in self.jobs:
def task():
try:
cohort = self.download_cohort(cohort_id)
self.cohort_storage.put_cohort(cohort)
except Exception as e:
raise e

future = self.executor.submit(task)
future.add_done_callback(lambda f: self._remove_job(cohort_id))
self.jobs[cohort_id] = future
return self.jobs[cohort_id]

def _remove_job(self, cohort_id: str):
if cohort_id in self.jobs:
del self.jobs[cohort_id]

def download_cohort(self, cohort_id: str) -> Cohort:
cohort = self.cohort_storage.get_cohort(cohort_id)
return self.cohort_download_api.get_cohort(cohort_id, cohort)
66 changes: 66 additions & 0 deletions src/amplitude_experiment/cohort/cohort_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Dict, Set, Optional
from threading import RLock

from .cohort import Cohort, USER_GROUP_TYPE


class CohortStorage:
def get_cohort(self, cohort_id: str):
raise NotImplementedError

def get_cohorts(self):
raise NotImplementedError

def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]:
raise NotImplementedError

def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]:
raise NotImplementedError

def put_cohort(self, cohort_description: Cohort):
raise NotImplementedError

def delete_cohort(self, group_type: str, cohort_id: str):
raise NotImplementedError


class InMemoryCohortStorage(CohortStorage):
def __init__(self):
self.lock = RLock()
self.group_to_cohort_store: Dict[str, Set[str]] = {}
self.cohort_store: Dict[str, Cohort] = {}

def get_cohort(self, cohort_id: str):
with self.lock:
return self.cohort_store.get(cohort_id)

def get_cohorts(self):
return self.cohort_store.copy()

def get_cohorts_for_user(self, user_id: str, cohort_ids: Set[str]) -> Set[str]:
return self.get_cohorts_for_group(USER_GROUP_TYPE, user_id, cohort_ids)

def get_cohorts_for_group(self, group_type: str, group_name: str, cohort_ids: Set[str]) -> Set[str]:
result = set()
with self.lock:
group_type_cohorts = self.group_to_cohort_store.get(group_type, {})
for cohort_id in group_type_cohorts:
members = self.cohort_store.get(cohort_id).member_ids
if cohort_id in cohort_ids and group_name in members:
result.add(cohort_id)
return result

def put_cohort(self, cohort: Cohort):
with self.lock:
if cohort.group_type not in self.group_to_cohort_store:
self.group_to_cohort_store[cohort.group_type] = set()
self.group_to_cohort_store[cohort.group_type].add(cohort.id)
self.cohort_store[cohort.id] = cohort

def delete_cohort(self, group_type: str, cohort_id: str):
with self.lock:
group_cohorts = self.group_to_cohort_store.get(group_type, {})
if cohort_id in group_cohorts:
group_cohorts.remove(cohort_id)
if cohort_id in self.cohort_store:
del self.cohort_store[cohort_id]
23 changes: 23 additions & 0 deletions src/amplitude_experiment/cohort/cohort_sync_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
DEFAULT_COHORT_SYNC_URL = 'https://cohort-v2.lab.amplitude.com'
EU_COHORT_SYNC_URL = 'https://cohort-v2.lab.eu.amplitude.com'


class CohortSyncConfig:
"""Experiment Cohort Sync Configuration
This configuration is used to set up the cohort loader. The cohort loader is responsible for
downloading cohorts from the server and storing them locally.
Parameters:
api_key (str): The project API Key
secret_key (str): The project Secret Key
max_cohort_size (int): The maximum cohort size that can be downloaded
cohort_request_delay_millis (int): The delay in milliseconds between cohort download requests
cohort_server_url (str): The server endpoint from which to request cohorts
"""

def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 15000,
cohort_request_delay_millis: int = 5000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL):
self.api_key = api_key
self.secret_key = secret_key
self.max_cohort_size = max_cohort_size
self.cohort_request_delay_millis = cohort_request_delay_millis
self.cohort_server_url = cohort_server_url
113 changes: 113 additions & 0 deletions src/amplitude_experiment/deployment/deployment_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import logging
from typing import Optional, Set
import threading

from ..exception import CohortNotModifiedException
from ..local.config import LocalEvaluationConfig
from ..cohort.cohort_loader import CohortLoader
from ..cohort.cohort_storage import CohortStorage
from ..flag.flag_config_api import FlagConfigApi
from ..flag.flag_config_storage import FlagConfigStorage
from ..local.poller import Poller
from ..util.flag_config import get_all_cohort_ids_from_flag


class DeploymentRunner:
def __init__(
self,
config: LocalEvaluationConfig,
flag_config_api: FlagConfigApi,
flag_config_storage: FlagConfigStorage,
cohort_storage: CohortStorage,
cohort_loader: Optional[CohortLoader] = None,
):
self.config = config
self.flag_config_api = flag_config_api
self.flag_config_storage = flag_config_storage
self.cohort_storage = cohort_storage
self.cohort_loader = cohort_loader
self.lock = threading.Lock()
self.poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_refresh)
self.logger = logging.getLogger("Amplitude")
self.logger.addHandler(logging.StreamHandler())
if self.config.debug:
self.logger.setLevel(logging.DEBUG)

def start(self):
with self.lock:
self.refresh()
self.poller.start()

def stop(self):
self.poller.stop()

def __periodic_refresh(self):
try:
self.refresh()
except Exception as e:
self.logger.error(f"Refresh flag and cohort configs failed: {e}")

def refresh(self):
self.logger.debug("Refreshing flag configs.")
try:
flag_configs = self.flag_config_api.get_flag_configs()
except Exception as e:
self.logger.error(f'Failed to fetch flag configs: {e}')
raise Exception

flag_keys = {flag['key'] for flag in flag_configs}
self.flag_config_storage.remove_if(lambda f: f['key'] not in flag_keys)

for flag_config in flag_configs:
tyiuhc marked this conversation as resolved.
Show resolved Hide resolved
cohort_ids = get_all_cohort_ids_from_flag(flag_config)
if not self.cohort_loader or not cohort_ids:
self.logger.debug(f"Putting non-cohort flag {flag_config['key']}")
self.flag_config_storage.put_flag_config(flag_config)
continue

# Keep track of old flag and cohort for each flag
old_flag_config = self.flag_config_storage.get_flag_config(flag_config['key'])

try:
self._load_cohorts(flag_config, cohort_ids)
tyiuhc marked this conversation as resolved.
Show resolved Hide resolved
self.flag_config_storage.put_flag_config(flag_config) # Store new flag config
self.logger.debug(f"Stored flag config {flag_config['key']}")

except Exception as e:
self.logger.warning(f"Failed to load all cohorts for flag {flag_config['key']}. "
f"Using the old flag config.")
self.flag_config_storage.put_flag_config(old_flag_config)
raise e

self._delete_unused_cohorts()
self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.")

def _load_cohorts(self, flag_config: dict, cohort_ids: Set[str]):
def task():
try:
for cohort_id in cohort_ids:
future = self.cohort_loader.load_cohort(cohort_id)
future.result()
self.logger.debug(f"Cohort {cohort_id} loaded for flag {flag_config['key']}")
except Exception as e:
if not isinstance(e, CohortNotModifiedException):
self.logger.error(f"Failed to load cohorts for flag {flag_config['key']}: {e}")
raise e

cohort_fetched = self.cohort_loader.executor.submit(task)
# Wait for both flag and cohort loading to complete
cohort_fetched.result()

def _delete_unused_cohorts(self):
flag_cohort_ids = set()
for flag in self.flag_config_storage.get_flag_configs().values():
flag_cohort_ids.update(get_all_cohort_ids_from_flag(flag))

storage_cohorts = self.cohort_storage.get_cohorts()
deleted_cohort_ids = set(storage_cohorts.keys()) - flag_cohort_ids

for deleted_cohort_id in deleted_cohort_ids:
deleted_cohort = storage_cohorts.get(deleted_cohort_id)
if deleted_cohort is not None:
self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id)

16 changes: 16 additions & 0 deletions src/amplitude_experiment/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,19 @@ class FetchException(Exception):
def __init__(self, status_code, message):
super().__init__(message)
self.status_code = status_code


class CohortNotModifiedException(Exception):
def __init__(self, message):
super().__init__(message)


class CohortTooLargeException(Exception):
def __init__(self, message):
super().__init__(message)


class HTTPErrorResponseException(Exception):
def __init__(self, status_code, message):
super().__init__(message)
self.status_code = status_code
Loading
Loading