Skip to content

Commit

Permalink
add switch to async
Browse files Browse the repository at this point in the history
  • Loading branch information
swathipil committed Aug 2, 2022
1 parent 8de034d commit 9625aad
Show file tree
Hide file tree
Showing 18 changed files with 829 additions and 313 deletions.
5 changes: 2 additions & 3 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> AccessToken
if not scopes:
raise ValueError("No token scope provided.")

return _generate_sas_token(scopes[0], self.policy, self.key)


Expand Down Expand Up @@ -291,7 +290,7 @@ def __init__(
**kwargs: Any,
) -> None:
self._uamqp_transport = kwargs.pop("uamqp_transport", True)
self._amqp_transport = UamqpTransport if self._uamqp_transport else None
self._amqp_transport = kwargs.pop("amqp_transport", UamqpTransport)

self.eventhub_name = eventhub_name
if not eventhub_name:
Expand All @@ -302,7 +301,7 @@ def __init__(
if isinstance(credential, AzureSasCredential):
self._credential = EventhubAzureSasTokenCredential(credential)
elif isinstance(credential, AzureNamedKeyCredential):
self._credential = EventhubAzureNamedKeyTokenCredential(credential)
self._credential = EventhubAzureNamedKeyTokenCredential(credential) # type: ignore
else:
self._credential = credential # type: ignore
self._keep_alive = kwargs.get("keep_alive", 30)
Expand Down
3 changes: 2 additions & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,8 +511,9 @@ def __init__(
**kwargs,
) -> None:
# TODO: this changes API, check with Anna if valid -
# If possible, move out message creation to right before sending.
# Need move out message creation to right before sending.
# Might take more time to loop through events and add them all to batch in `send` than in `add` here
# Default async vs sync might cause issues.
self._amqp_transport = kwargs.pop("amqp_transport", UamqpTransport)


Expand Down
4 changes: 2 additions & 2 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

if TYPE_CHECKING:
from typing import Deque
from uamqp import ReceiveClient as uamqp_ReceiveClient, Message as uamqp_Message
from uamqp import ReceiveClient as uamqp_ReceiveClient, Message as uamqp_Message, types as uamqp_types
from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth

from ._consumer_client import EventHubConsumerClient
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs: Any)
self._auto_reconnect = auto_reconnect
self._retry_policy = self._amqp_transport.create_retry_policy(self._client._config)
self._reconnect_backoff = 1
link_properties: Dict[bytes, int] = {}
link_properties: Dict[uamqp_types.AMQPTypes, uamqp_types.AMQPType] = {}
self._error = None
self._timeout = 0
self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
)
from typing_extensions import Literal

from .exceptions import ConnectError, EventHubError
from .amqp import AmqpAnnotatedMessage
from ._client_base import ClientBase
from ._producer import EventHubProducer
from ._constants import ALL_PARTITIONS, MAX_MESSAGE_LENGTH_BYTES
Expand Down
24 changes: 1 addition & 23 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@

# TODO: remove after fixing up async
from uamqp import types
from uamqp.message import MessageHeader
PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY)


Expand Down Expand Up @@ -132,7 +131,7 @@ def send_context_manager():
else:
yield None

# TODO: delete after async unit tests have been refactored

def set_event_partition_key(event, partition_key):
# type: (Union[AmqpAnnotatedMessage, EventData], Optional[Union[bytes, str]]) -> None
if not partition_key:
Expand All @@ -155,27 +154,6 @@ def set_event_partition_key(event, partition_key):
raw_message.header.durable = True


def set_message_partition_key(message, partition_key):
# type: (Message, Optional[Union[bytes, str]]) -> None
"""Set the partition key as an annotation on a uamqp message.
:param ~uamqp.Message message: The message to update.
:param str partition_key: The partition key value.
:rtype: None
"""
if partition_key:
annotations = message.annotations
if annotations is None:
annotations = dict()
annotations[
PROP_PARTITION_KEY_AMQP_SYMBOL
] = partition_key # pylint:disable=protected-access
header = MessageHeader()
header.durable = True
message.annotations = annotations
message.header = header


@contextmanager
def send_context_manager():
span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import queue
Expand All @@ -14,6 +15,7 @@
from ...exceptions import OperationTimeoutError

if TYPE_CHECKING:
from .._transport._base_async import AmqpTransportAsync
from ..._producer_client import SendEventTypes

_LOGGER = logging.getLogger(__name__)
Expand All @@ -32,7 +34,8 @@ def __init__(
max_message_size_on_link: int,
*,
max_wait_time: float = 1,
max_buffer_length: int
max_buffer_length: int,
amqp_transport: AmqpTransportAsync
):
self._buffered_queue: queue.Queue = queue.Queue()
self._max_buffer_len = max_buffer_length
Expand All @@ -47,11 +50,12 @@ def __init__(
self._cur_batch: Optional[EventDataBatch] = None
self._max_message_size_on_link = max_message_size_on_link
self._check_max_wait_time_future = None
self._amqp_transport = amqp_transport
self.partition_id = partition_id

async def start(self):
async with self._lock:
self._cur_batch = EventDataBatch(self._max_message_size_on_link)
self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
self._running = True
if self._max_wait_time:
self._last_send_time = time.time()
Expand Down Expand Up @@ -113,11 +117,11 @@ async def put_events(self, events, timeout_time=None):
self._buffered_queue.put(self._cur_batch)
self._buffered_queue.put(events)
# create a new batch for incoming events
self._cur_batch = EventDataBatch(self._max_message_size_on_link)
self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
except ValueError:
# add single event exceeds the cur batch size, create new batch
self._buffered_queue.put(self._cur_batch)
self._cur_batch = EventDataBatch(self._max_message_size_on_link)
self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
self._cur_batch.add(events)
self._cur_buffered_len += new_events_len

Expand Down Expand Up @@ -145,7 +149,7 @@ async def _flush(self, timeout_time=None, raise_error=True):
_LOGGER.info("Partition: %r started flushing.", self.partition_id)
if self._cur_batch: # if there is batch, enqueue it to the buffer first
self._buffered_queue.put(self._cur_batch)
self._cur_batch = EventDataBatch(self._max_message_size_on_link)
self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
while self._cur_buffered_len:
remaining_time = timeout_time - time.time() if timeout_time else None
if (remaining_time and remaining_time > 0) or remaining_time is None:
Expand Down Expand Up @@ -187,7 +191,7 @@ async def _flush(self, timeout_time=None, raise_error=True):
break
# after finishing flushing, reset cur batch and put it into the buffer
self._last_send_time = time.time()
self._cur_batch = EventDataBatch(self._max_message_size_on_link)
self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
_LOGGER.info("Partition %r finished flushing.", self.partition_id)

async def check_max_wait_time_worker(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
from typing import Dict, List, Callable, Optional, Awaitable, TYPE_CHECKING
Expand All @@ -13,6 +14,7 @@
from ...exceptions import EventDataSendError, ConnectError, EventHubError

if TYPE_CHECKING:
from .._transport._base_async import AmqpTransportAsync
from ..._producer_client import SendEventTypes

_LOGGER = logging.getLogger(__name__)
Expand All @@ -33,6 +35,7 @@ def __init__(
*,
max_buffer_length: int = 1500,
max_wait_time: float = 1,
amqp_transport: AmqpTransportAsync,
):
self._buffered_producers: Dict[str, BufferedProducer] = {}
self._partition_ids: List[str] = partitions
Expand All @@ -45,6 +48,7 @@ def __init__(
self._partition_resolver = PartitionResolver(self._partition_ids)
self._max_wait_time = max_wait_time
self._max_buffer_length = max_buffer_length
self._amqp_transport = amqp_transport

async def _get_partition_id(self, partition_id, partition_key):
if partition_id:
Expand Down Expand Up @@ -77,6 +81,7 @@ async def enqueue_events(
self._max_message_size_on_link,
max_wait_time=self._max_wait_time,
max_buffer_length=self._max_buffer_length,
amqp_transport=self._amqp_transport,
)
await buffered_producer.start()
self._buffered_producers[pid] = buffered_producer
Expand Down
Loading

0 comments on commit 9625aad

Please sign in to comment.