Skip to content

Commit

Permalink
[formrecognizer] reduce time for recorded tests runs (#11970)
Browse files Browse the repository at this point in the history
* wip

* wip

* add transport wrapper for get clients

* fix mypy

* fix

* refactor

* pass through kwargs

* add polling interval tests

* feedback
  • Loading branch information
kristapratico authored Jun 15, 2020
1 parent 4ee3368 commit 010bc93
Show file tree
Hide file tree
Showing 29 changed files with 1,551 additions and 496 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ def __init__(self, endpoint, credential, **kwargs):
# type: (str, Union[AzureKeyCredential, TokenCredential], Any) -> None

authentication_policy = get_authentication_policy(credential)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
self._client = FormRecognizer(
endpoint=endpoint,
credential=credential, # type: ignore
sdk_moniker=USER_AGENT,
authentication_policy=authentication_policy,
polling_interval=polling_interval,
**kwargs
)

Expand Down Expand Up @@ -111,7 +113,7 @@ def begin_recognize_receipts(self, receipt, **kwargs):
:caption: Recognize US sales receipt fields.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -162,7 +164,7 @@ def begin_recognize_receipts_from_url(self, receipt_url, **kwargs):
:caption: Recognize US sales receipt fields from a URL.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
include_text_content = kwargs.pop("include_text_content", False)

Expand Down Expand Up @@ -210,7 +212,7 @@ def begin_recognize_content(self, form, **kwargs):
:caption: Recognize text and content/layout information from a form.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -246,7 +248,7 @@ def begin_recognize_content_from_url(self, form_url, **kwargs):
:raises ~azure.core.exceptions.HttpResponseError:
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

return self._client.begin_analyze_layout_async(
Expand Down Expand Up @@ -296,7 +298,7 @@ def begin_recognize_custom_forms(self, model_id, form, **kwargs):
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -348,7 +350,7 @@ def begin_recognize_custom_forms_from_url(self, model_id, form_url, **kwargs):
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
include_text_content = kwargs.pop("include_text_content", False)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from azure.core.tracing.decorator import distributed_trace
from azure.core.polling import LROPoller
from azure.core.polling.base_polling import LROBasePolling
from azure.core.pipeline import Pipeline
from ._generated._form_recognizer_client import FormRecognizerClient as FormRecognizer
from ._generated.models import (
TrainRequest,
Expand All @@ -26,7 +27,7 @@
CopyOperationResult,
CopyAuthorizationResult
)
from ._helpers import error_map, get_authentication_policy, POLLING_INTERVAL
from ._helpers import error_map, get_authentication_policy, POLLING_INTERVAL, TransportWrapper
from ._models import (
CustomFormModelInfo,
AccountProperties,
Expand Down Expand Up @@ -78,11 +79,13 @@ def __init__(self, endpoint, credential, **kwargs):
self._endpoint = endpoint
self._credential = credential
authentication_policy = get_authentication_policy(credential)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
self._client = FormRecognizer(
endpoint=self._endpoint,
credential=self._credential, # type: ignore
sdk_moniker=USER_AGENT,
authentication_policy=authentication_policy,
polling_interval=polling_interval,
**kwargs
)

Expand Down Expand Up @@ -129,7 +132,7 @@ def callback(raw_response):

cls = kwargs.pop("cls", None)
continuation_token = kwargs.pop("continuation_token", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
deserialization_callback = cls if cls else callback

if continuation_token:
Expand Down Expand Up @@ -339,7 +342,7 @@ def begin_copy_model(
if not model_id:
raise ValueError("model_id cannot be None or empty.")

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
Expand Down Expand Up @@ -371,11 +374,20 @@ def get_form_recognizer_client(self, **kwargs):
:rtype: ~azure.ai.formrecognizer.FormRecognizerClient
:return: A FormRecognizerClient
"""
return FormRecognizerClient(

_pipeline = Pipeline(
transport=TransportWrapper(self._client._client._pipeline._transport),
policies=self._client._client._pipeline._impl_policies
) # type: Pipeline
client = FormRecognizerClient(
endpoint=self._endpoint,
credential=self._credential,
pipeline=_pipeline,
**kwargs
)
# need to share config, but can't pass as a keyword into client
client._client._config = self._client._client._config
return client

def close(self):
# type: () -> None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import six
from azure.core.credentials import AzureKeyCredential
from azure.core.pipeline.policies import AzureKeyCredentialPolicy
from azure.core.pipeline.transport import HttpTransport
from azure.core.exceptions import (
ResourceNotFoundError,
ResourceExistsError,
Expand All @@ -24,6 +25,30 @@
}


class TransportWrapper(HttpTransport):
"""Wrapper class that ensures that an inner client created
by a `get_client` method does not close the outer transport for the parent
when used in a context manager.
"""
def __init__(self, transport):
self._transport = transport

def send(self, request, **kwargs):
return self._transport.send(request, **kwargs)

def open(self):
pass

def close(self):
pass

def __enter__(self):
pass

def __exit__(self, *args): # pylint: disable=arguments-differ
pass


def get_authentication_policy(credential):
authentication_policy = None
if credential is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ def __init__(
) -> None:

authentication_policy = get_authentication_policy(credential)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
self._client = FormRecognizer(
endpoint=endpoint,
credential=credential, # type: ignore
sdk_moniker=USER_AGENT,
authentication_policy=authentication_policy,
polling_interval=polling_interval,
**kwargs
)

Expand Down Expand Up @@ -119,7 +121,7 @@ async def begin_recognize_receipts(
:caption: Recognize US sales receipt fields.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -176,7 +178,7 @@ async def begin_recognize_receipts_from_url(
:caption: Recognize US sales receipt fields from a URL.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
include_text_content = kwargs.pop("include_text_content", False)

Expand Down Expand Up @@ -230,7 +232,7 @@ async def begin_recognize_content(
:caption: Recognize text and content/layout information from a form.
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -268,7 +270,7 @@ async def begin_recognize_content_from_url(self, form_url: str, **kwargs: Any) -
:raises ~azure.core.exceptions.HttpResponseError:
"""

polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
return await self._client.begin_analyze_layout_async( # type: ignore
file_stream={"source": form_url},
Expand Down Expand Up @@ -324,7 +326,7 @@ async def begin_recognize_custom_forms(
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
Expand Down Expand Up @@ -385,7 +387,7 @@ async def begin_recognize_custom_forms_from_url(
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
include_text_content = kwargs.pop("include_text_content", False)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
TYPE_CHECKING,
)
from azure.core.polling import AsyncLROPoller
from azure.core.pipeline import AsyncPipeline
from azure.core.polling.async_base_polling import AsyncLROBasePolling
from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.decorator_async import distributed_trace_async
from ._form_recognizer_client_async import FormRecognizerClient
from ._helpers_async import AsyncTransportWrapper
from .._generated.aio._form_recognizer_client_async import FormRecognizerClient as FormRecognizer
from .._generated.models import (
TrainRequest,
Expand Down Expand Up @@ -81,13 +83,14 @@ def __init__(
) -> None:
self._endpoint = endpoint
self._credential = credential

authentication_policy = get_authentication_policy(credential)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
self._client = FormRecognizer(
endpoint=self._endpoint,
credential=self._credential, # type: ignore
sdk_moniker=USER_AGENT,
authentication_policy=authentication_policy,
polling_interval=polling_interval,
**kwargs
)

Expand Down Expand Up @@ -138,7 +141,7 @@ def callback(raw_response):

cls = kwargs.pop("cls", None)
continuation_token = kwargs.pop("continuation_token", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
deserialization_callback = cls if cls else callback

if continuation_token:
Expand Down Expand Up @@ -361,7 +364,7 @@ async def begin_copy_model(
raise ValueError("model_id cannot be None or empty.")

continuation_token = kwargs.pop("continuation_token", None)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)

def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
copy_result = self._client._deserialize(CopyOperationResult, raw_response)
Expand Down Expand Up @@ -395,11 +398,19 @@ def get_form_recognizer_client(self, **kwargs: Any) -> FormRecognizerClient:
:rtype: ~azure.ai.formrecognizer.aio.FormRecognizerClient
:return: A FormRecognizerClient
"""
return FormRecognizerClient(
_pipeline = AsyncPipeline(
transport=AsyncTransportWrapper(self._client._client._pipeline._transport),
policies=self._client._client._pipeline._impl_policies
) # type: AsyncPipeline
client = FormRecognizerClient(
endpoint=self._endpoint,
credential=self._credential,
pipeline=_pipeline,
**kwargs
)
# need to share config, but can't pass as a keyword into client
client._client._config = self._client._client._config
return client

async def __aenter__(self) -> "FormTrainingClient":
await self._client.__aenter__()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# coding=utf-8
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

from azure.core.pipeline.transport import AsyncHttpTransport


class AsyncTransportWrapper(AsyncHttpTransport):
"""Wrapper class that ensures that an inner client created
by a `get_client` method does not close the outer transport for the parent
when used in a context manager.
"""
def __init__(self, async_transport):
self._transport = async_transport

async def send(self, request, **kwargs):
return await self._transport.send(request, **kwargs)

async def open(self):
pass

async def close(self):
pass

async def __aenter__(self):
pass

async def __aexit__(self, *args): # pylint: disable=arguments-differ
pass
Loading

0 comments on commit 010bc93

Please sign in to comment.