diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 3e6e113c9097..f72611edca5a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -194,7 +194,6 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type: (str, Any) -> AccessToken if not scopes: raise ValueError("No token scope provided.") - return _generate_sas_token(scopes[0], self.policy, self.key) @@ -291,7 +290,7 @@ def __init__( **kwargs: Any, ) -> None: self._uamqp_transport = kwargs.pop("uamqp_transport", True) - self._amqp_transport = UamqpTransport if self._uamqp_transport else None + self._amqp_transport = kwargs.pop("amqp_transport", UamqpTransport) self.eventhub_name = eventhub_name if not eventhub_name: @@ -302,7 +301,7 @@ def __init__( if isinstance(credential, AzureSasCredential): self._credential = EventhubAzureSasTokenCredential(credential) elif isinstance(credential, AzureNamedKeyCredential): - self._credential = EventhubAzureNamedKeyTokenCredential(credential) + self._credential = EventhubAzureNamedKeyTokenCredential(credential) # type: ignore else: self._credential = credential # type: ignore self._keep_alive = kwargs.get("keep_alive", 30) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index b302b9c7727e..26f5d42b53cf 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -511,8 +511,9 @@ def __init__( **kwargs, ) -> None: # TODO: this changes API, check with Anna if valid - - # If possible, move out message creation to right before sending. + # Need move out message creation to right before sending. # Might take more time to loop through events and add them all to batch in `send` than in `add` here + # Default async vs sync might cause issues. self._amqp_transport = kwargs.pop("amqp_transport", UamqpTransport) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index 6015bf0186f2..ad647a011a1a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from typing import Deque - from uamqp import ReceiveClient as uamqp_ReceiveClient, Message as uamqp_Message + from uamqp import ReceiveClient as uamqp_ReceiveClient, Message as uamqp_Message, types as uamqp_types from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth from ._consumer_client import EventHubConsumerClient @@ -97,7 +97,7 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs: Any) self._auto_reconnect = auto_reconnect self._retry_policy = self._amqp_transport.create_retry_policy(self._client._config) self._reconnect_backoff = 1 - link_properties: Dict[bytes, int] = {} + link_properties: Dict[uamqp_types.AMQPTypes, uamqp_types.AMQPType] = {} self._error = None self._timeout = 0 self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 3f00c3d501ae..1fecbd80f978 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -19,8 +19,6 @@ ) from typing_extensions import Literal -from .exceptions import ConnectError, EventHubError -from .amqp import AmqpAnnotatedMessage from ._client_base import ClientBase from ._producer import EventHubProducer from ._constants import ALL_PARTITIONS, MAX_MESSAGE_LENGTH_BYTES diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index 410bd2ff5536..a3a0e25df56a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -43,7 +43,6 @@ # TODO: remove after fixing up async from uamqp import types -from uamqp.message import MessageHeader PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) @@ -132,7 +131,7 @@ def send_context_manager(): else: yield None -# TODO: delete after async unit tests have been refactored + def set_event_partition_key(event, partition_key): # type: (Union[AmqpAnnotatedMessage, EventData], Optional[Union[bytes, str]]) -> None if not partition_key: @@ -155,27 +154,6 @@ def set_event_partition_key(event, partition_key): raw_message.header.durable = True -def set_message_partition_key(message, partition_key): - # type: (Message, Optional[Union[bytes, str]]) -> None - """Set the partition key as an annotation on a uamqp message. - - :param ~uamqp.Message message: The message to update. - :param str partition_key: The partition key value. - :rtype: None - """ - if partition_key: - annotations = message.annotations - if annotations is None: - annotations = dict() - annotations[ - PROP_PARTITION_KEY_AMQP_SYMBOL - ] = partition_key # pylint:disable=protected-access - header = MessageHeader() - header.durable = True - message.annotations = annotations - message.header = header - - @contextmanager def send_context_manager(): span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py index 2d98878d5146..67fee8dd2a58 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging import queue @@ -14,6 +15,7 @@ from ...exceptions import OperationTimeoutError if TYPE_CHECKING: + from .._transport._base_async import AmqpTransportAsync from ..._producer_client import SendEventTypes _LOGGER = logging.getLogger(__name__) @@ -32,7 +34,8 @@ def __init__( max_message_size_on_link: int, *, max_wait_time: float = 1, - max_buffer_length: int + max_buffer_length: int, + amqp_transport: AmqpTransportAsync ): self._buffered_queue: queue.Queue = queue.Queue() self._max_buffer_len = max_buffer_length @@ -47,11 +50,12 @@ def __init__( self._cur_batch: Optional[EventDataBatch] = None self._max_message_size_on_link = max_message_size_on_link self._check_max_wait_time_future = None + self._amqp_transport = amqp_transport self.partition_id = partition_id async def start(self): async with self._lock: - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) self._running = True if self._max_wait_time: self._last_send_time = time.time() @@ -113,11 +117,11 @@ async def put_events(self, events, timeout_time=None): self._buffered_queue.put(self._cur_batch) self._buffered_queue.put(events) # create a new batch for incoming events - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) except ValueError: # add single event exceeds the cur batch size, create new batch self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) self._cur_batch.add(events) self._cur_buffered_len += new_events_len @@ -145,7 +149,7 @@ async def _flush(self, timeout_time=None, raise_error=True): _LOGGER.info("Partition: %r started flushing.", self.partition_id) if self._cur_batch: # if there is batch, enqueue it to the buffer first self._buffered_queue.put(self._cur_batch) - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) while self._cur_buffered_len: remaining_time = timeout_time - time.time() if timeout_time else None if (remaining_time and remaining_time > 0) or remaining_time is None: @@ -187,7 +191,7 @@ async def _flush(self, timeout_time=None, raise_error=True): break # after finishing flushing, reset cur batch and put it into the buffer self._last_send_time = time.time() - self._cur_batch = EventDataBatch(self._max_message_size_on_link) + self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport) _LOGGER.info("Partition %r finished flushing.", self.partition_id) async def check_max_wait_time_worker(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py index ecae49098086..04e5a12ea69f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging from typing import Dict, List, Callable, Optional, Awaitable, TYPE_CHECKING @@ -13,6 +14,7 @@ from ...exceptions import EventDataSendError, ConnectError, EventHubError if TYPE_CHECKING: + from .._transport._base_async import AmqpTransportAsync from ..._producer_client import SendEventTypes _LOGGER = logging.getLogger(__name__) @@ -33,6 +35,7 @@ def __init__( *, max_buffer_length: int = 1500, max_wait_time: float = 1, + amqp_transport: AmqpTransportAsync, ): self._buffered_producers: Dict[str, BufferedProducer] = {} self._partition_ids: List[str] = partitions @@ -45,6 +48,7 @@ def __init__( self._partition_resolver = PartitionResolver(self._partition_ids) self._max_wait_time = max_wait_time self._max_buffer_length = max_buffer_length + self._amqp_transport = amqp_transport async def _get_partition_id(self, partition_id, partition_key): if partition_id: @@ -77,6 +81,7 @@ async def enqueue_events( self._max_message_size_on_link, max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, + amqp_transport=self._amqp_transport, ) await buffered_producer.start() self._buffered_producers[pid] = buffered_producer diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py index 5a5a312485c4..b82ae698ed6b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import logging import asyncio @@ -39,10 +39,11 @@ MGMT_PARTITION_OPERATION, MGMT_STATUS_CODE, MGMT_STATUS_DESC, + READ_OPERATION, ) from ._async_utils import get_dict_with_loop_if_needed from ._connection_manager_async import get_connection_manager -from ._error_async import _handle_exception +from ._transport._uamqp_transport_async import UamqpTransportAsync if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -211,6 +212,8 @@ def __init__( **kwargs: Any ) -> None: self._internal_kwargs = get_dict_with_loop_if_needed(kwargs.get("loop", None)) + self._uamqp_transport = kwargs.pop("uamqp_transport", True) + self._amqp_transport = UamqpTransportAsync if isinstance(credential, AzureSasCredential): self._credential = EventhubAzureSasTokenCredentialAsync(credential) # type: ignore elif isinstance(credential, AzureNamedKeyCredential): @@ -221,6 +224,8 @@ def __init__( fully_qualified_namespace=fully_qualified_namespace, eventhub_name=eventhub_name, credential=self._credential, + uamqp_transport=self._uamqp_transport, + amqp_transport=self._amqp_transport, **kwargs ) self._conn_manager_async = get_connection_manager(**kwargs) @@ -255,32 +260,19 @@ async def _create_auth_async(self) -> authentication.JWTTokenAsync: except AttributeError: token_type = b"jwt" if token_type == b"servicebus.windows.net:sastoken": - auth = authentication.JWTTokenAsync( - self._auth_uri, + return await self._amqp_transport.create_token_auth( self._auth_uri, functools.partial(self._credential.get_token, self._auth_uri), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, - refresh_window=300, + config=self._config, + update_token=True, ) - await auth.update_token() - return auth - return authentication.JWTTokenAsync( - self._auth_uri, + return await self._amqp_transport.create_token_auth( self._auth_uri, functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, + config=self._config, + update_token=False, ) async def _close_connection_async(self) -> None: @@ -322,19 +314,21 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> last_exception = None while retried_times <= self._config.max_retries: mgmt_auth = await self._create_auth_async() - mgmt_client = AMQPClientAsync( - self._mgmt_target, auth=mgmt_auth, debug=self._config.network_tracing + mgmt_client = self._amqp_transport.create_mgmt_client( + self._address, mgmt_auth=mgmt_auth, config=self._config ) try: - conn = await self._conn_manager_async.get_connection( - self._address.hostname, mgmt_auth - ) - mgmt_msg.application_properties["security_token"] = mgmt_auth.token - await mgmt_client.open_async(connection=conn) - response = await mgmt_client.mgmt_request_async( + await mgmt_client.open_async() + while not await mgmt_client.client_ready_async(): + await asyncio.sleep(0.05) + mgmt_msg.application_properties[ + "security_token" + ] = await self._amqp_transport.get_updated_token(mgmt_auth) + response = await self._amqp_transport.mgmt_client_request( + mgmt_client, mgmt_msg, - constants.READ_OPERATION, - op_type=op_type, + operation=READ_OPERATION, + operation_type=op_type, status_code_field=MGMT_STATUS_CODE, description_fields=MGMT_STATUS_DESC, ) @@ -347,26 +341,23 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> if status_code < 400: return response if status_code in [401]: - raise errors.AuthenticationException( - "Management authentication failed. Status code: {}, Description: {!r}".format( - status_code, description - ) + raise self._amqp_transport.get_error( + self._amqp_transport.AUTH_EXCEPTION, + f"Management authentication failed. Status code: {status_code}, Description: {description!r}" ) if status_code in [404]: - raise ConnectError( - "Management connection failed. Status code: {}, Description: {!r}".format( - status_code, description - ) - ) - raise errors.AMQPConnectionError( - "Management request error. Status code: {}, Description: {!r}".format( - status_code, description + return self._amqp_transport.get_error( + self._amqp_transport.CONNECTION_ERROR, + f"Management connection failed. Status code: {status_code}, Description: {description!r}" ) + return self._amqp_transport.get_error( + self._amqp_transport.AMQP_CONNECTION_ERROR, + f"Management request error. Status code: {status_code}, Description: {description!r}" ) except asyncio.CancelledError: # pylint: disable=try-except-raise raise except Exception as exception: # pylint:disable=broad-except - last_exception = await _handle_exception(exception, self) + last_exception = await self._amqp_transport._handle_exception(exception, self) # pylint: disable=protected-access await self._backoff_async( retried_times=retried_times, last_exception=last_exception ) @@ -380,12 +371,14 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> await mgmt_client.close_async() async def _get_eventhub_properties_async(self) -> Dict[str, Any]: - mgmt_msg = Message(application_properties={"name": self.eventhub_name}) + mgmt_msg = mgmt_msg = self._amqp_transport.MESSAGE( + application_properties={"name": self.eventhub_name} + ) response = await self._management_request_async( mgmt_msg, op_type=MGMT_OPERATION ) output = {} - eh_info = response.get_data() # type: Dict[bytes, Any] + eh_info: Dict[bytes, Any] = response.value if eh_info: output["eventhub_name"] = eh_info[b"name"].decode("utf-8") output["created_at"] = utc_from_timestamp( @@ -402,7 +395,7 @@ async def _get_partition_ids_async(self) -> List[str]: async def _get_partition_properties_async( self, partition_id: str ) -> Dict[str, Any]: - mgmt_msg = Message( + mgmt_msg = self._amqp_transport.MESSAGE( application_properties={ "name": self.eventhub_name, "partition": partition_id, @@ -411,7 +404,7 @@ async def _get_partition_properties_async( response = await self._management_request_async( mgmt_msg, op_type=MGMT_PARTITION_OPERATION ) - partition_info = response.get_data() # type: Dict[bytes, Union[bytes, int]] + partition_info = response.value # type: Dict[bytes, Union[bytes, int]] output = {} # type: Dict[str, Any] if partition_info: output["eventhub_name"] = cast(bytes, partition_info[b"name"]).decode( @@ -463,16 +456,12 @@ async def _open(self) -> None: await self._handler.close_async() auth = await self._client._create_auth_async() self._create_handler(auth) - await self._handler.open_async( - connection=await self._client._conn_manager_async.get_connection( - self._client._address.hostname, auth - ) - ) + await self._handler.open_async() while not await self._handler.client_ready_async(): await asyncio.sleep(0.05, **self._internal_kwargs) self._max_message_size_on_link = ( - self._handler.message_handler._link.peer_max_message_size - or constants.MAX_MESSAGE_LENGTH_BYTES + self._amqp_transport.get_remote_max_message_size(self._handler) + or constants.MAX_FRAME_SIZE_BYTES ) self.running = True @@ -487,11 +476,14 @@ async def _close_connection_async(self) -> None: await self._client._conn_manager_async.reset_connection_if_broken() # pylint:disable=protected-access async def _handle_exception(self, exception: Exception) -> Exception: - if not self.running and isinstance(exception, compat.TimeoutException): - exception = errors.AuthenticationException("Authorization timeout.") - return await _handle_exception(exception, self) - - return await _handle_exception(exception, self) + if not self.running and isinstance(exception, self._amqp_transport.TIMEOUT_EXCEPTION): + exception = self._amqp_transport.get_error( + self._amqp_transport.AUTH_EXCEPTION, + "Authorization timeout." + ) + return await self._amqp_transport._handle_exception( # pylint: disable=protected-access + exception, self + ) async def _do_retryable_operation( self, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py index d53443f12fd7..32e544344989 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py @@ -5,8 +5,12 @@ from typing import TYPE_CHECKING +from uamqp import c_uamqp from uamqp.async_ops import ConnectionAsync +from .._connection_manager import _ConnectionMode +from .._constants import TransportType + if TYPE_CHECKING: from uamqp.authentication import JWTTokenAsync @@ -28,6 +32,62 @@ async def reset_connection_if_broken(self) -> None: pass +class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes + def __init__(self, **kwargs) -> None: + self._loop = kwargs.get("loop") + self._lock = Lock(loop=self._loop) + self._conn = None + + self._container_id = kwargs.get("container_id") + self._debug = kwargs.get("debug") + self._error_policy = kwargs.get("error_policy") + self._properties = kwargs.get("properties") + self._encoding = kwargs.get("encoding") or "UTF-8" + self._transport_type = kwargs.get("transport_type") or TransportType.Amqp + self._http_proxy = kwargs.get("http_proxy") + self._max_frame_size = kwargs.get("max_frame_size") + self._channel_max = kwargs.get("channel_max") + self._idle_timeout = kwargs.get("idle_timeout") + self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get( + "remote_idle_timeout_empty_frame_send_ratio" + ) + + async def get_connection(self, host: str, auth: "JWTTokenAsync") -> ConnectionAsync: + async with self._lock: + if self._conn is None: + self._conn = ConnectionAsync( + host, + auth, + container_id=self._container_id, + max_frame_size=self._max_frame_size, + channel_max=self._channel_max, + idle_timeout=self._idle_timeout, + properties=self._properties, + remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio, + error_policy=self._error_policy, + debug=self._debug, + loop=self._loop, + encoding=self._encoding, + ) + return self._conn + + async def close_connection(self) -> None: + async with self._lock: + if self._conn: + await self._conn.destroy_async() + self._conn = None + + async def reset_connection_if_broken(self) -> None: + async with self._lock: + if self._conn and self._conn._state in ( # pylint:disable=protected-access + c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member + ): + self._conn = None + + class _SeparateConnectionManager(object): def __init__(self, **kwargs) -> None: pass @@ -43,4 +103,7 @@ async def reset_connection_if_broken(self) -> None: def get_connection_manager(**kwargs) -> "ConnectionManager": + connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection) + if connection_mode == _ConnectionMode.ShareConnection: + return _SharedConnectionManager(**kwargs) return _SeparateConnectionManager(**kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index afdcad0ad9e2..f3bcb0a36636 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import time import asyncio import uuid @@ -9,19 +10,16 @@ from collections import deque from typing import TYPE_CHECKING, Callable, Awaitable, cast, Dict, Optional, Union, List -import uamqp -from uamqp import errors, types, utils -from uamqp import ReceiveClientAsync, Source - from ._client_base_async import ConsumerProducerMixin from ._async_utils import get_dict_with_loop_if_needed from .._common import EventData -from ..exceptions import _error_handler from .._utils import create_properties, event_position_selector from .._constants import EPOCH_SYMBOL, TIMEOUT_SYMBOL, RECEIVER_RUNTIME_METRIC_SYMBOL if TYPE_CHECKING: from typing import Deque + import uamqp + from uamqp import ReceiveClientAsync, Source, types from uamqp.authentication import JWTTokenAsync from ._consumer_client_async import EventHubConsumerClient @@ -79,9 +77,10 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N self.running = False self.closed = False - self._on_event_received = kwargs[ + self._amqp_transport = kwargs.pop("amqp_transport") + self._on_event_received: Callable[[Union[Optional[EventData], List[EventData]]], Awaitable[None]] = kwargs[ "on_event_received" - ] # type: Callable[[Union[Optional[EventData], List[EventData]]], Awaitable[None]] + ] self._internal_kwargs = get_dict_with_loop_if_needed(kwargs.get("loop", None)) self._client = client self._source = source @@ -91,81 +90,65 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N self._owner_level = owner_level self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, - on_error=_error_handler, # pylint:disable=protected-access - ) + self._retry_policy = self._amqp_transport.create_retry_policy(self._client._config) self._reconnect_backoff = 1 self._timeout = 0 - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None - self._link_properties = {} # type: Dict[types.AMQPType, types.AMQPType] + self._idle_timeout = (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) if idle_timeout else None + link_properties: Dict[types.AMQPType, types.AMQPType] = {} partition = self._source.split("/")[-1] self._partition = partition - self._name = "EHReceiver-{}-partition{}".format(uuid.uuid4(), partition) + self._name = f"EHReceiver-{uuid.uuid4()}-partition{partition}" if owner_level is not None: - self._link_properties[types.AMQPSymbol(EPOCH_SYMBOL)] = types.AMQPLong( - int(owner_level) - ) + link_properties[EPOCH_SYMBOL] = int(owner_level) link_property_timeout_ms = ( self._client._config.receive_timeout or self._timeout # pylint:disable=protected-access - ) * 1000 - self._link_properties[types.AMQPSymbol(TIMEOUT_SYMBOL)] = types.AMQPLong( - int(link_property_timeout_ms) - ) - self._handler = None # type: Optional[ReceiveClientAsync] + ) * self._amqp_transport.IDLE_TIMEOUT_FACTOR + link_properties[TIMEOUT_SYMBOL] = int(link_property_timeout_ms) + self._link_properties = self._amqp_transport.create_link_properties(link_properties) + self._handler: Optional[ReceiveClientAsync] = None self._track_last_enqueued_event_properties = ( track_last_enqueued_event_properties ) - self._message_buffer = deque() # type: Deque[uamqp.Message] - self._last_received_event = None # type: Optional[EventData] - - def _create_handler(self, auth: "JWTTokenAsync") -> None: - source = Source(self._source) - if self._offset is not None: - source.set_filter( - event_position_selector(self._offset, self._offset_inclusive) - ) - desired_capabilities = None - if self._track_last_enqueued_event_properties: - symbol_array = [types.AMQPSymbol(RECEIVER_RUNTIME_METRIC_SYMBOL)] - desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) - - properties = create_properties( - self._client._config.user_agent # pylint:disable=protected-access + self._message_buffer: Deque[uamqp.Message] = deque() + self._last_received_event: Optional[EventData] = None + + def _create_handler(self, auth: JWTTokenAsync) -> None: + source = self._amqp_transport.create_source( + self._source, + self._offset, + event_position_selector(self._offset, self._offset_inclusive) ) - self._handler = ReceiveClientAsync( - source, + desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None + + self._handler = self._amqp_transport.create_receive_client( + config=self._client._config, # pylint:disable=protected-access + source=source, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - prefetch=self._prefetch, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access + link_credit=self._prefetch, link_properties=self._link_properties, - timeout=self._timeout, idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, - receive_settle_mode=uamqp.constants.ReceiverSettleMode.ReceiveAndDelete, - auto_complete=False, - properties=properties, + properties=create_properties( + self._client._config.user_agent, amqp_transport=self._amqp_transport # pylint:disable=protected-access + ), desired_capabilities=desired_capabilities, - **self._internal_kwargs - ) - - self._handler._streaming_receive = True # pylint:disable=protected-access - self._handler._message_received_callback = ( # pylint:disable=protected-access - self._message_received + streaming_receive=True, + message_received_callback=self._message_received, ) async def _open_with_retry(self) -> None: await self._do_retryable_operation(self._open, operation_need_param=False) def _message_received(self, message: uamqp.Message) -> None: - self._message_buffer.appendleft(message) + self._message_buffer.append(message) def _next_message_in_buffer(self): # pylint:disable=protected-access - message = self._message_buffer.pop() + message = self._message_buffer.popleft() event_data = EventData._from_message(message) self._last_received_event = event_data return event_data diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py index 20c3468bd8b1..0cf4634a655b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging import datetime @@ -21,7 +22,7 @@ from ._eventprocessor.event_processor import EventProcessor from ._consumer_async import EventHubConsumer from ._client_base_async import ClientBaseAsync -from .._constants import ALL_PARTITIONS +from .._constants import ALL_PARTITIONS, TransportType from .._eventprocessor.common import LoadBalancingStrategy @@ -215,6 +216,7 @@ def _create_consumer( prefetch=prefetch, idle_timeout=self._idle_timeout, track_last_enqueued_event_properties=track_last_enqueued_event_properties, + amqp_transport=self._amqp_transport, **self._internal_kwargs, ) return handler @@ -231,7 +233,7 @@ def from_connection_string( auth_timeout: float = 60, user_agent: Optional[str] = None, retry_total: int = 3, - transport_type: Optional["TransportType"] = None, + transport_type: Optional["TransportType"] = TransportType.Amqp, checkpoint_store: Optional["CheckpointStore"] = None, load_balancing_interval: float = 10, **kwargs: Any diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_error_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_error_async.py deleted file mode 100644 index e272f496ec81..000000000000 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_error_async.py +++ /dev/null @@ -1,74 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- -import asyncio -import logging -from typing import TYPE_CHECKING, Union, cast - -from uamqp import errors - -from ..exceptions import ( - _create_eventhub_exception, - EventHubError, - EventDataSendError, - EventDataError, -) - -if TYPE_CHECKING: - from ._client_base_async import ClientBaseAsync, ConsumerProducerMixin - -_LOGGER = logging.getLogger(__name__) - - -async def _handle_exception( # pylint:disable=too-many-branches, too-many-statements - exception: Exception, closable: Union["ClientBaseAsync", "ConsumerProducerMixin"] -) -> Exception: - # pylint: disable=protected-access - if isinstance(exception, asyncio.CancelledError): - raise exception - error = exception - try: - name = cast("ConsumerProducerMixin", closable)._name - except AttributeError: - name = cast("ClientBaseAsync", closable)._container_id - if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise - _LOGGER.info("%r stops due to keyboard interrupt", name) - await cast("ConsumerProducerMixin", closable)._close_connection_async() - raise error - elif isinstance(exception, EventHubError): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - raise error - elif isinstance( - exception, - ( - errors.MessageAccepted, - errors.MessageAlreadySettled, - errors.MessageModified, - errors.MessageRejected, - errors.MessageReleased, - errors.MessageContentTooLarge, - ), - ): - _LOGGER.info("%r Event data error (%r)", name, exception) - error = EventDataError(str(exception), exception) - raise error - elif isinstance(exception, errors.MessageException): - _LOGGER.info("%r Event data send error (%r)", name, exception) - error = EventDataSendError(str(exception), exception) - raise error - else: - try: - if isinstance(exception, errors.AuthenticationException): - await closable._close_connection_async() - elif isinstance(exception, errors.LinkDetach): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - elif isinstance(exception, errors.ConnectionClose): - await closable._close_connection_async() - elif isinstance(exception, errors.MessageHandlerError): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - else: # errors.AMQPConnectionError, compat.TimeoutException, and any other errors - await closable._close_connection_async() - except AttributeError: - pass - return _create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index 47db456b0a5f..04d4865ce92f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -2,23 +2,20 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import uuid import asyncio import logging from typing import Iterable, Union, Optional, Any, AnyStr, List, TYPE_CHECKING import time -from uamqp import types, constants, errors -from uamqp import SendClientAsync - from azure.core.tracing import AbstractSpan from .._common import EventData, EventDataBatch -from ..exceptions import _error_handler, OperationTimeoutError +from ..exceptions import OperationTimeoutError from .._producer import _set_partition_key, _set_trace_message from .._utils import ( create_properties, - set_message_partition_key, trace_message, send_context_manager, transform_outbound_single_message, @@ -29,6 +26,9 @@ from ._async_utils import get_dict_with_loop_if_needed if TYPE_CHECKING: + from uamqp import types, constants, errors + from uamqp import SendClientAsync + from uamqp.authentication import JWTTokenAsync # pylint: disable=ungrouped-imports from ._producer_client_async import EventHubProducerClient @@ -60,8 +60,9 @@ class EventHubProducer( Default value is `True`. """ - def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> None: + def __init__(self, client: EventHubProducerClient, target: str, **kwargs) -> None: super().__init__() + self._amqp_transport = kwargs.pop("amqp_transport") partition = kwargs.get("partition", None) send_timeout = kwargs.get("send_timeout", 60) keep_alive = kwargs.get("keep_alive", None) @@ -79,10 +80,14 @@ def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> N self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect self._timeout = send_timeout - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, - on_error=_error_handler, # pylint:disable=protected-access + self._idle_timeout = ( + (idle_timeout * self._amqp_transport.IDLE_TIMEOUT_FACTOR) + if idle_timeout + else None + ) + + self._retry_policy = self._amqp_transport.create_retry_policy( + config=self._client._config ) self._reconnect_backoff = 1 self._name = "EHProducer-{}".format(uuid.uuid4()) @@ -91,29 +96,31 @@ def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> N if partition: self._target += "/Partitions/" + partition self._name += "-partition{}".format(partition) - self._handler = None # type: Optional[SendClientAsync] - self._outcome = None # type: Optional[constants.MessageSendResult] - self._condition = None # type: Optional[Exception] + self._handler: Optional[SendClientAsync] = None + self._outcome: Optional[constants.MessageSendResult] = None + self._condition: Optional[Exception] = None self._lock = asyncio.Lock(**self._internal_kwargs) - self._link_properties = { - types.AMQPSymbol(TIMEOUT_SYMBOL): types.AMQPLong(int(self._timeout * 1000)) - } + self._link_properties = self._amqp_transport.create_link_properties( + {TIMEOUT_SYMBOL: int(self._timeout * 1000)} + ) + def _create_handler(self, auth: "JWTTokenAsync") -> None: - self._handler = SendClientAsync( - self._target, + self._handler = self._amqp_transport.create_send_client( + config=self._client._config, # pylint:disable=protected-access + target=self._target, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - msg_timeout=self._timeout * 1000, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, link_properties=self._link_properties, properties=create_properties( - self._client._config.user_agent # pylint:disable=protected-access + self._client._config.user_agent, # pylint: disable=protected-access + amqp_transport=self._amqp_transport, ), - **self._internal_kwargs + msg_timeout=self._timeout * 1000, ) async def _open_with_retry(self) -> Any: @@ -121,38 +128,15 @@ async def _open_with_retry(self) -> Any: self._open, operation_need_param=False ) - def _set_msg_timeout( - self, timeout_time: Optional[float], last_exception: Optional[Exception] - ) -> None: - if not timeout_time: - return - remaining_time = timeout_time - time.time() - if remaining_time <= 0.0: - if last_exception: - error = last_exception - else: - error = OperationTimeoutError("Send operation timed out") - _LOGGER.info("%r send operation timed out. (%r)", self._name, error) - raise error - self._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access - async def _send_event_data( self, timeout_time: Optional[float] = None, last_exception: Optional[Exception] = None, ) -> None: - # TODO: Correct uAMQP type hints if self._unsent_events: - await self._open() - self._set_msg_timeout(timeout_time, last_exception) - self._handler.queue_message(*self._unsent_events) # type: ignore - await self._handler.wait_async() # type: ignore - self._unsent_events = self._handler.pending_messages # type: ignore - if self._outcome != constants.MessageSendResult.Ok: - if self._outcome == constants.MessageSendResult.Timeout: - self._condition = OperationTimeoutError("Send operation timed out") - if self._condition: - raise self._condition + self._amqp_transport.send_messages( + self, timeout_time, last_exception, _LOGGER + ) async def _send_event_data_with_retry( self, timeout: Optional[float] = None @@ -183,16 +167,20 @@ def _wrap_eventdata( ) -> Union[EventData, EventDataBatch]: if isinstance(event_data, (EventData, AmqpAnnotatedMessage)): outgoing_event_data = transform_outbound_single_message( - event_data, EventData + event_data, EventData, self._amqp_transport.to_outgoing_amqp_message ) if partition_key: - set_message_partition_key(outgoing_event_data.message, partition_key) + self._amqp_transport.set_message_partition_key( + outgoing_event_data._message, partition_key # pylint: disable=protected-access + ) wrapper_event_data = outgoing_event_data trace_message(wrapper_event_data, span) else: if isinstance( event_data, EventDataBatch ): # The partition_key in the param will be omitted. + if not event_data: + return event_data if ( partition_key and partition_key @@ -203,15 +191,16 @@ def _wrap_eventdata( ) for ( event - ) in event_data.message._body_gen: # pylint: disable=protected-access + ) in event_data._message.data: # pylint: disable=protected-access trace_message(event, span) wrapper_event_data = event_data # type:ignore else: if partition_key: - event_data = _set_partition_key(event_data, partition_key) + event_data = _set_partition_key( + event_data, partition_key, self._amqp_transport + ) event_data = _set_trace_message(event_data, span) - wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # type: ignore # pylint: disable=protected-access - wrapper_event_data.message.on_send_complete = self._on_outcome + wrapper_event_data = EventDataBatch._from_batch(event_data, self._amqp_transport, partition_key) # type: ignore # pylint: disable=protected-access return wrapper_event_data async def send( @@ -253,7 +242,11 @@ async def send( wrapper_event_data = self._wrap_eventdata( event_data, child, partition_key ) - self._unsent_events = [wrapper_event_data.message] + + if not wrapper_event_data: + return + + self._unsent_events = [wrapper_event_data._message] # pylint: disable=protected-access if child: self._client._add_span_request_attributes( # pylint: disable=protected-access diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index a425bea6d059..15e7c345f8be 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -16,7 +16,7 @@ from ._producer_async import EventHubProducer from ._buffered_producer import BufferedProducerDispatcher from .._utils import set_event_partition_key -from .._constants import ALL_PARTITIONS +from .._constants import ALL_PARTITIONS, TransportType from .._common import EventDataBatch, EventData if TYPE_CHECKING: @@ -232,6 +232,7 @@ async def _buffered_send(self, events, **kwargs): self._max_message_size_on_link, max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, + amqp_transport=self._amqp_transport, ) await self._buffered_producer_dispatcher.enqueue_events(events, **kwargs) @@ -301,9 +302,11 @@ async def _get_max_message_size(self) -> None: EventHubProducer, self._producers[ALL_PARTITIONS] )._open_with_retry() self._max_message_size_on_link = ( - cast( # type: ignore + self._amqp_transport.get_remote_max_message_size( + cast( # type: ignore EventHubProducer, self._producers[ALL_PARTITIONS] - )._handler.message_handler._link.peer_max_message_size + )._handler + ) or constants.MAX_MESSAGE_LENGTH_BYTES ) @@ -350,6 +353,7 @@ def _create_producer( partition=partition_id, send_timeout=send_timeout, idle_timeout=self._idle_timeout, + amqp_transport = self._amqp_transport, **self._internal_kwargs ) return handler @@ -402,7 +406,7 @@ def from_connection_string( auth_timeout: float = 60, user_agent: Optional[str] = None, retry_total: int = 3, - transport_type: Optional["TransportType"] = None, + transport_type: Optional["TransportType"] = TransportType.Amqp, **kwargs: Any ) -> "EventHubProducerClient": """Create an EventHubProducerClient from a connection string. @@ -719,6 +723,7 @@ async def create_batch( max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link), partition_id=partition_id, partition_key=partition_key, + amqp_transport=self._amqp_transport, ) return event_data_batch diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py new file mode 100644 index 000000000000..ea36b4288da0 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -0,0 +1,232 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from abc import ABC, abstractmethod + +class AmqpTransportAsync(ABC): + """ + Abstract class that defines a set of common methods needed by producer and consumer. + """ + # define constants + BATCH_MESSAGE = None + MAX_FRAME_SIZE_BYTES = None + IDLE_TIMEOUT_FACTOR = None + MESSAGE = None + + # define symbols + PRODUCT_SYMBOL = None + VERSION_SYMBOL = None + FRAMEWORK_SYMBOL = None + PLATFORM_SYMBOL = None + USER_AGENT_SYMBOL = None + PROP_PARTITION_KEY_AMQP_SYMBOL = None + + # errors + AMQP_LINK_ERROR = None + LINK_STOLEN_CONDITION = None + MGMT_AUTH_EXCEPTION = None + CONNECTION_ERROR = None + AMQP_CONNECTION_ERROR = None + + @staticmethod + @abstractmethod + def to_outgoing_amqp_message(annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + async def get_batch_message_encoded_size(message): + """ + Gets the batch message encoded size given an underlying Message. + :param uamqp.BatchMessage message: Message to get encoded size of. + :rtype: int + """ + return await message.gather()[0].get_message_encoded_size() + + @staticmethod + @abstractmethod + async def get_message_encoded_size(message): + """ + Gets the message encoded size given an underlying Message. + :param uamqp.Message or pyamqp.Message message: Message to get encoded size of. + :rtype: int + """ + + @staticmethod + @abstractmethod + async def get_remote_max_message_size(handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + + @staticmethod + @abstractmethod + async def create_retry_policy(config): + """ + Creates the error retry policy. + :param ~azure.eventhub._configuration.Configuration config: Configuration. + """ + + @staticmethod + @abstractmethod + async def create_link_properties(link_properties): + """ + Creates and returns the link properties. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict + """ + + @staticmethod + @abstractmethod + async def create_send_client(*, config, **kwargs): + """ + Creates and returns the send client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + + @staticmethod + @abstractmethod + async def send_messages(producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + + @staticmethod + @abstractmethod + async def set_message_partition_key(message, partition_key, **kwargs): + """Set the partition key as an annotation on a uamqp message. + + :param message: The message to update. + :param str partition_key: The partition key value. + :rtype: None + """ + + @staticmethod + @abstractmethod + async def add_batch(batch_message, outgoing_event_data, event_data): + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + + @staticmethod + @abstractmethod + async def create_source(source, offset, selector): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes selector: Required. + """ + + @staticmethod + @abstractmethod + async def create_receive_client(*, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword Source source: Required. The source. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + @staticmethod + @abstractmethod + async def open_receive_client(*, handler, client, auth): + """ + Opens the receive client. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + """ + + @staticmethod + @abstractmethod + async def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Whether to update token. If not updating token, + then pass 300 to refresh_window. Only used by uamqp. + """ + + @staticmethod + @abstractmethod + async def create_mgmt_client(address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + @staticmethod + @abstractmethod + async def get_updated_token(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + + @staticmethod + @abstractmethod + async def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + + @staticmethod + @abstractmethod + async def get_error(error, message, *, condition=None): + """ + Gets error and passes in error message, and, if applicable, condition. + :param error: The error to raise. + :param str message: Error message. + :param condition: Optional error condition. Will not be used by uamqp. + """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py new file mode 100644 index 000000000000..1c5cdcb4797c --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py @@ -0,0 +1,356 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from __future__ import annotations +import asyncio +import logging +from typing import Optional, Union, Any, cast, TYPE_CHECKING + +try: + from uamqp import ( + BatchMessage, + constants, + MessageBodyType, + Message, + types, + SendClientAsync, + ReceiveClientAsync, + Source, + utils, + authentication, + AMQPClientAsync, + compat, + errors, + ) + from uamqp.message import ( + MessageHeader, + MessageProperties, + ) + uamqp_installed = True +except ImportError: + uamqp_installed = False + +from ._base_async import AmqpTransportAsync +from ...amqp._constants import AmqpMessageBodyType +from ..._constants import ( + NO_RETRY_ERRORS, + PROP_PARTITION_KEY, +) + +from ...exceptions import ( + ConnectError, + EventDataError, + EventDataSendError, + OperationTimeoutError, + EventHubError, + AuthenticationError, + ConnectionLostError, + EventDataError, + EventDataSendError, +) + +if TYPE_CHECKING: + from .._client_base_async import ClientBaseAsync, ConsumerProducerMixin + +_LOGGER = logging.getLogger(__name__) + +if uamqp_installed: + + from ..._transport._uamqp_transport import UamqpTransport + + class UamqpTransportAsync(UamqpTransport, AmqpTransportAsync): + """ + Class which defines uamqp-based methods used by the producer and consumer. + """ + + @staticmethod + async def get_batch_message_encoded_size(message): + """ + Gets the batch message encoded size given an underlying Message. + :param uamqp.BatchMessage message: Message to get encoded size of. + :rtype: int + """ + return await message.gather()[0].get_message_encoded_size() + + @staticmethod + def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument + """ + Creates and returns the uamqp SendClient. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + target = kwargs.pop("target") + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + + return SendClientAsync( + target, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + **kwargs + ) + + @staticmethod + async def send_messages(producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + # pylint: disable=protected-access + await producer._open() + producer._unsent_events[0].on_send_complete = producer._on_outcome + UamqpTransportAsync._set_msg_timeout(producer, timeout_time, last_exception, logger) + producer._handler.queue_message(*producer._unsent_events) # type: ignore + await producer._handler.wait_async() # type: ignore + producer._unsent_events = producer._handler.pending_messages # type: ignore + if producer._outcome != constants.MessageSendResult.Ok: + if producer._outcome == constants.MessageSendResult.Timeout: + producer._condition = OperationTimeoutError("Send operation timed out") + if producer._condition: + raise producer._condition + + @staticmethod + def set_message_partition_key(message, partition_key, **kwargs): # pylint:disable=unused-argument + # type: (Message, Optional[Union[bytes, str]], Any) -> Message + """Set the partition key as an annotation on a uamqp message. + + :param ~uamqp.Message message: The message to update. + :param str partition_key: The partition key value. + :rtype: Message + """ + if partition_key: + annotations = message.annotations + if annotations is None: + annotations = {} + annotations[ + UamqpTransport.PROP_PARTITION_KEY_AMQP_SYMBOL # TODO: see if setting non-amqp symbol is valid + ] = partition_key + header = MessageHeader() + header.durable = True + message.annotations = annotations + message.header = header + return message + + @staticmethod + async def add_batch(batch_message, outgoing_event_data, event_data): + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + # pylint: disable=protected-access + batch_message._internal_events.append(event_data) + batch_message._message._body_gen.append( + outgoing_event_data._message + ) + + @staticmethod + async def create_receive_client(*, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + source = kwargs.pop("source") + symbol_array = kwargs.pop("desired_capabilities") + desired_capabilities = None + if symbol_array: + symbol_array = [types.AMQPSymbol(symbol) for symbol in symbol_array] + desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + link_credit = kwargs.pop("link_credit") + streaming_receive = kwargs.pop("streaming_receive") + message_received_callback = kwargs.pop("message_received_callback") + + client = ReceiveClientAsync( + source, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + desired_capabilities=desired_capabilities, + prefetch=link_credit, + receive_settle_mode=constants.ReceiverSettleMode.ReceiveAndDelete, + auto_complete=False, + **kwargs + ) + # pylint:disable=protected-access + client._streaming_receive = streaming_receive + client._message_received_callback = (message_received_callback) + return client + + @staticmethod + async def open_receive_client(*, handler, client, auth): + """ + Opens the receive client and returns ready status. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + :param auth: Auth. + :rtype: bool + """ + # pylint:disable=protected-access + await handler.open() + + @staticmethod + async def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Required. Whether to update token. If not updating token, + then pass 300 to refresh_window. + """ + update_token = kwargs.pop("update_token") + refresh_window = 300 + if update_token: + refresh_window = 0 + + token_auth = authentication.JWTTokenAsync( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + refresh_window=refresh_window + ) + if update_token: + await token_auth.update_token() + return token_auth + + @staticmethod + def create_mgmt_client(address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + mgmt_target = f"amqps://{address.hostname}{address.path}" + return AMQPClientAsync( + mgmt_target, + auth=mgmt_auth, + debug=config.network_tracing + ) + + @staticmethod + async def get_updated_token(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + return mgmt_auth.token + + @staticmethod + async def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + operation_type = kwargs.pop("operation_type") + operation = kwargs.pop("operation") + return await mgmt_client.mgmt_request_async( + mgmt_msg, + operation, + op_type=operation_type, + **kwargs + ) + + @staticmethod + async def _handle_exception( # pylint:disable=too-many-branches, too-many-statements + exception: Exception, closable: Union["ClientBaseAsync", "ConsumerProducerMixin"] + ) -> Exception: + # pylint: disable=protected-access + if isinstance(exception, asyncio.CancelledError): + raise exception + error = exception + try: + name = cast("ConsumerProducerMixin", closable)._name + except AttributeError: + name = cast("ClientBaseAsync", closable)._container_id + if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise + _LOGGER.info("%r stops due to keyboard interrupt", name) + await cast("ConsumerProducerMixin", closable)._close_connection_async() + raise error + elif isinstance(exception, EventHubError): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + raise error + elif isinstance( + exception, + ( + errors.MessageAccepted, + errors.MessageAlreadySettled, + errors.MessageModified, + errors.MessageRejected, + errors.MessageReleased, + errors.MessageContentTooLarge, + ), + ): + _LOGGER.info("%r Event data error (%r)", name, exception) + error = EventDataError(str(exception), exception) + raise error + elif isinstance(exception, errors.MessageException): + _LOGGER.info("%r Event data send error (%r)", name, exception) + error = EventDataSendError(str(exception), exception) + raise error + else: + try: + if isinstance(exception, errors.AuthenticationException): + await closable._close_connection_async() + elif isinstance(exception, errors.LinkDetach): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + elif isinstance(exception, errors.ConnectionClose): + await closable._close_connection_async() + elif isinstance(exception, errors.MessageHandlerError): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + else: # errors.AMQPConnectionError, compat.TimeoutException, and any other errors + await closable._close_connection_async() + except AttributeError: + pass + return UamqpTransport._create_eventhub_exception(exception) \ No newline at end of file diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py index 8001d97cea6d..a688e70a8861 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py @@ -98,28 +98,3 @@ class OperationTimeoutError(EventHubError): class OwnershipLostError(Exception): """Raised when `update_checkpoint` detects the ownership to a partition has been lost.""" - -# TODO: delete when async unittests have been refactored -def _create_eventhub_exception(exception): - if isinstance(exception, errors.AuthenticationException): - error = AuthenticationError(str(exception), exception) - elif isinstance(exception, errors.VendorLinkDetach): - error = ConnectError(str(exception), exception) - elif isinstance(exception, errors.LinkDetach): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.ConnectionClose): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.MessageHandlerError): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.AMQPConnectionError): - error_type = ( - AuthenticationError - if str(exception).startswith("Unable to open authentication session") - else ConnectError - ) - error = error_type(str(exception), exception) - elif isinstance(exception, compat.TimeoutException): - error = ConnectionLostError(str(exception), exception) - else: - error = EventHubError(str(exception), exception) - return error