Skip to content

Commit

Permalink
chore: typing, linting, and formatting work
Browse files Browse the repository at this point in the history
  • Loading branch information
pgautier404 committed Jul 10, 2023
1 parent f73ba5b commit 17d478c
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 24 deletions.
4 changes: 3 additions & 1 deletion src/momento/internal/aio/_scs_grpc_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Optional

import grpc
from momento_wire_types import cacheclient_pb2_grpc as cache_client
from momento_wire_types import cachepubsub_pb2_grpc as pubsub_client
Expand Down Expand Up @@ -108,7 +110,7 @@ def async_stub(self) -> pubsub_client.PubsubStub:
return pubsub_client.PubsubStub(self._secure_channel) # type: ignore[no-untyped-call]


def _interceptors(auth_token: str, retry_strategy: RetryStrategy = None) -> list[grpc.aio.ClientInterceptor]:
def _interceptors(auth_token: str, retry_strategy: Optional[RetryStrategy] = None) -> list[grpc.aio.ClientInterceptor]:
headers = [
Header("authorization", auth_token),
Header("agent", f"python:{_ControlGrpcManager.version}"),
Expand Down
10 changes: 5 additions & 5 deletions src/momento/internal/aio/_scs_pubsub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def publish(self, cache_name: str, topic_name: str, value: str | bytes) ->
value=topic_value,
)

await self._build_stub().Publish(
await self._build_stub().Publish( # type: ignore[misc]
request,
)
return TopicPublish.Success()
Expand All @@ -85,21 +85,21 @@ async def subscribe(self, cache_name: str, topic_name: str) -> TopicSubscribeRes
topic=topic_name,
# TODO: resume_at_topic_sequence_number
)
stream = self._build_stream_stub().Subscribe(
stream = self._build_stream_stub().Subscribe( # type: ignore[misc]
request,
)

# Ping the stream to provide a nice error message if the cache does not exist.
msg = await stream.read()
msg_type = msg.WhichOneof("kind")
msg: pubsub_pb._SubscriptionItem = await stream.read() # type: ignore[misc]
msg_type: str = msg.WhichOneof("kind")
if msg_type == "heartbeat":
# The first message to a new subscription is always a heartbeat.
pass
else:
err = Exception(f"expected a heartbeat message but got '{msg_type}'")
self._log_request_error("subscribe", err)
return TopicSubscribe.Error(convert_error(err))
return TopicSubscribe.SubscriptionAsync(cache_name, topic_name, client_stream=stream)
return TopicSubscribe.SubscriptionAsync(cache_name, topic_name, client_stream=stream) # type: ignore[misc]
except Exception as e:
self._log_request_error("subscribe", e)
return TopicSubscribe.Error(convert_error(e))
Expand Down
10 changes: 7 additions & 3 deletions src/momento/internal/synchronous/_scs_grpc_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Optional

import grpc
from momento_wire_types import cacheclient_pb2_grpc as cache_client
from momento_wire_types import cachepubsub_pb2_grpc as pubsub_client
Expand Down Expand Up @@ -73,7 +75,7 @@ def __init__(self, configuration: TopicConfiguration, credential_provider: Crede
intercept_channel = grpc.intercept_channel(
self._secure_channel, *_interceptors(credential_provider.auth_token, None)
)
self._stub = pubsub_client.PubsubStub(intercept_channel)
self._stub = pubsub_client.PubsubStub(intercept_channel) # type: ignore[no-untyped-call]

def close(self) -> None:
self._secure_channel.close()
Expand All @@ -95,7 +97,7 @@ def __init__(self, configuration: TopicConfiguration, credential_provider: Crede
intercept_channel = grpc.intercept_channel(
self._secure_channel, *_stream_interceptors(credential_provider.auth_token)
)
self._stub = pubsub_client.PubsubStub(intercept_channel)
self._stub = pubsub_client.PubsubStub(intercept_channel) # type: ignore[no-untyped-call]

def close(self) -> None:
self._secure_channel.close()
Expand All @@ -104,7 +106,9 @@ def stub(self) -> pubsub_client.PubsubStub:
return self._stub


def _interceptors(auth_token: str, retry_strategy: RetryStrategy = None) -> list[grpc.UnaryUnaryClientInterceptor]:
def _interceptors(
auth_token: str, retry_strategy: Optional[RetryStrategy] = None
) -> list[grpc.UnaryUnaryClientInterceptor]:
headers = [Header("authorization", auth_token), Header("agent", f"python:{_ControlGrpcManager.version}")]
return list(
filter(
Expand Down
8 changes: 4 additions & 4 deletions src/momento/internal/synchronous/_scs_pubsub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,21 @@ def subscribe(self, cache_name: str, topic_name: str) -> TopicSubscribeResponse:
topic=topic_name,
# TODO: resume_at_topic_sequence_number
)
stream = self._build_stream_stub().Subscribe(
stream = self._build_stream_stub().Subscribe( # type: ignore[misc]
request,
)

# Ping the stream to provide a nice error message if the cache does not exist.
msg = stream.next()
msg_type = msg.WhichOneof("kind")
msg: pubsub_pb._SubscriptionItem = stream.next() # type: ignore[misc]
msg_type: str = msg.WhichOneof("kind")
if msg_type == "heartbeat":
# The first message to a new subscription is always a heartbeat.
pass
else:
err = Exception(f"expected a heartbeat message but got '{msg_type}'")
self._log_request_error("subscribe", err)
return TopicSubscribe.Error(convert_error(err))
return TopicSubscribe.Subscription(cache_name, topic_name, client_stream=stream)
return TopicSubscribe.Subscription(cache_name, topic_name, client_stream=stream) # type: ignore[misc]
except Exception as e:
self._log_request_error("subscribe", e)
return TopicSubscribe.Error(convert_error(e))
Expand Down
32 changes: 21 additions & 11 deletions src/momento/responses/pubsub/subscribe.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from abc import ABC
from grpc.aio._interceptor import InterceptedUnaryStreamCall
from grpc._channel import _MultiThreadedRendezvous
from typing import Optional

from grpc._channel import _MultiThreadedRendezvous
from grpc.aio._interceptor import InterceptedUnaryStreamCall
from momento_wire_types import cachepubsub_pb2

from ... import logs
from ..mixins import ErrorResponseMixin
from ..response import PubsubResponse
from .subscription_item import TopicSubscriptionItem, TopicSubscriptionItemResponse

from momento_wire_types import cachepubsub_pb2


class TopicSubscribeResponse(PubsubResponse):
"""Parent response type for a topic `publish` request.
Expand All @@ -27,12 +27,8 @@ class TopicSubscribe(ABC):
class SubscriptionBase(TopicSubscribeResponse):
"""Base class for common logic shared between async and synchronous subscriptions."""

def __init__(self, cache_name: str, topic_name: str, client_stream: InterceptedUnaryStreamCall):
self._logger = logs.logger
self._cache_name = cache_name
self._topic_name = topic_name
self._client_stream = client_stream # type: ignore[misc]
self._last_known_sequence_number: Optional[int] = None
_logger = logs.logger
_last_known_sequence_number: Optional[int] = None

def _process_result(self, result: cachepubsub_pb2._SubscriptionItem) -> Optional[TopicSubscriptionItemResponse]:
msg_type: str = result.WhichOneof("kind")
Expand All @@ -54,9 +50,15 @@ def _process_result(self, result: cachepubsub_pb2._SubscriptionItem) -> Optional
return None

class SubscriptionAsync(SubscriptionBase):
"""Indicates the request was successful."""
"""Provides the async version of a topic subscription."""

def __init__(self, cache_name: str, topic_name: str, client_stream: InterceptedUnaryStreamCall):
self._cache_name = cache_name
self._topic_name = topic_name
self._client_stream = client_stream # type: ignore[misc]

async def item(self) -> TopicSubscriptionItemResponse:
"""Returns the next published item from the subscription."""
while True:
try:
result: cachepubsub_pb2._SubscriptionItem = await self._client_stream.read() # type: ignore[misc]
Expand All @@ -70,7 +72,15 @@ async def item(self) -> TopicSubscriptionItemResponse:
return item

class Subscription(SubscriptionBase):
"""Provides the synchronous version of a topic subscription."""

def __init__(self, cache_name: str, topic_name: str, client_stream: _MultiThreadedRendezvous):
self._cache_name = cache_name
self._topic_name = topic_name
self._client_stream = client_stream # type: ignore[misc]

def item(self) -> TopicSubscriptionItemResponse:
"""Returns the next published item from the subscription."""
while True:
try:
result: cachepubsub_pb2._SubscriptionItem = self._client_stream.next() # type: ignore[misc]
Expand Down

0 comments on commit 17d478c

Please sign in to comment.