diff --git a/src/momento/internal/aio/_scs_grpc_manager.py b/src/momento/internal/aio/_scs_grpc_manager.py index 0531ee90..86bf92d6 100644 --- a/src/momento/internal/aio/_scs_grpc_manager.py +++ b/src/momento/internal/aio/_scs_grpc_manager.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Optional + import grpc from momento_wire_types import cacheclient_pb2_grpc as cache_client from momento_wire_types import cachepubsub_pb2_grpc as pubsub_client @@ -108,7 +110,7 @@ def async_stub(self) -> pubsub_client.PubsubStub: return pubsub_client.PubsubStub(self._secure_channel) # type: ignore[no-untyped-call] -def _interceptors(auth_token: str, retry_strategy: RetryStrategy = None) -> list[grpc.aio.ClientInterceptor]: +def _interceptors(auth_token: str, retry_strategy: Optional[RetryStrategy] = None) -> list[grpc.aio.ClientInterceptor]: headers = [ Header("authorization", auth_token), Header("agent", f"python:{_ControlGrpcManager.version}"), diff --git a/src/momento/internal/aio/_scs_pubsub_client.py b/src/momento/internal/aio/_scs_pubsub_client.py index 09a8a4a7..4bd49285 100644 --- a/src/momento/internal/aio/_scs_pubsub_client.py +++ b/src/momento/internal/aio/_scs_pubsub_client.py @@ -67,7 +67,7 @@ async def publish(self, cache_name: str, topic_name: str, value: str | bytes) -> value=topic_value, ) - await self._build_stub().Publish( + await self._build_stub().Publish( # type: ignore[misc] request, ) return TopicPublish.Success() @@ -85,13 +85,13 @@ async def subscribe(self, cache_name: str, topic_name: str) -> TopicSubscribeRes topic=topic_name, # TODO: resume_at_topic_sequence_number ) - stream = self._build_stream_stub().Subscribe( + stream = self._build_stream_stub().Subscribe( # type: ignore[misc] request, ) # Ping the stream to provide a nice error message if the cache does not exist. - msg = await stream.read() - msg_type = msg.WhichOneof("kind") + msg: pubsub_pb._SubscriptionItem = await stream.read() # type: ignore[misc] + msg_type: str = msg.WhichOneof("kind") if msg_type == "heartbeat": # The first message to a new subscription is always a heartbeat. pass @@ -99,7 +99,7 @@ async def subscribe(self, cache_name: str, topic_name: str) -> TopicSubscribeRes err = Exception(f"expected a heartbeat message but got '{msg_type}'") self._log_request_error("subscribe", err) return TopicSubscribe.Error(convert_error(err)) - return TopicSubscribe.SubscriptionAsync(cache_name, topic_name, client_stream=stream) + return TopicSubscribe.SubscriptionAsync(cache_name, topic_name, client_stream=stream) # type: ignore[misc] except Exception as e: self._log_request_error("subscribe", e) return TopicSubscribe.Error(convert_error(e)) diff --git a/src/momento/internal/synchronous/_scs_grpc_manager.py b/src/momento/internal/synchronous/_scs_grpc_manager.py index d0c11689..06a89643 100644 --- a/src/momento/internal/synchronous/_scs_grpc_manager.py +++ b/src/momento/internal/synchronous/_scs_grpc_manager.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Optional + import grpc from momento_wire_types import cacheclient_pb2_grpc as cache_client from momento_wire_types import cachepubsub_pb2_grpc as pubsub_client @@ -73,7 +75,7 @@ def __init__(self, configuration: TopicConfiguration, credential_provider: Crede intercept_channel = grpc.intercept_channel( self._secure_channel, *_interceptors(credential_provider.auth_token, None) ) - self._stub = pubsub_client.PubsubStub(intercept_channel) + self._stub = pubsub_client.PubsubStub(intercept_channel) # type: ignore[no-untyped-call] def close(self) -> None: self._secure_channel.close() @@ -95,7 +97,7 @@ def __init__(self, configuration: TopicConfiguration, credential_provider: Crede intercept_channel = grpc.intercept_channel( self._secure_channel, *_stream_interceptors(credential_provider.auth_token) ) - self._stub = pubsub_client.PubsubStub(intercept_channel) + self._stub = pubsub_client.PubsubStub(intercept_channel) # type: ignore[no-untyped-call] def close(self) -> None: self._secure_channel.close() @@ -104,7 +106,9 @@ def stub(self) -> pubsub_client.PubsubStub: return self._stub -def _interceptors(auth_token: str, retry_strategy: RetryStrategy = None) -> list[grpc.UnaryUnaryClientInterceptor]: +def _interceptors( + auth_token: str, retry_strategy: Optional[RetryStrategy] = None +) -> list[grpc.UnaryUnaryClientInterceptor]: headers = [Header("authorization", auth_token), Header("agent", f"python:{_ControlGrpcManager.version}")] return list( filter( diff --git a/src/momento/internal/synchronous/_scs_pubsub_client.py b/src/momento/internal/synchronous/_scs_pubsub_client.py index 4f62ad19..36f2bfa1 100644 --- a/src/momento/internal/synchronous/_scs_pubsub_client.py +++ b/src/momento/internal/synchronous/_scs_pubsub_client.py @@ -85,13 +85,13 @@ def subscribe(self, cache_name: str, topic_name: str) -> TopicSubscribeResponse: topic=topic_name, # TODO: resume_at_topic_sequence_number ) - stream = self._build_stream_stub().Subscribe( + stream = self._build_stream_stub().Subscribe( # type: ignore[misc] request, ) # Ping the stream to provide a nice error message if the cache does not exist. - msg = stream.next() - msg_type = msg.WhichOneof("kind") + msg: pubsub_pb._SubscriptionItem = stream.next() # type: ignore[misc] + msg_type: str = msg.WhichOneof("kind") if msg_type == "heartbeat": # The first message to a new subscription is always a heartbeat. pass @@ -99,7 +99,7 @@ def subscribe(self, cache_name: str, topic_name: str) -> TopicSubscribeResponse: err = Exception(f"expected a heartbeat message but got '{msg_type}'") self._log_request_error("subscribe", err) return TopicSubscribe.Error(convert_error(err)) - return TopicSubscribe.Subscription(cache_name, topic_name, client_stream=stream) + return TopicSubscribe.Subscription(cache_name, topic_name, client_stream=stream) # type: ignore[misc] except Exception as e: self._log_request_error("subscribe", e) return TopicSubscribe.Error(convert_error(e)) diff --git a/src/momento/responses/pubsub/subscribe.py b/src/momento/responses/pubsub/subscribe.py index c862e6f7..957ffdb7 100644 --- a/src/momento/responses/pubsub/subscribe.py +++ b/src/momento/responses/pubsub/subscribe.py @@ -1,15 +1,15 @@ from abc import ABC -from grpc.aio._interceptor import InterceptedUnaryStreamCall -from grpc._channel import _MultiThreadedRendezvous from typing import Optional +from grpc._channel import _MultiThreadedRendezvous +from grpc.aio._interceptor import InterceptedUnaryStreamCall +from momento_wire_types import cachepubsub_pb2 + from ... import logs from ..mixins import ErrorResponseMixin from ..response import PubsubResponse from .subscription_item import TopicSubscriptionItem, TopicSubscriptionItemResponse -from momento_wire_types import cachepubsub_pb2 - class TopicSubscribeResponse(PubsubResponse): """Parent response type for a topic `publish` request. @@ -27,12 +27,8 @@ class TopicSubscribe(ABC): class SubscriptionBase(TopicSubscribeResponse): """Base class for common logic shared between async and synchronous subscriptions.""" - def __init__(self, cache_name: str, topic_name: str, client_stream: InterceptedUnaryStreamCall): - self._logger = logs.logger - self._cache_name = cache_name - self._topic_name = topic_name - self._client_stream = client_stream # type: ignore[misc] - self._last_known_sequence_number: Optional[int] = None + _logger = logs.logger + _last_known_sequence_number: Optional[int] = None def _process_result(self, result: cachepubsub_pb2._SubscriptionItem) -> Optional[TopicSubscriptionItemResponse]: msg_type: str = result.WhichOneof("kind") @@ -54,9 +50,15 @@ def _process_result(self, result: cachepubsub_pb2._SubscriptionItem) -> Optional return None class SubscriptionAsync(SubscriptionBase): - """Indicates the request was successful.""" + """Provides the async version of a topic subscription.""" + + def __init__(self, cache_name: str, topic_name: str, client_stream: InterceptedUnaryStreamCall): + self._cache_name = cache_name + self._topic_name = topic_name + self._client_stream = client_stream # type: ignore[misc] async def item(self) -> TopicSubscriptionItemResponse: + """Returns the next published item from the subscription.""" while True: try: result: cachepubsub_pb2._SubscriptionItem = await self._client_stream.read() # type: ignore[misc] @@ -70,7 +72,15 @@ async def item(self) -> TopicSubscriptionItemResponse: return item class Subscription(SubscriptionBase): + """Provides the synchronous version of a topic subscription.""" + + def __init__(self, cache_name: str, topic_name: str, client_stream: _MultiThreadedRendezvous): + self._cache_name = cache_name + self._topic_name = topic_name + self._client_stream = client_stream # type: ignore[misc] + def item(self) -> TopicSubscriptionItemResponse: + """Returns the next published item from the subscription.""" while True: try: result: cachepubsub_pb2._SubscriptionItem = self._client_stream.next() # type: ignore[misc]