Skip to content

Commit

Permalink
Added SSE client
Browse files Browse the repository at this point in the history
  • Loading branch information
vazarkevych committed Aug 26, 2024
1 parent cd1c4f8 commit fa5431d
Showing 1 changed file with 189 additions and 1 deletion.
190 changes: 189 additions & 1 deletion growthbook/growthbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys
import json
from abc import ABC, abstractmethod
import threading
import logging

from typing import Optional, Any, Set, Tuple, List, Dict
Expand All @@ -23,7 +24,9 @@
from base64 import b64decode
from time import time
import aiohttp
import asyncio

from aiohttp.client_exceptions import ClientConnectorError, ClientResponseError, ClientPayloadError
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding
from urllib3 import PoolManager
Expand Down Expand Up @@ -805,10 +808,138 @@ def destroy(self) -> None:
self.docs.clear()


class SSEClient:
    """Client for the GrowthBook server-sent-events (SSE) streaming endpoint.

    Runs an aiohttp streaming GET on a dedicated asyncio event loop inside a
    background thread.  Parsed events are delivered to ``on_event`` as dicts
    of the form ``{'type': <event name>, 'data': <payload text>}``.
    """

    def __init__(self, api_host, client_key, on_event, reconnect_delay=5, headers=None):
        """
        :param api_host: Base API URL; falls back to https://cdn.growthbook.io when falsy.
        :param client_key: SDK key appended to the ``/sub/`` streaming path.
        :param on_event: Callable invoked with each parsed SSE event dict.
        :param reconnect_delay: Seconds to sleep before reconnecting after a transient error.
        :param headers: Optional extra HTTP headers merged over the defaults.
        """
        self.api_host = api_host
        self.client_key = client_key

        self.on_event = on_event
        self.reconnect_delay = reconnect_delay

        self._sse_session = None
        self._sse_thread = None
        self._loop = None

        self.is_running = False

        self.headers = {
            "Accept": "application/json; q=0.5, text/event-stream",
            "Cache-Control": "no-cache",
        }

        if headers:
            self.headers.update(headers)

    def connect(self):
        """Start the background streaming thread (no-op if already running)."""
        if self.is_running:
            logger.debug("Streaming session is already running.")
            return

        self.is_running = True
        self._sse_thread = threading.Thread(target=self._run_sse_channel)
        self._sse_thread.start()

    def disconnect(self):
        """Stop streaming: close the session on the loop's thread, then join it."""
        self.is_running = False
        if self._loop and self._loop.is_running():
            # _stop_session must execute on the loop's own thread.
            future = asyncio.run_coroutine_threadsafe(self._stop_session(), self._loop)
            try:
                future.result()
            except Exception as e:
                logger.error(f"Streaming disconnect error: {e}")

        if self._sse_thread:
            self._sse_thread.join(timeout=5)

        logger.debug("Streaming session disconnected")

    def _get_sse_url(self, api_host: str, client_key: str) -> str:
        """Build the ``/sub/<client_key>`` streaming URL from the API host."""
        api_host = (api_host or "https://cdn.growthbook.io").rstrip("/")
        return f"{api_host}/sub/{client_key}"

    async def _init_session(self):
        """Connect (and reconnect) to the SSE endpoint until stopped or fatally rejected."""
        url = self._get_sse_url(self.api_host, self.client_key)

        while self.is_running:
            try:
                async with aiohttp.ClientSession(headers=self.headers) as session:
                    self._sse_session = session

                    async with session.get(url) as response:
                        response.raise_for_status()
                        await self._process_response(response)
            except ClientResponseError as e:
                # HTTP-level rejection (e.g. bad client key): do not retry.
                logger.error(f"Streaming error, closing connection: {e.status} {e.message}")
                self.is_running = False
                break
            except (ClientConnectorError, ClientPayloadError) as e:
                # Transient network failure: back off, then reconnect.
                logger.error(f"Streaming error: {e}")
                if not self.is_running:
                    break
                await self._wait_for_reconnect()
            except asyncio.TimeoutError:
                # Fix: this was `except TimeoutError` with a log message that
                # read the nonexistent `self.timeout` attribute, raising
                # AttributeError from inside the handler.  aiohttp raises
                # asyncio.TimeoutError (only an alias of the builtin on 3.11+).
                logger.warning("Streaming connection timed out.")
                await self._wait_for_reconnect()
            except asyncio.CancelledError:
                logger.debug("Streaming was cancelled.")
                break
            finally:
                await self._close_session()

    async def _process_response(self, response):
        """Parse the SSE byte stream line by line and dispatch complete events."""
        event_data = {}
        async for line in response.content:
            decoded_line = line.decode('utf-8').strip()
            if decoded_line.startswith("event:"):
                event_data['type'] = decoded_line[len("event:"):].strip()
            elif decoded_line.startswith("data:"):
                # Multi-line data fields accumulate, separated by newlines.
                event_data['data'] = event_data.get('data', '') + f"\n{decoded_line[len('data:'):].strip()}"
            elif not decoded_line:
                # A blank line terminates an event.
                if 'type' in event_data and 'data' in event_data:
                    self.on_event(event_data)
                event_data = {}

        # Flush a trailing event if the stream ended without a blank line.
        if 'type' in event_data and 'data' in event_data:
            self.on_event(event_data)

    async def _wait_for_reconnect(self):
        """Sleep for the configured delay before the next connection attempt."""
        logger.debug(f"Attempting to reconnect streaming in {self.reconnect_delay}")
        await asyncio.sleep(self.reconnect_delay)

    async def _close_session(self):
        """Close the aiohttp session if one is open."""
        if self._sse_session:
            await self._sse_session.close()
            logger.debug("Streaming session closed.")

    def _run_sse_channel(self):
        """Thread target: own a private event loop running the streaming coroutine."""
        self._loop = asyncio.new_event_loop()

        try:
            self._loop.run_until_complete(self._init_session())
        except asyncio.CancelledError:
            pass
        finally:
            self._loop.run_until_complete(self._loop.shutdown_asyncgens())
            self._loop.close()

    async def _stop_session(self):
        """Runs on the loop thread: close the session and cancel remaining tasks."""
        if self._sse_session:
            await self._sse_session.close()

        if self._loop and self._loop.is_running():
            # Fix: exclude the current task — the original cancelled and
            # awaited *every* pending task, including this coroutine itself,
            # so the self-cancellation propagated back to disconnect().
            current = asyncio.current_task()
            tasks = [
                task for task in asyncio.all_tasks(self._loop)
                if not task.done() and task is not current
            ]
            for task in tasks:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

class FeatureRepository(object):
def __init__(self) -> None:
    """Set up default state: in-memory cache, no HTTP pool, no SSE client yet."""
    # Network resources are created lazily on first use.
    self.http: Optional[PoolManager] = None
    self.sse_client: Optional[SSEClient] = None
    # Default cache backend; replaceable via set_cache().
    self.cache: AbstractFeatureCache = InMemoryFeatureCache()

def set_cache(self, cache: AbstractFeatureCache) -> None:
    """Replace the feature cache backend (defaults to the in-memory cache)."""
    self.cache = cache
Expand Down Expand Up @@ -930,6 +1061,14 @@ async def _fetch_features_async(
logger.warning("GrowthBook API response missing features")
return None


def startAutoRefresh(self, api_host, client_key, cb):
self.sse_client = self.sse_client or SSEClient(api_host=api_host, client_key=client_key, on_event=cb)
self.sse_client.connect()

def stopAutoRefresh(self):
self.sse_client.disconnect()

@staticmethod
def _get_features_url(api_host: str, client_key: str) -> str:
api_host = (api_host or "https://cdn.growthbook.io").rstrip("/")
Expand All @@ -939,7 +1078,6 @@ def _get_features_url(api_host: str, client_key: str) -> str:
# Singleton instance shared by every GrowthBook client in the process
feature_repo = FeatureRepository()


class GrowthBook(object):
def __init__(
self,
Expand All @@ -956,6 +1094,7 @@ def __init__(
forced_variations: dict = {},
sticky_bucket_service: AbstractStickyBucketService = None,
sticky_bucket_identifier_attributes: List[str] = None,
streaming: bool = False,
# Deprecated args
trackingCallback=None,
qaMode: bool = False,
Expand All @@ -981,6 +1120,8 @@ def __init__(
self._qaMode = qa_mode or qaMode
self._trackingCallback = on_experiment_viewed or trackingCallback

self._streaming = streaming

# Deprecated args
self._user = user
self._groups = groups
Expand All @@ -994,6 +1135,10 @@ def __init__(
if features:
self.setFeatures(features)

if self._streaming:
self.load_features()
self.startAutoRefresh()

def load_features(self) -> None:
if not self._client_key:
raise ValueError("Must specify `client_key` to refresh features")
Expand All @@ -1014,6 +1159,49 @@ async def load_features_async(self) -> None:
if features is not None:
self.setFeatures(features)

def features_event_handler(self, features):
decoded = json.loads(features)
if not decoded:
return None

if "encryptedFeatures" in decoded:
if not self._decryption_key:
raise ValueError("Must specify decryption_key")
try:
decrypted = decrypt(decoded["encryptedFeatures"], self._decryption_key)
return json.loads(decrypted)
except Exception:
logger.warning(
"Failed to decrypt features from GrowthBook API response"
)
return None
elif "features" in decoded:
self.set_features(decoded["features"])
else:
logger.warning("GrowthBook API response missing features")

def dispatch_sse_event(self, event_data):
    """Route one parsed SSE event to the matching handler.

    'features' events carry the payload inline; 'features-updated' events
    signal that a fresh fetch is required.  Other event types are ignored.
    """
    kind = event_data['type']
    payload = event_data['data']
    if kind == 'features':
        self.features_event_handler(payload)
    elif kind == 'features-updated':
        self.load_features()


def startAutoRefresh(self):
if not self._client_key:
raise ValueError("Must specify `client_key` to start features streaming")

feature_repo.startAutoRefresh(
api_host=self._api_host,
client_key=self._client_key,
cb=self.dispatch_sse_event
)

def stopAutoRefresh(self):
    """Stop the shared repository's SSE stream (started by startAutoRefresh)."""
    feature_repo.stopAutoRefresh()

# @deprecated, use set_features
def setFeatures(self, features: dict) -> None:
    """Deprecated camelCase alias for set_features(); kept for backward compatibility."""
    return self.set_features(features)
Expand Down

0 comments on commit fa5431d

Please sign in to comment.