Skip to content

Commit

Permalink
Add GraphQL Transport WS Support
Browse files Browse the repository at this point in the history
  • Loading branch information
arunsureshkumar committed Mar 19, 2024
1 parent 481493f commit c601dd5
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 60 deletions.
2 changes: 2 additions & 0 deletions examples/fastapi/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from starlette.websockets import WebSocket

from graphql_ws.integrations.fastapi.server import FastAPISubscriptionServer
from graphql_ws.protocols import ProtocolEnum
from graphql_ws.subscription_managers import AsyncSubscriptionManager
from .schema import schema

Expand Down Expand Up @@ -43,5 +44,6 @@ async def websocket_endpoint(websocket: WebSocket):
request_context=Context(
request=websocket, subscription_manager_=subscription_manager
),
protocol=ProtocolEnum.GRAPHQL_TRANSPORT_WS
)
return websocket
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,12 @@ class CompleteGraphQLTransportWSMessage(BiDirectionalMessage):

type: str = "complete"

@property
def data(self) -> dict:
return {
"id": self.id,
"type": self.type
}

def __init__(self, id_: str):
self.id = id_
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ class ConnectionAckGraphQLTransportWSMessage(ServerToClientMessage):

def __init__(self, payload: dict | None = None):
self.payload: dict = payload

@property
def data(self) -> dict:
return {"type": self.type, "payload": self.payload}
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
class ConnectionInitGraphQLTransportWSMessage:
from graphql_ws.protocols.messages import ClientToServerMessage


class ConnectionInitGraphQLTransportWSMessage(ClientToServerMessage):
"""
Direction: Client -> Server
Expand All @@ -19,4 +22,5 @@ class ConnectionInitGraphQLTransportWSMessage:
type: str = "connection_init"

def __init__(self, payload: dict | None = None):
super().__init__(self.type, payload)
self.payload: dict = payload
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ class NextGraphQLTransportWSMessage(ServerToClientMessage):
been emitted, the Complete message will follow indicating stream completion.
"""

type: str = "next"

def __init__(self, id_: str, payload: ExecutionResult | None = None):
self.id = id_
self.type: str = "next"
self.payload: ExecutionResult = payload

@property
def data(self) -> dict:
return {"id": self.id, "type": self.type, "payload": self.payload}
11 changes: 8 additions & 3 deletions graphql_ws/protocols/graphql_transport_ws/message_types/ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ class PingGraphQLTransportWSMessage(BiDirectionalMessage):
The optional payload field can be used to transfer additional details about the ping.
"""

def __init__(self, payload: dict | None = None):
self.type: str = "ping"
self.payload: dict = payload
type: str = "ping"

@property
def data(self) -> dict:
return {"type": self.type}

def __init__(self):
pass
11 changes: 8 additions & 3 deletions graphql_ws/protocols/graphql_transport_ws/message_types/pong.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ class PongGraphQLTransportWSMessage(BiDirectionalMessage):
The optional payload field can be used to transfer additional details about the pong.
"""

def __init__(self, payload: dict | None = None):
self.type: str = "pong"
self.payload: dict = payload
type: str = "pong"

@property
def data(self) -> dict:
return {"type": self.type}

def __init__(self):
pass
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
from graphql_ws.protocols.messages import ClientToServerMessage


class SubscribeMessagePayload(object):
def __init__(
self, operation_name: str | None, query: str, variables: dict, extensions: dict
):
self.operation_name = operation_name
self.query = query
self.variables = variables
self.extensions = extensions


class SubscribeGraphQLTransportWSMessage(ClientToServerMessage):
"""
Direction: Client -> Server
Expand All @@ -31,8 +21,9 @@ class SubscribeGraphQLTransportWSMessage(ClientToServerMessage):
"""

def __init__(self, id_: str, payload: SubscribeMessagePayload, type: str):
super().__init__(type, payload, id_)
type: str = "subscribe"

def __init__(self, id_: str, payload: dict):
super().__init__(self.type, payload, id_)
self.id = id_
self.type: str = "subscribe"
self.payload: SubscribeMessagePayload = payload
self.payload: dict = payload
9 changes: 8 additions & 1 deletion graphql_ws/protocols/messages/bi_directional.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
from abc import abstractmethod


class BiDirectionalMessage:
pass

@property
@abstractmethod
def data(self) -> dict:
pass
25 changes: 19 additions & 6 deletions graphql_ws/protocols/messages/message_parser.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json

from graphql_ws.protocols.exceptions import UnSupportedProtocolException
from graphql_ws.protocols.graphql_transport_ws.message_types import ConnectionInitGraphQLTransportWSMessage, \
SubscribeGraphQLTransportWSMessage, PingGraphQLTransportWSMessage, PongGraphQLTransportWSMessage, \
CompleteGraphQLTransportWSMessage
from graphql_ws.protocols.graphql_ws.message_types import (
ConnectionInitGraphQLWSMessage,
StartGraphQLWSMessage,
ConnectionTerminateGraphQLWSMessage,
)
from graphql_ws.protocols.messages import ClientToServerMessage
from graphql_ws.protocols.messages import ClientToServerMessage, BiDirectionalMessage
from graphql_ws.protocols.messages.exceptions import ClientToServerMessageInvalid
from graphql_ws.protocols.protocol import ProtocolEnum

Expand All @@ -15,7 +18,7 @@ class MessageParser(object):
def __init__(self, protocol: ProtocolEnum):
self.protocol = protocol

def parse_client_message(self, message: str) -> ClientToServerMessage:
def parse_client_message(self, message: str) -> ClientToServerMessage | BiDirectionalMessage:
message = json.loads(message)
if self.protocol == ProtocolEnum.GRAPHQL_WS:
match message.get("type"):
Expand All @@ -33,10 +36,20 @@ def parse_client_message(self, message: str) -> ClientToServerMessage:
raise ClientToServerMessageInvalid()
elif self.protocol == ProtocolEnum.GRAPHQL_TRANSPORT_WS:
match message.get("type"):
case 1:
print("This is case 1")
case 2:
print("This is case 2")
case ConnectionInitGraphQLTransportWSMessage.type:
return ConnectionInitGraphQLTransportWSMessage(
payload=message.get("payload")
)
case SubscribeGraphQLTransportWSMessage.type:
return SubscribeGraphQLTransportWSMessage(
id_=message.get("id"), payload=message.get("payload")
)
case PingGraphQLTransportWSMessage.type:
return PingGraphQLTransportWSMessage()
case PongGraphQLTransportWSMessage.type:
return PongGraphQLTransportWSMessage()
case CompleteGraphQLTransportWSMessage.type:
return CompleteGraphQLTransportWSMessage(id_=message.get("id"))
case _:
raise ClientToServerMessageInvalid()
else:
Expand Down
106 changes: 75 additions & 31 deletions graphql_ws/servers/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from ..contexts import AsyncConnectionContext, BaseConnectionContext
from ..protocols import ProtocolEnum
from ..protocols.graphql_transport_ws.message_types import (
SubscribeGraphQLTransportWSMessage,
SubscribeGraphQLTransportWSMessage, ConnectionInitGraphQLTransportWSMessage, ConnectionAckGraphQLTransportWSMessage,
NextGraphQLTransportWSMessage, PingGraphQLTransportWSMessage, PongGraphQLTransportWSMessage,
CompleteGraphQLTransportWSMessage,
)
from ..protocols.graphql_ws.message_types import (
ConnectionInitGraphQLWSMessage,
Expand Down Expand Up @@ -32,12 +34,15 @@ def __init__(self, schema, subscription_manager: AsyncSubscriptionManager):
super().__init__(schema, subscription_manager=subscription_manager)

async def execute(self, params):
return await self.schema.subscribe(
query=params.get("request_string"),
operation_name=params.get("operation_name"),
variable_values=params.get("variable_values"),
context_value=params.get("context_value"),
)
try:
return await self.schema.subscribe(
query=params.get("request_string"),
operation_name=params.get("operation_name"),
variable_values=params.get("variable_values"),
context_value=params.get("context_value"),
)
except Exception as error:
a = 1

async def on_message(
self,
Expand All @@ -61,7 +66,7 @@ async def process_message(
connection_context: AsyncConnectionContext,
message: ClientToServerMessage | BiDirectionalMessage,
):
if protocol.GRAPHQL_WS:
if protocol == protocol.GRAPHQL_WS:
if isinstance(message, ConnectionInitGraphQLWSMessage):
await self.on_connection_init(connection_context, message)
elif isinstance(message, StartGraphQLWSMessage):
Expand All @@ -74,8 +79,20 @@ async def process_message(
connection_context=connection_context
)
await self.on_terminate(connection_context, message)
elif protocol.GRAPHQL_TRANSPORT_WS:
pass
elif protocol == protocol.GRAPHQL_TRANSPORT_WS:
if isinstance(message, ConnectionInitGraphQLTransportWSMessage):
await self.on_connection_init(connection_context, message)
elif isinstance(message, PingGraphQLTransportWSMessage):
await self.send_message(connection_context, PongGraphQLTransportWSMessage())
elif isinstance(message, PongGraphQLTransportWSMessage):
await self.send_message(connection_context, PingGraphQLTransportWSMessage())
elif isinstance(message, SubscribeGraphQLTransportWSMessage):
connection_context.id = message.id
await self.on_start(connection_context, message)
elif isinstance(message, CompleteGraphQLTransportWSMessage):
await self.subscription_manager.unsubscribe(
connection_context=connection_context
)
else:
pass

Expand All @@ -95,34 +112,56 @@ async def on_start(
if hasattr(execution_result, "__aiter__"):
iterator = execution_result.__aiter__()
async for result in iterator:
if isinstance(message, StartGraphQLWSMessage):
await self.send_message(
connection_context,
message=DataGraphQLWSMessage(
id_=message.id, payload=result.formatted
),
)
elif isinstance(message, SubscribeGraphQLTransportWSMessage):
await self.send_message(
connection_context,
message=NextGraphQLTransportWSMessage(
id_=message.id, payload=result.formatted
),
)
else:
if is_awaitable(execution_result):
execution_result = await execution_result
if isinstance(message, StartGraphQLWSMessage):
await self.send_message(
connection_context,
message=DataGraphQLWSMessage(
id_=message.id, payload=result.formatted
id_=message.id, payload=execution_result.formatted
),
)
elif isinstance(message, SubscribeGraphQLTransportWSMessage):
await self.send_message(
connection_context,
message=NextGraphQLTransportWSMessage(
id_=message.id, payload=execution_result.formatted
),
)
else:
if is_awaitable(execution_result):
execution_result = await execution_result
await self.send_message(
connection_context,
message=DataGraphQLWSMessage(
id_=message.id, payload=execution_result.formatted
),
)
except SubscriberAlreadyExistException as error:
pass
# await self.send_error(connection_context, error)
else:
await self.on_complete(
connection_context=connection_context,
message=CompleteGraphQLWSMessage(id_=message.id),
)
if isinstance(message, StartGraphQLWSMessage):
await self.on_complete(
connection_context=connection_context,
message=CompleteGraphQLWSMessage(id_=message.id),
)
elif isinstance(message, SubscribeGraphQLTransportWSMessage):
await self.on_complete(
connection_context=connection_context,
message=CompleteGraphQLTransportWSMessage(id_=message.id),
)

async def on_complete(
self,
connection_context: AsyncConnectionContext,
message: CompleteGraphQLWSMessage,
message: CompleteGraphQLWSMessage | CompleteGraphQLTransportWSMessage,
):
await self.send_message(connection_context, message=message)
await self.subscription_manager.unsubscribe(
Expand All @@ -133,7 +172,7 @@ async def on_complete(
async def on_terminate(
self,
connection_context: AsyncConnectionContext,
message: ConnectionTerminateGraphQLWSMessage,
message: ConnectionTerminateGraphQLWSMessage | StopGraphQLWSMessage,
):
await self.subscription_manager.unsubscribe(
connection_context=connection_context
Expand Down Expand Up @@ -164,13 +203,18 @@ async def send_error(
async def on_connection_init(
self,
connection_context: AsyncConnectionContext,
message: ConnectionInitGraphQLWSMessage,
message: ConnectionInitGraphQLWSMessage | ConnectionInitGraphQLTransportWSMessage,
):
await self.on_connect(connection_context, message)
await self.send_message(
connection_context,
message=ConnectionAckGraphQLWSMessage(payload=message.payload),
)
if isinstance(message, ConnectionInitGraphQLWSMessage):
await self.send_message(
connection_context,
message=ConnectionAckGraphQLWSMessage(payload=message.payload),
)
elif isinstance(message, ConnectionInitGraphQLTransportWSMessage):
await self.send_message(
connection_context,
message=ConnectionAckGraphQLTransportWSMessage(payload=message.payload))

async def send_message(
self,
Expand Down
2 changes: 1 addition & 1 deletion graphql_ws/servers/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def on_complete(
def on_terminate(
self,
connection_context: BaseConnectionContext,
message: ConnectionTerminateGraphQLWSMessage,
message: ConnectionTerminateGraphQLWSMessage | StopGraphQLWSMessage,
):
pass

Expand Down

0 comments on commit c601dd5

Please sign in to comment.