Skip to content

Commit

Permalink
feat: add set_auth method (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
grdsdev authored Aug 8, 2024
1 parent 7485e7c commit 5859c72
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 66 deletions.
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ typing-extensions = "^4.12.2"

[tool.poetry.dev-dependencies]
pytest = "^8.3.1"
python-dotenv = "^1.0.1"

[tool.poetry.group.dev.dependencies]
python-semantic-release = ">=8.3,<10.0"
Expand Down
61 changes: 49 additions & 12 deletions realtime/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import asyncio
import json
import logging
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple

from realtime.message import ChannelEvents
from realtime.types import Callback

from .presence import RealtimePresence
Expand All @@ -21,6 +23,35 @@ class CallbackListener(NamedTuple):
callback: Callback


class Push:
def __init__(self, channel: Channel, event: str, payload: Dict[str, Any] = {}):
self.channel = channel
self.event = event
self.payload = payload
self.ref = channel.socket._make_ref()

def send(self):
asyncio.get_event_loop().run_until_complete(self._send())

async def _send(self):
self.ref = self.channel.socket._make_ref()

message = {
"topic": self.channel.topic,
"event": self.event,
"payload": self.payload,
"ref": self.ref,
}

try:
await self.socket.ws_connection.send(json.dumps(message))
except Exception as e:
logging.error(f"send push failed: {e}")

def update_payload(self, payload: Dict[str, Any]):
self.payload = {**self.payload, **payload}


class Channel:
"""
`Channel` is an abstraction for a topic listener for an existing socket connection.
Expand Down Expand Up @@ -230,19 +261,25 @@ def rejoin(self) -> None:
):
self.channel_params["filter"] = self.filter

join_req = {
"topic": self.topic,
"event": "phx_join",
"payload": {"config": self.channel_params},
"ref": None,
}
try:
asyncio.get_event_loop().run_until_complete(
self.socket.ws_connection.send(json.dumps(join_req))
access_token_payload = {}

if self.socket._access_token is not None:
access_token_payload["access_token"] = self.socket._access_token

self._push(
ChannelEvents.join,
{"config": self.channel_params, "access_token": access_token_payload},
)

def _push(self, event: str, payload: Dict[str, Any]) -> Push:
if not self.joined:
raise Exception(
f"tried to push '{event}' to '{self.topic}' before joining. Use channel.subscribe() before pushing events"
)
except Exception as e:
print(e)
return

push = Push(self, event, payload)
push.send()
return push

# @Deprecated:
# You should use `subscribe` instead of this low-level method. It will be removed in the future.
Expand Down
19 changes: 18 additions & 1 deletion realtime/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from collections import defaultdict
from functools import wraps
from typing import Any, DefaultDict, Dict, List
from typing import Any, DefaultDict, Dict, List, Union

import websockets

Expand Down Expand Up @@ -55,10 +55,14 @@ def __init__(
self.hb_interval = hb_interval
self.ws_connection: websockets.client.WebSocketClientProtocol
self.kept_alive = False
self.ref = 0
self.auto_reconnect = auto_reconnect

self.channels: DefaultDict[str, List[Channel]] = defaultdict(list)

self._access_token: Union[str, None] = token
self._api_key = token

@ensure_connection
def listen(self) -> None:
"""
Expand Down Expand Up @@ -196,3 +200,16 @@ def summary(self) -> None:
for topic, chans in self.channels.items():
for chan in chans:
print(f"Topic: {topic} | Events: {[e for e, _ in chan.listeners]}]")

@ensure_connection
def set_auth(self, token: Union[str, None]) -> None:
self._access_token = token

for _, channels in self.channels.items():
for channel in channels:
if channel.joined:
channel._push(ChannelEvents.access_token, {"access_token": token})

def _make_ref(self) -> str:
self.ref += 1
return f"{self.ref}"
2 changes: 1 addition & 1 deletion realtime/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class ChannelEvents(str, Enum):
reply = "phx_reply"
leave = "phx_leave"
heartbeat = "heartbeat"
auth = "phx_auth"
access_token = "access_token"


PHOENIX_CHANNEL = "phoenix"
Expand Down
17 changes: 17 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import os
from dotenv import load_dotenv

from realtime.connection import Socket

load_dotenv()

@pytest.fixture
def socket() -> Socket:
return Socket(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_ANON_KEY"))

def test_set_auth(socket: Socket):
socket.connect()

socket.set_auth("jwt")
assert socket._access_token == "jwt"
27 changes: 0 additions & 27 deletions tests/tests_close.py

This file was deleted.

24 changes: 0 additions & 24 deletions tests/type_test.py

This file was deleted.

0 comments on commit 5859c72

Please sign in to comment.