diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index c4a9ff45c800..07fe891132ae 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -1,19 +1,19 @@ # Release History -## 5.10.1 (Unreleased) +## 5.10.1 (2022-08-18) This version and all future versions will require Python 3.7+, Python 3.6 is no longer supported. -### Features Added - -### Breaking Changes +### Bugs Fixed - Fixed a bug in `BufferedProducer` that would block when flushing the queue causing the client to freeze up (issue #23510). - -### Bugs Fixed +- Fixed a bug in the async `EventHubProducerClient` and `EventHubConsumerClient` that set the default value of the `transport_type` parameter in the constructors to `None` rather than `TransportType.Amqp`. ### Other Changes +- Internal refactoring to support upcoming Pure Python AMQP-based release. +- Updated uAMQP dependency to 1.6.0. + ## 5.10.0 (2022-06-08) ### Features Added diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py index c2a457b2726e..6645bf9ea577 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py @@ -2,12 +2,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from uamqp import constants from ._common import EventData, EventDataBatch from ._version import VERSION __version__ = VERSION +from ._constants import TransportType from ._producer_client import EventHubProducerClient from ._consumer_client import EventHubConsumerClient from ._client_base import EventHubSharedKeyCredential @@ -19,8 +19,6 @@ EventHubConnectionStringProperties, ) -TransportType = constants.TransportType - __all__ = [ "EventData", "EventDataBatch", diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py index 1eb06757f230..14cce8b88697 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import time import queue import logging @@ -14,6 +15,7 @@ from ..exceptions import OperationTimeoutError if TYPE_CHECKING: + from .._transport._base import AmqpTransport from .._producer_client import SendEventTypes _LOGGER = logging.getLogger(__name__) @@ -30,8 +32,8 @@ def __init__( max_message_size_on_link: int, executor: ThreadPoolExecutor, *, - max_wait_time: float = 1, - max_buffer_length: int + max_buffer_length: int, + max_wait_time: float = 1 ): self._buffered_queue: queue.Queue = queue.Queue() self._max_buffer_len = max_buffer_length diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py index 3c112de18d18..71f97f15fecd 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer_dispatcher.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import logging from threading import Lock from concurrent.futures import ThreadPoolExecutor @@ -14,6 +15,7 @@ if TYPE_CHECKING: from .._producer_client import SendEventTypes + from .._transport._base import AmqpTransport _LOGGER = logging.getLogger(__name__) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 98cc19b0bfa7..61e59688cbc4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import logging import uuid @@ -10,16 +10,13 @@ import functools import collections from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING, cast, Union -from datetime import timedelta - try: - from urlparse import urlparse - from urllib import quote_plus # type: ignore + from typing import TypeAlias except ImportError: - from urllib.parse import urlparse, quote_plus + from typing_extensions import TypeAlias +from datetime import timedelta +from urllib.parse import urlparse -from uamqp import AMQPClient, Message, authentication, constants, errors, compat, utils -import six from azure.core.credentials import ( AccessToken, AzureSasCredential, @@ -29,19 +26,25 @@ from azure.core.pipeline.policies import RetryMode -from .exceptions import _handle_exception, ClientClosedError, ConnectError +from ._transport._uamqp_transport import UamqpTransport +from .exceptions import ClientClosedError from ._configuration import Configuration -from ._utils import utc_from_timestamp, parse_sas_credential +from ._utils import utc_from_timestamp, parse_sas_credential, generate_sas_token from ._connection_manager import get_connection_manager from ._constants import ( CONTAINER_PREFIX, JWT_TOKEN_SCOPE, - MGMT_OPERATION, - MGMT_PARTITION_OPERATION, + READ_OPERATION, MGMT_STATUS_CODE, MGMT_STATUS_DESC, + MGMT_OPERATION, + MGMT_PARTITION_OPERATION, ) +if TYPE_CHECKING: + from azure.core.credentials import TokenCredential + from uamqp import Message as uamqp_Message + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth _LOGGER = logging.getLogger(__name__) _Address = collections.namedtuple("_Address", "hostname path") @@ -142,11 +145,8 @@ def _generate_sas_token(uri, policy, key, expiry=None): expiry = timedelta(hours=1) # Default to 1 hour. abs_expiry = int(time.time()) + expiry.seconds - encoded_uri = quote_plus(uri).encode("utf-8") # pylint: disable=no-member - encoded_policy = quote_plus(policy).encode("utf-8") # pylint: disable=no-member - encoded_key = key.encode("utf-8") - token = utils.create_sas_token(encoded_policy, encoded_key, encoded_uri, expiry) + token = generate_sas_token(uri, policy, key, abs_expiry).encode() return AccessToken(token=token, expires_on=abs_expiry) @@ -265,7 +265,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument if TYPE_CHECKING: from azure.core.credentials import TokenCredential - CredentialTypes = Union[ + CredentialTypes: TypeAlias = Union[ AzureSasCredential, AzureNamedKeyCredential, EventHubSharedKeyCredential, @@ -274,8 +274,16 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument class ClientBase(object): # pylint:disable=too-many-instance-attributes - def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwargs): - # type: (str, str, CredentialTypes, Any) -> None + def __init__( + self, + fully_qualified_namespace: str, + eventhub_name: str, + credential: CredentialTypes, + **kwargs: Any, + ) -> None: + uamqp_transport = kwargs.pop("uamqp_transport", True) + self._amqp_transport = kwargs.pop("amqp_transport", UamqpTransport) + self.eventhub_name = eventhub_name if not eventhub_name: raise ValueError("The eventhub name can not be None or empty.") @@ -290,11 +298,12 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg self._credential = credential # type: ignore self._keep_alive = kwargs.get("keep_alive", 30) self._auto_reconnect = kwargs.get("auto_reconnect", True) - self._mgmt_target = "amqps://{}/{}".format( - self._address.hostname, self.eventhub_name + self._auth_uri = f"sb://{self._address.hostname}{self._address.path}" + self._config = Configuration( + uamqp_transport=uamqp_transport, + hostname=self._address.hostname, + **kwargs, ) - self._auth_uri = "sb://{}{}".format(self._address.hostname, self._address.path) - self._config = Configuration(**kwargs) self._debug = self._config.network_tracing self._conn_manager = get_connection_manager(**kwargs) self._idle_timeout = kwargs.get("idle_timeout", None) @@ -313,11 +322,10 @@ def _from_connection_string(conn_str, **kwargs): kwargs["credential"] = EventHubSharedKeyCredential(policy, key) return kwargs - def _create_auth(self): - # type: () -> authentication.JWTTokenAuth + def _create_auth(self) -> uamqp_JWTTokenAuth: """ - Create an ~uamqp.authentication.SASTokenAuth instance to authenticate - the session. + Create an ~uamqp.authentication.SASTokenAuth instance + to authenticate the session. """ try: # ignore mypy's warning because token_type is Optional @@ -325,32 +333,19 @@ def _create_auth(self): except AttributeError: token_type = b"jwt" if token_type == b"servicebus.windows.net:sastoken": - auth = authentication.JWTTokenAuth( - self._auth_uri, + return self._amqp_transport.create_token_auth( self._auth_uri, functools.partial(self._credential.get_token, self._auth_uri), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, + config=self._config, + update_token=True, ) - auth.update_token() - return auth - return authentication.JWTTokenAuth( - self._auth_uri, + return self._amqp_transport.create_token_auth( self._auth_uri, functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, - refresh_window=300, + config=self._config, + update_token=False, ) def _close_connection(self): @@ -385,26 +380,32 @@ def _backoff( ) raise last_exception - def _management_request(self, mgmt_msg, op_type): - # type: (Message, bytes) -> Any + def _management_request( + self, mgmt_msg: uamqp_Message, op_type: bytes + ) -> Any: # pylint:disable=assignment-from-none retried_times = 0 last_exception = None while retried_times <= self._config.max_retries: mgmt_auth = self._create_auth() - mgmt_client = AMQPClient( - self._mgmt_target, auth=mgmt_auth, debug=self._config.network_tracing + mgmt_client = self._amqp_transport.create_mgmt_client( + self._address, mgmt_auth=mgmt_auth, config=self._config ) try: conn = self._conn_manager.get_connection( # pylint:disable=assignment-from-none - self._address.hostname, mgmt_auth + host=self._address.hostname, auth=mgmt_auth ) mgmt_client.open(connection=conn) - mgmt_msg.application_properties["security_token"] = mgmt_auth.token - response = mgmt_client.mgmt_request( + while not mgmt_client.client_ready(): + time.sleep(0.05) + mgmt_msg.application_properties[ + "security_token" + ] = self._amqp_transport.get_updated_token(mgmt_auth) + response = self._amqp_transport.mgmt_client_request( + mgmt_client, mgmt_msg, - constants.READ_OPERATION, - op_type=op_type, + operation=READ_OPERATION, + operation_type=op_type, status_code_field=MGMT_STATUS_CODE, description_fields=MGMT_STATUS_DESC, ) @@ -412,29 +413,15 @@ def _management_request(self, mgmt_msg, op_type): description = response.application_properties.get( MGMT_STATUS_DESC ) # type: Optional[Union[str, bytes]] - if description and isinstance(description, six.binary_type): + if description and isinstance(description, bytes): description = description.decode("utf-8") if status_code < 400: return response - if status_code in [401]: - raise errors.AuthenticationException( - "Management authentication failed. Status code: {}, Description: {!r}".format( - status_code, description - ) - ) - if status_code in [404]: - raise ConnectError( - "Management connection failed. Status code: {}, Description: {!r}".format( - status_code, description - ) - ) - raise errors.AMQPConnectionError( - "Management request error. Status code: {}, Description: {!r}".format( - status_code, description - ) - ) + raise self._amqp_transport.get_error(status_code, description) except Exception as exception: # pylint: disable=broad-except - last_exception = _handle_exception(exception, self) + last_exception = self._amqp_transport._handle_exception( # pylint: disable=protected-access + exception, self + ) self._backoff( retried_times=retried_times, last_exception=last_exception ) @@ -453,12 +440,13 @@ def _add_span_request_attributes(self, span): span.add_attribute("message_bus.destination", self._address.path) span.add_attribute("peer.address", self._address.hostname) - def _get_eventhub_properties(self): - # type:() -> Dict[str, Any] - mgmt_msg = Message(application_properties={"name": self.eventhub_name}) + def _get_eventhub_properties(self) -> Dict[str, Any]: + mgmt_msg = self._amqp_transport.build_message( + application_properties={"name": self.eventhub_name} + ) response = self._management_request(mgmt_msg, op_type=MGMT_OPERATION) output = {} - eh_info = response.get_data() # type: Dict[bytes, Any] + eh_info: Dict[bytes, Any] = response.value if eh_info: output["eventhub_name"] = eh_info[b"name"].decode("utf-8") output["created_at"] = utc_from_timestamp( @@ -475,14 +463,14 @@ def _get_partition_ids(self): def _get_partition_properties(self, partition_id): # type:(str) -> Dict[str, Any] - mgmt_msg = Message( + mgmt_msg = self._amqp_transport.build_message( application_properties={ "name": self.eventhub_name, "partition": partition_id, } ) response = self._management_request(mgmt_msg, op_type=MGMT_PARTITION_OPERATION) - partition_info = response.get_data() # type: Dict[bytes, Any] + partition_info = response.value # type: Dict[bytes, Any] output = {} if partition_info: output["eventhub_name"] = partition_info[b"name"].decode("utf-8") @@ -520,9 +508,7 @@ def _create_handler(self, auth): def _check_closed(self): if self.closed: raise ClientClosedError( - "{} has been closed. Please create a new one to handle event data.".format( - self._name - ) + f"{self._name} has been closed. Please create a new one to handle event data." ) def _open(self): @@ -533,17 +519,16 @@ def _open(self): self._handler.close() auth = self._client._create_auth() self._create_handler(auth) - self._handler.open( - connection=self._client._conn_manager.get_connection( - self._client._address.hostname, auth - ) # pylint: disable=protected-access + conn = self._client._conn_manager.get_connection( # pylint: disable=protected-access + host=self._client._address.hostname, auth=auth ) + self._handler.open(connection=conn) while not self._handler.client_ready(): time.sleep(0.05) self._max_message_size_on_link = ( - self._handler.message_handler._link.peer_max_message_size - or constants.MAX_MESSAGE_LENGTH_BYTES - ) # pylint: disable=protected-access + self._amqp_transport.get_remote_max_message_size(self._handler) + or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES + ) self.running = True def _close_handler(self): @@ -556,9 +541,10 @@ def _close_connection(self): self._client._conn_manager.reset_connection_if_broken() # pylint: disable=protected-access def _handle_exception(self, exception): - if not self.running and isinstance(exception, compat.TimeoutException): - exception = errors.AuthenticationException("Authorization timeout.") - return _handle_exception(exception, self) + exception = self._amqp_transport.check_timeout_exception(self, exception) + return self._amqp_transport._handle_exception( # pylint: disable=protected-access + exception, self + ) def _do_retryable_operation(self, operation, timeout=None, **kwargs): # pylint:disable=protected-access @@ -576,7 +562,7 @@ def _do_retryable_operation(self, operation, timeout=None, **kwargs): return operation( timeout_time=timeout_time, last_exception=last_exception, - **kwargs + **kwargs, ) return operation() except Exception as exception: # pylint:disable=broad-except diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 1ec58f126b52..969d9c5bfb5f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -2,9 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import json +import datetime import logging import uuid from typing import ( @@ -20,12 +21,7 @@ ) from typing_extensions import TypedDict -import six - -from uamqp import BatchMessage, Message, constants - from ._utils import ( - set_message_partition_key, trace_message, utc_from_timestamp, transform_outbound_single_message, @@ -35,7 +31,6 @@ PROP_SEQ_NUMBER, PROP_OFFSET, PROP_PARTITION_KEY, - PROP_PARTITION_KEY_AMQP_SYMBOL, PROP_TIMESTAMP, PROP_ABSOLUTE_EXPIRY_TIME, PROP_CONTENT_ENCODING, @@ -57,9 +52,11 @@ AmqpMessageHeader, AmqpMessageProperties, ) +from ._transport._uamqp_transport import UamqpTransport if TYPE_CHECKING: - import datetime + from uamqp import Message as uamqp_Message, BatchMessage as uamqp_BatchMessage + from ._transport._base import AmqpTransport MessageContent = TypedDict("MessageContent", {"content": bytes, "content_type": str}) PrimitiveTypes = Optional[ @@ -127,62 +124,60 @@ def __init__( self._raw_amqp_message = AmqpAnnotatedMessage( # type: ignore data_body=body, annotations={}, application_properties={} ) - self.message = ( - self._raw_amqp_message._message - ) # pylint:disable=protected-access + # amqp message to be reset right before sending + self._message = UamqpTransport.to_outgoing_amqp_message(self._raw_amqp_message) + self.message = self._message self._raw_amqp_message.header = AmqpMessageHeader() self._raw_amqp_message.properties = AmqpMessageProperties() self.message_id = None self.content_type = None self.correlation_id = None - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: # pylint: disable=bare-except try: body_str = self.body_as_str() except: body_str = "" - event_repr = "body='{}'".format(body_str) + event_repr = f"body='{body_str}'" try: - event_repr += ", properties={}".format(self.properties) + event_repr += f", properties={self.properties}" except: event_repr += ", properties=" try: - event_repr += ", offset={}".format(self.offset) + event_repr += f", offset={self.offset}" except: event_repr += ", offset=" try: - event_repr += ", sequence_number={}".format(self.sequence_number) + event_repr += f", sequence_number={self.sequence_number}" except: event_repr += ", sequence_number=" try: - event_repr += ", partition_key={!r}".format(self.partition_key) + event_repr += f", partition_key={self.partition_key!r}" except: event_repr += ", partition_key=" try: - event_repr += ", enqueued_time={!r}".format(self.enqueued_time) + event_repr += f", enqueued_time={self.enqueued_time!r}" except: event_repr += ", enqueued_time=" - return "EventData({})".format(event_repr) + return f"EventData({event_repr})" - def __str__(self): - # type: () -> str + def __str__(self) -> str: try: body_str = self.body_as_str() except: # pylint: disable=bare-except body_str = "" - event_str = "{{ body: '{}'".format(body_str) + event_str = f"{{ body: '{body_str}'" try: - event_str += ", properties: {}".format(self.properties) + event_str += f", properties: {self.properties}" if self.offset: - event_str += ", offset: {}".format(self.offset) + event_str += f", offset: {self.offset}" if self.sequence_number: - event_str += ", sequence_number: {}".format(self.sequence_number) + event_str += f", sequence_number: {self.sequence_number}" if self.partition_key: - event_str += ", partition_key={!r}".format(self.partition_key) + event_str += f", partition_key={self.partition_key!r}" if self.enqueued_time: - event_str += ", enqueued_time={!r}".format(self.enqueued_time) + event_str += f", enqueued_time={self.enqueued_time!r}" except: # pylint: disable=bare-except pass event_str += " }" @@ -213,8 +208,11 @@ def from_message_content( # pylint: disable=unused-argument return event_data @classmethod - def _from_message(cls, message, raw_amqp_message=None): - # type: (Message, Optional[AmqpAnnotatedMessage]) -> EventData + def _from_message( + cls, + message: uamqp_Message, + raw_amqp_message: Optional[AmqpAnnotatedMessage] = None, + ) -> EventData: # pylint:disable=protected-access """Internal use only. @@ -225,8 +223,9 @@ def _from_message(cls, message, raw_amqp_message=None): :rtype: ~azure.eventhub.EventData """ event_data = cls(body="") - event_data.message = message # pylint: disable=protected-access + event_data._message = message + event_data.message = message event_data._raw_amqp_message = ( raw_amqp_message if raw_amqp_message @@ -234,39 +233,24 @@ def _from_message(cls, message, raw_amqp_message=None): ) return event_data - def _encode_message(self): - # type: () -> bytes - # pylint: disable=protected-access - return self._raw_amqp_message._message.encode_message() - - def _decode_non_data_body_as_str(self, encoding="UTF-8"): - # type: (str) -> str + def _decode_non_data_body_as_str(self, encoding: str = "UTF-8") -> str: # pylint: disable=protected-access - body = self.raw_amqp_message._message._body + body = self.raw_amqp_message.body if self.body_type == AmqpMessageBodyType.VALUE: - if not body.data: + if not body: return "" - return str(decode_with_recurse(body.data, encoding)) + return str(decode_with_recurse(body, encoding)) - seq_list = [d for seq_section in body.data for d in seq_section] + seq_list = [d for seq_section in body for d in seq_section] return str(decode_with_recurse(seq_list, encoding)) - def _to_outgoing_message(self): - # type: () -> EventData - self.message = ( - self._raw_amqp_message._to_outgoing_amqp_message() # pylint:disable=protected-access - ) - return self - @property - def raw_amqp_message(self): - # type: () -> AmqpAnnotatedMessage + def raw_amqp_message(self) -> AmqpAnnotatedMessage: """Advanced usage only. The internal AMQP message payload that is sent or received.""" return self._raw_amqp_message @property - def sequence_number(self): - # type: () -> Optional[int] + def sequence_number(self) -> Optional[int]: """The sequence number of the event. :rtype: int @@ -274,8 +258,7 @@ def sequence_number(self): return self._raw_amqp_message.annotations.get(PROP_SEQ_NUMBER, None) @property - def offset(self): - # type: () -> Optional[str] + def offset(self) -> Optional[str]: """The offset of the event. :rtype: str @@ -286,8 +269,7 @@ def offset(self): return None @property - def enqueued_time(self): - # type: () -> Optional[datetime.datetime] + def enqueued_time(self) -> Optional[datetime.datetime]: """The enqueued timestamp of the event. :rtype: datetime.datetime @@ -298,20 +280,15 @@ def enqueued_time(self): return None @property - def partition_key(self): - # type: () -> Optional[bytes] + def partition_key(self) -> Optional[bytes]: """The partition key of the event. :rtype: bytes """ - try: - return self._raw_amqp_message.annotations[PROP_PARTITION_KEY_AMQP_SYMBOL] - except KeyError: - return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) + return self._raw_amqp_message.annotations.get(PROP_PARTITION_KEY, None) @property - def properties(self): - # type: () -> Dict[Union[str, bytes], Any] + def properties(self) -> Dict[Union[str, bytes], Any]: """Application-defined properties on the event. :rtype: dict @@ -319,8 +296,7 @@ def properties(self): return self._raw_amqp_message.application_properties @properties.setter - def properties(self, value): - # type: (Dict[Union[str, bytes], Any]) -> None + def properties(self, value: Dict[Union[str, bytes], Any]): """Application-defined properties on the event. :param dict value: The application properties for the EventData. @@ -329,8 +305,7 @@ def properties(self, value): self._raw_amqp_message.application_properties = properties @property - def system_properties(self): - # type: () -> Dict[bytes, Any] + def system_properties(self) -> Dict[bytes, Any]: """Metadata set by the Event Hubs Service associated with the event. An EventData could have some or all of the following meta data depending on the source @@ -368,8 +343,7 @@ def system_properties(self): return self._sys_properties @property - def body(self): - # type: () -> PrimitiveTypes + def body(self) -> PrimitiveTypes: """The body of the Message. The format may vary depending on the body type: For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, the body could be bytes or Iterable[bytes]. @@ -386,16 +360,14 @@ def body(self): raise ValueError("Event content empty.") @property - def body_type(self): - # type: () -> AmqpMessageBodyType + def body_type(self) -> AmqpMessageBodyType: """The body type of the underlying AMQP message. :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType """ return self._raw_amqp_message.body_type - def body_as_str(self, encoding="UTF-8"): - # type: (str) -> str + def body_as_str(self, encoding: str = "UTF-8") -> str: """The content of the event as a string, if the data is of a compatible type. :param encoding: The encoding to use for decoding event data. @@ -408,18 +380,15 @@ def body_as_str(self, encoding="UTF-8"): return self._decode_non_data_body_as_str(encoding=encoding) return "".join(b.decode(encoding) for b in cast(Iterable[bytes], data)) except TypeError: - return six.text_type(data) + return str(data) except: # pylint: disable=bare-except pass try: return cast(bytes, data).decode(encoding) except Exception as e: - raise TypeError( - "Message data is not compatible with string type: {}".format(e) - ) + raise TypeError(f"Message data is not compatible with string type: {e}") - def body_as_json(self, encoding="UTF-8"): - # type: (str) -> Dict[str, Any] + def body_as_json(self, encoding: str = "UTF-8") -> Dict[str, Any]: """The content of the event loaded as a JSON object, if the data is compatible. :param encoding: The encoding to use for decoding event data. @@ -430,11 +399,10 @@ def body_as_json(self, encoding="UTF-8"): try: return json.loads(data_str) except Exception as e: - raise TypeError("Event data is not compatible with JSON type: {}".format(e)) + raise TypeError(f"Event data is not compatible with JSON type: {e}") @property - def content_type(self): - # type: () -> Optional[str] + def content_type(self) -> Optional[str]: """The content type descriptor. Optionally describes the payload of the message, with a descriptor following the format of RFC2045, Section 5, for example "application/json". @@ -448,15 +416,13 @@ def content_type(self): return self._raw_amqp_message.properties.content_type @content_type.setter - def content_type(self, value): - # type: (str) -> None + def content_type(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.content_type = value @property - def correlation_id(self): - # type: () -> Optional[str] + def correlation_id(self) -> Optional[str]: """The correlation identifier. Allows an application to specify a context for the message for the purposes of correlation, for example reflecting the MessageId of a message that is being replied to. @@ -470,15 +436,13 @@ def correlation_id(self): return self._raw_amqp_message.properties.correlation_id @correlation_id.setter - def correlation_id(self, value): - # type: (str) -> None + def correlation_id(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.correlation_id = value @property - def message_id(self): - # type: () -> Optional[str] + def message_id(self) -> Optional[str]: """The id to identify the message. The message identifier is an application-defined value that uniquely identifies the message and its payload. The identifier is a free-form string and can reflect a GUID or an identifier derived from the @@ -494,7 +458,7 @@ def message_id(self): return self._raw_amqp_message.properties.message_id @message_id.setter - def message_id(self, value): + def message_id(self, value: str) -> None: if not self._raw_amqp_message.properties: self._raw_amqp_message.properties = AmqpMessageProperties() self._raw_amqp_message.properties.message_id = value @@ -525,11 +489,16 @@ class EventDataBatch(object): Event Hub decided by the service. """ - def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None): - # type: (Optional[int], Optional[str], Optional[Union[str, bytes]]) -> None + def __init__( + self, + max_size_in_bytes: Optional[int] = None, + partition_id: Optional[str] = None, + partition_key: Optional[Union[str, bytes]] = None + ) -> None: + self._amqp_transport = UamqpTransport if partition_key and not isinstance( - partition_key, (six.text_type, six.binary_type) + partition_key, (str, bytes) ): _LOGGER.info( "WARNING: Setting partition_key of non-string value on the events to be sent is discouraged " @@ -538,35 +507,50 @@ def __init__(self, max_size_in_bytes=None, partition_id=None, partition_key=None "partition_key to only be string type, they might fail to parse the non-string value." ) - self.max_size_in_bytes = max_size_in_bytes or constants.MAX_MESSAGE_LENGTH_BYTES - self.message = BatchMessage(data=[], multi_messages=False, properties=None) + self.max_size_in_bytes = ( + max_size_in_bytes or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES + ) + self._message = self._amqp_transport.build_batch_message(data=[]) self._partition_id = partition_id self._partition_key = partition_key - set_message_partition_key(self.message, self._partition_key) - self._size = self.message.gather()[0].get_message_encoded_size() + self._message = self._amqp_transport.set_message_partition_key( + self._message, self._partition_key + ) + self.message: uamqp_BatchMessage = self._message + self._size = self._amqp_transport.get_batch_message_encoded_size(self._message) self._count = 0 - self._internal_events: List[Union[EventData, AmqpAnnotatedMessage]] = [] - - def __repr__(self): - # type: () -> str - batch_repr = "max_size_in_bytes={}, partition_id={}, partition_key={!r}, event_count={}".format( - self.max_size_in_bytes, self._partition_id, self._partition_key, self._count + self._internal_events: List[ + Union[EventData, AmqpAnnotatedMessage] + ] = [] + + def __repr__(self) -> str: + batch_repr = ( + f"max_size_in_bytes={self.max_size_in_bytes}, partition_id={self._partition_id}, " + f"partition_key={self._partition_key!r}, event_count={self._count}" ) - return "EventDataBatch({})".format(batch_repr) + return f"EventDataBatch({batch_repr})" - def __len__(self): + def __len__(self) -> int: return self._count @classmethod - def _from_batch(cls, batch_data, partition_key=None): - # type: (Iterable[EventData], Optional[AnyStr]) -> EventDataBatch + def _from_batch( + cls, + batch_data: Iterable[EventData], + amqp_transport: AmqpTransport, + partition_key: Optional[AnyStr] = None, + ) -> EventDataBatch: outgoing_batch_data = [ - transform_outbound_single_message(m, EventData) for m in batch_data + transform_outbound_single_message( + m, EventData, amqp_transport.to_outgoing_amqp_message + ) + for m in batch_data ] batch_data_instance = cls(partition_key=partition_key) - for data in outgoing_batch_data: - batch_data_instance.add(data) + + for event_data in outgoing_batch_data: + batch_data_instance.add(event_data) return batch_data_instance def _load_events(self, events): @@ -581,16 +565,14 @@ def _load_events(self, events): ) @property - def size_in_bytes(self): - # type: () -> int + def size_in_bytes(self) -> int: """The combined size of the events in the batch, in bytes. :rtype: int """ return self._size - def add(self, event_data): - # type: (Union[EventData, AmqpAnnotatedMessage]) -> None + def add(self, event_data: Union[EventData, AmqpAnnotatedMessage]) -> None: """Try to add an EventData to the batch. The total size of an added event is the sum of its body, properties, etc. @@ -603,7 +585,9 @@ def add(self, event_data): :raise: :class:`ValueError`, when exceeding the size limit. """ - outgoing_event_data = transform_outbound_single_message(event_data, EventData) + outgoing_event_data = transform_outbound_single_message( + event_data, EventData, self._amqp_transport.to_outgoing_amqp_message + ) if self._partition_key: if ( @@ -614,13 +598,15 @@ def add(self, event_data): "The partition key of event_data does not match the partition key of this batch." ) if not outgoing_event_data.partition_key: - set_message_partition_key( - outgoing_event_data.message, self._partition_key + self._amqp_transport.set_message_partition_key( + outgoing_event_data._message, # pylint: disable=protected-access + self._partition_key, ) trace_message(outgoing_event_data) - event_data_size = outgoing_event_data.message.get_message_encoded_size() - + event_data_size = self._amqp_transport.get_message_encoded_size( + outgoing_event_data._message # pylint: disable=protected-access + ) # For a BatchMessage, if the encoded_message_size of event_data is < 256, then the overhead cost to encode that # message into the BatchMessage would be 5 bytes, if >= 256, it would be 8 bytes. size_after_add = ( @@ -631,14 +617,9 @@ def add(self, event_data): if size_after_add > self.max_size_in_bytes: raise ValueError( - "EventDataBatch has reached its size limit: {}".format( - self.max_size_in_bytes - ) + f"EventDataBatch has reached its size limit: {self.max_size_in_bytes}" ) - self._internal_events.append(event_data) - self.message._body_gen.append( # pylint: disable=protected-access - outgoing_event_data - ) + self._amqp_transport.add_batch(self, outgoing_event_data, event_data) self._size = size_after_add self._count += 1 diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py index 652fcb9bbc04..00c03ca4197b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py @@ -3,14 +3,10 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- from typing import Optional, Dict, Any +from urllib.parse import urlparse -try: - from urlparse import urlparse -except ImportError: - from urllib.parse import urlparse - -from uamqp.constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT from azure.core.pipeline.policies import RetryMode +from ._constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT class Configuration(object): # pylint:disable=too-many-instance-attributes @@ -39,10 +35,14 @@ def __init__(self, **kwargs): self.connection_verify = kwargs.get("connection_verify") # type: Optional[str] self.connection_port = DEFAULT_AMQPS_PORT self.custom_endpoint_hostname = None + self.hostname = kwargs.pop("hostname") + uamqp_transport = kwargs.pop("uamqp_transport") - if self.http_proxy or self.transport_type == TransportType.AmqpOverWebsocket: + if self.http_proxy or self.transport_type.value == TransportType.AmqpOverWebsocket.value: self.transport_type = TransportType.AmqpOverWebsocket self.connection_port = DEFAULT_AMQP_WSS_PORT + if not uamqp_transport: + self.hostname += "/$servicebus/websocket" # custom end point if self.custom_endpoint_address: @@ -53,5 +53,7 @@ def __init__(self, **kwargs): endpoint = urlparse(self.custom_endpoint_address) self.transport_type = TransportType.AmqpOverWebsocket self.custom_endpoint_hostname = endpoint.hostname + if not uamqp_transport: + self.custom_endpoint_address += "/$servicebus/websocket" # in case proxy and custom endpoint are both provided, we default port to 443 if it's not provided self.connection_port = endpoint.port or DEFAULT_AMQP_WSS_PORT diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py index 623a25ece678..f8e109a224cc 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py @@ -3,12 +3,17 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from typing import TYPE_CHECKING +from __future__ import annotations +from typing import TYPE_CHECKING, Optional +from threading import Lock +from enum import Enum -from uamqp import Connection +from ._transport._uamqp_transport import UamqpTransport +from ._constants import TransportType if TYPE_CHECKING: from uamqp.authentication import JWTTokenAuth + from uamqp import Connection try: from typing_extensions import Protocol @@ -16,8 +21,9 @@ Protocol = object # type: ignore class ConnectionManager(Protocol): - def get_connection(self, host, auth): - # type: (str, 'JWTTokenAuth') -> Connection + def get_connection( + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAuth] = None, endpoint: Optional[str] = None + ) -> Connection: pass def close_connection(self): @@ -27,12 +33,74 @@ def reset_connection_if_broken(self): pass +class _ConnectionMode(Enum): + ShareConnection = 1 + SeparateConnection = 2 + + +class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes + def __init__(self, **kwargs): + self._lock = Lock() + self._conn: Connection = None + + self._container_id = kwargs.get("container_id") + self._debug = kwargs.get("debug") + self._error_policy = kwargs.get("error_policy") + self._properties = kwargs.get("properties") + self._encoding = kwargs.get("encoding") or "UTF-8" + self._transport_type = kwargs.get("transport_type") or TransportType.Amqp + self._http_proxy = kwargs.get("http_proxy") + self._max_frame_size = kwargs.get("max_frame_size") + self._channel_max = kwargs.get("channel_max") + self._idle_timeout = kwargs.get("idle_timeout") + self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get( + "remote_idle_timeout_empty_frame_send_ratio" + ) + self._amqp_transport = kwargs.get("amqp_transport", UamqpTransport) + + def get_connection( + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAuth] = None, endpoint: Optional[str] = None + ) -> Connection: + with self._lock: + if self._conn is None: + self._conn = self._amqp_transport.create_connection( + host=host, + auth=auth, + endpoint=endpoint, + container_id=self._container_id, + max_frame_size=self._max_frame_size, + channel_max=self._channel_max, + idle_timeout=self._idle_timeout, + properties=self._properties, + remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio, + error_policy=self._error_policy, + debug=self._debug, + encoding=self._encoding, + ) + return self._conn + + def close_connection(self): + # type: () -> None + with self._lock: + if self._conn: + self._amqp_transport.close_connection(self._conn) + self._conn = None + + def reset_connection_if_broken(self): + # type: () -> None + with self._lock: + conn_state = self._amqp_transport.get_connection_state(self._conn) + if self._conn and conn_state in self._amqp_transport.CONNECTION_CLOSING_STATES: + self._conn = None + + class _SeparateConnectionManager(object): def __init__(self, **kwargs): pass - def get_connection(self, host, auth): # pylint:disable=unused-argument, no-self-use - # type: (str, JWTTokenAuth) -> None + def get_connection( # pylint:disable=unused-argument, no-self-use + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAuth] = None, endpoint: Optional[str] = None + ) -> None: return None def close_connection(self): @@ -46,4 +114,7 @@ def reset_connection_if_broken(self): def get_connection_manager(**kwargs): # type: (...) -> 'ConnectionManager' + connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection) # type: ignore + if connection_mode == _ConnectionMode.ShareConnection: + return _SharedConnectionManager(**kwargs) return _SeparateConnectionManager(**kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py index 8c21614a3932..eb8fd4f6198f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_constants.py @@ -4,13 +4,12 @@ # -------------------------------------------------------------------------------------------- from __future__ import unicode_literals -from uamqp import types +from enum import Enum PROP_SEQ_NUMBER = b"x-opt-sequence-number" PROP_OFFSET = b"x-opt-offset" PROP_PARTITION_KEY = b"x-opt-partition-key" -PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) PROP_TIMESTAMP = b"x-opt-enqueued-time" PROP_LAST_ENQUEUED_SEQUENCE_NUMBER = b"last_enqueued_sequence_number" PROP_LAST_ENQUEUED_OFFSET = b"last_enqueued_offset" @@ -45,10 +44,33 @@ MGMT_STATUS_DESC = b"status-description" USER_AGENT_PREFIX = "azsdk-python-eventhubs" -NO_RETRY_ERRORS = ( +NO_RETRY_ERRORS = [ b"com.microsoft:argument-out-of-range", b"com.microsoft:entity-disabled", b"com.microsoft:auth-failed", b"com.microsoft:precondition-failed", b"com.microsoft:argument-error", -) +] + +CUSTOM_CONDITION_BACKOFF = { + b"com.microsoft:server-busy": 4, + b"com.microsoft:timeout": 2, + b"com.microsoft:operation-cancelled": 0, + b"com.microsoft:container-close": 4 +} + + +## all below - previously uamqp +class TransportType(Enum): + """Transport type + The underlying transport protocol type: + Amqp: AMQP over the default TCP transport protocol, it uses port 5671. + AmqpOverWebsocket: Amqp over the Web Sockets transport protocol, it uses + port 443. + """ + Amqp = 1 + AmqpOverWebsocket = 2 + +DEFAULT_AMQPS_PORT = 5671 +DEFAULT_AMQP_WSS_PORT = 443 +READ_OPERATION = b"READ" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index a2a2f80e3df4..8f647a60c6e7 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -2,19 +2,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import time import uuid import logging from collections import deque -from typing import TYPE_CHECKING, Callable, Dict, Optional, Any +from typing import TYPE_CHECKING, Callable, Dict, Optional, Any, Deque -import uamqp -from uamqp import types, errors, utils -from uamqp import ReceiveClient, Source, Message - -from .exceptions import _error_handler from ._common import EventData from ._client_base import ConsumerProducerMixin from ._utils import create_properties, event_position_selector @@ -25,8 +20,9 @@ ) if TYPE_CHECKING: - from typing import Deque - from uamqp.authentication import JWTTokenAuth + from uamqp import ReceiveClient as uamqp_ReceiveClient, Message as uamqp_Message, types as uamqp_types + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + from ._consumer_client import EventHubConsumerClient @@ -69,8 +65,7 @@ class EventHubConsumer( It is set to `False` by default. """ - def __init__(self, client, source, **kwargs): - # type: (EventHubConsumerClient, str, Any) -> None + def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs: Any) -> None: event_position = kwargs.get("event_position", None) prefetch = kwargs.get("prefetch", 300) owner_level = kwargs.get("owner_level", None) @@ -86,9 +81,10 @@ def __init__(self, client, source, **kwargs): self.stop = False # used by event processor self.handler_ready = False - self._on_event_received = kwargs[ + self._amqp_transport = kwargs.pop("amqp_transport") + self._on_event_received: Callable[[EventData], None] = kwargs[ "on_event_received" - ] # type: Callable[[EventData], None] + ] self._client = client self._source = source self._offset = event_position @@ -97,110 +93,89 @@ def __init__(self, client, source, **kwargs): self._owner_level = owner_level self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, - on_error=_error_handler, # pylint:disable=protected-access - ) + self._retry_policy = self._amqp_transport.create_retry_policy(self._client._config) self._reconnect_backoff = 1 - self._link_properties = {} # type: Dict[types.AMQPType, types.AMQPType] + link_properties: Dict[uamqp_types.AMQPType, uamqp_types.AMQPType] = {} self._error = None self._timeout = 0 - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None - partition = self._source.split("/")[-1] - self._partition = partition - self._name = "EHConsumer-{}-partition{}".format(uuid.uuid4(), partition) + self._idle_timeout = (idle_timeout * self._amqp_transport.TIMEOUT_FACTOR) if idle_timeout else None + self._partition = self._source.split("/")[-1] + self._name = f"EHConsumer-{uuid.uuid4()}-partition{self._partition}" if owner_level is not None: - self._link_properties[types.AMQPSymbol(EPOCH_SYMBOL)] = types.AMQPLong( - int(owner_level) - ) + link_properties[EPOCH_SYMBOL] = int(owner_level) link_property_timeout_ms = ( - self._client._config.receive_timeout - or self._timeout # pylint:disable=protected-access - ) * 1000 - self._link_properties[types.AMQPSymbol(TIMEOUT_SYMBOL)] = types.AMQPLong( - int(link_property_timeout_ms) - ) - self._handler = None # type: Optional[ReceiveClient] + self._client._config.receive_timeout or self._timeout # pylint:disable=protected-access + ) * self._amqp_transport.TIMEOUT_FACTOR + link_properties[TIMEOUT_SYMBOL] = int(link_property_timeout_ms) + self._link_properties = self._amqp_transport.create_link_properties(link_properties) + self._handler: Optional[uamqp_ReceiveClient] = None self._track_last_enqueued_event_properties = ( track_last_enqueued_event_properties ) - self._message_buffer = deque() # type: Deque[Message] - self._last_received_event = None # type: Optional[EventData] - self._receive_start_time = None # type: Optional[float] - - def _create_handler(self, auth): - # type: (JWTTokenAuth) -> None - source = Source(self._source) - if self._offset is not None: - source.set_filter( - event_position_selector(self._offset, self._offset_inclusive) - ) - desired_capabilities = None - if self._track_last_enqueued_event_properties: - symbol_array = [types.AMQPSymbol(RECEIVER_RUNTIME_METRIC_SYMBOL)] - desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) - - properties = create_properties( - self._client._config.user_agent # pylint:disable=protected-access + self._message_buffer: Deque[uamqp_Message] = deque() + self._last_received_event: Optional[EventData] = None + self._receive_start_time: Optional[float]= None + + def _create_handler(self, auth: uamqp_JWTTokenAuth) -> None: + source = self._amqp_transport.create_source( + self._source, + self._offset, + event_position_selector(self._offset, self._offset_inclusive) ) - self._handler = ReceiveClient( - source, + desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None + + self._handler = self._amqp_transport.create_receive_client( + config=self._client._config, # pylint:disable=protected-access + source=source, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - prefetch=self._prefetch, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access + link_credit=self._prefetch, link_properties=self._link_properties, timeout=self._timeout, idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, - receive_settle_mode=uamqp.constants.ReceiverSettleMode.ReceiveAndDelete, - auto_complete=False, - properties=properties, + properties=create_properties( + self._client._config.user_agent, amqp_transport=self._amqp_transport # pylint:disable=protected-access + ), desired_capabilities=desired_capabilities, + streaming_receive=True, + message_received_callback=self._message_received, ) - self._handler._streaming_receive = True # pylint:disable=protected-access - self._handler._message_received_callback = ( # pylint:disable=protected-access - self._message_received - ) - - def _open_with_retry(self): - # type: () -> None + def _open_with_retry(self) -> None: self._do_retryable_operation(self._open, operation_need_param=False) - def _message_received(self, message): - # type: (uamqp.Message) -> None + def _message_received(self, message: uamqp_Message) -> None: # pylint:disable=protected-access - self._message_buffer.appendleft(message) + self._message_buffer.append(message) def _next_message_in_buffer(self): # pylint:disable=protected-access - message = self._message_buffer.pop() + message = self._message_buffer.popleft() event_data = EventData._from_message(message) self._last_received_event = event_data return event_data - def _open(self): - # type: () -> bool - """Open the EventHubConsumer/EventHubProducer using the supplied connection.""" + def _open(self) -> bool: + """Open the EventHubConsumer/EventHubProducer using the supplied connection. + """ # pylint: disable=protected-access if not self.running: if self._handler: self._handler.close() auth = self._client._create_auth() self._create_handler(auth) - self._handler.open( - connection=self._client._conn_manager.get_connection( - self._client._address.hostname, auth - ) # pylint: disable=protected-access + conn = self._client._conn_manager.get_connection( # pylint: disable=protected-access + host=self._client._address.hostname, auth=auth ) - self.handler_ready = False + self._handler.open(connection=conn) + while not self._handler.client_ready(): + time.sleep(0.05) + self.handler_ready = True self.running = True - if not self.handler_ready: - if self._handler.client_ready(): # type: ignore - self.handler_ready = True return self.handler_ready def receive(self, batch=False, max_batch_size=300, max_wait_time=None): @@ -211,20 +186,15 @@ def receive(self, batch=False, max_batch_size=300, max_wait_time=None): self._receive_start_time = self._receive_start_time or time.time() deadline = self._receive_start_time + ( max_wait_time or 0 - ) # max_wait_time can be None + ) if len(self._message_buffer) < max_batch_size: while retried_times <= max_retries: try: if self._open(): - self._handler.do_work() # type: ignore + self._handler.do_work(batch=self._prefetch) # type: ignore break except Exception as exception: # pylint: disable=broad-except - if ( - isinstance(exception, uamqp.errors.LinkDetach) - and exception.condition # pylint: disable=no-member - == uamqp.constants.ErrorCodes.LinkStolen - ): - raise self._handle_exception(exception) + self._amqp_transport.check_link_stolen(self, exception) if not self.running: # exit by close return if self._last_received_event: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py index f45597f06c6b..5b48324bbe05 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py @@ -146,6 +146,7 @@ def __init__( **kwargs # type: Any ): # type: (...) -> None + self._checkpoint_store = kwargs.pop("checkpoint_store", None) self._load_balancing_interval = kwargs.pop("load_balancing_interval", None) if self._load_balancing_interval is None: @@ -210,6 +211,7 @@ def _create_consumer( prefetch=prefetch, idle_timeout=self._idle_timeout, track_last_enqueued_event_properties=track_last_enqueued_event_properties, + amqp_transport=self._amqp_transport, ) return handler @@ -222,9 +224,6 @@ def from_connection_string(cls, conn_str, consumer_group, **kwargs): :param str consumer_group: Receive events from the Event Hub for this consumer group. :keyword str eventhub_name: The path of the specific Event Hub to connect the client to. :keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`. - :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following - keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). - Additionally the following keys may also be present: `'username', 'password'`. :keyword float auth_timeout: The time in seconds to wait for a token to be authorized by the service. The default value is 60 seconds. If set to 0, no timeout will be enforced from the client. :keyword str user_agent: If specified, this will be added in front of the user agent string. @@ -254,6 +253,9 @@ def from_connection_string(cls, conn_str, consumer_group, **kwargs): If the port 5671 is unavailable/blocked in the network environment, `TransportType.AmqpOverWebsocket` could be used instead which uses port 443 for communication. :paramtype transport_type: ~azure.eventhub.TransportType + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + Additionally the following keys may also be present: `'username', 'password'`. :keyword checkpoint_store: A manager that stores the partition load-balancing and checkpoint data when receiving events. The checkpoint store will be used in both cases of receiving from all partitions or a single partition. In the latter case load-balancing does not apply. @@ -285,9 +287,9 @@ def from_connection_string(cls, conn_str, consumer_group, **kwargs): :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. - :rtype: ~azure.eventhub.EventHubConsumerClient + .. admonition:: Example: .. literalinclude:: ../samples/sync_samples/sample_code_eventhub.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 9bda02cf1d7d..5b42d964400e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -2,11 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import uuid import logging -import time import threading from typing import ( Iterable, @@ -18,17 +17,10 @@ TYPE_CHECKING, ) # pylint: disable=unused-import -from uamqp import types, constants, errors -from uamqp import SendClient - -from azure.core.tracing import AbstractSpan - -from .exceptions import _error_handler, OperationTimeoutError from ._common import EventData, EventDataBatch from ._client_base import ConsumerProducerMixin from ._utils import ( create_properties, - set_message_partition_key, trace_message, send_context_manager, transform_outbound_single_message, @@ -39,19 +31,29 @@ _LOGGER = logging.getLogger(__name__) if TYPE_CHECKING: - from uamqp.authentication import JWTTokenAuth # pylint: disable=ungrouped-imports + from azure.core.tracing import AbstractSpan + + from uamqp import constants as uamqp_constants, SendClient as uamqp_SendClient + from uamqp.authentication import JWTTokenAuth as uamqp_JWTTokenAuth + from ._transport._base import AmqpTransport from ._producer_client import EventHubProducerClient +_LOGGER = logging.getLogger(__name__) -def _set_partition_key(event_datas, partition_key): - # type: (Iterable[EventData], AnyStr) -> Iterable[EventData] + +def _set_partition_key( + event_datas: Iterable[EventData], + partition_key: AnyStr, + amqp_transport: AmqpTransport, +) -> Iterable[EventData]: for ed in iter(event_datas): - set_message_partition_key(ed.message, partition_key) + amqp_transport.set_message_partition_key(ed._message, partition_key) # pylint: disable=protected-access yield ed -def _set_trace_message(event_datas, parent_span=None): - # type: (Iterable[EventData], Optional[AbstractSpan]) -> Iterable[EventData] +def _set_trace_message( + event_datas: Iterable[EventData], parent_span: Optional["AbstractSpan"] = None +) -> Iterable[EventData]: for ed in iter(event_datas): trace_message(ed, parent_span) yield ed @@ -82,8 +84,11 @@ class EventHubProducer( Default value is `True`. """ - def __init__(self, client, target, **kwargs): - # type: (EventHubProducerClient, str, Any) -> None + def __init__( + self, client: "EventHubProducerClient", target: str, **kwargs: Any + ) -> None: + + self._amqp_transport = kwargs.pop("amqp_transport") partition = kwargs.get("partition", None) send_timeout = kwargs.get("send_timeout", 60) keep_alive = kwargs.get("keep_alive", None) @@ -98,83 +103,59 @@ def __init__(self, client, target, **kwargs): self._target = target self._partition = partition self._timeout = send_timeout - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None + self._idle_timeout = ( + (idle_timeout * self._amqp_transport.TIMEOUT_FACTOR) + if idle_timeout + else None + ) self._error = None self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, - on_error=_error_handler, # pylint: disable=protected-access + self._retry_policy = self._amqp_transport.create_retry_policy( + config=self._client._config ) self._reconnect_backoff = 1 - self._name = "EHProducer-{}".format(uuid.uuid4()) - self._unsent_events = [] # type: List[Any] + self._name = f"EHProducer-{uuid.uuid4()}" + self._unsent_events: List[Any] = [] if partition: self._target += "/Partitions/" + partition - self._name += "-partition{}".format(partition) - self._handler = None # type: Optional[SendClient] - self._outcome = None # type: Optional[constants.MessageSendResult] - self._condition = None # type: Optional[Exception] + self._name += f"-partition{partition}" + self._handler: Optional[uamqp_SendClient] = None + self._outcome: Optional[uamqp_constants.MessageSendResult] = None + self._condition: Optional[Exception] = None self._lock = threading.Lock() - self._link_properties = { - types.AMQPSymbol(TIMEOUT_SYMBOL): types.AMQPLong(int(self._timeout * 1000)) - } - - def _create_handler(self, auth): - # type: (JWTTokenAuth) -> None - self._handler = SendClient( - self._target, + self._link_properties = self._amqp_transport.create_link_properties( + {TIMEOUT_SYMBOL: int(self._timeout * self._amqp_transport.TIMEOUT_FACTOR)} + ) + + def _create_handler( + self, auth: uamqp_JWTTokenAuth + ) -> None: + self._handler = self._amqp_transport.create_send_client( + config=self._client._config, # pylint:disable=protected-access + target=self._target, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - msg_timeout=self._timeout * 1000, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, link_properties=self._link_properties, properties=create_properties( - self._client._config.user_agent # pylint: disable=protected-access + self._client._config.user_agent, # pylint: disable=protected-access + amqp_transport=self._amqp_transport, ), + msg_timeout=self._timeout * 1000, ) - def _open_with_retry(self): - # type: () -> None + def _open_with_retry(self) -> None: return self._do_retryable_operation(self._open, operation_need_param=False) - def _set_msg_timeout(self, timeout_time, last_exception): - # type: (Optional[float], Optional[Exception]) -> None - if not timeout_time: - return - remaining_time = timeout_time - time.time() - if remaining_time <= 0.0: - if last_exception: - error = last_exception - else: - error = OperationTimeoutError("Send operation timed out") - _LOGGER.info("%r send operation timed out. (%r)", self._name, error) - raise error - self._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access - - def _send_event_data(self, timeout_time=None, last_exception=None): - # type: (Optional[float], Optional[Exception]) -> None - if self._unsent_events: - self._open() - self._set_msg_timeout(timeout_time, last_exception) - self._handler.queue_message(*self._unsent_events) # type: ignore - self._handler.wait() # type: ignore - self._unsent_events = self._handler.pending_messages # type: ignore - if self._outcome != constants.MessageSendResult.Ok: - if self._outcome == constants.MessageSendResult.Timeout: - self._condition = OperationTimeoutError("Send operation timed out") - if self._condition: - raise self._condition - - def _send_event_data_with_retry(self, timeout=None): - # type: (Optional[float]) -> None - return self._do_retryable_operation(self._send_event_data, timeout=timeout) - - def _on_outcome(self, outcome, condition): - # type: (constants.MessageSendResult, Optional[Exception]) -> None + def _on_outcome( + self, + outcome: "uamqp_constants.MessageSendResult", + condition: Optional[Exception], + ) -> None: """ Called when the outcome is received for a delivery. @@ -186,19 +167,33 @@ def _on_outcome(self, outcome, condition): self._outcome = outcome self._condition = condition + def _send_event_data( + self, + timeout_time: Optional[float] = None, + last_exception: Optional[Exception] = None, + ) -> None: + if self._unsent_events: + self._amqp_transport.send_messages( + self, timeout_time, last_exception, _LOGGER + ) + + def _send_event_data_with_retry(self, timeout: Optional[float] = None) -> None: + return self._do_retryable_operation(self._send_event_data, timeout=timeout) + def _wrap_eventdata( self, - event_data, # type: Union[EventData, AmqpAnnotatedMessage, EventDataBatch, Iterable[EventData]] - span, # type: Optional[AbstractSpan] - partition_key, # type: Optional[AnyStr] - ): - # type: (...) -> Union[EventData, EventDataBatch] + event_data: Union[EventData, EventDataBatch, Iterable[EventData], AmqpAnnotatedMessage], + span: Optional["AbstractSpan"], + partition_key: Optional[AnyStr], + ) -> Union[EventData, EventDataBatch]: if isinstance(event_data, (EventData, AmqpAnnotatedMessage)): outgoing_event_data = transform_outbound_single_message( - event_data, EventData + event_data, EventData, self._amqp_transport.to_outgoing_amqp_message ) if partition_key: - set_message_partition_key(outgoing_event_data.message, partition_key) + self._amqp_transport.set_message_partition_key( + outgoing_event_data._message, partition_key # pylint: disable=protected-access + ) wrapper_event_data = outgoing_event_data trace_message(wrapper_event_data, span) else: @@ -217,30 +212,32 @@ def _wrap_eventdata( ) for ( event - ) in event_data.message._body_gen: # pylint: disable=protected-access + ) in event_data._message.data: # pylint: disable=protected-access trace_message(event, span) wrapper_event_data = event_data # type:ignore else: if partition_key: - event_data = _set_partition_key(event_data, partition_key) + event_data = _set_partition_key( + event_data, partition_key, self._amqp_transport + ) event_data = _set_trace_message(event_data, span) - wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # type: ignore # pylint: disable=protected-access - wrapper_event_data.message.on_send_complete = self._on_outcome + wrapper_event_data = EventDataBatch._from_batch( # type: ignore # pylint: disable=protected-access + event_data, self._amqp_transport, partition_key=partition_key + ) return wrapper_event_data def send( self, - event_data, # type: Union[EventData, AmqpAnnotatedMessage, EventDataBatch, Iterable[EventData]] - partition_key=None, # type: Optional[AnyStr] - timeout=None, # type: Optional[float] - ): - # type:(...) -> None + event_data: Union[EventData, EventDataBatch, Iterable[EventData], AmqpAnnotatedMessage], + partition_key: Optional[AnyStr] = None, + timeout: Optional[float] = None, + ) -> None: """ Sends an event data and blocks until acknowledgement is received or operation times out. :param event_data: The event to be sent. It can be an EventData object, or iterable of EventData objects - :type event_data: ~azure.eventhub.common.EventData, Iterator, Generator, list + :type event_data: ~azure.eventhub.common.EventData, Iterator, Generator, list or AmqpAnnotatedMessage :param partition_key: With the given partition_key, event data will land to a particular partition of the Event Hub decided by the service. partition_key could be omitted if event_data is of type ~azure.eventhub.EventDataBatch. @@ -269,17 +266,14 @@ def send( if not wrapper_event_data: return - self._unsent_events = [wrapper_event_data.message] - + self._unsent_events = [wrapper_event_data._message] # pylint: disable=protected-access if child: self._client._add_span_request_attributes( # pylint: disable=protected-access child ) - self._send_event_data_with_retry(timeout=timeout) - def close(self): - # type:() -> None + def close(self) -> None: """ Close down the handler. If the handler has already closed, this will be a no op. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 595e7736bb40..efebf97f5fe0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -19,12 +19,10 @@ ) from typing_extensions import Literal -from uamqp import constants - from ._client_base import ClientBase -from ._common import EventDataBatch, EventData -from ._constants import ALL_PARTITIONS from ._producer import EventHubProducer +from ._constants import ALL_PARTITIONS +from ._common import EventDataBatch, EventData from ._buffered_producer import BufferedProducerDispatcher from ._utils import set_event_partition_key from .amqp import AmqpAnnotatedMessage @@ -249,7 +247,7 @@ def _buffered_send(self, events, **kwargs): self._max_message_size_on_link, max_wait_time=self._max_wait_time, max_buffer_length=self._max_buffer_length, - executor=self._executor, + executor=self._executor ) self._buffered_producer_dispatcher.enqueue_events(events, **kwargs) @@ -296,7 +294,7 @@ def _buffered_send_batch(self, event_data_batch, **kwargs): def _buffered_send_event(self, event, **kwargs): partition_key = kwargs.get("partition_key") - set_event_partition_key(event, partition_key) + set_event_partition_key(event, partition_key, self._amqp_transport) timeout = kwargs.get("timeout") timeout_time = time.time() + timeout if timeout else None self._buffered_send( @@ -322,8 +320,10 @@ def _get_max_message_size(self): EventHubProducer, self._producers[ALL_PARTITIONS] )._open_with_retry() self._max_message_size_on_link = ( - self._producers[ALL_PARTITIONS]._handler.message_handler._link.peer_max_message_size # type: ignore - or constants.MAX_MESSAGE_LENGTH_BYTES + self._amqp_transport.get_remote_max_message_size( + self._producers[ALL_PARTITIONS]._handler # type: ignore + ) + or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES ) def _start_producer(self, partition_id, send_timeout): @@ -364,6 +364,7 @@ def _create_producer(self, partition_id=None, send_timeout=None): partition=partition_id, send_timeout=send_timeout, idle_timeout=self._idle_timeout, + amqp_transport=self._amqp_transport, ) return handler @@ -477,6 +478,9 @@ def from_connection_string( If the port 5671 is unavailable/blocked in the network environment, `TransportType.AmqpOverWebsocket` could be used instead which uses port 443 for communication. :paramtype transport_type: ~azure.eventhub.TransportType + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + Additionally the following keys may also be present: `'username', 'password'`. :keyword str custom_endpoint_address: The custom endpoint address to use for establishing a connection to the Event Hubs service, allowing network requests to be routed through any application gateways or other paths needed for the host environment. Default is None. @@ -723,7 +727,7 @@ def create_batch(self, **kwargs): event_data_batch = EventDataBatch( max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link), partition_id=partition_id, - partition_key=partition_key, + partition_key=partition_key ) return event_data_batch diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py new file mode 100644 index 000000000000..d67cceedcd40 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_base.py @@ -0,0 +1,289 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from __future__ import annotations +from typing import Tuple, Union, TYPE_CHECKING +from abc import ABC, abstractmethod + +if TYPE_CHECKING: + from uamqp import types as uamqp_types + +class AmqpTransport(ABC): # pylint: disable=too-many-public-methods + """ + Abstract class that defines a set of common methods needed by producer and consumer. + """ + # define constants + MAX_FRAME_SIZE_BYTES: int + MAX_MESSAGE_LENGTH_BYTES: int + TIMEOUT_FACTOR: int + CONNECTION_CLOSING_STATES: Tuple + + # define symbols + PRODUCT_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + VERSION_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + FRAMEWORK_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + PLATFORM_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + USER_AGENT_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + PROP_PARTITION_KEY_AMQP_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + + @staticmethod + @abstractmethod + def build_message(**kwargs): + """ + Creates a uamqp.Message or pyamqp.Message with given arguments. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + @abstractmethod + def build_batch_message(**kwargs): + """ + Creates a uamqp.BatchMessage or pyamqp.BatchMessage with given arguments. + :rtype: uamqp.BatchMessage or pyamqp.BatchMessage + """ + + @staticmethod + @abstractmethod + def to_outgoing_amqp_message(annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + @abstractmethod + def get_message_encoded_size(message): + """ + Gets the message encoded size given an underlying Message. + :param uamqp.Message or pyamqp.Message message: Message to get encoded size of. + :rtype: int + """ + + @staticmethod + @abstractmethod + def get_remote_max_message_size(handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + + @staticmethod + @abstractmethod + def create_retry_policy(config): + """ + Creates the error retry policy. + :param ~azure.eventhub._configuration.Configuration config: Configuration. + """ + + @staticmethod + @abstractmethod + def create_link_properties(link_properties): + """ + Creates and returns the link properties. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict + """ + + @staticmethod + @abstractmethod + def create_connection(**kwargs): + """ + Creates and returns the uamqp Connection object. + :keyword str host: The hostname, used by uamqp. + :keyword JWTTokenAuth auth: The auth, used by uamqp. + :keyword str endpoint: The endpoint, used by pyamqp. + :keyword str container_id: Required. + :keyword int max_frame_size: Required. + :keyword int channel_max: Required. + :keyword int idle_timeout: Required. + :keyword Dict properties: Required. + :keyword int remote_idle_timeout_empty_frame_send_ratio: Required. + :keyword error_policy: Required. + :keyword bool debug: Required. + :keyword str encoding: Required. + """ + + @staticmethod + @abstractmethod + def close_connection(connection): + """ + Closes existing connection. + :param connection: uamqp or pyamqp Connection. + """ + + @staticmethod + @abstractmethod + def get_connection_state(connection): + """ + Gets connection state. + :param connection: uamqp or pyamqp Connection. + """ + + @staticmethod + @abstractmethod + def create_send_client(*, config, **kwargs): + """ + Creates and returns the send client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + + @staticmethod + @abstractmethod + def send_messages(producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + + @staticmethod + @abstractmethod + def set_message_partition_key(message, partition_key, **kwargs): + """Set the partition key as an annotation on a uamqp message. + + :param message: The message to update. + :param str partition_key: The partition key value. + :rtype: None + """ + + @staticmethod + @abstractmethod + def add_batch(batch_message, outgoing_event_data, event_data): + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + + @staticmethod + @abstractmethod + def create_source(source, offset, selector): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes selector: Required. + """ + + @staticmethod + @abstractmethod + def create_receive_client(*, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword Source source: Required. The source. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + @staticmethod + @abstractmethod + def open_receive_client(*, handler, client, auth): + """ + Opens the receive client. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + """ + + @staticmethod + @abstractmethod + def check_link_stolen(consumer, exception): + """ + Checks if link stolen and handles exception. + :param consumer: The EventHubConsumer. + :param exception: Exception to check. + """ + + @staticmethod + @abstractmethod + def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Whether to update token. If not updating token, + then pass 300 to refresh_window. Only used by uamqp. + """ + + @staticmethod + @abstractmethod + def create_mgmt_client(address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + @staticmethod + @abstractmethod + def get_updated_token(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + + @staticmethod + @abstractmethod + def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + + @staticmethod + @abstractmethod + def get_error(status_code, description): + """ + Gets error corresponding to status code. + :param status_code: Status code. + :param str description: Description of error. + """ + + @staticmethod + @abstractmethod + def check_timeout_exception(base, exception): + """ + Checks if timeout exception. + :param base: ClientBase. + :param exception: Exception to check. + """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py new file mode 100644 index 000000000000..018d3611aa72 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_uamqp_transport.py @@ -0,0 +1,645 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import time +import logging +from typing import Optional, Union, Any, Tuple + +try: + from uamqp import ( + c_uamqp, + BatchMessage, + constants, + MessageBodyType, + Message, + types, + SendClient, + ReceiveClient, + Source, + utils, + authentication, + AMQPClient, + compat, + errors, + Connection, + ) + from uamqp.message import ( + MessageHeader, + MessageProperties, + ) + uamqp_installed = True +except ImportError: + uamqp_installed = False + +from ._base import AmqpTransport +from ..amqp._constants import AmqpMessageBodyType +from .._constants import ( + NO_RETRY_ERRORS, + PROP_PARTITION_KEY, +) + +from ..exceptions import ( + ConnectError, + OperationTimeoutError, + EventHubError, + AuthenticationError, + ConnectionLostError, + EventDataError, + EventDataSendError, +) + +_LOGGER = logging.getLogger(__name__) + +if uamqp_installed: + def _error_handler(error): + """ + Called internally when an event has failed to send so we + can parse the error to determine whether we should attempt + to retry sending the event again. + Returns the action to take according to error type. + + :param error: The error received in the send attempt. + :type error: Exception + :rtype: ~uamqp.errors.ErrorAction + """ + if error.condition == b"com.microsoft:server-busy": + return errors.ErrorAction(retry=True, backoff=4) + if error.condition == b"com.microsoft:timeout": + return errors.ErrorAction(retry=True, backoff=2) + if error.condition == b"com.microsoft:operation-cancelled": + return errors.ErrorAction(retry=True) + if error.condition == b"com.microsoft:container-close": + return errors.ErrorAction(retry=True, backoff=4) + if error.condition in NO_RETRY_ERRORS: + return errors.ErrorAction(retry=False) + return errors.ErrorAction(retry=True) + + + class UamqpTransport(AmqpTransport): # pylint: disable=too-many-public-methods + """ + Class which defines uamqp-based methods used by the producer and consumer. + """ + # define constants + MAX_FRAME_SIZE_BYTES = constants.MAX_FRAME_SIZE_BYTES + MAX_MESSAGE_LENGTH_BYTES = constants.MAX_MESSAGE_LENGTH_BYTES + TIMEOUT_FACTOR = 1000 + CONNECTION_CLOSING_STATES: Tuple = ( # pylint:disable=protected-access + c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member + c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member + ) + + # define symbols + PRODUCT_SYMBOL = types.AMQPSymbol("product") + VERSION_SYMBOL = types.AMQPSymbol("version") + FRAMEWORK_SYMBOL = types.AMQPSymbol("framework") + PLATFORM_SYMBOL = types.AMQPSymbol("platform") + USER_AGENT_SYMBOL = types.AMQPSymbol("user-agent") + PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) + + @staticmethod + def build_message(**kwargs): + """ + Creates a uamqp.Message or pyamqp.Message with given arguments. + :rtype: uamqp.Message or pyamqp.Message + """ + return Message(**kwargs) + + @staticmethod + def build_batch_message(**kwargs): + """ + Creates a uamqp.BatchMessage or pyamqp.BatchMessage with given arguments. + :rtype: uamqp.BatchMessage or pyamqp.BatchMessage + """ + return BatchMessage(**kwargs) + + @staticmethod + def to_outgoing_amqp_message(annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: uamqp.Message + """ + message_header = None + if annotated_message.header: + message_header = MessageHeader() + message_header.delivery_count = annotated_message.header.delivery_count + message_header.time_to_live = annotated_message.header.time_to_live + message_header.first_acquirer = annotated_message.header.first_acquirer + message_header.durable = annotated_message.header.durable + message_header.priority = annotated_message.header.priority + + message_properties = None + if annotated_message.properties: + message_properties = MessageProperties( + message_id=annotated_message.properties.message_id, + user_id=annotated_message.properties.user_id, + to=annotated_message.properties.to, + subject=annotated_message.properties.subject, + reply_to=annotated_message.properties.reply_to, + correlation_id=annotated_message.properties.correlation_id, + content_type=annotated_message.properties.content_type, + content_encoding=annotated_message.properties.content_encoding, + creation_time=int(annotated_message.properties.creation_time) + if annotated_message.properties.creation_time else None, + absolute_expiry_time=int(annotated_message.properties.absolute_expiry_time) + if annotated_message.properties.absolute_expiry_time else None, + group_id=annotated_message.properties.group_id, + group_sequence=annotated_message.properties.group_sequence, + reply_to_group_id=annotated_message.properties.reply_to_group_id, + encoding=annotated_message._encoding # pylint: disable=protected-access + ) + + # pylint: disable=protected-access + amqp_body_type = annotated_message.body_type + if amqp_body_type == AmqpMessageBodyType.DATA: + amqp_body_type = MessageBodyType.Data + amqp_body = list(annotated_message._data_body) + elif amqp_body_type == AmqpMessageBodyType.SEQUENCE: + amqp_body_type = MessageBodyType.Sequence + amqp_body = list(annotated_message._sequence_body) + else: + amqp_body_type = MessageBodyType.Value + amqp_body = annotated_message._value_body + + return Message( + body=amqp_body, + body_type=amqp_body_type, + header=message_header, + properties=message_properties, + application_properties=annotated_message.application_properties, + annotations=annotated_message.annotations, + delivery_annotations=annotated_message.delivery_annotations, + footer=annotated_message.footer + ) + + @staticmethod + def get_batch_message_encoded_size(message): + """ + Gets the batch message encoded size given an underlying Message. + :param uamqp.BatchMessage message: Message to get encoded size of. + :rtype: int + """ + return message.gather()[0].get_message_encoded_size() + + @staticmethod + def get_message_encoded_size(message): + """ + Gets the message encoded size given an underlying Message. + :param uamqp.Message message: Message to get encoded size of. + :rtype: int + """ + return message.get_message_encoded_size() + + @staticmethod + def get_remote_max_message_size(handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + return handler.message_handler._link.peer_max_message_size # pylint:disable=protected-access + + @staticmethod + def create_retry_policy(config): + """ + Creates the error retry policy. + :param ~azure.eventhub._configuration.Configuration config: Configuration. + """ + return errors.ErrorPolicy(max_retries=config.max_retries, on_error=_error_handler) + + @staticmethod + def create_link_properties(link_properties): + """ + Creates and returns the link properties. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict + """ + return {types.AMQPSymbol(symbol): types.AMQPLong(value) for (symbol, value) in link_properties.items()} + + @staticmethod + def create_connection(**kwargs): + """ + Creates and returns the uamqp Connection object. + :keyword str host: The hostname, used by uamqp. + :keyword JWTTokenAuth auth: The auth, used by uamqp. + :keyword str endpoint: The endpoint, used by pyamqp. + :keyword str container_id: Required. + :keyword int max_frame_size: Required. + :keyword int channel_max: Required. + :keyword int idle_timeout: Required. + :keyword Dict properties: Required. + :keyword int remote_idle_timeout_empty_frame_send_ratio: Required. + :keyword error_policy: Required. + :keyword bool debug: Required. + :keyword str encoding: Required. + """ + endpoint = kwargs.pop("endpoint") # pylint:disable=unused-variable + host = kwargs.pop("host") + auth = kwargs.pop("auth") + return Connection( + host, + auth, + **kwargs + ) + + @staticmethod + def close_connection(connection): + """ + Closes existing connection. + :param connection: uamqp or pyamqp Connection. + """ + connection.destroy() + + @staticmethod + def get_connection_state(connection): + """ + Gets connection state. + :param connection: uamqp or pyamqp Connection. + """ + return connection._state # pylint:disable=protected-access + + @staticmethod + def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument + """ + Creates and returns the uamqp SendClient. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + target = kwargs.pop("target") + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + + return SendClient( + target, + debug=network_trace, + error_policy=retry_policy, + **kwargs + ) + + @staticmethod + def _set_msg_timeout(producer, timeout_time, last_exception, logger): + if not timeout_time: + return + remaining_time = timeout_time - time.time() + if remaining_time <= 0.0: + if last_exception: + error = last_exception + else: + error = OperationTimeoutError("Send operation timed out") + logger.info("%r send operation timed out. (%r)", producer._name, error) # pylint: disable=protected-access + raise error + producer._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access + + @staticmethod + def send_messages(producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + # pylint: disable=protected-access + producer._open() + producer._unsent_events[0].on_send_complete = producer._on_outcome + UamqpTransport._set_msg_timeout(producer, timeout_time, last_exception, logger) + producer._handler.queue_message(*producer._unsent_events) # type: ignore + producer._handler.wait() # type: ignore + producer._unsent_events = producer._handler.pending_messages # type: ignore + if producer._outcome != constants.MessageSendResult.Ok: + if producer._outcome == constants.MessageSendResult.Timeout: + producer._condition = OperationTimeoutError("Send operation timed out") + if producer._condition: + raise producer._condition + + @staticmethod + def set_message_partition_key(message, partition_key, **kwargs): # pylint:disable=unused-argument + # type: (Message, Optional[Union[bytes, str]], Any) -> Message + """Set the partition key as an annotation on a uamqp message. + + :param ~uamqp.Message message: The message to update. + :param str partition_key: The partition key value. + :rtype: Message + """ + if partition_key: + annotations = message.annotations + if annotations is None: + annotations = {} + annotations[ + UamqpTransport.PROP_PARTITION_KEY_AMQP_SYMBOL # TODO: see if setting non-amqp symbol is valid + ] = partition_key + header = MessageHeader() + header.durable = True + message.annotations = annotations + message.header = header + return message + + @staticmethod + def add_batch(batch_message, outgoing_event_data, event_data): + """ + Add EventData to the data body of the BatchMessage. + :param batch_message: BatchMessage to add data to. + :param outgoing_event_data: Transformed EventData for sending. + :param event_data: EventData to add to internal batch events. uamqp use only. + :rtype: None + """ + # pylint: disable=protected-access + batch_message._internal_events.append(event_data) + batch_message._message._body_gen.append( + outgoing_event_data._message + ) + + @staticmethod + def create_source(source, offset, selector): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes selector: Required. + """ + source = Source(source) + if offset is not None: + source.set_filter(selector) + return source + + @staticmethod + def create_receive_client(*, config, **kwargs): # pylint: disable=unused-argument + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + source = kwargs.pop("source") + symbol_array = kwargs.pop("desired_capabilities") + desired_capabilities = None + if symbol_array: + symbol_array = [types.AMQPSymbol(symbol) for symbol in symbol_array] + desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + link_credit = kwargs.pop("link_credit") + streaming_receive = kwargs.pop("streaming_receive") + message_received_callback = kwargs.pop("message_received_callback") + + client = ReceiveClient( + source, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + desired_capabilities=desired_capabilities, + prefetch=link_credit, + receive_settle_mode=constants.ReceiverSettleMode.ReceiveAndDelete, + auto_complete=False, + **kwargs + ) + # pylint:disable=protected-access + client._streaming_receive = streaming_receive + client._message_received_callback = (message_received_callback) + return client + + @staticmethod + def open_receive_client(*, handler, client, auth): + """ + Opens the receive client and returns ready status. + :param ReceiveClient handler: The receive client. + :param ~azure.eventhub.EventHubConsumerClient client: The consumer client. + :param auth: Auth. + :rtype: bool + """ + # pylint:disable=protected-access + handler.open(connection=client._conn_manager.get_connection( + client._address.hostname, auth + )) + + @staticmethod + def check_link_stolen(consumer, exception): + """ + Checks if link stolen and handles exception. + :param consumer: The EventHubConsumer. + :param exception: Exception to check. + """ + if ( + isinstance(exception, errors.LinkDetach) + and exception.condition == constants.ErrorCodes.LinkStolen # pylint: disable=no-member + ): + raise consumer._handle_exception(exception) # pylint: disable=protected-access + + @staticmethod + def create_token_auth(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Required. Whether to update token. If not updating token, + then pass 300 to refresh_window. + """ + update_token = kwargs.pop("update_token") + refresh_window = 300 + if update_token: + refresh_window = 0 + + token_auth = authentication.JWTTokenAuth( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + refresh_window=refresh_window + ) + if update_token: + token_auth.update_token() + return token_auth + + @staticmethod + def create_mgmt_client(address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + mgmt_target = f"amqps://{address.hostname}{address.path}" + return AMQPClient( + mgmt_target, + auth=mgmt_auth, + debug=config.network_tracing + ) + + @staticmethod + def get_updated_token(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + return mgmt_auth.token + + @staticmethod + def mgmt_client_request(mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + operation_type = kwargs.pop("operation_type") + operation = kwargs.pop("operation") + return mgmt_client.mgmt_request( + mgmt_msg, + operation, + op_type=operation_type, + **kwargs + ) + + @staticmethod + def get_error(status_code, description): + """ + Gets error corresponding to status code. + :param status_code: Status code. + :param str description: Description of error. + """ + if status_code in [401]: + return errors.AuthenticationException( + f"Management authentication failed. Status code: {status_code}, Description: {description!r}" + ) + if status_code in [404]: + return ConnectError( + f"Management connection failed. Status code: {status_code}, Description: {description!r}" + ) + return errors.AMQPConnectionError( + f"Management request error. Status code: {status_code}, Description: {description!r}" + ) + + @staticmethod + def check_timeout_exception(base, exception): + """ + Checks if timeout exception. + :param base: ClientBase. + :param exception: Exception to check. + """ + if not base.running and isinstance( + exception, compat.TimeoutException + ): + exception = UamqpTransport.get_error( + errors.AuthenticationException, + "Authorization timeout." + ) + return exception + + @staticmethod + def _create_eventhub_exception(exception): + if isinstance(exception, errors.AuthenticationException): + error = AuthenticationError(str(exception), exception) + elif isinstance(exception, errors.VendorLinkDetach): + error = ConnectError(str(exception), exception) + elif isinstance(exception, errors.LinkDetach): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.ConnectionClose): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.MessageHandlerError): + error = ConnectionLostError(str(exception), exception) + elif isinstance(exception, errors.AMQPConnectionError): + error_type = ( + AuthenticationError + if str(exception).startswith("Unable to open authentication session") + else ConnectError + ) + error = error_type(str(exception), exception) + elif isinstance(exception, compat.TimeoutException): + error = ConnectionLostError(str(exception), exception) + else: + error = EventHubError(str(exception), exception) + return error + + @staticmethod + def _handle_exception( + exception, closable + ): # pylint:disable=too-many-branches, too-many-statements + try: # closable is a producer/consumer object + name = closable._name # pylint: disable=protected-access + except AttributeError: # closable is an client object + name = closable._container_id # pylint: disable=protected-access + if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise + _LOGGER.info("%r stops due to keyboard interrupt", name) + closable._close_connection() # pylint:disable=protected-access + raise exception + elif isinstance(exception, EventHubError): + closable._close_handler() # pylint:disable=protected-access + raise exception + elif isinstance( + exception, + ( + errors.MessageAccepted, + errors.MessageAlreadySettled, + errors.MessageModified, + errors.MessageRejected, + errors.MessageReleased, + errors.MessageContentTooLarge, + ), + ): + _LOGGER.info("%r Event data error (%r)", name, exception) + error = EventDataError(str(exception), exception) + raise error + elif isinstance(exception, errors.MessageException): + _LOGGER.info("%r Event data send error (%r)", name, exception) + error = EventDataSendError(str(exception), exception) + raise error + else: + if isinstance(exception, errors.AuthenticationException): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + elif isinstance(exception, errors.LinkDetach): + if hasattr(closable, "_close_handler"): + closable._close_handler() # pylint:disable=protected-access + elif isinstance(exception, errors.ConnectionClose): + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + elif isinstance(exception, errors.MessageHandlerError): + if hasattr(closable, "_close_handler"): + closable._close_handler() # pylint:disable=protected-access + else: # errors.AMQPConnectionError, compat.TimeoutException + if hasattr(closable, "_close_connection"): + closable._close_connection() # pylint:disable=protected-access + return UamqpTransport._create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py index 744ebfaf3df0..fc9b1bc8c9c9 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_utils.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations from contextlib import contextmanager import sys @@ -10,8 +10,14 @@ import datetime import calendar import logging +from base64 import b64encode +from hashlib import sha256 +from hmac import HMAC +from urllib.parse import urlencode, quote_plus +import time from typing import ( TYPE_CHECKING, + cast, Type, Optional, Dict, @@ -20,12 +26,11 @@ Iterable, Tuple, Mapping, + Callable ) import six - -from uamqp import types -from uamqp.message import MessageHeader +from uamqp import types as uamqp_types from azure.core.settings import settings from azure.core.tracing import SpanKind, Link @@ -33,7 +38,6 @@ from .amqp import AmqpAnnotatedMessage, AmqpMessageHeader from ._version import VERSION from ._constants import ( - PROP_PARTITION_KEY_AMQP_SYMBOL, MAX_USER_AGENT_LENGTH, USER_AGENT_PREFIX, PROP_LAST_ENQUEUED_SEQUENCE_NUMBER, @@ -43,9 +47,10 @@ PROP_TIMESTAMP, ) + if TYPE_CHECKING: # pylint: disable=ungrouped-imports - from uamqp import Message + from ._transport._base import AmqpTransport from azure.core.tracing import AbstractSpan from azure.core.credentials import AzureSasCredential from ._common import EventData @@ -87,43 +92,52 @@ def utc_from_timestamp(timestamp): return datetime.datetime.fromtimestamp(timestamp, tz=TZ_UTC) -def create_properties(user_agent=None): - # type: (Optional[str]) -> Dict[types.AMQPSymbol, str] +def create_properties( + user_agent: Optional[str] = None, *, amqp_transport: AmqpTransport +) -> Dict[uamqp_types.AMQPSymbol, str]: """ Format the properties with which to instantiate the connection. This acts like a user agent over HTTP. :rtype: dict """ - properties = {} - properties[types.AMQPSymbol("product")] = USER_AGENT_PREFIX - properties[types.AMQPSymbol("version")] = VERSION - framework = "Python/{}.{}.{}".format( - sys.version_info[0], sys.version_info[1], sys.version_info[2] - ) - properties[types.AMQPSymbol("framework")] = framework + properties: Dict[Any, str] = {} + properties[amqp_transport.PRODUCT_SYMBOL] = USER_AGENT_PREFIX + properties[amqp_transport.VERSION_SYMBOL] = VERSION + framework = f"Python/{sys.version_info[0]}.{sys.version_info[1]}.{sys.version_info[2]}" + properties[amqp_transport.FRAMEWORK_SYMBOL] = framework platform_str = platform.platform() - properties[types.AMQPSymbol("platform")] = platform_str + properties[amqp_transport.PLATFORM_SYMBOL] = platform_str - final_user_agent = "{}/{} {} ({})".format( - USER_AGENT_PREFIX, VERSION, framework, platform_str - ) + final_user_agent = f"{USER_AGENT_PREFIX}/{VERSION} {framework} ({platform_str})" if user_agent: - final_user_agent = "{} {}".format(user_agent, final_user_agent) + final_user_agent = f"{user_agent} {final_user_agent}" if len(final_user_agent) > MAX_USER_AGENT_LENGTH: raise ValueError( - "The user-agent string cannot be more than {} in length." - "Current user_agent string is: {} with length: {}".format( - MAX_USER_AGENT_LENGTH, final_user_agent, len(final_user_agent) - ) + f"The user-agent string cannot be more than {MAX_USER_AGENT_LENGTH} in length." + f"Current user_agent string is: {final_user_agent} with length: {len(final_user_agent)}" ) - properties[types.AMQPSymbol("user-agent")] = final_user_agent + properties[amqp_transport.USER_AGENT_SYMBOL] = final_user_agent return properties -def set_event_partition_key(event, partition_key): - # type: (Union[AmqpAnnotatedMessage, EventData], Optional[Union[bytes, str]]) -> None +@contextmanager +def send_context_manager(): + span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] + + if span_impl_type is not None: + with span_impl_type(name="Azure.EventHubs.send", kind=SpanKind.CLIENT) as child: + yield child + else: + yield None + + +def set_event_partition_key( + event: Union[AmqpAnnotatedMessage, EventData], + partition_key: Optional[Union[bytes, str]], + amqp_transport: AmqpTransport +) -> None: if not partition_key: return @@ -134,9 +148,9 @@ def set_event_partition_key(event, partition_key): annotations = raw_message.annotations if annotations is None: - annotations = dict() + annotations = {} annotations[ - PROP_PARTITION_KEY_AMQP_SYMBOL + amqp_transport.PROP_PARTITION_KEY_AMQP_SYMBOL ] = partition_key # pylint:disable=protected-access if not raw_message.header: raw_message.header = AmqpMessageHeader(header=True) @@ -144,38 +158,6 @@ def set_event_partition_key(event, partition_key): raw_message.header.durable = True -def set_message_partition_key(message, partition_key): - # type: (Message, Optional[Union[bytes, str]]) -> None - """Set the partition key as an annotation on a uamqp message. - - :param ~uamqp.Message message: The message to update. - :param str partition_key: The partition key value. - :rtype: None - """ - if partition_key: - annotations = message.annotations - if annotations is None: - annotations = dict() - annotations[ - PROP_PARTITION_KEY_AMQP_SYMBOL - ] = partition_key # pylint:disable=protected-access - header = MessageHeader() - header.durable = True - message.annotations = annotations - message.header = header - - -@contextmanager -def send_context_manager(): - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] - - if span_impl_type is not None: - with span_impl_type(name="Azure.EventHubs.send", kind=SpanKind.CLIENT) as child: - yield child - else: - yield None - - def trace_message(event, parent_span=None): # type: (EventData, Optional[AbstractSpan]) -> None """Add tracing information to this event. @@ -236,15 +218,13 @@ def event_position_selector(value, inclusive=False): value.microsecond / 1000 ) return ( - "amqp.annotation.x-opt-enqueued-time {} '{}'".format( - operator, int(timestamp) - ) + f"amqp.annotation.x-opt-enqueued-time {operator} '{int(timestamp)}'" ).encode("utf-8") elif isinstance(value, six.integer_types): return ( - "amqp.annotation.x-opt-sequence-number {} '{}'".format(operator, value) + f"amqp.annotation.x-opt-sequence-number {operator} '{value}'" ).encode("utf-8") - return ("amqp.annotation.x-opt-offset {} '{}'".format(operator, value)).encode( + return (f"amqp.annotation.x-opt-offset {operator} '{value}'").encode( "utf-8" ) @@ -259,23 +239,23 @@ def get_last_enqueued_event_properties(event_data): if event_data._last_enqueued_event_properties: return event_data._last_enqueued_event_properties - if event_data.message.delivery_annotations: - sequence_number = event_data.message.delivery_annotations.get( + if event_data._message.delivery_annotations: + sequence_number = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_SEQUENCE_NUMBER, None ) - enqueued_time_stamp = event_data.message.delivery_annotations.get( + enqueued_time_stamp = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_TIME_UTC, None ) if enqueued_time_stamp: enqueued_time_stamp = utc_from_timestamp(float(enqueued_time_stamp) / 1000) - retrieval_time_stamp = event_data.message.delivery_annotations.get( + retrieval_time_stamp = event_data._message.delivery_annotations.get( PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC, None ) if retrieval_time_stamp: retrieval_time_stamp = utc_from_timestamp( float(retrieval_time_stamp) / 1000 ) - offset_bytes = event_data.message.delivery_annotations.get( + offset_bytes = event_data._message.delivery_annotations.get( PROP_LAST_ENQUEUED_OFFSET, None ) offset = offset_bytes.decode("UTF-8") if offset_bytes else None @@ -301,8 +281,8 @@ def parse_sas_credential(credential): return (sas, expiry) -def transform_outbound_single_message(message, message_type): - # type: (Union[AmqpAnnotatedMessage, EventData], Type[EventData]) -> EventData +def transform_outbound_single_message(message, message_type, to_outgoing_amqp_message): + # type: (Union[AmqpAnnotatedMessage, EventData], Type[EventData], Callable) -> EventData """ This method serves multiple goals: 1. update the internal message to reflect any updates to settable properties on EventData @@ -314,14 +294,20 @@ def transform_outbound_single_message(message, message_type): :rtype: EventData """ try: - # EventData # pylint: disable=protected-access - return message._to_outgoing_message() # type: ignore + # If EventData, set EventData._message to uamqp/pyamqp.Message right before sending. + message = cast("EventData", message) + message._message = to_outgoing_amqp_message(message.raw_amqp_message) + return message # type: ignore except AttributeError: - # AmqpAnnotatedMessage # pylint: disable=protected-access + # If AmqpAnnotatedMessage, create EventData object with _from_message. + # event_data._message will be set to outgoing uamqp/pyamqp.Message. + # event_data.raw_amqp_message will be set to AmqpAnnotatedMessage. + message = cast(AmqpAnnotatedMessage, message) + amqp_message = to_outgoing_amqp_message(message) return message_type._from_message( - message=message._to_outgoing_amqp_message(), raw_amqp_message=message # type: ignore + message=amqp_message, raw_amqp_message=message # type: ignore ) @@ -359,3 +345,32 @@ def decode_with_recurse(data, encoding="UTF-8"): return decoded_list return data + + +def generate_sas_token(audience, policy, key, expiry=None): + """ + Generate a sas token according to the given audience, policy, key and expiry + :param str audience: + :param str policy: + :param str key: + :param int expiry: abs expiry time + :rtype: str + """ + if not expiry: + expiry = int(time.time()) + 3600 # Default to 1 hour. + + encoded_uri = quote_plus(audience) + encoded_policy = quote_plus(policy).encode("utf-8") + encoded_key = key.encode("utf-8") + + ttl = int(expiry) + sign_key = '%s\n%d' % (encoded_uri, ttl) + signature = b64encode(HMAC(encoded_key, sign_key.encode('utf-8'), sha256).digest()) + result = { + 'sr': audience, + 'sig': signature, + 'se': str(ttl) + } + if policy: + result['skn'] = encoded_policy + return 'SharedAccessSignature ' + urlencode(result) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py index 90590263024f..2c2a7158a838 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging import queue @@ -14,6 +15,7 @@ from ...exceptions import OperationTimeoutError if TYPE_CHECKING: + from .._transport._base_async import AmqpTransportAsync from ..._producer_client import SendEventTypes _LOGGER = logging.getLogger(__name__) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py index ecae49098086..d3f2135ff170 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_dispatcher_async.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging from typing import Dict, List, Callable, Optional, Awaitable, TYPE_CHECKING @@ -13,6 +14,7 @@ from ...exceptions import EventDataSendError, ConnectError, EventHubError if TYPE_CHECKING: + from .._transport._base_async import AmqpTransportAsync from ..._producer_client import SendEventTypes _LOGGER = logging.getLogger(__name__) @@ -32,7 +34,7 @@ def __init__( max_message_size_on_link: int, *, max_buffer_length: int = 1500, - max_wait_time: float = 1, + max_wait_time: float = 1 ): self._buffered_producers: Dict[str, BufferedProducer] = {} self._partition_ids: List[str] = partitions diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py index 5a5a312485c4..e9becf77a884 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals +from __future__ import unicode_literals, annotations import logging import asyncio @@ -10,15 +10,6 @@ import functools from typing import TYPE_CHECKING, Any, Dict, List, Callable, Optional, Union, cast -import six -from uamqp import ( - authentication, - constants, - errors, - compat, - Message, - AMQPClientAsync, -) from azure.core.credentials import ( AccessToken, AzureSasCredential, @@ -32,19 +23,25 @@ _get_backoff_time, ) from .._utils import utc_from_timestamp, parse_sas_credential -from ..exceptions import ClientClosedError, ConnectError +from ..exceptions import ClientClosedError from .._constants import ( JWT_TOKEN_SCOPE, MGMT_OPERATION, MGMT_PARTITION_OPERATION, MGMT_STATUS_CODE, MGMT_STATUS_DESC, + READ_OPERATION, ) from ._async_utils import get_dict_with_loop_if_needed from ._connection_manager_async import get_connection_manager -from ._error_async import _handle_exception +from ._transport._uamqp_transport_async import UamqpTransportAsync if TYPE_CHECKING: + from uamqp import ( + authentication, + Message, + AMQPClientAsync, + ) from azure.core.credentials_async import AsyncTokenCredential CredentialTypes = Union[ @@ -211,6 +208,8 @@ def __init__( **kwargs: Any ) -> None: self._internal_kwargs = get_dict_with_loop_if_needed(kwargs.get("loop", None)) + uamqp_transport = kwargs.pop("uamqp_transport", True) + self._amqp_transport = UamqpTransportAsync if isinstance(credential, AzureSasCredential): self._credential = EventhubAzureSasTokenCredentialAsync(credential) # type: ignore elif isinstance(credential, AzureNamedKeyCredential): @@ -221,6 +220,8 @@ def __init__( fully_qualified_namespace=fully_qualified_namespace, eventhub_name=eventhub_name, credential=self._credential, + uamqp_transport=uamqp_transport, + amqp_transport=self._amqp_transport, **kwargs ) self._conn_manager_async = get_connection_manager(**kwargs) @@ -255,32 +256,19 @@ async def _create_auth_async(self) -> authentication.JWTTokenAsync: except AttributeError: token_type = b"jwt" if token_type == b"servicebus.windows.net:sastoken": - auth = authentication.JWTTokenAsync( - self._auth_uri, + return await self._amqp_transport.create_token_auth_async( self._auth_uri, functools.partial(self._credential.get_token, self._auth_uri), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, - refresh_window=300, + config=self._config, + update_token=True, ) - await auth.update_token() - return auth - return authentication.JWTTokenAsync( - self._auth_uri, + return await self._amqp_transport.create_token_auth_async( self._auth_uri, functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, - timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, - custom_endpoint_hostname=self._config.custom_endpoint_hostname, - port=self._config.connection_port, - verify=self._config.connection_verify, + config=self._config, + update_token=False, ) async def _close_connection_async(self) -> None: @@ -322,19 +310,24 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> last_exception = None while retried_times <= self._config.max_retries: mgmt_auth = await self._create_auth_async() - mgmt_client = AMQPClientAsync( - self._mgmt_target, auth=mgmt_auth, debug=self._config.network_tracing + mgmt_client = self._amqp_transport.create_mgmt_client( + self._address, mgmt_auth=mgmt_auth, config=self._config ) try: conn = await self._conn_manager_async.get_connection( - self._address.hostname, mgmt_auth + host=self._address.hostname, auth=mgmt_auth ) - mgmt_msg.application_properties["security_token"] = mgmt_auth.token await mgmt_client.open_async(connection=conn) - response = await mgmt_client.mgmt_request_async( + while not await mgmt_client.client_ready_async(): + await asyncio.sleep(0.05) + mgmt_msg.application_properties[ + "security_token" + ] = await self._amqp_transport.get_updated_token_async(mgmt_auth) + response = await self._amqp_transport.mgmt_client_request_async( + mgmt_client, mgmt_msg, - constants.READ_OPERATION, - op_type=op_type, + operation=READ_OPERATION, + operation_type=op_type, status_code_field=MGMT_STATUS_CODE, description_fields=MGMT_STATUS_DESC, ) @@ -342,31 +335,15 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> description = response.application_properties.get( MGMT_STATUS_DESC ) # type: Optional[Union[str, bytes]] - if description and isinstance(description, six.binary_type): + if description and isinstance(description, bytes): description = description.decode("utf-8") if status_code < 400: return response - if status_code in [401]: - raise errors.AuthenticationException( - "Management authentication failed. Status code: {}, Description: {!r}".format( - status_code, description - ) - ) - if status_code in [404]: - raise ConnectError( - "Management connection failed. Status code: {}, Description: {!r}".format( - status_code, description - ) - ) - raise errors.AMQPConnectionError( - "Management request error. Status code: {}, Description: {!r}".format( - status_code, description - ) - ) + raise self._amqp_transport.get_error(status_code, description) except asyncio.CancelledError: # pylint: disable=try-except-raise raise except Exception as exception: # pylint:disable=broad-except - last_exception = await _handle_exception(exception, self) + last_exception = await self._amqp_transport._handle_exception_async(exception, self) # pylint: disable=protected-access await self._backoff_async( retried_times=retried_times, last_exception=last_exception ) @@ -380,12 +357,14 @@ async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> await mgmt_client.close_async() async def _get_eventhub_properties_async(self) -> Dict[str, Any]: - mgmt_msg = Message(application_properties={"name": self.eventhub_name}) + mgmt_msg = mgmt_msg = self._amqp_transport.build_message( + application_properties={"name": self.eventhub_name} + ) response = await self._management_request_async( mgmt_msg, op_type=MGMT_OPERATION ) output = {} - eh_info = response.get_data() # type: Dict[bytes, Any] + eh_info: Dict[bytes, Any] = response.value if eh_info: output["eventhub_name"] = eh_info[b"name"].decode("utf-8") output["created_at"] = utc_from_timestamp( @@ -402,7 +381,7 @@ async def _get_partition_ids_async(self) -> List[str]: async def _get_partition_properties_async( self, partition_id: str ) -> Dict[str, Any]: - mgmt_msg = Message( + mgmt_msg = self._amqp_transport.build_message( application_properties={ "name": self.eventhub_name, "partition": partition_id, @@ -411,7 +390,7 @@ async def _get_partition_properties_async( response = await self._management_request_async( mgmt_msg, op_type=MGMT_PARTITION_OPERATION ) - partition_info = response.get_data() # type: Dict[bytes, Union[bytes, int]] + partition_info = response.value # type: Dict[bytes, Union[bytes, int]] output = {} # type: Dict[str, Any] if partition_info: output["eventhub_name"] = cast(bytes, partition_info[b"name"]).decode( @@ -463,16 +442,16 @@ async def _open(self) -> None: await self._handler.close_async() auth = await self._client._create_auth_async() self._create_handler(auth) - await self._handler.open_async( - connection=await self._client._conn_manager_async.get_connection( - self._client._address.hostname, auth - ) + conn = await self._client._conn_manager_async.get_connection( + host=self._client._address.hostname, auth=auth ) + await self._handler.open_async(connection=conn) while not await self._handler.client_ready_async(): await asyncio.sleep(0.05, **self._internal_kwargs) + # pylint: disable=protected-access self._max_message_size_on_link = ( - self._handler.message_handler._link.peer_max_message_size - or constants.MAX_MESSAGE_LENGTH_BYTES + self._client._amqp_transport.get_remote_max_message_size(self._handler) + or self._client._amqp_transport.MAX_MESSAGE_LENGTH_BYTES ) self.running = True @@ -487,11 +466,11 @@ async def _close_connection_async(self) -> None: await self._client._conn_manager_async.reset_connection_if_broken() # pylint:disable=protected-access async def _handle_exception(self, exception: Exception) -> Exception: - if not self.running and isinstance(exception, compat.TimeoutException): - exception = errors.AuthenticationException("Authorization timeout.") - return await _handle_exception(exception, self) - - return await _handle_exception(exception, self) + # pylint: disable=protected-access + exception = self._client._amqp_transport.check_timeout_exception(self, exception) + return await self._client._amqp_transport._handle_exception_async( + exception, self + ) async def _do_retryable_operation( self, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py index d53443f12fd7..ec3c0fffaebf 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py @@ -3,12 +3,17 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -from typing import TYPE_CHECKING +from __future__ import annotations +from typing import TYPE_CHECKING, Optional +from asyncio import Lock -from uamqp.async_ops import ConnectionAsync +from ._transport._uamqp_transport_async import UamqpTransportAsync +from .._connection_manager import _ConnectionMode +from .._constants import TransportType if TYPE_CHECKING: from uamqp.authentication import JWTTokenAsync + from uamqp.async_ops import ConnectionAsync try: from typing_extensions import Protocol @@ -17,7 +22,7 @@ class ConnectionManager(Protocol): async def get_connection( - self, host: str, auth: "JWTTokenAsync" + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAsync] = None, endpoint: Optional[str] = None ) -> ConnectionAsync: pass @@ -28,11 +33,69 @@ async def reset_connection_if_broken(self) -> None: pass +class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes + def __init__(self, **kwargs) -> None: + self._loop = kwargs.get("loop") + self._lock = Lock(loop=self._loop) + self._conn = None + + self._container_id = kwargs.get("container_id") + self._debug = kwargs.get("debug") + self._error_policy = kwargs.get("error_policy") + self._properties = kwargs.get("properties") + self._encoding = kwargs.get("encoding") or "UTF-8" + self._transport_type = kwargs.get("transport_type") or TransportType.Amqp + self._http_proxy = kwargs.get("http_proxy") + self._max_frame_size = kwargs.get("max_frame_size") + self._channel_max = kwargs.get("channel_max") + self._idle_timeout = kwargs.get("idle_timeout") + self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get( + "remote_idle_timeout_empty_frame_send_ratio" + ) + self._amqp_transport = kwargs.get("amqp_transport", UamqpTransportAsync) + + async def get_connection( + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAsync] = None, endpoint: Optional[str] = None + ) -> ConnectionAsync: + async with self._lock: + if self._conn is None: + self._conn = self._amqp_transport.create_connection_async( + host=host, + auth=auth, + endpoint=endpoint, + container_id=self._container_id, + max_frame_size=self._max_frame_size, + channel_max=self._channel_max, + idle_timeout=self._idle_timeout, + properties=self._properties, + remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio, + error_policy=self._error_policy, + debug=self._debug, + loop=self._loop, + encoding=self._encoding, + ) + return self._conn + + async def close_connection(self) -> None: + async with self._lock: + if self._conn: + await self._amqp_transport.close_connection_async(self._conn) + self._conn = None + + async def reset_connection_if_broken(self) -> None: + async with self._lock: + conn_state = self._amqp_transport.get_connection_state(self._conn) + if self._conn and conn_state in self._amqp_transport.CONNECTION_CLOSING_STATES: + self._conn = None + + class _SeparateConnectionManager(object): def __init__(self, **kwargs) -> None: pass - async def get_connection(self, host: str, auth: "JWTTokenAsync") -> None: + async def get_connection( + self, *, host: Optional[str] = None, auth: Optional[JWTTokenAsync] = None, endpoint: Optional[str] = None + ) -> None: pass # return None async def close_connection(self) -> None: @@ -43,4 +106,7 @@ async def reset_connection_if_broken(self) -> None: def get_connection_manager(**kwargs) -> "ConnectionManager": + connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection) # type: ignore + if connection_mode == _ConnectionMode.ShareConnection: + return _SharedConnectionManager(**kwargs) return _SeparateConnectionManager(**kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index afdcad0ad9e2..c918bca559c8 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -2,26 +2,22 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import time -import asyncio +from __future__ import annotations import uuid import logging from collections import deque -from typing import TYPE_CHECKING, Callable, Awaitable, cast, Dict, Optional, Union, List - -import uamqp -from uamqp import errors, types, utils -from uamqp import ReceiveClientAsync, Source +from typing import TYPE_CHECKING, Callable, Awaitable, Dict, Optional, Union, List from ._client_base_async import ConsumerProducerMixin from ._async_utils import get_dict_with_loop_if_needed from .._common import EventData -from ..exceptions import _error_handler from .._utils import create_properties, event_position_selector from .._constants import EPOCH_SYMBOL, TIMEOUT_SYMBOL, RECEIVER_RUNTIME_METRIC_SYMBOL if TYPE_CHECKING: from typing import Deque + import uamqp + from uamqp import ReceiveClientAsync, Source, types from uamqp.authentication import JWTTokenAsync from ._consumer_client_async import EventHubConsumerClient @@ -79,9 +75,10 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N self.running = False self.closed = False - self._on_event_received = kwargs[ + self._amqp_transport = kwargs.pop("amqp_transport") + self._on_event_received: Callable[[Union[Optional[EventData], List[EventData]]], Awaitable[None]] = kwargs[ "on_event_received" - ] # type: Callable[[Union[Optional[EventData], List[EventData]]], Awaitable[None]] + ] self._internal_kwargs = get_dict_with_loop_if_needed(kwargs.get("loop", None)) self._client = client self._source = source @@ -91,81 +88,64 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N self._owner_level = owner_level self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, - on_error=_error_handler, # pylint:disable=protected-access - ) + self._retry_policy = self._amqp_transport.create_retry_policy(self._client._config) self._reconnect_backoff = 1 self._timeout = 0 - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None - self._link_properties = {} # type: Dict[types.AMQPType, types.AMQPType] - partition = self._source.split("/")[-1] - self._partition = partition - self._name = "EHReceiver-{}-partition{}".format(uuid.uuid4(), partition) + self._idle_timeout = (idle_timeout * self._amqp_transport.TIMEOUT_FACTOR) if idle_timeout else None + link_properties: Dict[types.AMQPType, types.AMQPType] = {} + self._partition = self._source.split("/")[-1] + self._name = f"EHReceiver-{uuid.uuid4()}-partition{self._partition}" if owner_level is not None: - self._link_properties[types.AMQPSymbol(EPOCH_SYMBOL)] = types.AMQPLong( - int(owner_level) - ) + link_properties[EPOCH_SYMBOL] = int(owner_level) link_property_timeout_ms = ( self._client._config.receive_timeout or self._timeout # pylint:disable=protected-access - ) * 1000 - self._link_properties[types.AMQPSymbol(TIMEOUT_SYMBOL)] = types.AMQPLong( - int(link_property_timeout_ms) - ) - self._handler = None # type: Optional[ReceiveClientAsync] + ) * self._amqp_transport.TIMEOUT_FACTOR + link_properties[TIMEOUT_SYMBOL] = int(link_property_timeout_ms) + self._link_properties = self._amqp_transport.create_link_properties(link_properties) + self._handler: Optional[ReceiveClientAsync] = None self._track_last_enqueued_event_properties = ( track_last_enqueued_event_properties ) - self._message_buffer = deque() # type: Deque[uamqp.Message] - self._last_received_event = None # type: Optional[EventData] - - def _create_handler(self, auth: "JWTTokenAsync") -> None: - source = Source(self._source) - if self._offset is not None: - source.set_filter( - event_position_selector(self._offset, self._offset_inclusive) - ) - desired_capabilities = None - if self._track_last_enqueued_event_properties: - symbol_array = [types.AMQPSymbol(RECEIVER_RUNTIME_METRIC_SYMBOL)] - desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) - - properties = create_properties( - self._client._config.user_agent # pylint:disable=protected-access + self._message_buffer: Deque[uamqp.Message] = deque() + self._last_received_event: Optional[EventData] = None + + def _create_handler(self, auth: JWTTokenAsync) -> None: + source = self._amqp_transport.create_source( + self._source, + self._offset, + event_position_selector(self._offset, self._offset_inclusive) ) - self._handler = ReceiveClientAsync( - source, + desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None + + self._handler = self._amqp_transport.create_receive_client( + config=self._client._config, # pylint:disable=protected-access + source=source, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - prefetch=self._prefetch, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access + link_credit=self._prefetch, link_properties=self._link_properties, - timeout=self._timeout, idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, - receive_settle_mode=uamqp.constants.ReceiverSettleMode.ReceiveAndDelete, - auto_complete=False, - properties=properties, + properties=create_properties( + self._client._config.user_agent, amqp_transport=self._amqp_transport # pylint:disable=protected-access + ), desired_capabilities=desired_capabilities, - **self._internal_kwargs - ) - - self._handler._streaming_receive = True # pylint:disable=protected-access - self._handler._message_received_callback = ( # pylint:disable=protected-access - self._message_received + streaming_receive=True, + message_received_callback=self._message_received, ) async def _open_with_retry(self) -> None: await self._do_retryable_operation(self._open, operation_need_param=False) def _message_received(self, message: uamqp.Message) -> None: - self._message_buffer.appendleft(message) + self._message_buffer.append(message) def _next_message_in_buffer(self): # pylint:disable=protected-access - message = self._message_buffer.pop() + message = self._message_buffer.popleft() event_data = EventData._from_message(message) self._last_received_event = event_data return event_data @@ -173,57 +153,4 @@ def _next_message_in_buffer(self): async def receive( self, batch=False, max_batch_size=300, max_wait_time=None ) -> None: - max_retries = ( - self._client._config.max_retries # pylint:disable=protected-access - ) - has_not_fetched_once = True # ensure one trip when max_wait_time is very small - deadline = time.time() + (max_wait_time or 0) # max_wait_time can be None - while len(self._message_buffer) < max_batch_size and ( - time.time() < deadline or has_not_fetched_once - ): - retried_times = 0 - has_not_fetched_once = False - while retried_times <= max_retries: - try: - await self._open() - await cast( - ReceiveClientAsync, self._handler - ).do_work_async() # uamqp sleeps 0.05 if none received - break - except asyncio.CancelledError: # pylint: disable=try-except-raise - raise - except Exception as exception: # pylint: disable=broad-except - if ( - isinstance(exception, uamqp.errors.LinkDetach) - and exception.condition # pylint: disable=no-member - == uamqp.constants.ErrorCodes.LinkStolen - ): - raise await self._handle_exception(exception) - if not self.running: # exit by close - return - if self._last_received_event: - self._offset = self._last_received_event.offset - last_exception = await self._handle_exception(exception) - retried_times += 1 - if retried_times > max_retries: - _LOGGER.info( - "%r operation has exhausted retry. Last exception: %r.", - self._name, - last_exception, - ) - raise last_exception - - if self._message_buffer: - while self._message_buffer: - if batch: - events_for_callback = [] # type: List[EventData] - for _ in range(min(max_batch_size, len(self._message_buffer))): - events_for_callback.append(self._next_message_in_buffer()) - await self._on_event_received(events_for_callback) - else: - await self._on_event_received(self._next_message_in_buffer()) - elif max_wait_time: - if batch: - await self._on_event_received([]) - else: - await self._on_event_received(None) + await self._amqp_transport.receive_messages(self, batch, max_batch_size, max_wait_time) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py index 20c3468bd8b1..73ef31f79ac4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import asyncio import logging import datetime @@ -21,13 +22,12 @@ from ._eventprocessor.event_processor import EventProcessor from ._consumer_async import EventHubConsumer from ._client_base_async import ClientBaseAsync -from .._constants import ALL_PARTITIONS +from .._constants import ALL_PARTITIONS, TransportType from .._eventprocessor.common import LoadBalancingStrategy if TYPE_CHECKING: from ._client_base_async import CredentialTypes - from uamqp.constants import TransportType from ._eventprocessor.partition_context import PartitionContext from ._eventprocessor.checkpoint_store import CheckpointStore from .._common import EventData @@ -215,6 +215,7 @@ def _create_consumer( prefetch=prefetch, idle_timeout=self._idle_timeout, track_last_enqueued_event_properties=track_last_enqueued_event_properties, + amqp_transport=self._amqp_transport, **self._internal_kwargs, ) return handler @@ -231,7 +232,7 @@ def from_connection_string( auth_timeout: float = 60, user_agent: Optional[str] = None, retry_total: int = 3, - transport_type: Optional["TransportType"] = None, + transport_type: TransportType = TransportType.Amqp, checkpoint_store: Optional["CheckpointStore"] = None, load_balancing_interval: float = 10, **kwargs: Any diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_error_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_error_async.py deleted file mode 100644 index e272f496ec81..000000000000 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_error_async.py +++ /dev/null @@ -1,74 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- -import asyncio -import logging -from typing import TYPE_CHECKING, Union, cast - -from uamqp import errors - -from ..exceptions import ( - _create_eventhub_exception, - EventHubError, - EventDataSendError, - EventDataError, -) - -if TYPE_CHECKING: - from ._client_base_async import ClientBaseAsync, ConsumerProducerMixin - -_LOGGER = logging.getLogger(__name__) - - -async def _handle_exception( # pylint:disable=too-many-branches, too-many-statements - exception: Exception, closable: Union["ClientBaseAsync", "ConsumerProducerMixin"] -) -> Exception: - # pylint: disable=protected-access - if isinstance(exception, asyncio.CancelledError): - raise exception - error = exception - try: - name = cast("ConsumerProducerMixin", closable)._name - except AttributeError: - name = cast("ClientBaseAsync", closable)._container_id - if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise - _LOGGER.info("%r stops due to keyboard interrupt", name) - await cast("ConsumerProducerMixin", closable)._close_connection_async() - raise error - elif isinstance(exception, EventHubError): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - raise error - elif isinstance( - exception, - ( - errors.MessageAccepted, - errors.MessageAlreadySettled, - errors.MessageModified, - errors.MessageRejected, - errors.MessageReleased, - errors.MessageContentTooLarge, - ), - ): - _LOGGER.info("%r Event data error (%r)", name, exception) - error = EventDataError(str(exception), exception) - raise error - elif isinstance(exception, errors.MessageException): - _LOGGER.info("%r Event data send error (%r)", name, exception) - error = EventDataSendError(str(exception), exception) - raise error - else: - try: - if isinstance(exception, errors.AuthenticationException): - await closable._close_connection_async() - elif isinstance(exception, errors.LinkDetach): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - elif isinstance(exception, errors.ConnectionClose): - await closable._close_connection_async() - elif isinstance(exception, errors.MessageHandlerError): - await cast("ConsumerProducerMixin", closable)._close_handler_async() - else: # errors.AMQPConnectionError, compat.TimeoutException, and any other errors - await closable._close_connection_async() - except AttributeError: - pass - return _create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index 47db456b0a5f..8f0fe7d6004f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -2,23 +2,18 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from __future__ import annotations import uuid import asyncio import logging from typing import Iterable, Union, Optional, Any, AnyStr, List, TYPE_CHECKING -import time - -from uamqp import types, constants, errors -from uamqp import SendClientAsync from azure.core.tracing import AbstractSpan from .._common import EventData, EventDataBatch -from ..exceptions import _error_handler, OperationTimeoutError from .._producer import _set_partition_key, _set_trace_message from .._utils import ( create_properties, - set_message_partition_key, trace_message, send_context_manager, transform_outbound_single_message, @@ -29,6 +24,9 @@ from ._async_utils import get_dict_with_loop_if_needed if TYPE_CHECKING: + from uamqp import types, constants, errors + from uamqp import SendClientAsync + from uamqp.authentication import JWTTokenAsync # pylint: disable=ungrouped-imports from ._producer_client_async import EventHubProducerClient @@ -60,8 +58,9 @@ class EventHubProducer( Default value is `True`. """ - def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> None: + def __init__(self, client: EventHubProducerClient, target: str, **kwargs) -> None: super().__init__() + self._amqp_transport = kwargs.pop("amqp_transport") partition = kwargs.get("partition", None) send_timeout = kwargs.get("send_timeout", 60) keep_alive = kwargs.get("keep_alive", None) @@ -79,10 +78,14 @@ def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> N self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect self._timeout = send_timeout - self._idle_timeout = (idle_timeout * 1000) if idle_timeout else None - self._retry_policy = errors.ErrorPolicy( - max_retries=self._client._config.max_retries, - on_error=_error_handler, # pylint:disable=protected-access + self._idle_timeout = ( + (idle_timeout * self._amqp_transport.TIMEOUT_FACTOR) + if idle_timeout + else None + ) + + self._retry_policy = self._amqp_transport.create_retry_policy( + config=self._client._config ) self._reconnect_backoff = 1 self._name = "EHProducer-{}".format(uuid.uuid4()) @@ -91,29 +94,31 @@ def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> N if partition: self._target += "/Partitions/" + partition self._name += "-partition{}".format(partition) - self._handler = None # type: Optional[SendClientAsync] - self._outcome = None # type: Optional[constants.MessageSendResult] - self._condition = None # type: Optional[Exception] + self._handler: Optional[SendClientAsync] = None + self._outcome: Optional[constants.MessageSendResult] = None + self._condition: Optional[Exception] = None self._lock = asyncio.Lock(**self._internal_kwargs) - self._link_properties = { - types.AMQPSymbol(TIMEOUT_SYMBOL): types.AMQPLong(int(self._timeout * 1000)) - } + self._link_properties = self._amqp_transport.create_link_properties( + {TIMEOUT_SYMBOL: int(self._timeout * self._amqp_transport.TIMEOUT_FACTOR)} + ) + def _create_handler(self, auth: "JWTTokenAsync") -> None: - self._handler = SendClientAsync( - self._target, + self._handler = self._amqp_transport.create_send_client( + config=self._client._config, # pylint:disable=protected-access + target=self._target, auth=auth, - debug=self._client._config.network_tracing, # pylint:disable=protected-access - msg_timeout=self._timeout * 1000, + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access idle_timeout=self._idle_timeout, - error_policy=self._retry_policy, + retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, link_properties=self._link_properties, properties=create_properties( - self._client._config.user_agent # pylint:disable=protected-access + self._client._config.user_agent, # pylint: disable=protected-access + amqp_transport=self._amqp_transport, ), - **self._internal_kwargs + msg_timeout=self._timeout * 1000, ) async def _open_with_retry(self) -> Any: @@ -121,38 +126,15 @@ async def _open_with_retry(self) -> Any: self._open, operation_need_param=False ) - def _set_msg_timeout( - self, timeout_time: Optional[float], last_exception: Optional[Exception] - ) -> None: - if not timeout_time: - return - remaining_time = timeout_time - time.time() - if remaining_time <= 0.0: - if last_exception: - error = last_exception - else: - error = OperationTimeoutError("Send operation timed out") - _LOGGER.info("%r send operation timed out. (%r)", self._name, error) - raise error - self._handler._msg_timeout = remaining_time * 1000 # type: ignore # pylint: disable=protected-access - async def _send_event_data( self, timeout_time: Optional[float] = None, last_exception: Optional[Exception] = None, ) -> None: - # TODO: Correct uAMQP type hints if self._unsent_events: - await self._open() - self._set_msg_timeout(timeout_time, last_exception) - self._handler.queue_message(*self._unsent_events) # type: ignore - await self._handler.wait_async() # type: ignore - self._unsent_events = self._handler.pending_messages # type: ignore - if self._outcome != constants.MessageSendResult.Ok: - if self._outcome == constants.MessageSendResult.Timeout: - self._condition = OperationTimeoutError("Send operation timed out") - if self._condition: - raise self._condition + await self._amqp_transport.send_messages_async( + self, timeout_time, last_exception, _LOGGER + ) async def _send_event_data_with_retry( self, timeout: Optional[float] = None @@ -183,16 +165,20 @@ def _wrap_eventdata( ) -> Union[EventData, EventDataBatch]: if isinstance(event_data, (EventData, AmqpAnnotatedMessage)): outgoing_event_data = transform_outbound_single_message( - event_data, EventData + event_data, EventData, self._amqp_transport.to_outgoing_amqp_message ) if partition_key: - set_message_partition_key(outgoing_event_data.message, partition_key) + self._amqp_transport.set_message_partition_key( + outgoing_event_data._message, partition_key # pylint: disable=protected-access + ) wrapper_event_data = outgoing_event_data trace_message(wrapper_event_data, span) else: if isinstance( event_data, EventDataBatch ): # The partition_key in the param will be omitted. + if not event_data: + return event_data if ( partition_key and partition_key @@ -203,15 +189,18 @@ def _wrap_eventdata( ) for ( event - ) in event_data.message._body_gen: # pylint: disable=protected-access + ) in event_data._message.data: # pylint: disable=protected-access trace_message(event, span) wrapper_event_data = event_data # type:ignore else: if partition_key: - event_data = _set_partition_key(event_data, partition_key) + event_data = _set_partition_key( + event_data, partition_key, self._amqp_transport + ) event_data = _set_trace_message(event_data, span) - wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # type: ignore # pylint: disable=protected-access - wrapper_event_data.message.on_send_complete = self._on_outcome + wrapper_event_data = EventDataBatch._from_batch( # type: ignore # pylint: disable=protected-access + event_data, self._amqp_transport, partition_key + ) return wrapper_event_data async def send( @@ -253,7 +242,11 @@ async def send( wrapper_event_data = self._wrap_eventdata( event_data, child, partition_key ) - self._unsent_events = [wrapper_event_data.message] + + if not wrapper_event_data: + return + + self._unsent_events = [wrapper_event_data._message] # pylint: disable=protected-access if child: self._client._add_span_request_attributes( # pylint: disable=protected-access diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index a425bea6d059..af59da1efc5a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -8,7 +8,6 @@ from typing import Any, Union, List, Optional, Dict, Callable, cast from typing_extensions import TYPE_CHECKING, Literal, Awaitable, overload -from uamqp import constants from ..exceptions import ConnectError, EventHubError from ..amqp import AmqpAnnotatedMessage @@ -16,12 +15,11 @@ from ._producer_async import EventHubProducer from ._buffered_producer import BufferedProducerDispatcher from .._utils import set_event_partition_key -from .._constants import ALL_PARTITIONS +from .._constants import ALL_PARTITIONS, TransportType from .._common import EventDataBatch, EventData if TYPE_CHECKING: from ._client_base_async import CredentialTypes - from uamqp.constants import TransportType # pylint: disable=ungrouped-imports SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]] @@ -277,7 +275,7 @@ async def _buffered_send_batch(self, event_data_batch, **kwargs): async def _buffered_send_event(self, event, **kwargs): partition_key = kwargs.get("partition_key") - set_event_partition_key(event, partition_key) + set_event_partition_key(event, partition_key, self._amqp_transport) timeout = kwargs.get("timeout") timeout_time = time.time() + timeout if timeout else None await self._buffered_send( @@ -301,10 +299,12 @@ async def _get_max_message_size(self) -> None: EventHubProducer, self._producers[ALL_PARTITIONS] )._open_with_retry() self._max_message_size_on_link = ( - cast( # type: ignore + self._amqp_transport.get_remote_max_message_size( + cast( # type: ignore EventHubProducer, self._producers[ALL_PARTITIONS] - )._handler.message_handler._link.peer_max_message_size - or constants.MAX_MESSAGE_LENGTH_BYTES + )._handler + ) + or self._amqp_transport.MAX_MESSAGE_LENGTH_BYTES ) async def _start_producer( @@ -350,6 +350,7 @@ def _create_producer( partition=partition_id, send_timeout=send_timeout, idle_timeout=self._idle_timeout, + amqp_transport = self._amqp_transport, **self._internal_kwargs ) return handler @@ -402,7 +403,7 @@ def from_connection_string( auth_timeout: float = 60, user_agent: Optional[str] = None, retry_total: int = 3, - transport_type: Optional["TransportType"] = None, + transport_type: TransportType = TransportType.Amqp, **kwargs: Any ) -> "EventHubProducerClient": """Create an EventHubProducerClient from a connection string. @@ -718,7 +719,7 @@ async def create_batch( event_data_batch = EventDataBatch( max_size_in_bytes=(max_size_in_bytes or self._max_message_size_on_link), partition_id=partition_id, - partition_key=partition_key, + partition_key=partition_key ) return event_data_batch diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/__init__.py new file mode 100644 index 000000000000..34913fb394d7 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py new file mode 100644 index 000000000000..ce9342c607ed --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_base_async.py @@ -0,0 +1,272 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from __future__ import annotations +from typing import Tuple, Union, TYPE_CHECKING +from abc import ABC, abstractmethod + +if TYPE_CHECKING: + from uamqp import types as uamqp_types + +class AmqpTransportAsync(ABC): # pylint: disable=too-many-public-methods + """ + Abstract class that defines a set of common methods needed by producer and consumer. + """ + # define constants + MAX_FRAME_SIZE_BYTES: int + MAX_MESSAGE_LENGTH_BYTES: int + TIMEOUT_FACTOR: int + CONNECTION_CLOSING_STATES: Tuple + + # define symbols + PRODUCT_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + VERSION_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + FRAMEWORK_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + PLATFORM_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + USER_AGENT_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + PROP_PARTITION_KEY_AMQP_SYMBOL: Union[uamqp_types.AMQPSymbol, str, bytes] + + + @staticmethod + @abstractmethod + def build_message(**kwargs): + """ + Creates a uamqp.Message or pyamqp.Message with given arguments. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + @abstractmethod + def build_batch_message(**kwargs): + """ + Creates a uamqp.BatchMessage or pyamqp.BatchMessage with given arguments. + :rtype: uamqp.BatchMessage or pyamqp.BatchMessage + """ + + @staticmethod + @abstractmethod + def to_outgoing_amqp_message(annotated_message): + """ + Converts an AmqpAnnotatedMessage into an Amqp Message. + :param AmqpAnnotatedMessage annotated_message: AmqpAnnotatedMessage to convert. + :rtype: uamqp.Message or pyamqp.Message + """ + + @staticmethod + @abstractmethod + def get_batch_message_encoded_size(message): + """ + Gets the batch message encoded size given an underlying Message. + :param uamqp.BatchMessage message: Message to get encoded size of. + :rtype: int + """ + + @staticmethod + @abstractmethod + def get_remote_max_message_size(handler): + """ + Returns max peer message size. + :param AMQPClient handler: Client to get remote max message size on link from. + :rtype: int + """ + + @staticmethod + @abstractmethod + def create_retry_policy(config): + """ + Creates the error retry policy. + :param ~azure.eventhub._configuration.Configuration config: Configuration. + """ + + @staticmethod + @abstractmethod + def create_link_properties(link_properties): + """ + Creates and returns the link properties. + :param dict[bytes, int] link_properties: The dict of symbols and corresponding values. + :rtype: dict + """ + + @staticmethod + @abstractmethod + async def create_connection_async(**kwargs): + """ + Creates and returns the uamqp async Connection object. + :keyword str host: The hostname, used by uamqp. + :keyword JWTTokenAuth auth: The auth, used by uamqp. + :keyword str endpoint: The endpoint, used by pyamqp. + :keyword str container_id: Required. + :keyword int max_frame_size: Required. + :keyword int channel_max: Required. + :keyword int idle_timeout: Required. + :keyword Dict properties: Required. + :keyword int remote_idle_timeout_empty_frame_send_ratio: Required. + :keyword error_policy: Required. + :keyword bool debug: Required. + :keyword str encoding: Required. + """ + + @staticmethod + @abstractmethod + async def close_connection_async(connection): + """ + Closes existing connection. + :param connection: uamqp or pyamqp Connection. + """ + + @staticmethod + @abstractmethod + def get_connection_state(connection): + """ + Gets connection state. + :param connection: uamqp or pyamqp Connection. + """ + + @staticmethod + @abstractmethod + def create_send_client(*, config, **kwargs): + """ + Creates and returns the send client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + + @staticmethod + @abstractmethod + async def send_messages_async(producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + + @staticmethod + @abstractmethod + def set_message_partition_key(message, partition_key, **kwargs): + """Set the partition key as an annotation on a uamqp message. + + :param message: The message to update. + :param str partition_key: The partition key value. + :rtype: None + """ + + @staticmethod + @abstractmethod + def create_source(source, offset, selector): + """ + Creates and returns the Source. + + :param str source: Required. + :param int offset: Required. + :param bytes selector: Required. + """ + + @staticmethod + @abstractmethod + def create_receive_client(*, config, **kwargs): + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword Source source: Required. The source. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. Missing in pyamqp. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + @staticmethod + @abstractmethod + async def receive_messages(consumer, batch, max_batch_size, max_wait_time): + """ + Receives messages, creates events, and returns them by calling the on received callback. + :param ~azure.eventhub.aio.EventHubConsumer consumer: The EventHubConsumer. + :param bool batch: If receive batch or single event. + :param int max_batch_size: Max batch size. + :param int or None max_wait_time: Max wait time. + """ + + @staticmethod + @abstractmethod + async def create_token_auth_async(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Whether to update token. If not updating token, + then pass 300 to refresh_window. Only used by uamqp. + """ + + @staticmethod + @abstractmethod + def create_mgmt_client(address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + @staticmethod + @abstractmethod + async def get_updated_token_async(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + + @staticmethod + @abstractmethod + async def mgmt_client_request_async(mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + + @staticmethod + @abstractmethod + def get_error(status_code, description): + """ + Gets error corresponding to status code. + :param status_code: Status code. + :param str description: Description of error. + """ + + @staticmethod + @abstractmethod + def check_timeout_exception(base, exception): + """ + Checks if timeout exception. + :param base: ClientBase. + :param exception: Exception to check. + """ diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py new file mode 100644 index 000000000000..72b91aee7766 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_uamqp_transport_async.py @@ -0,0 +1,372 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from __future__ import annotations +import asyncio +import time +import logging +from typing import Union, cast, TYPE_CHECKING, List + +from uamqp import ( + constants, + types, + SendClientAsync, + ReceiveClientAsync, + utils, + authentication, + AMQPClientAsync, + errors, +) +from uamqp.async_ops import ConnectionAsync + +from ._base_async import AmqpTransportAsync +from ..._transport._uamqp_transport import UamqpTransport +from ...exceptions import ( + OperationTimeoutError, + EventHubError, + EventDataError, + EventDataSendError, +) + +if TYPE_CHECKING: + from .._client_base_async import ClientBaseAsync, ConsumerProducerMixin + from ..._common import EventData + +_LOGGER = logging.getLogger(__name__) + +class UamqpTransportAsync(UamqpTransport, AmqpTransportAsync): + """ + Class which defines uamqp-based methods used by the producer and consumer. + """ + + @staticmethod + async def create_connection_async(**kwargs): + """ + Creates and returns the uamqp async Connection object. + :keyword str host: The hostname, used by uamqp. + :keyword JWTTokenAuth auth: The auth, used by uamqp. + :keyword str endpoint: The endpoint, used by pyamqp. + :keyword str container_id: Required. + :keyword int max_frame_size: Required. + :keyword int channel_max: Required. + :keyword int idle_timeout: Required. + :keyword Dict properties: Required. + :keyword int remote_idle_timeout_empty_frame_send_ratio: Required. + :keyword error_policy: Required. + :keyword bool debug: Required. + :keyword str encoding: Required. + """ + endpoint = kwargs.pop("endpoint") # pylint:disable=unused-variable + host = kwargs.pop("host") + auth = kwargs.pop("auth") + return ConnectionAsync( + host, + auth, + **kwargs + ) + + @staticmethod + async def close_connection_async(connection): + """ + Closes existing connection. + :param connection: uamqp or pyamqp Connection. + """ + await connection.destroy_async() + + @staticmethod + def create_send_client(*, config, **kwargs): # pylint:disable=unused-argument + """ + Creates and returns the uamqp SendClient. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str target: Required. The target. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword keep_alive_interval: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + """ + target = kwargs.pop("target") + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + + return SendClientAsync( + target, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + **kwargs + ) + + @staticmethod + async def send_messages_async(producer, timeout_time, last_exception, logger): + """ + Handles sending of event data messages. + :param ~azure.eventhub._producer.EventHubProducer producer: The producer with handler to send messages. + :param int timeout_time: Timeout time. + :param last_exception: Exception to raise if message timed out. Only used by uamqp transport. + :param logger: Logger. + """ + # pylint: disable=protected-access + await producer._open() + producer._unsent_events[0].on_send_complete = producer._on_outcome + UamqpTransportAsync._set_msg_timeout(producer, timeout_time, last_exception, logger) + producer._handler.queue_message(*producer._unsent_events) # type: ignore + await producer._handler.wait_async() # type: ignore + producer._unsent_events = producer._handler.pending_messages # type: ignore + if producer._outcome != constants.MessageSendResult.Ok: + if producer._outcome == constants.MessageSendResult.Timeout: + producer._condition = OperationTimeoutError("Send operation timed out") + if producer._condition: + raise producer._condition + + @staticmethod + def create_receive_client(*, config, **kwargs): # pylint:disable=unused-argument + """ + Creates and returns the receive client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + + :keyword str source: Required. The source. + :keyword str offset: Required. + :keyword str offset_inclusive: Required. + :keyword JWTTokenAuth auth: Required. + :keyword int idle_timeout: Required. + :keyword network_trace: Required. + :keyword retry_policy: Required. + :keyword str client_name: Required. + :keyword dict link_properties: Required. + :keyword properties: Required. + :keyword link_credit: Required. The prefetch. + :keyword keep_alive_interval: Required. + :keyword desired_capabilities: Required. + :keyword streaming_receive: Required. + :keyword message_received_callback: Required. + :keyword timeout: Required. + """ + + source = kwargs.pop("source") + symbol_array = kwargs.pop("desired_capabilities") + desired_capabilities = None + if symbol_array: + symbol_array = [types.AMQPSymbol(symbol) for symbol in symbol_array] + desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) + retry_policy = kwargs.pop("retry_policy") + network_trace = kwargs.pop("network_trace") + link_credit = kwargs.pop("link_credit") + streaming_receive = kwargs.pop("streaming_receive") + message_received_callback = kwargs.pop("message_received_callback") + + client = ReceiveClientAsync( + source, + debug=network_trace, # pylint:disable=protected-access + error_policy=retry_policy, + desired_capabilities=desired_capabilities, + prefetch=link_credit, + receive_settle_mode=constants.ReceiverSettleMode.ReceiveAndDelete, + auto_complete=False, + **kwargs + ) + # pylint:disable=protected-access + client._streaming_receive = streaming_receive + client._message_received_callback = (message_received_callback) + return client + + @staticmethod + async def receive_messages(consumer, batch, max_batch_size, max_wait_time): + """ + Receives messages, creates events, and returns them by calling the on received callback. + :param ~azure.eventhub.aio.EventHubConsumer consumer: The EventHubConsumer. + :param bool batch: If receive batch or single event. + :param int max_batch_size: Max batch size. + :param int or None max_wait_time: Max wait time. + """ + # pylint:disable=protected-access + max_retries = ( + consumer._client._config.max_retries # pylint:disable=protected-access + ) + has_not_fetched_once = True # ensure one trip when max_wait_time is very small + deadline = time.time() + (max_wait_time or 0) # max_wait_time can be None + while len(consumer._message_buffer) < max_batch_size and ( + time.time() < deadline or has_not_fetched_once + ): + retried_times = 0 + has_not_fetched_once = False + while retried_times <= max_retries: + try: + await consumer._open() + await cast( + ReceiveClientAsync, consumer._handler + ).do_work_async() # uamqp sleeps 0.05 if none received + break + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise + except Exception as exception: # pylint: disable=broad-except + if ( + isinstance(exception, errors.LinkDetach) + and exception.condition == constants.ErrorCodes.LinkStolen # pylint: disable=no-member + ): + raise await consumer._handle_exception(exception) + if not consumer.running: # exit by close + return + if consumer._last_received_event: + consumer._offset = consumer._last_received_event.offset + last_exception = await consumer._handle_exception(exception) + retried_times += 1 + if retried_times > max_retries: + _LOGGER.info( + "%r operation has exhausted retry. Last exception: %r.", + consumer._name, + last_exception, + ) + raise last_exception + + if consumer._message_buffer: + while consumer._message_buffer: + if batch: + events_for_callback: List[EventData] = [] + for _ in range(min(max_batch_size, len(consumer._message_buffer))): + events_for_callback.append(consumer._next_message_in_buffer()) + await consumer._on_event_received(events_for_callback) + else: + await consumer._on_event_received(consumer._next_message_in_buffer()) + elif max_wait_time: + if batch: + await consumer._on_event_received([]) + else: + await consumer._on_event_received(None) + + @staticmethod + async def create_token_auth_async(auth_uri, get_token, token_type, config, **kwargs): + """ + Creates the JWTTokenAuth. + :param str auth_uri: The auth uri to pass to JWTTokenAuth. + :param get_token: The callback function used for getting and refreshing + tokens. It should return a valid jwt token each time it is called. + :param bytes token_type: Token type. + :param ~azure.eventhub._configuration.Configuration config: EH config. + + :keyword bool update_token: Required. Whether to update token. If not updating token, + then pass 300 to refresh_window. + """ + update_token = kwargs.pop("update_token") + refresh_window = 300 + if update_token: + refresh_window = 0 + + token_auth = authentication.JWTTokenAsync( + auth_uri, + auth_uri, + get_token, + token_type=token_type, + timeout=config.auth_timeout, + http_proxy=config.http_proxy, + transport_type=config.transport_type, + custom_endpoint_hostname=config.custom_endpoint_hostname, + port=config.connection_port, + verify=config.connection_verify, + refresh_window=refresh_window + ) + if update_token: + await token_auth.update_token() + return token_auth + + @staticmethod + def create_mgmt_client(address, mgmt_auth, config): + """ + Creates and returns the mgmt AMQP client. + :param _Address address: Required. The Address. + :param JWTTokenAuth mgmt_auth: Auth for client. + :param ~azure.eventhub._configuration.Configuration config: The configuration. + """ + + mgmt_target = f"amqps://{address.hostname}{address.path}" + return AMQPClientAsync( + mgmt_target, + auth=mgmt_auth, + debug=config.network_tracing + ) + + @staticmethod + async def get_updated_token_async(mgmt_auth): + """ + Return updated auth token. + :param mgmt_auth: Auth. + """ + return mgmt_auth.token + + @staticmethod + async def mgmt_client_request_async(mgmt_client, mgmt_msg, **kwargs): + """ + Send mgmt request. + :param AMQP Client mgmt_client: Client to send request with. + :param str mgmt_msg: Message. + :keyword bytes operation: Operation. + :keyword operation_type: Op type. + :keyword status_code_field: mgmt status code. + :keyword description_fields: mgmt status desc. + """ + operation_type = kwargs.pop("operation_type") + operation = kwargs.pop("operation") + return await mgmt_client.mgmt_request_async( + mgmt_msg, + operation, + op_type=operation_type, + **kwargs + ) + + @staticmethod + async def _handle_exception_async( # pylint:disable=too-many-branches, too-many-statements + exception: Exception, closable: Union["ClientBaseAsync", "ConsumerProducerMixin"] + ) -> Exception: + # pylint: disable=protected-access + if isinstance(exception, asyncio.CancelledError): + raise exception + error = exception + try: + name = cast("ConsumerProducerMixin", closable)._name + except AttributeError: + name = cast("ClientBaseAsync", closable)._container_id + if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise + _LOGGER.info("%r stops due to keyboard interrupt", name) + await cast("ConsumerProducerMixin", closable)._close_connection_async() + raise error + elif isinstance(exception, EventHubError): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + raise error + elif isinstance( + exception, + ( + errors.MessageAccepted, + errors.MessageAlreadySettled, + errors.MessageModified, + errors.MessageRejected, + errors.MessageReleased, + errors.MessageContentTooLarge, + ), + ): + _LOGGER.info("%r Event data error (%r)", name, exception) + error = EventDataError(str(exception), exception) + raise error + elif isinstance(exception, errors.MessageException): + _LOGGER.info("%r Event data send error (%r)", name, exception) + error = EventDataSendError(str(exception), exception) + raise error + else: + try: + if isinstance(exception, errors.AuthenticationException): + await closable._close_connection_async() + elif isinstance(exception, errors.LinkDetach): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + elif isinstance(exception, errors.ConnectionClose): + await closable._close_connection_async() + elif isinstance(exception, errors.MessageHandlerError): + await cast("ConsumerProducerMixin", closable)._close_handler_async() + else: # errors.AMQPConnectionError, compat.TimeoutException, and any other errors + await closable._close_connection_async() + except AttributeError: + pass + return UamqpTransport._create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index 28a3c9e79fa4..d6cc5937a11d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -4,11 +4,11 @@ # license information. # ------------------------------------------------------------------------- -from typing import Optional, Any, cast, Mapping, Dict +from __future__ import annotations +from typing import Optional, Any, cast, Mapping, Dict, Union, List -import uamqp - -from ._constants import AMQP_MESSAGE_BODY_TYPE_MAP, AmqpMessageBodyType +from ._amqp_utils import normalized_data_body, normalized_sequence_body +from ._constants import AmqpMessageBodyType from .._mixin import DictMixin @@ -19,11 +19,9 @@ class AmqpAnnotatedMessage(object): access to low-level AMQP message sections. There should be one and only one of either data_body, sequence_body or value_body being set as the body of the AmqpAnnotatedMessage; if more than one body is set, `ValueError` will be raised. - Please refer to the AMQP spec: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#section-message-format for more information on the message format. - :keyword data_body: The body consists of one or more data sections and each section contains opaque binary data. :paramtype data_body: Union[str, bytes, List[Union[str, bytes]]] :keyword sequence_body: The body consists of one or more sequence sections and @@ -47,12 +45,15 @@ class AmqpAnnotatedMessage(object): def __init__(self, **kwargs): # type: (Any) -> None - self._message = kwargs.pop("message", None) self._encoding = kwargs.pop("encoding", "UTF-8") + self._data_body: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None + self._sequence_body: Optional[List[Any]] = None + self._value_body: Any = None # internal usage only for Event Hub received message - if self._message: - self._from_amqp_message(self._message) + message = kwargs.pop("message", None) + if message: + self._from_amqp_message(message) return # manually constructed AMQPAnnotatedMessage @@ -69,21 +70,17 @@ def __init__(self, **kwargs): "or value_body being set as the body of the AmqpAnnotatedMessage." ) - self._body = None - self._body_type = None + self._body_type: AmqpMessageBodyType = None # type: ignore if "data_body" in kwargs: - self._body = kwargs.get("data_body") - self._body_type = uamqp.MessageBodyType.Data + self._data_body = normalized_data_body(kwargs.get("data_body")) + self._body_type = AmqpMessageBodyType.DATA elif "sequence_body" in kwargs: - self._body = kwargs.get("sequence_body") - self._body_type = uamqp.MessageBodyType.Sequence + self._sequence_body = normalized_sequence_body(kwargs.get("sequence_body")) + self._body_type = AmqpMessageBodyType.SEQUENCE elif "value_body" in kwargs: - self._body = kwargs.get("value_body") - self._body_type = uamqp.MessageBodyType.Value + self._value_body = kwargs.get("value_body") + self._body_type = AmqpMessageBodyType.VALUE - self._message = uamqp.message.Message( - body=self._body, body_type=self._body_type - ) header_dict = cast(Mapping, kwargs.get("header")) self._header = AmqpMessageHeader(**header_dict) if "header" in kwargs else None self._footer = kwargs.get("footer") @@ -95,11 +92,16 @@ def __init__(self, **kwargs): self._annotations = kwargs.get("annotations") self._delivery_annotations = kwargs.get("delivery_annotations") - def __str__(self): - return str(self._message) + def __str__(self) -> str: + if self._body_type == AmqpMessageBodyType.DATA: + return "".join(d.decode(self._encoding) for d in self._data_body) # type: ignore + if self._body_type == AmqpMessageBodyType.SEQUENCE: + return str(self._sequence_body) + if self._body_type == AmqpMessageBodyType.VALUE: + return str(self._value_body) + return "" - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: # pylint: disable=bare-except message_repr = "body={}".format(str(self)) message_repr += ", body_type={}".format(self.body_type) @@ -134,143 +136,79 @@ def __repr__(self): return "AmqpAnnotatedMessage({})".format(message_repr)[:1024] def _from_amqp_message(self, message): - # populate the properties from an uamqp message - self._properties = ( - AmqpMessageProperties( - message_id=message.properties.message_id, - user_id=message.properties.user_id, - to=message.properties.to, - subject=message.properties.subject, - reply_to=message.properties.reply_to, - correlation_id=message.properties.correlation_id, - content_type=message.properties.content_type, - content_encoding=message.properties.content_encoding, - absolute_expiry_time=message.properties.absolute_expiry_time, - creation_time=message.properties.creation_time, - group_id=message.properties.group_id, - group_sequence=message.properties.group_sequence, - reply_to_group_id=message.properties.reply_to_group_id, - ) - if message.properties - else None - ) - self._header = ( - AmqpMessageHeader( - delivery_count=message.header.delivery_count, - time_to_live=message.header.time_to_live, - first_acquirer=message.header.first_acquirer, - durable=message.header.durable, - priority=message.header.priority, - ) - if message.header - else None - ) - self._footer = message.footer - self._annotations = message.annotations - self._delivery_annotations = message.delivery_annotations - self._application_properties = message.application_properties - - def _to_outgoing_amqp_message(self): - message_header = None - if self.header: - message_header = uamqp.message.MessageHeader() - message_header.delivery_count = self.header.delivery_count - message_header.time_to_live = self.header.time_to_live - message_header.first_acquirer = self.header.first_acquirer - message_header.durable = self.header.durable - message_header.priority = self.header.priority - - message_properties = None - if self.properties: - message_properties = uamqp.message.MessageProperties( - message_id=self.properties.message_id, - user_id=self.properties.user_id, - to=self.properties.to, - subject=self.properties.subject, - reply_to=self.properties.reply_to, - correlation_id=self.properties.correlation_id, - content_type=self.properties.content_type, - content_encoding=self.properties.content_encoding, - creation_time=int(self.properties.creation_time) - if self.properties.creation_time - else None, - absolute_expiry_time=int(self.properties.absolute_expiry_time) - if self.properties.absolute_expiry_time - else None, - group_id=self.properties.group_id, - group_sequence=self.properties.group_sequence, - reply_to_group_id=self.properties.reply_to_group_id, - encoding=self._encoding, - ) - - amqp_body = self._message._body # pylint: disable=protected-access - if isinstance(amqp_body, uamqp.message.DataBody): - amqp_body_type = uamqp.MessageBodyType.Data - amqp_body = list(amqp_body.data) - elif isinstance(amqp_body, uamqp.message.SequenceBody): - amqp_body_type = uamqp.MessageBodyType.Sequence - amqp_body = list(amqp_body.data) + self._properties = AmqpMessageProperties( + message_id=message.properties.message_id, + user_id=message.properties.user_id, + to=message.properties.to, + subject=message.properties.subject, + reply_to=message.properties.reply_to, + correlation_id=message.properties.correlation_id, + content_type=message.properties.content_type, + content_encoding=message.properties.content_encoding, + absolute_expiry_time=message.properties.absolute_expiry_time, + creation_time=message.properties.creation_time, + group_id=message.properties.group_id, + group_sequence=message.properties.group_sequence, + reply_to_group_id=message.properties.reply_to_group_id, + ) if message.properties else None + self._header = AmqpMessageHeader( + delivery_count=message.header.delivery_count, + time_to_live=message.header.ttl, + first_acquirer=message.header.first_acquirer, + durable=message.header.durable, + priority=message.header.priority + ) if message.header else None + self._footer = message.footer if message.footer else {} + self._annotations = message.message_annotations if message.message_annotations else {} + self._delivery_annotations = message.delivery_annotations if message.delivery_annotations else {} + self._application_properties = message.application_properties if message.application_properties else {} + if message.data: + self._data_body = cast(List, list(message.data)) + self._body_type = AmqpMessageBodyType.DATA + elif message.sequence: + self._sequence_body = cast(List, list(message.sequence)) + self._body_type = AmqpMessageBodyType.SEQUENCE else: - # amqp_body is type of uamqp.message.ValueBody - amqp_body_type = uamqp.MessageBodyType.Value - amqp_body = amqp_body.data - - return uamqp.message.Message( - body=amqp_body, - body_type=amqp_body_type, - header=message_header, - properties=message_properties, - application_properties=self.application_properties, - annotations=self.annotations, - delivery_annotations=self.delivery_annotations, - footer=self.footer, - ) + self._value_body = message.value + self._body_type = AmqpMessageBodyType.VALUE @property - def body(self): - # type: () -> Any + def body(self) -> Any: """The body of the Message. The format may vary depending on the body type: - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.DATA`, - the body could be bytes or Iterable[bytes]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.SEQUENCE`, - the body could be List or Iterable[List]. - For :class:`azure.eventhub.amqp.AmqpMessageBodyType.VALUE`, - the body could be any type. - + For ~azure.eventhub.AmqpMessageBodyType.DATA, the body could be bytes or Iterable[bytes] + For ~azure.eventhub.AmqpMessageBodyType.SEQUENCE, the body could be List or Iterable[List] + For ~azure.eventhub.AmqpMessageBodyType.VALUE, the body could be any type. :rtype: Any """ - return self._message.get_data() + if self._body_type == AmqpMessageBodyType.DATA: # pylint:disable=no-else-return + return (i for i in cast(List, self._data_body)) # type: ignore + elif self._body_type == AmqpMessageBodyType.SEQUENCE: + return (i for i in cast(List, self._sequence_body)) + elif self._body_type == AmqpMessageBodyType.VALUE: + return self._value_body + return None @property - def body_type(self): - # type: () -> AmqpMessageBodyType + def body_type(self) -> AmqpMessageBodyType: """The body type of the underlying AMQP message. - - :rtype: ~azure.eventhub.amqp.AmqpMessageBodyType + rtype: ~azure.eventhub.amqp.AmqpMessageBodyType """ - return AMQP_MESSAGE_BODY_TYPE_MAP.get( - self._message._body.type, # pylint: disable=protected-access - AmqpMessageBodyType.VALUE, - ) + return self._body_type @property - def properties(self): - # type: () -> Optional[AmqpMessageProperties] + def properties(self) -> Optional[AmqpMessageProperties]: """ Properties to add to the message. - :rtype: Optional[~azure.eventhub.amqp.AmqpMessageProperties] """ return self._properties @properties.setter - def properties(self, value): - # type: (AmqpMessageProperties) -> None + def properties(self, value: AmqpMessageProperties) -> None: self._properties = value @property - def application_properties(self): - # type: () -> Optional[Dict] + def application_properties(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Service specific application properties. @@ -279,13 +217,11 @@ def application_properties(self): return self._application_properties @application_properties.setter - def application_properties(self, value): - # type: (Dict) -> None + def application_properties(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._application_properties = value @property - def annotations(self): - # type: () -> Optional[Dict] + def annotations(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Service specific message annotations. @@ -294,13 +230,11 @@ def annotations(self): return self._annotations @annotations.setter - def annotations(self, value): - # type: (Dict) -> None + def annotations(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._annotations = value @property - def delivery_annotations(self): - # type: () -> Optional[Dict] + def delivery_annotations(self) -> Optional[Dict[Union[str, bytes], Any]]: """ Delivery-specific non-standard properties at the head of the message. Delivery annotations convey information from the sending peer to the receiving peer. @@ -310,28 +244,23 @@ def delivery_annotations(self): return self._delivery_annotations @delivery_annotations.setter - def delivery_annotations(self, value): - # type: (Dict) -> None + def delivery_annotations(self, value: Optional[Dict[Union[str, bytes], Any]]) -> None: self._delivery_annotations = value @property - def header(self): - # type: () -> Optional[AmqpMessageHeader] + def header(self) -> Optional[AmqpMessageHeader]: """ The message header. - :rtype: Optional[~azure.eventhub.amqp.AmqpMessageHeader] """ return self._header @header.setter - def header(self, value): - # type: (AmqpMessageHeader) -> None + def header(self, value: AmqpMessageHeader) -> None: self._header = value @property - def footer(self): - # type: () -> Optional[Dict] + def footer(self) -> Optional[Dict[Any, Any]]: """ The message footer. @@ -340,8 +269,7 @@ def footer(self): return self._footer @footer.setter - def footer(self, value): - # type: (Dict) -> None + def footer(self, value: Optional[Dict[Any, Any]]) -> None: self._footer = value @@ -350,11 +278,9 @@ class AmqpMessageHeader(DictMixin): The Message header. This is only used on received message, and not set on messages being sent. The properties set on any given message will depend on the Service and not all messages will have all properties. - Please refer to the AMQP spec: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#type-header for more information on the message header. - :keyword delivery_count: The number of unsuccessful previous attempts to deliver this message. If this value is non-zero it can be taken as an indication that the delivery might be a duplicate. On first delivery, the value is zero. It is @@ -425,11 +351,9 @@ class AmqpMessageProperties(DictMixin): The properties that are actually used will depend on the service implementation. Not all received messages will have all properties, and not all properties will be utilized on a sent message. - Please refer to the AMQP spec: http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-messaging-v1.0-os.html#type-properties for more information on the message properties. - :keyword message_id: Message-id, if set, uniquely identifies a message within the message system. The message producer is usually responsible for setting the message-id in such a way that it is assured to be globally unique. A broker MAY discard a message as a duplicate if the value diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py new file mode 100644 index 000000000000..c620c149ea5e --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_utils.py @@ -0,0 +1,25 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +def encode_str(data, encoding='utf-8'): + try: + return data.encode(encoding) + except AttributeError: + return data + +def normalized_data_body(data, **kwargs): + # A helper method to normalize input into AMQP Data Body format + encoding = kwargs.get("encoding", "utf-8") + if isinstance(data, list): + return [encode_str(item, encoding) for item in data] + return [encode_str(data, encoding)] + +def normalized_sequence_body(sequence): + # A helper method to normalize input into AMQP Sequence Body format + if isinstance(sequence, list) and all([isinstance(b, list) for b in sequence]): + return sequence + if isinstance(sequence, list): + return [sequence] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py index 1e2e7d3b6577..576321d4cc2b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_constants.py @@ -4,8 +4,6 @@ # license information. # ------------------------------------------------------------------------- from enum import Enum - -from uamqp import MessageBodyType from azure.core import CaseInsensitiveEnumMeta @@ -13,10 +11,3 @@ class AmqpMessageBodyType(str, Enum, metaclass=CaseInsensitiveEnumMeta): DATA = "data" SEQUENCE = "sequence" VALUE = "value" - - -AMQP_MESSAGE_BODY_TYPE_MAP = { - MessageBodyType.Data.value: AmqpMessageBodyType.DATA, - MessageBodyType.Sequence.value: AmqpMessageBodyType.SEQUENCE, - MessageBodyType.Value.value: AmqpMessageBodyType.VALUE, -} diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py index 6d90033502f8..f686251e6e95 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/exceptions.py @@ -2,40 +2,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import logging import six -from uamqp import errors, compat - -from ._constants import NO_RETRY_ERRORS - -_LOGGER = logging.getLogger(__name__) - - -def _error_handler(error): - """ - Called internally when an event has failed to send so we - can parse the error to determine whether we should attempt - to retry sending the event again. - Returns the action to take according to error type. - - :param error: The error received in the send attempt. - :type error: Exception - :rtype: ~uamqp.errors.ErrorAction - """ - if error.condition == b"com.microsoft:server-busy": - return errors.ErrorAction(retry=True, backoff=4) - if error.condition == b"com.microsoft:timeout": - return errors.ErrorAction(retry=True, backoff=2) - if error.condition == b"com.microsoft:operation-cancelled": - return errors.ErrorAction(retry=True) - if error.condition == b"com.microsoft:container-close": - return errors.ErrorAction(retry=True, backoff=4) - if error.condition in NO_RETRY_ERRORS: - return errors.ErrorAction(retry=False) - return errors.ErrorAction(retry=True) - - class EventHubError(Exception): """Represents an error occurred in the client. @@ -125,79 +93,3 @@ class OperationTimeoutError(EventHubError): class OwnershipLostError(Exception): """Raised when `update_checkpoint` detects the ownership to a partition has been lost.""" - - -def _create_eventhub_exception(exception): - if isinstance(exception, errors.AuthenticationException): - error = AuthenticationError(str(exception), exception) - elif isinstance(exception, errors.VendorLinkDetach): - error = ConnectError(str(exception), exception) - elif isinstance(exception, errors.LinkDetach): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.ConnectionClose): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.MessageHandlerError): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.AMQPConnectionError): - error_type = ( - AuthenticationError - if str(exception).startswith("Unable to open authentication session") - else ConnectError - ) - error = error_type(str(exception), exception) - elif isinstance(exception, compat.TimeoutException): - error = ConnectionLostError(str(exception), exception) - else: - error = EventHubError(str(exception), exception) - return error - - -def _handle_exception( - exception, closable -): # pylint:disable=too-many-branches, too-many-statements - try: # closable is a producer/consumer object - name = closable._name # pylint: disable=protected-access - except AttributeError: # closable is an client object - name = closable._container_id # pylint: disable=protected-access - if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise - _LOGGER.info("%r stops due to keyboard interrupt", name) - closable._close_connection() # pylint:disable=protected-access - raise exception - elif isinstance(exception, EventHubError): - closable._close_handler() # pylint:disable=protected-access - raise exception - elif isinstance( - exception, - ( - errors.MessageAccepted, - errors.MessageAlreadySettled, - errors.MessageModified, - errors.MessageRejected, - errors.MessageReleased, - errors.MessageContentTooLarge, - ), - ): - _LOGGER.info("%r Event data error (%r)", name, exception) - error = EventDataError(str(exception), exception) - raise error - elif isinstance(exception, errors.MessageException): - _LOGGER.info("%r Event data send error (%r)", name, exception) - error = EventDataSendError(str(exception), exception) - raise error - else: - if isinstance(exception, errors.AuthenticationException): - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - elif isinstance(exception, errors.LinkDetach): - if hasattr(closable, "_close_handler"): - closable._close_handler() # pylint:disable=protected-access - elif isinstance(exception, errors.ConnectionClose): - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - elif isinstance(exception, errors.MessageHandlerError): - if hasattr(closable, "_close_handler"): - closable._close_handler() # pylint:disable=protected-access - else: # errors.AMQPConnectionError, compat.TimeoutException - if hasattr(closable, "_close_connection"): - closable._close_connection() # pylint:disable=protected-access - return _create_eventhub_exception(exception) diff --git a/sdk/eventhub/azure-eventhub/conftest.py b/sdk/eventhub/azure-eventhub/conftest.py index 802b99d429f9..981fcf68f53a 100644 --- a/sdk/eventhub/azure-eventhub/conftest.py +++ b/sdk/eventhub/azure-eventhub/conftest.py @@ -42,6 +42,9 @@ def sleep(request): sleep = request.config.getoption("--sleep") return sleep.lower() in ('true', 'yes', '1', 'y') +@pytest.fixture(scope="session", params=[True]) +def uamqp_transport(request): + return request.param def get_logger(filename, level=logging.INFO): azure_logger = logging.getLogger("azure.eventhub") @@ -68,6 +71,13 @@ def get_logger(filename, level=logging.INFO): log = get_logger(None, logging.DEBUG) +@pytest.fixture(scope="session") +def timeout_factor(uamqp_transport): + if uamqp_transport: + return 1000 + else: + return 1 + @pytest.fixture(scope="session") def resource_group(): try: diff --git a/sdk/eventhub/azure-eventhub/dev_requirements.txt b/sdk/eventhub/azure-eventhub/dev_requirements.txt index df47262912ac..eef3c3014389 100644 --- a/sdk/eventhub/azure-eventhub/dev_requirements.txt +++ b/sdk/eventhub/azure-eventhub/dev_requirements.txt @@ -5,4 +5,3 @@ azure-mgmt-eventhub==10.0.0 azure-mgmt-resource==20.0.0 aiohttp>=3.0 -e ../../../tools/azure-devtools --e ../../servicebus/azure-servicebus \ No newline at end of file diff --git a/sdk/eventhub/azure-eventhub/setup.py b/sdk/eventhub/azure-eventhub/setup.py index f7c5843f4a68..b7d2c5294da4 100644 --- a/sdk/eventhub/azure-eventhub/setup.py +++ b/sdk/eventhub/azure-eventhub/setup.py @@ -69,7 +69,7 @@ packages=find_packages(exclude=exclude_packages), install_requires=[ "azure-core<2.0.0,>=1.14.0", - "uamqp>=1.5.1,<2.0.0", + "uamqp>=1.6.0,<2.0.0", "typing-extensions>=4.0.1", ] ) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py index 6b9157d30440..141d2b4861d6 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py @@ -22,11 +22,12 @@ AmqpMessageProperties, ) + @pytest.mark.liveTest @pytest.mark.asyncio -async def test_send_amqp_annotated_message(connstr_receivers): +async def test_send_amqp_annotated_message(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) async with client: sequence_body = [b'message', 123.456, True] footer = {'footer_key': 'footer_value'} @@ -126,7 +127,7 @@ async def on_event(partition_context, event): on_event.received = [] client = EventHubConsumerClient.from_connection_string(connection_str, - consumer_group='$default') + consumer_group='$default', uamqp_transport=uamqp_transport) async with client: task = asyncio.ensure_future(client.receive(on_event, starting_position="-1")) await asyncio.sleep(15) @@ -342,7 +343,7 @@ async def test_send_multiple_partition_with_app_prop_async(connstr_receivers): async def test_send_over_websocket_async(connstr_receivers): connection_str, receivers = connstr_receivers client = EventHubProducerClient.from_connection_string(connection_str, - transport_type=TransportType.AmqpOverWebsocket) + transport_type=uamqp.constants.TransportType.AmqpOverWebsocket) async with client: batch = await client.create_batch(partition_id="0") diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py index c00ea84067ea..3fd512bbf141 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py @@ -14,17 +14,20 @@ @pytest.mark.liveTest -def test_client_secret_credential(live_eventhub): +def test_client_secret_credential(live_eventhub, uamqp_transport): credential = EnvironmentCredential() producer_client = EventHubProducerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport) consumer_client = EventHubConsumerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], consumer_group='$default', credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') batch.add(EventData(body='A single message')) @@ -51,10 +54,12 @@ def on_event(partition_context, event): @pytest.mark.liveTest -def test_client_sas_credential(live_eventhub): +def test_client_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. hostname = live_eventhub['hostname'] - producer_client = EventHubProducerClient.from_connection_string(live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub']) + producer_client = EventHubProducerClient.from_connection_string( + live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub'], uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -67,7 +72,8 @@ def test_client_sas_credential(live_eventhub): token = credential.get_token(auth_uri).token producer_client = EventHubProducerClient(fully_qualified_namespace=hostname, eventhub_name=live_eventhub['event_hub'], - credential=EventHubSASTokenCredential(token, time.time() + 3000)) + credential=EventHubSASTokenCredential(token, time.time() + 3000), + uamqp_transport=uamqp_transport) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -77,7 +83,8 @@ def test_client_sas_credential(live_eventhub): # Finally let's do it with SAS token + conn str token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode()) conn_str_producer_client = EventHubProducerClient.from_connection_string(token_conn_str, - eventhub_name=live_eventhub['event_hub']) + eventhub_name=live_eventhub['event_hub'], + uamqp_transport=uamqp_transport) with conn_str_producer_client: batch = conn_str_producer_client.create_batch(partition_id='0') @@ -86,10 +93,12 @@ def test_client_sas_credential(live_eventhub): @pytest.mark.liveTest -def test_client_azure_sas_credential(live_eventhub): +def test_client_azure_sas_credential(live_eventhub, uamqp_transport): # This should "just work" to validate known-good. hostname = live_eventhub['hostname'] - producer_client = EventHubProducerClient.from_connection_string(live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub']) + producer_client = EventHubProducerClient.from_connection_string( + live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub'], uamqp_transport=uamqp_transport + ) with producer_client: batch = producer_client.create_batch(partition_id='0') @@ -111,13 +120,14 @@ def test_client_azure_sas_credential(live_eventhub): @pytest.mark.liveTest -def test_client_azure_named_key_credential(live_eventhub): +def test_client_azure_named_key_credential(live_eventhub, uamqp_transport): credential = AzureNamedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) consumer_client = EventHubConsumerClient(fully_qualified_namespace=live_eventhub['hostname'], eventhub_name=live_eventhub['event_hub'], consumer_group='$default', credential=credential, - user_agent='customized information') + user_agent='customized information', + uamqp_transport=uamqp_transport) assert consumer_client.get_eventhub_properties() is not None diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py index 631b45595224..99b9f3da29a2 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py @@ -40,25 +40,26 @@ def random_pkey_generation(partitions): @pytest.mark.liveTest() -def test_producer_client_constructor(connection_str): +def test_producer_client_constructor(connection_str, uamqp_transport): def on_success(events, pid): pass def on_error(events, error, pid): pass with pytest.raises(TypeError): - EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True) + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, uamqp_transport=uamqp_transport) with pytest.raises(TypeError): - EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_success=on_success) + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_success=on_success, uamqp_transport=uamqp_transport) with pytest.raises(TypeError): - EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_error=on_error) + EventHubProducerClient.from_connection_string(connection_str, buffered_mode=True, on_error=on_error, uamqp_transport=uamqp_transport) with pytest.raises(ValueError): EventHubProducerClient.from_connection_string( connection_str, buffered_mode=True, on_success=on_success, on_error=on_error, - max_wait_time=0 + max_wait_time=0, + uamqp_transport=uamqp_transport ) with pytest.raises(ValueError): EventHubProducerClient.from_connection_string( @@ -66,7 +67,8 @@ def on_error(events, error, pid): buffered_mode=True, on_success=on_success, on_error=on_error, - max_buffer_length=0 + max_buffer_length=0, + uamqp_transport=uamqp_transport ) @@ -80,13 +82,13 @@ def on_error(events, error, pid): ] ) @pytest.mark.liveTest -def test_basic_send_single_events_round_robin(connection_str, flush_after_sending, close_after_sending): +def test_basic_send_single_events_round_robin(connection_str, flush_after_sending, close_after_sending, uamqp_transport): received_events = defaultdict(list) def on_event(partition_context, event): received_events[partition_context.partition_id].append(event) - consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default", uamqp_transport=uamqp_transport) receive_thread = Thread(target=consumer.receive, args=(on_event,)) receive_thread.daemon = True receive_thread.start() @@ -107,7 +109,8 @@ def on_error(events, pid, err): connection_str, buffered_mode=True, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) with producer: @@ -185,13 +188,13 @@ def on_error(events, pid, err): (False, False) ] ) -def test_basic_send_batch_events_round_robin(connection_str, flush_after_sending, close_after_sending): +def test_basic_send_batch_events_round_robin(connection_str, flush_after_sending, close_after_sending, uamqp_transport): received_events = defaultdict(list) def on_event(partition_context, event): received_events[partition_context.partition_id].append(event) - consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default", uamqp_transport=uamqp_transport) receive_thread = Thread(target=consumer.receive, args=(on_event,)) receive_thread.daemon = True receive_thread.start() @@ -209,7 +212,8 @@ def on_error(events, pid, err): connection_str, buffered_mode=True, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) with producer: @@ -292,13 +296,13 @@ def on_error(events, pid, err): @pytest.mark.liveTest -def test_send_with_hybrid_partition_assignment(connection_str): +def test_send_with_hybrid_partition_assignment(connection_str, uamqp_transport): received_events = defaultdict(list) def on_event(partition_context, event): received_events[partition_context.partition_id].append(event) - consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default", uamqp_transport=uamqp_transport) receive_thread = Thread(target=consumer.receive, args=(on_event,)) receive_thread.daemon = True receive_thread.start() @@ -316,7 +320,8 @@ def on_error(events, pid, err): connection_str, buffered_mode=True, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) with producer: @@ -381,13 +386,13 @@ def on_error(events, pid, err): receive_thread.join() -def test_send_with_timing_configuration(connection_str): +def test_send_with_timing_configuration(connection_str, uamqp_transport): received_events = defaultdict(list) def on_event(partition_context, event): received_events[partition_context.partition_id].append(event) - consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default", uamqp_transport=uamqp_transport) receive_thread = Thread(target=consumer.receive, args=(on_event,)) receive_thread.daemon = True receive_thread.start() @@ -408,7 +413,8 @@ def on_error(events, pid, err): buffered_mode=True, max_wait_time=10, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) with producer: @@ -428,7 +434,8 @@ def on_error(events, pid, err): max_wait_time=1000, max_buffer_length=10, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) sent_events.clear() @@ -457,13 +464,13 @@ def on_error(events, pid, err): @pytest.mark.liveTest -def test_long_sleep(connection_str): +def test_long_sleep(connection_str, uamqp_transport): received_events = defaultdict(list) def on_event(partition_context, event): received_events[partition_context.partition_id].append(event) - consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default") + consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default", uamqp_transport=uamqp_transport) receive_thread = Thread(target=consumer.receive, args=(on_event,)) receive_thread.daemon = True receive_thread.start() @@ -481,7 +488,8 @@ def on_error(events, pid, err): connection_str, buffered_mode=True, on_success=on_success, - on_error=on_error + on_error=on_error, + uamqp_transport=uamqp_transport ) with producer: diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py index 8da5ddeb6ead..8b7420075ef4 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py @@ -9,11 +9,16 @@ @pytest.mark.liveTest -def test_receive_no_partition(connstr_senders): +def test_receive_no_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) senders[1].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', receive_timeout=1) + client = EventHubConsumerClient.from_connection_string( + connection_str, + consumer_group='$default', + receive_timeout=1, + uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): on_event.received += 1 @@ -36,7 +41,7 @@ def on_event(partition_context, event): args=(on_event,), kwargs={"starting_position": "-1"}) worker.start() - time.sleep(10) + time.sleep(20) assert on_event.received == 2 checkpoints = list(client._event_processors.values())[0]._checkpoint_store.list_checkpoints( on_event.namespace, on_event.eventhub_name, on_event.consumer_group @@ -46,10 +51,12 @@ def on_event(partition_context, event): @pytest.mark.liveTest -def test_receive_partition(connstr_senders): +def test_receive_partition(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): on_event.received += 1 @@ -74,16 +81,18 @@ def on_event(partition_context, event): @pytest.mark.liveTest -def test_receive_load_balancing(connstr_senders): +def test_receive_load_balancing(connstr_senders, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - test code using multiple threads. Sometimes OSX aborts python process") connection_str, senders = connstr_senders cs = InMemoryCheckpointStore() client1 = EventHubConsumerClient.from_connection_string( - connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1) + connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1, uamqp_transport=uamqp_transport + ) client2 = EventHubConsumerClient.from_connection_string( - connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1) + connection_str, consumer_group='$default', checkpoint_store=cs, load_balancing_interval=1, uamqp_transport=uamqp_transport + ) def on_event(partition_context, event): pass @@ -105,13 +114,15 @@ def on_event(partition_context, event): assert len(client2._event_processors[("$default", ALL_PARTITIONS)]._consumers) == 1 -def test_receive_batch_no_max_wait_time(connstr_senders): +def test_receive_batch_no_max_wait_time(connstr_senders, uamqp_transport): '''Test whether callback is called when max_wait_time is None and max_batch_size has reached ''' connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) senders[1].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event_batch(partition_context, event_batch): on_event_batch.received += len(event_batch) @@ -133,7 +144,7 @@ def on_event_batch(partition_context, event_batch): worker = threading.Thread(target=client.receive_batch, args=(on_event_batch,), kwargs={"starting_position": "-1"}) worker.start() - time.sleep(10) + time.sleep(20) assert on_event_batch.received == 2 checkpoints = list(client._event_processors.values())[0]._checkpoint_store.list_checkpoints( @@ -146,14 +157,14 @@ def on_event_batch(partition_context, event_batch): worker.join() + @pytest.mark.parametrize("max_wait_time, sleep_time, expected_result", [(3, 10, []), - (3, 2, None), - ]) -def test_receive_batch_empty_with_max_wait_time(connection_str, max_wait_time, sleep_time, expected_result): + (3, 2, None)]) +def test_receive_batch_empty_with_max_wait_time(uamqp_transport, connection_str, max_wait_time, sleep_time, expected_result): '''Test whether event handler is called when max_wait_time > 0 and no event is received ''' - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) def on_event_batch(partition_context, event_batch): on_event_batch.event_batch = event_batch @@ -168,13 +179,15 @@ def on_event_batch(partition_context, event_batch): worker.join() -def test_receive_batch_early_callback(connstr_senders): +def test_receive_batch_early_callback(connstr_senders, uamqp_transport): ''' Test whether the callback is called once max_batch_size reaches and before max_wait_time reaches. ''' connection_str, senders = connstr_senders for _ in range(10): senders[0].send(EventData("Test EventData")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) def on_event_batch(partition_context, event_batch): on_event_batch.received += len(event_batch) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py index 3b7249c2cace..ba58cf2a00c7 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_negative.py @@ -19,14 +19,19 @@ ) from azure.eventhub import EventHubConsumerClient from azure.eventhub import EventHubProducerClient +try: + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except (ImportError, ModuleNotFoundError): + UamqpTransport = None @pytest.mark.liveTest -def test_send_batch_with_invalid_hostname(invalid_hostname): +def test_send_batch_with_invalid_hostname(invalid_hostname, uamqp_transport): + amqp_transport = UamqpTransport if uamqp_transport else None if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " "and blocking other tests") - client = EventHubProducerClient.from_connection_string(invalid_hostname) + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) with client: with pytest.raises(ConnectError): batch = EventDataBatch() @@ -40,7 +45,7 @@ def on_error(events, pid, err): on_error.err = err on_error.err = None - client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error) + client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error, uamqp_transport=uamqp_transport) with client: batch = EventDataBatch() batch.add(EventData("test data")) @@ -48,18 +53,20 @@ def on_error(events, pid, err): assert isinstance(on_error.err, ConnectError) on_error.err = None - client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error) + client = EventHubProducerClient.from_connection_string(invalid_hostname, on_error=on_error, uamqp_transport=uamqp_transport) with client: client.send_event(EventData("test data")) assert isinstance(on_error.err, ConnectError) @pytest.mark.liveTest -def test_receive_with_invalid_hostname_sync(invalid_hostname): +def test_receive_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): def on_event(partition_context, event): pass - client = EventHubConsumerClient.from_connection_string(invalid_hostname, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + invalid_hostname, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event, )) @@ -70,8 +77,9 @@ def on_event(partition_context, event): @pytest.mark.liveTest -def test_send_batch_with_invalid_key(invalid_key): - client = EventHubProducerClient.from_connection_string(invalid_key) +def test_send_batch_with_invalid_key(invalid_key, uamqp_transport): + client = EventHubProducerClient.from_connection_string(invalid_key, uamqp_transport=uamqp_transport) + amqp_transport = UamqpTransport if uamqp_transport else None try: with pytest.raises(ConnectError): batch = EventDataBatch() @@ -82,10 +90,10 @@ def test_send_batch_with_invalid_key(invalid_key): @pytest.mark.liveTest -def test_send_batch_to_invalid_partitions(connection_str): +def test_send_batch_to_invalid_partitions(connection_str, uamqp_transport): partitions = ["XYZ", "-1", "1000", "-"] for p in partitions: - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: with pytest.raises(ConnectError): batch = client.create_batch(partition_id=p) @@ -96,10 +104,10 @@ def test_send_batch_to_invalid_partitions(connection_str): @pytest.mark.liveTest -def test_send_batch_too_large_message(connection_str): +def test_send_batch_too_large_message(connection_str, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - open issue regarding message size") - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: data = EventData(b"A" * 1100000) batch = client.create_batch() @@ -110,8 +118,8 @@ def test_send_batch_too_large_message(connection_str): @pytest.mark.liveTest -def test_send_batch_null_body(connection_str): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_send_batch_null_body(connection_str, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) try: with pytest.raises(ValueError): data = EventData(None) @@ -123,19 +131,19 @@ def test_send_batch_null_body(connection_str): @pytest.mark.liveTest -def test_create_batch_with_invalid_hostname_sync(invalid_hostname): +def test_create_batch_with_invalid_hostname_sync(invalid_hostname, uamqp_transport): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " "and blocking other tests") - client = EventHubProducerClient.from_connection_string(invalid_hostname) + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) with client: with pytest.raises(ConnectError): client.create_batch(max_size_in_bytes=300) @pytest.mark.liveTest -def test_create_batch_with_too_large_size_sync(connection_str): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_create_batch_with_too_large_size_sync(connection_str, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: with pytest.raises(ValueError): client.create_batch(max_size_in_bytes=5 * 1024 * 1024) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py index eb197eec44b0..6a4cb8b6eccf 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_properties.py @@ -12,57 +12,67 @@ @pytest.mark.liveTest -def test_get_properties(live_eventhub): +def test_get_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: properties = client.get_eventhub_properties() assert properties['eventhub_name'] == live_eventhub['event_hub'] and properties['partition_ids'] == ['0', '1'] @pytest.mark.liveTest -def test_get_properties_with_auth_error_sync(live_eventhub): +def test_get_properties_with_auth_error_sync(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], "AaBbCcDdEeFf=")) + EventHubSharedKeyCredential(live_eventhub['key_name'], "AaBbCcDdEeFf="), + uamqp_transport=uamqp_transport + ) with client: with pytest.raises(AuthenticationError) as e: client.get_eventhub_properties() client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential("invalid", live_eventhub['access_key']) + EventHubSharedKeyCredential("invalid", live_eventhub['access_key']), uamqp_transport=uamqp_transport ) with client: with pytest.raises(AuthenticationError) as e: client.get_eventhub_properties() @pytest.mark.liveTest -def test_get_properties_with_connect_error(live_eventhub): +def test_get_properties_with_connect_error(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], "invalid", '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport ) with client: with pytest.raises(ConnectError) as e: client.get_eventhub_properties() client = EventHubConsumerClient("invalid.servicebus.windows.net", live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport ) with client: with pytest.raises(EventHubError) as e: # This can be either ConnectError or ConnectionLostError client.get_eventhub_properties() @pytest.mark.liveTest -def test_get_partition_ids(live_eventhub): +def test_get_partition_ids(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: partition_ids = client.get_partition_ids() assert partition_ids == ['0', '1'] @pytest.mark.liveTest -def test_get_partition_properties(live_eventhub): +def test_get_partition_properties(live_eventhub, uamqp_transport): client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], '$default', - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']), + uamqp_transport=uamqp_transport + ) with client: properties = client.get_partition_properties('0') assert properties['eventhub_name'] == live_eventhub['event_hub'] \ diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py index 21d6e249581e..22133f7983e3 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py @@ -9,13 +9,14 @@ import pytest import time import datetime +import uamqp from azure.eventhub import EventData, TransportType, EventHubConsumerClient from azure.eventhub.exceptions import EventHubError @pytest.mark.liveTest -def test_receive_end_of_stream(connstr_senders): +def test_receive_end_of_stream(connstr_senders, uamqp_transport): def on_event(partition_context, event): if partition_context.partition_id == "0": assert event.body_as_str() == "Receiving only a single event" @@ -29,7 +30,9 @@ def on_event(partition_context, event): assert ", partition_key: 0" in event_str on_event.called = False connection_str, senders = connstr_senders - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"partition_id": "0", "starting_position": "@latest"}) @@ -50,7 +53,7 @@ def on_event(partition_context, event): ("sequence", True, "Inclusive"), ("enqueued_time", False, "Exclusive")]) @pytest.mark.liveTest -def test_receive_with_event_position_sync(connstr_senders, position, inclusive, expected_result): +def test_receive_with_event_position_sync(uamqp_transport, connstr_senders, position, inclusive, expected_result): def on_event(partition_context, event): assert partition_context.last_enqueued_event_properties.get('sequence_number') == event.sequence_number assert partition_context.last_enqueued_event_properties.get('offset') == event.offset @@ -69,7 +72,9 @@ def on_event(partition_context, event): connection_str, senders = connstr_senders senders[0].send(EventData(b"Inclusive")) senders[1].send(EventData(b"Inclusive")) - client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"starting_position": "-1", @@ -82,7 +87,9 @@ def on_event(partition_context, event): thread.join() senders[0].send(EventData(expected_result)) senders[1].send(EventData(expected_result)) - client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client2 = EventHubConsumerClient.from_connection_string( + connection_str, consumer_group='$default', uamqp_transport=uamqp_transport + ) with client2: thread = threading.Thread(target=client2.receive, args=(on_event,), kwargs={"starting_position": on_event.event_position, @@ -90,14 +97,14 @@ def on_event(partition_context, event): "track_last_enqueued_event_properties": True}) thread.daemon = True thread.start() - time.sleep(10) + time.sleep(15) assert on_event.event.body_as_str() == expected_result thread.join() @pytest.mark.liveTest -def test_receive_owner_level(connstr_senders): +def test_receive_owner_level(connstr_senders, uamqp_transport): def on_event(partition_context, event): pass def on_error(partition_context, error): @@ -105,8 +112,8 @@ def on_error(partition_context, error): on_error.error = None connection_str, senders = connstr_senders - client1 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') - client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') + client1 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) + client2 = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', uamqp_transport=uamqp_transport) with client1, client2: thread1 = threading.Thread(target=client1.receive, args=(on_event,), kwargs={"partition_id": "0", "starting_position": "-1", @@ -129,7 +136,7 @@ def on_error(partition_context, error): @pytest.mark.liveTest -def test_receive_over_websocket_sync(connstr_senders): +def test_receive_over_websocket_sync(connstr_senders, uamqp_transport): app_prop = {"raw_prop": "raw_value"} content_type = "text/plain" message_id_base = "mess_id_sample_" @@ -143,7 +150,8 @@ def on_event(partition_context, event): connection_str, senders = connstr_senders client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default', - transport_type=TransportType.AmqpOverWebsocket) + transport_type=TransportType.AmqpOverWebsocket, + uamqp_transport=uamqp_transport) event_list = [] for i in range(5): diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py index 0abfa7a12d2f..c07489acface 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_reconnect.py @@ -7,23 +7,26 @@ import time import pytest -import uamqp -from uamqp import authentication, errors, c_uamqp, compat - - from azure.eventhub import ( EventData, EventHubSharedKeyCredential, EventHubProducerClient, - EventHubConsumerClient + EventHubConsumerClient, ) from azure.eventhub.exceptions import OperationTimeoutError +from azure.eventhub._utils import transform_outbound_single_message +import uamqp +from uamqp import compat +from azure.eventhub._transport._uamqp_transport import UamqpTransport + @pytest.mark.liveTest -def test_send_with_long_interval_sync(live_eventhub, sleep): +def test_send_with_long_interval_sync(live_eventhub, sleep, uamqp_transport, timeout_factor): test_partition = "0" sender = EventHubProducerClient(live_eventhub['hostname'], live_eventhub['event_hub'], - EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + EventHubSharedKeyCredential(live_eventhub['key_name'], + live_eventhub['access_key']), uamqp_transport=uamqp_transport + ) with sender: batch = sender.create_batch(partition_id=test_partition) batch.add(EventData(b"A single event")) @@ -31,7 +34,10 @@ def test_send_with_long_interval_sync(live_eventhub, sleep): if sleep: time.sleep(250) else: - sender._producers[test_partition]._handler._connection._conn.destroy() + if uamqp_transport: + sender._producers[test_partition]._handler._connection._conn.destroy() + else: + pass batch = sender.create_batch(partition_id=test_partition) batch.add(EventData(b"A single event")) sender.send_batch(batch) @@ -39,22 +45,23 @@ def test_send_with_long_interval_sync(live_eventhub, sleep): received = [] uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub']) - sas_auth = authentication.SASTokenAuth.from_shared_access_key( - uri, live_eventhub['key_name'], live_eventhub['access_key']) - + if uamqp_transport: + sas_auth = uamqp.authentication.SASTokenAuth.from_shared_access_key( + uri, live_eventhub['key_name'], live_eventhub['access_key']) source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( live_eventhub['hostname'], live_eventhub['event_hub'], live_eventhub['consumer_group'], test_partition) - receiver = uamqp.ReceiveClient(source, auth=sas_auth, debug=False, timeout=5000, prefetch=500) + if uamqp_transport: + receiver = uamqp.ReceiveClient(source, auth=sas_auth, debug=False, timeout=5000, prefetch=500) try: receiver.open() # receive_message_batch() returns immediately once it receives any messages before the max_batch_size # and timeout reach. Could be 1, 2, or any number between 1 and max_batch_size. # So call it twice to ensure the two events are received. - received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5000)]) - received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5000)]) + received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5 * timeout_factor)]) + received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(max_batch_size=1, timeout=5 * timeout_factor)]) finally: receiver.close() assert len(received) == 2 @@ -62,42 +69,65 @@ def test_send_with_long_interval_sync(live_eventhub, sleep): @pytest.mark.liveTest -def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers): +def test_send_connection_idle_timeout_and_reconnect_sync(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(conn_str=connection_str, idle_timeout=10) + amqp_transport = UamqpTransport + client = EventHubProducerClient.from_connection_string( + conn_str=connection_str, idle_timeout=10, uamqp_transport=uamqp_transport + ) with client: ed = EventData('data') sender = client._create_producer(partition_id='0') with sender: - sender._open_with_retry() - time.sleep(11) - sender._unsent_events = [ed.message] - ed.message.on_send_complete = sender._on_outcome + sender._open_with_retry() + time.sleep(11) + ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) + sender._unsent_events = [ed._message] + if uamqp_transport: + sender._unsent_events[0].on_send_complete = sender._on_outcome with pytest.raises((uamqp.errors.ConnectionClose, - uamqp.errors.MessageHandlerError, OperationTimeoutError)): - # Mac may raise OperationTimeoutError or MessageHandlerError + uamqp.errors.MessageHandlerError, OperationTimeoutError)): sender._send_event_data() + else: + # for pyamqp add later + pass + if uamqp_transport: sender._send_event_data_with_retry() + if not uamqp_transport: + client = EventHubProducerClient.from_connection_string( + conn_str=connection_str, idle_timeout=10, uamqp_transport=uamqp_transport + ) + with client: + ed = EventData('data') + sender = client._create_producer(partition_id='0') + with sender: + sender._open_with_retry() + time.sleep(11) + ed = transform_outbound_single_message(ed, EventData, amqp_transport.to_outgoing_amqp_message) + sender._unsent_events = [ed._message] + sender._send_event_data() + retry = 0 while retry < 3: try: - messages = receivers[0].receive_message_batch(max_batch_size=10, timeout=10000) + messages = receivers[0].receive_message_batch(max_batch_size=10, timeout=10 * timeout_factor) if messages: received_ed1 = EventData._from_message(messages[0]) assert received_ed1.body_as_str() == 'data' break - except compat.TimeoutException: + except (compat.TimeoutException, TimeoutError): retry += 1 @pytest.mark.liveTest -def test_receive_connection_idle_timeout_and_reconnect_sync(connstr_senders): +def test_receive_connection_idle_timeout_and_reconnect_sync(connstr_senders, uamqp_transport): connection_str, senders = connstr_senders client = EventHubConsumerClient.from_connection_string( conn_str=connection_str, consumer_group='$default', - idle_timeout=10 + idle_timeout=10, + uamqp_transport=uamqp_transport ) def on_event_received(event): @@ -112,7 +142,7 @@ def on_event_received(event): senders[0].send(ed) consumer._handler.do_work() - assert consumer._handler._connection._state == c_uamqp.ConnectionState.DISCARDING + assert consumer._handler._connection._state == uamqp.c_uamqp.ConnectionState.DISCARDING duration = 10 now_time = time.time() diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index 0276de9aae0d..d75b4d013470 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -12,6 +12,7 @@ import sys import uamqp +from uamqp.message import MessageProperties from azure.eventhub import EventData, TransportType, EventDataBatch from azure.eventhub import EventHubProducerClient, EventHubConsumerClient from azure.eventhub.exceptions import EventDataSendError, OperationTimeoutError @@ -21,12 +22,16 @@ AmqpAnnotatedMessage, AmqpMessageProperties, ) +try: + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except (ImportError, ModuleNotFoundError): + UamqpTransport = None @pytest.mark.liveTest -def test_send_with_partition_key(connstr_receivers, live_eventhub): +def test_send_with_partition_key(connstr_receivers, live_eventhub, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: data_val = 0 for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]: @@ -54,7 +59,7 @@ def test_send_with_partition_key(connstr_receivers, live_eventhub): for index, partition in enumerate(receivers): retry_total = 0 while retry_total < 3: - timeout = 5000 + retry_total * 1000 + timeout = (5 + retry_total) * timeout_factor try: received = partition.receive_message_batch(timeout=timeout) for message in received: @@ -98,11 +103,11 @@ def test_send_with_partition_key(connstr_receivers, live_eventhub): @pytest.mark.liveTest -def test_send_and_receive_large_body_size(connstr_receivers): +def test_send_and_receive_large_body_size(connstr_receivers, uamqp_transport, timeout_factor): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - open issue regarding message size") connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: payload = 250 * 1024 batch = client.create_batch() @@ -111,8 +116,9 @@ def test_send_and_receive_large_body_size(connstr_receivers): client.send_event(EventData("A" * payload)) received = [] + timeout = 10 * timeout_factor for r in receivers: - received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=10000)]) + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) assert len(received) == 2 assert len(list(received[0].body)[0]) == payload @@ -128,7 +134,7 @@ def test_send_and_receive_large_body_size(connstr_receivers): received = [] for r in receivers: - received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=10000)]) + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) assert len(received) == 2 assert len(list(received[0].body)[0]) == payload @@ -136,9 +142,9 @@ def test_send_and_receive_large_body_size(connstr_receivers): @pytest.mark.liveTest -def test_send_amqp_annotated_message(connstr_receivers): +def test_send_amqp_annotated_message(connstr_receivers, uamqp_transport): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: sequence_body = [b'message', 123.456, True] footer = {'footer_key': 'footer_value'} @@ -174,7 +180,7 @@ def test_send_amqp_annotated_message(connstr_receivers): ) body_ed = """{"json_key": "json_val"}""" - prop_ed = {"raw_prop": "raw_value"} + prop_ed = {b"raw_prop": b"raw_value"} cont_type_ed = "text/plain" corr_id_ed = "corr_id" mess_id_ed = "mess_id" @@ -182,6 +188,7 @@ def test_send_amqp_annotated_message(connstr_receivers): event_data.content_type = cont_type_ed event_data.correlation_id = corr_id_ed event_data.message_id = mess_id_ed + event_data.properties = prop_ed batch = client.create_batch() batch.add(data_message) @@ -216,6 +223,7 @@ def check_values(event): assert event.correlation_id == corr_id_ed assert event.message_id == mess_id_ed assert event.content_type == cont_type_ed + assert event.properties == prop_ed assert event.body_type == AmqpMessageBodyType.DATA received_count["normal_msg"] += 1 elif raw_amqp_message.body_type == AmqpMessageBodyType.SEQUENCE: @@ -238,7 +246,8 @@ def on_event(partition_context, event): on_event.received = [] client = EventHubConsumerClient.from_connection_string(connection_str, - consumer_group='$default') + consumer_group='$default', + uamqp_transport=uamqp_transport) with client: thread = threading.Thread(target=client.receive, args=(on_event,), kwargs={"starting_position": "-1"}) @@ -257,17 +266,18 @@ def on_event(partition_context, event): @pytest.mark.parametrize("payload", [b"", b"A single event"]) @pytest.mark.liveTest -def test_send_and_receive_small_body(connstr_receivers, payload): +def test_send_and_receive_small_body(connstr_receivers, payload, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch() batch.add(EventData(payload)) client.send_batch(batch) client.send_event(EventData(payload)) received = [] + timeout = 5 * timeout_factor for r in receivers: - received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=5000)]) + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) assert len(received) == 2 assert list(received[0].body)[0] == payload @@ -275,9 +285,10 @@ def test_send_and_receive_small_body(connstr_receivers, payload): @pytest.mark.liveTest -def test_send_partition(connstr_receivers): +def test_send_partition(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + timeout = 5 * timeout_factor + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch() @@ -291,8 +302,8 @@ def test_send_partition(connstr_receivers): client.send_batch(batch) client.send_event(EventData(b"Data"), partition_id="1") - partition_0 = receivers[0].receive_message_batch(timeout=5000) - partition_1 = receivers[1].receive_message_batch(timeout=5000) + partition_0 = receivers[0].receive_message_batch(timeout=timeout) + partition_1 = receivers[1].receive_message_batch(timeout=timeout) assert len(partition_1) >= 2 assert len(partition_0) + len(partition_1) == 4 @@ -309,16 +320,17 @@ def test_send_partition(connstr_receivers): client.send_event(EventData(b"Data"), partition_id="0") time.sleep(5) - partition_0 = receivers[0].receive_message_batch(timeout=5000) - partition_1 = receivers[1].receive_message_batch(timeout=5000) + partition_0 = receivers[0].receive_message_batch(timeout=timeout) + partition_1 = receivers[1].receive_message_batch(timeout=timeout) assert len(partition_0) >= 2 assert len(partition_0) + len(partition_1) == 4 @pytest.mark.liveTest -def test_send_non_ascii(connstr_receivers): +def test_send_non_ascii(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + timeout = 5 * timeout_factor + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: batch = client.create_batch(partition_id="0") batch.add(EventData(u"é,è,à,ù,â,ê,î,ô,û")) @@ -330,8 +342,8 @@ def test_send_non_ascii(connstr_receivers): # receive_message_batch() returns immediately once it receives any messages before the max_batch_size # and timeout reach. Could be 1, 2, or any number between 1 and max_batch_size. # So call it twice to ensure the two events are received. - partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] + \ - [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] + partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=timeout)] + \ + [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=timeout)] assert len(partition_0) == 4 assert partition_0[0].body_as_str() == u"é,è,à,ù,â,ê,î,ô,û" assert partition_0[1].body_as_json() == {"foo": u"漢字"} @@ -340,12 +352,13 @@ def test_send_non_ascii(connstr_receivers): @pytest.mark.liveTest -def test_send_multiple_partitions_with_app_prop(connstr_receivers): +def test_send_multiple_partitions_with_app_prop(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers + timeout = 5 * timeout_factor app_prop_key = "raw_prop" app_prop_value = "raw_value" app_prop = {app_prop_key: app_prop_value} - client = EventHubProducerClient.from_connection_string(connection_str) + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: ed0 = EventData(b"Message 0") ed0.properties = app_prop @@ -361,20 +374,23 @@ def test_send_multiple_partitions_with_app_prop(connstr_receivers): client.send_batch(batch) client.send_event(ed1, partition_id="1") - partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] + partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=timeout)] assert len(partition_0) == 2 assert partition_0[0].properties[b"raw_prop"] == b"raw_value" assert partition_0[1].properties[b"raw_prop"] == b"raw_value" - partition_1 = [EventData._from_message(x) for x in receivers[1].receive_message_batch(timeout=5000)] + partition_1 = [EventData._from_message(x) for x in receivers[1].receive_message_batch(timeout=timeout)] assert len(partition_1) == 2 assert partition_1[0].properties[b"raw_prop"] == b"raw_value" assert partition_1[1].properties[b"raw_prop"] == b"raw_value" @pytest.mark.liveTest -def test_send_over_websocket_sync(connstr_receivers): +def test_send_over_websocket_sync(connstr_receivers, uamqp_transport, timeout_factor): + timeout = 10 * timeout_factor connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) + client = EventHubProducerClient.from_connection_string( + connection_str, transport_type=uamqp.constants.TransportType.AmqpOverWebsocket, uamqp_transport=uamqp_transport + ) with client: batch = client.create_batch(partition_id="0") @@ -384,17 +400,20 @@ def test_send_over_websocket_sync(connstr_receivers): time.sleep(1) received = [] - received.extend(receivers[0].receive_message_batch(max_batch_size=5, timeout=10000)) + received.extend(receivers[0].receive_message_batch(max_batch_size=5, timeout=timeout)) assert len(received) == 2 @pytest.mark.liveTest -def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers): +def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers + timeout = 5 * timeout_factor app_prop_key = "raw_prop" app_prop_value = "raw_value" app_prop = {app_prop_key: app_prop_value} - client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) + client = EventHubProducerClient.from_connection_string( + connection_str, transport_type=TransportType.AmqpOverWebsocket, uamqp_transport=uamqp_transport + ) with client: event_data_batch = client.create_batch(max_size_in_bytes=100000) while True: @@ -407,61 +426,66 @@ def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers): client.send_batch(event_data_batch) received = [] for r in receivers: - received.extend(r.receive_message_batch(timeout=5000)) + received.extend(r.receive_message_batch(timeout=timeout)) assert len(received) >= 1 assert EventData._from_message(received[0]).properties[b"raw_prop"] == b"raw_value" @pytest.mark.liveTest -def test_send_list(connstr_receivers): +def test_send_list(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + timeout = 10 * timeout_factor + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) payload = "A1" with client: client.send_batch([EventData(payload)]) received = [] for r in receivers: - received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=10000)]) + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) assert len(received) == 1 assert received[0].body_as_str() == payload @pytest.mark.liveTest -def test_send_list_partition(connstr_receivers): +def test_send_list_partition(connstr_receivers, uamqp_transport, timeout_factor): connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) + timeout = 10 * timeout_factor + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) payload = "A1" with client: client.send_batch([EventData(payload)], partition_id="0") - message = receivers[0].receive_message_batch(timeout=10000)[0] + message = receivers[0].receive_message_batch(timeout=timeout)[0] received = EventData._from_message(message) assert received.body_as_str() == payload + @pytest.mark.parametrize("to_send, exception_type", [([EventData("A"*1024)]*1100, ValueError), - ("any str", AttributeError) - ]) + ("any str", AttributeError)]) @pytest.mark.liveTest -def test_send_list_wrong_data(connection_str, to_send, exception_type): - client = EventHubProducerClient.from_connection_string(connection_str) +def test_send_list_wrong_data(connection_str, to_send, exception_type, uamqp_transport): + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) with client: with pytest.raises(exception_type): client.send_batch(to_send) + @pytest.mark.parametrize("partition_id, partition_key", [("0", None), (None, "pk")]) -def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key): +def test_send_batch_pid_pk(invalid_hostname, partition_id, partition_key, uamqp_transport): # Use invalid_hostname because this is not a live test. - client = EventHubProducerClient.from_connection_string(invalid_hostname) + amqp_transport = UamqpTransport if uamqp_transport else None + client = EventHubProducerClient.from_connection_string(invalid_hostname, uamqp_transport=uamqp_transport) batch = EventDataBatch(partition_id=partition_id, partition_key=partition_key) with client: with pytest.raises(TypeError): client.send_batch(batch, partition_id=partition_id, partition_key=partition_key) -def test_send_with_callback(connstr_receivers): + +def test_send_with_callback(connstr_receivers, uamqp_transport): def on_error(events, pid, err): on_error.err = err @@ -472,7 +496,7 @@ def on_success(events, pid): sent_events = [] on_error.err = None connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str, on_success=on_success, on_error=on_error) + client = EventHubProducerClient.from_connection_string(connection_str, on_success=on_success, on_error=on_error, uamqp_transport=uamqp_transport) with client: batch = client.create_batch() @@ -506,3 +530,63 @@ def on_success(events, pid): assert sent_events[-1][1] == "0" assert not on_error.err + +# TODO: add more checks after LegacyMessage has been added +@pytest.mark.liveTest +def test_send_message_modify_backcompat(connstr_receivers, uamqp_transport, timeout_factor): + connection_str, receivers = connstr_receivers + if uamqp_transport: + properties = MessageProperties + + timeout = 10 * timeout_factor + outgoing_event_data = EventData(body="hello") + message = outgoing_event_data.message + message.properties = properties(user_id='fake_user') + assert outgoing_event_data.message.properties.user_id == b'fake_user' + assert outgoing_event_data.message.state == uamqp.constants.MessageState.WaitingToBeSent + assert outgoing_event_data.message.delivery_annotations is None + assert outgoing_event_data.message.delivery_no is None + assert outgoing_event_data.message.delivery_tag is None + assert outgoing_event_data.message.on_send_complete is None + assert outgoing_event_data.message.footer is None + assert outgoing_event_data.message.retries == 0 + assert outgoing_event_data.message.idle_time == 0 + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) + with client: + client.send_batch([outgoing_event_data]) + received = [] + for r in receivers: + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) + + assert len(received) == 1 + received_ed = received[0] + # check that setting properties directly on uamqp message doesn't update the outgoing message from the event data + assert received_ed.message.properties.user_id is None + assert received_ed.message.state == uamqp.constants.MessageState.ReceivedSettled + assert received_ed.message.delivery_annotations is None + assert received_ed.message.delivery_no >= 1 + assert received_ed.message.delivery_tag is None + assert received_ed.message.on_send_complete is None + assert received_ed.message.footer is None + assert received_ed.message.retries >= 0 + assert received_ed.message.idle_time >= 0 + + # setting message properties by calling event data properties SHOULD update the outgoing uamqp message + received_ed.properties = {'prop': 'test'} + received_ed.message_id = "id_message" + received_ed.content_type = "content type" + received_ed.correlation_id = "correlation" + + client = EventHubProducerClient.from_connection_string(connection_str, uamqp_transport=uamqp_transport) + with client: + client.send_batch([received_ed]) + received = [] + for r in receivers: + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=timeout)]) + + assert len(received) == 1 + received_ed = received[0] + assert received_ed.message.application_properties == {b"prop": b"test"} + assert received_ed.message_id == "id_message" + assert received_ed.content_type == "content type" + assert received_ed.correlation_id == "correlation" diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index ef562c2628a5..37a8b9056758 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -1,9 +1,22 @@ +# -- coding: utf-8 -- +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + import platform import pytest -import uamqp from packaging import version -from azure.eventhub.amqp import AmqpAnnotatedMessage +try: + import uamqp + from azure.eventhub._transport._uamqp_transport import UamqpTransport +except ImportError: + UamqpTransport = None + pass +from azure.eventhub.amqp import AmqpAnnotatedMessage, AmqpMessageHeader, AmqpMessageProperties from azure.eventhub import _common +from azure.eventhub._utils import transform_outbound_single_message pytestmark = pytest.mark.skipif(platform.python_implementation() == "PyPy", reason="This is ignored for PyPy") @@ -55,23 +68,26 @@ def test_app_properties(): assert event_data.properties["a"] == "b" -def test_sys_properties(): - properties = uamqp.message.MessageProperties() - properties.message_id = "message_id" - properties.user_id = "user_id" - properties.to = "to" - properties.subject = "subject" - properties.reply_to = "reply_to" - properties.correlation_id = "correlation_id" - properties.content_type = "content_type" - properties.content_encoding = "content_encoding" - properties.absolute_expiry_time = 1 - properties.creation_time = 1 - properties.group_id = "group_id" - properties.group_sequence = 1 - properties.reply_to_group_id = "reply_to_group_id" - message = uamqp.Message(properties=properties) - message.annotations = {_common.PROP_OFFSET: "@latest"} +def test_sys_properties(uamqp_transport): + if uamqp_transport: + properties = uamqp.message.MessageProperties() + properties.message_id = "message_id" + properties.user_id = "user_id" + properties.to = "to" + properties.subject = "subject" + properties.reply_to = "reply_to" + properties.correlation_id = "correlation_id" + properties.content_type = "content_type" + properties.content_encoding = "content_encoding" + properties.absolute_expiry_time = 1 + properties.creation_time = 1 + properties.group_id = "group_id" + properties.group_sequence = 1 + properties.reply_to_group_id = "reply_to_group_id" + message = uamqp.message.Message(properties=properties) + message.annotations = {_common.PROP_OFFSET: "@latest"} + else: + pass ed = EventData._from_message(message) # type: EventData assert ed.system_properties[_common.PROP_OFFSET] == "@latest" @@ -90,39 +106,98 @@ def test_sys_properties(): assert ed.system_properties[_common.PROP_REPLY_TO_GROUP_ID] == properties.reply_to_group_id -def test_event_data_batch(): +def test_event_data_batch(uamqp_transport): + if uamqp_transport: + amqp_transport = UamqpTransport() + if version.parse(uamqp.__version__) >= version.parse("1.2.8"): + expected_result = 101 + else: + expected_result = 93 + else: + pass batch = EventDataBatch(max_size_in_bytes=110, partition_key="par") batch.add(EventData("A")) assert str(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" assert repr(batch) == "EventDataBatch(max_size_in_bytes=110, partition_id=None, partition_key='par', event_count=1)" - # In uamqp v1.2.8, the encoding size of a message has changed. delivery_count in message header is now set to 0 - # instead of None according to the C spec. - # This uamqp change is transparent to EH users so it's not considered as a breaking change. However, it's breaking - # the unit test here. The solution is to add backward compatibility in test. - if version.parse(uamqp.__version__) >= version.parse("1.2.8"): - assert batch.size_in_bytes == 101 and len(batch) == 1 - else: - assert batch.size_in_bytes == 93 and len(batch) == 1 + assert batch.size_in_bytes == expected_result and len(batch) == 1 + with pytest.raises(ValueError): batch.add(EventData("A")) -def test_event_data_from_message(): - message = uamqp.Message('A') + + +def test_event_data_from_message(uamqp_transport): + if uamqp_transport: + amqp_transport = UamqpTransport() + else: + pass + annotated_message = AmqpAnnotatedMessage(data_body=b'A') + message = amqp_transport.to_outgoing_amqp_message(annotated_message) event = EventData._from_message(message) assert event.content_type is None assert event.correlation_id is None assert event.message_id is None event.content_type = 'content_type' - event.correlation_id = 'correlation_id' + event.correlation_id = 'correlation_id' event.message_id = 'message_id' assert event.content_type == 'content_type' - assert event.correlation_id == 'correlation_id' + assert event.correlation_id == 'correlation_id' assert event.message_id == 'message_id' + assert list(event.body) == [b'A'] + def test_amqp_message_str_repr(): data_body = b'A' message = AmqpAnnotatedMessage(data_body=data_body) assert str(message) == 'A' assert 'AmqpAnnotatedMessage(body=A, body_type=data' in repr(message) + + +def test_amqp_message_from_message(uamqp_transport): + if uamqp_transport: + header = uamqp.message.MessageHeader() + header.delivery_count = 1 + header.time_to_live = 10000 + header.first_acquirer = True + header.durable = True + header.priority = 1 + properties = uamqp.message.MessageProperties() + properties.message_id = "message_id" + properties.user_id = "user_id" + properties.to = "to" + properties.subject = "subject" + properties.reply_to = "reply_to" + properties.correlation_id = "correlation_id" + properties.content_type = "content_type" + properties.content_encoding = "content_encoding" + properties.absolute_expiry_time = 1 + properties.creation_time = 1 + properties.group_id = "group_id" + properties.group_sequence = 1 + properties.reply_to_group_id = "reply_to_group_id" + message = uamqp.message.Message(header=header, properties=properties) + message.annotations = {_common.PROP_OFFSET: "@latest"} + else: + pass + + amqp_message = AmqpAnnotatedMessage(message=message) + assert amqp_message.properties.message_id == message.properties.message_id + assert amqp_message.properties.user_id == message.properties.user_id + assert amqp_message.properties.to == message.properties.to + assert amqp_message.properties.subject == message.properties.subject + assert amqp_message.properties.reply_to == message.properties.reply_to + assert amqp_message.properties.correlation_id == message.properties.correlation_id + assert amqp_message.properties.content_type == message.properties.content_type + assert amqp_message.properties.absolute_expiry_time == message.properties.absolute_expiry_time + assert amqp_message.properties.creation_time == message.properties.creation_time + assert amqp_message.properties.group_id == message.properties.group_id + assert amqp_message.properties.group_sequence == message.properties.group_sequence + assert amqp_message.properties.reply_to_group_id == message.properties.reply_to_group_id + assert amqp_message.header.time_to_live == message.header.ttl + assert amqp_message.header.delivery_count == message.header.delivery_count + assert amqp_message.header.first_acquirer == message.header.first_acquirer + assert amqp_message.header.durable == message.header.durable + assert amqp_message.header.priority == message.header.priority + assert amqp_message.annotations == message.message_annotations diff --git a/shared_requirements.txt b/shared_requirements.txt index 37bf110d5369..46ff09f99b82 100644 --- a/shared_requirements.txt +++ b/shared_requirements.txt @@ -118,7 +118,7 @@ msrest>=0.6.21 msrestazure<2.0.0,>=0.4.32 azure-mgmt-core<2.0.0,>=1.3.0 requests>=2.18.4 -uamqp~=1.5.0 +uamqp~=1.6.0 enum34>=1.0.4 certifi>=2017.4.17 aiohttp>=3.0 @@ -175,7 +175,7 @@ opentelemetry-sdk<2.0.0,>=1.5.0,!=1.10a0 #override azure-eventhub-checkpointstoreblob-aio azure-core<2.0.0,>=1.20.1 #override azure-eventhub-checkpointstoreblob-aio aiohttp<4.0,>=3.0 #override azure-eventhub-checkpointstoretable azure-core<2.0.0,>=1.14.0 -#override azure-eventhub uamqp>=1.5.1,<2.0.0 +#override azure-eventhub uamqp>=1.6.0,<2.0.0 #override azure-appconfiguration msrest>=0.6.10 #override azure-servicebus uamqp>=1.5.1,<2.0.0 #override azure-servicebus msrest>=0.6.17,<2.0.0