[textanalytics] enabling support for rehydration of LROs in TA #21338

Merged
1 change: 1 addition & 0 deletions sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md
@@ -13,6 +13,7 @@ This version of the SDK defaults to the latest supported API version, which curr
### Breaking Changes

### Bugs Fixed
- Restarting a long-running operation from a saved state is now supported for the `begin_analyze_actions` and `begin_analyze_healthcare_entities` methods.

### Other Changes
- Package requires [azure-core](https://pypi.org/project/azure-core/) version 1.16.0 or greater
@@ -3,6 +3,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

import base64
import functools
from typing import TYPE_CHECKING, Generic
from six.moves.urllib.parse import urlencode
from azure.core.polling._poller import PollingReturnType
@@ -110,6 +113,8 @@ def _poll(self):

class AnalyzeHealthcareEntitiesLROPollingMethod(TextAnalyticsLROPollingMethod):
def __init__(self, *args, **kwargs):
self._doc_id_order = kwargs.pop("doc_id_order", None)
self._show_stats = kwargs.pop("show_stats", None)
self._text_analytics_client = kwargs.pop("text_analytics_client")
super(AnalyzeHealthcareEntitiesLROPollingMethod, self).__init__(*args, **kwargs)

@@ -143,6 +148,13 @@ def id(self):
return None
return self._current_body.job_id

def get_continuation_token(self):
# type: () -> str
import pickle
self._initial_response.context.options["doc_id_order"] = self._doc_id_order
self._initial_response.context.options["show_stats"] = self._show_stats
return base64.b64encode(pickle.dumps(self._initial_response)).decode('ascii')
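
As an aside, a purely illustrative sketch (not part of this PR) of what the token produced here contains: a base64-encoded pickle of the initial pipeline response, with the values stashed in `context.options` above riding along inside it. The helper name `peek_stashed_options` is hypothetical.

```python
import base64
import pickle

def peek_stashed_options(continuation_token):
    # Hypothetical helper: reverse the encoding used by get_continuation_token() and
    # read back the options it stored on the pickled initial response.
    initial_response = pickle.loads(base64.b64decode(continuation_token))
    options = initial_response.context.options
    return options.get("doc_id_order"), options.get("show_stats")
```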


class AnalyzeHealthcareEntitiesLROPoller(LROPoller, Generic[PollingReturnType]):
def polling_method(self):
@@ -190,6 +202,24 @@ def id(self):
"""
return self.polling_method().id

@classmethod
def from_continuation_token(cls, polling_method, continuation_token, **kwargs): # type: ignore
# type: (AnalyzeHealthcareEntitiesLROPollingMethod, str, Any) -> AnalyzeHealthcareEntitiesLROPoller
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
continuation_token, **kwargs
)
polling_method._lro_algorithms = [ # pylint: disable=protected-access
TextAnalyticsOperationResourcePolling(
show_stats=initial_response.context.options["show_stats"]
)
]
return cls(
client,
initial_response,
functools.partial(deserialization_callback, initial_response),
polling_method
)

def cancel(self, **kwargs): # type: ignore
# type: (Any) -> LROPoller[None]
"""Cancel the operation currently being polled.
@@ -231,6 +261,12 @@ def cancel(self, **kwargs): # type: ignore


class AnalyzeActionsLROPollingMethod(TextAnalyticsLROPollingMethod):
def __init__(self, *args, **kwargs):
self._doc_id_order = kwargs.pop("doc_id_order", None)
self._task_id_order = kwargs.pop("task_id_order", None)
self._show_stats = kwargs.pop("show_stats", None)
super(AnalyzeActionsLROPollingMethod, self).__init__(*args, **kwargs)

@property
def _current_body(self):
from ._generated.models import AnalyzeJobMetadata
@@ -291,6 +327,14 @@ def id(self):
return None
return self._current_body.job_id

def get_continuation_token(self):
# type: () -> str
import pickle
self._initial_response.context.options["doc_id_order"] = self._doc_id_order
self._initial_response.context.options["task_id_order"] = self._task_id_order
self._initial_response.context.options["show_stats"] = self._show_stats
return base64.b64encode(pickle.dumps(self._initial_response)).decode('ascii')


class AnalyzeActionsLROPoller(LROPoller, Generic[PollingReturnType]):
def polling_method(self):
@@ -390,3 +434,21 @@ def id(self):
:rtype: str
"""
return self.polling_method().id

@classmethod
def from_continuation_token(cls, polling_method, continuation_token, **kwargs): # type: ignore
# type: (AnalyzeActionsLROPollingMethod, str, Any) -> AnalyzeActionsLROPoller
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
continuation_token, **kwargs
)
polling_method._lro_algorithms = [ # pylint: disable=protected-access
TextAnalyticsOperationResourcePolling(
show_stats=initial_response.context.options["show_stats"]
)
]
return cls(
client,
initial_response,
functools.partial(deserialization_callback, initial_response),
polling_method
)
@@ -5,7 +5,6 @@
# ------------------------------------
# pylint: disable=too-many-lines

import copy
from typing import (
Union,
Any,
@@ -18,6 +17,7 @@
from azure.core.tracing.decorator import distributed_trace
from azure.core.exceptions import HttpResponseError
from ._base_client import TextAnalyticsClientBase, TextAnalyticsApiVersion
from ._lro import AnalyzeActionsLROPoller, AnalyzeHealthcareEntitiesLROPoller
from ._request_handlers import (
_validate_input,
_determine_action_type,
@@ -71,7 +71,6 @@
MultiCategoryClassifyAction,
MultiCategoryClassifyResult,
)
from ._lro import AnalyzeHealthcareEntitiesLROPoller, AnalyzeActionsLROPoller


class TextAnalyticsClient(TextAnalyticsClientBase):
@@ -564,7 +563,10 @@ def begin_analyze_healthcare_entities( # type: ignore
see https://aka.ms/text-analytics-offsets
:keyword int polling_interval: Waiting time between two polls for LRO operations
if no Retry-After header is present. Defaults to 5 seconds.
:keyword str continuation_token: A continuation token to restart a poller from a saved state.
:keyword str continuation_token:
Call `continuation_token()` on the poller object to save the long-running operation (LRO)
state into an opaque token. Pass the value as the `continuation_token` keyword argument
to restart the LRO from a saved state.
:keyword bool disable_service_logs: Defaults to true, meaning that Text Analytics will not log your
input text on the service side for troubleshooting. If set to False, Text Analytics logs your
input text for 48 hours, solely to allow for troubleshooting issues in providing you with
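
To make the new keyword concrete, here is a minimal, hedged sketch of the save/restore round trip this docstring describes. The endpoint, key, and document values are placeholders rather than anything from this PR, and passing `None` for `documents` on rehydration is an assumption based on the early-return path added later in this diff.

```python
from azure.core.credentials import AzureKeyCredential
from azure.ai.textanalytics import TextAnalyticsClient

# Placeholder endpoint, key, and documents -- substitute your own resource values.
client = TextAnalyticsClient(
    "https://<your-resource>.cognitiveservices.azure.com/",
    AzureKeyCredential("<api-key>"),
)
documents = ["Patient does not suffer from high blood pressure."]

# Start the LRO and capture its state as an opaque token.
poller = client.begin_analyze_healthcare_entities(documents)
token = poller.continuation_token()

# Later, possibly in a different process: rebuild the poller from the token.
# `None` for documents assumes the input is ignored when continuation_token is set,
# as the continuation-token branch added later in this diff suggests.
rehydrated = client.begin_analyze_healthcare_entities(None, continuation_token=token)
for result in rehydrated.result():
    print(result)
```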
@@ -595,44 +597,64 @@ def begin_analyze_healthcare_entities( # type: ignore
"""
language_arg = kwargs.pop("language", None)
language = language_arg if language_arg is not None else self._default_language
docs = _validate_input(documents, "language", language)
model_version = kwargs.pop("model_version", None)
show_stats = kwargs.pop("show_stats", False)
polling_interval = kwargs.pop("polling_interval", 5)
continuation_token = kwargs.pop("continuation_token", None)
string_index_type = kwargs.pop(
"string_index_type", self._string_index_type_default
)
disable_service_logs = kwargs.pop("disable_service_logs", None)

if continuation_token:
def get_result_from_cont_token(initial_response, pipeline_response):
doc_id_order = initial_response.context.options["doc_id_order"]
show_stats = initial_response.context.options["show_stats"]
return self._healthcare_result_callback(
doc_id_order, pipeline_response, None, {}, show_stats=show_stats
)

return AnalyzeHealthcareEntitiesLROPoller.from_continuation_token(
polling_method=AnalyzeHealthcareEntitiesLROPollingMethod(
text_analytics_client=self._client,
timeout=polling_interval,
**kwargs
),
client=self._client._client, # pylint: disable=protected-access
deserialization_callback=get_result_from_cont_token,
continuation_token=continuation_token
)

docs = _validate_input(documents, "language", language)
doc_id_order = [doc.get("id") for doc in docs]
my_cls = kwargs.pop(
"cls",
partial(
self._healthcare_result_callback, doc_id_order, show_stats=show_stats
),
)
disable_service_logs = kwargs.pop("disable_service_logs", None)
polling_kwargs = kwargs
operation_kwargs = copy.copy(kwargs)
if disable_service_logs is not None:
operation_kwargs["logging_opt_out"] = disable_service_logs

try:
return self._client.begin_health(
docs,
model_version=model_version,
string_index_type=string_index_type,
logging_opt_out=disable_service_logs,
cls=my_cls,
polling=AnalyzeHealthcareEntitiesLROPollingMethod(
text_analytics_client=self._client,
timeout=polling_interval,
doc_id_order=doc_id_order,
show_stats=show_stats,
lro_algorithms=[
TextAnalyticsOperationResourcePolling(show_stats=show_stats)
TextAnalyticsOperationResourcePolling(
show_stats=show_stats,
)
],
**polling_kwargs
**kwargs
),
continuation_token=continuation_token,
**operation_kwargs
**kwargs
)

except ValueError as error:
@@ -894,6 +916,10 @@ def begin_analyze_actions( # type: ignore
:keyword bool show_stats: If set to true, response will contain document level statistics.
:keyword int polling_interval: Waiting time between two polls for LRO operations
if no Retry-After header is present. Defaults to 5 seconds.
:keyword str continuation_token:
Call `continuation_token()` on the poller object to save the long-running operation (LRO)
state into an opaque token. Pass the value as the `continuation_token` keyword argument
to restart the LRO from a saved state.
:return: An instance of an AnalyzeActionsLROPoller. Call `result()` on the poller
object to return a pageable heterogeneous list of lists. This list of lists is first ordered
by the documents you input, then ordered by the actions you input. For example,
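
Likewise, a hedged sketch of rehydrating `begin_analyze_actions` from a saved token. The endpoint, key, documents, and chosen action are placeholders, and the `None` placeholders on restart are an assumption mirroring the continuation-token branch added later in this diff.

```python
from azure.core.credentials import AzureKeyCredential
from azure.ai.textanalytics import TextAnalyticsClient, RecognizeEntitiesAction

client = TextAnalyticsClient(
    "https://<your-resource>.cognitiveservices.azure.com/",
    AzureKeyCredential("<api-key>"),
)
documents = ["Microsoft was founded by Bill Gates and Paul Allen."]

# Kick off the multi-action LRO and save its state.
poller = client.begin_analyze_actions(documents, actions=[RecognizeEntitiesAction()])
token = poller.continuation_token()

# Restart from the saved state; documents and actions are passed as None placeholders
# because the continuation-token path returns before they are validated.
rehydrated = client.begin_analyze_actions(None, None, continuation_token=token)
for document_results in rehydrated.result():
    for action_result in document_results:
        print(action_result)
```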
@@ -931,18 +957,37 @@ def begin_analyze_actions( # type: ignore
actions over a batch of documents.
"""

continuation_token = kwargs.pop("continuation_token", None)
display_name = kwargs.pop("display_name", None)
language_arg = kwargs.pop("language", None)
show_stats = kwargs.pop("show_stats", False)
polling_interval = kwargs.pop("polling_interval", 5)
language = language_arg if language_arg is not None else self._default_language

if continuation_token:
def get_result_from_cont_token(initial_response, pipeline_response):
doc_id_order = initial_response.context.options["doc_id_order"]
task_id_order = initial_response.context.options["task_id_order"]
show_stats = initial_response.context.options["show_stats"]
return self._analyze_result_callback(
doc_id_order, task_id_order, pipeline_response, None, {}, show_stats=show_stats
)

return AnalyzeActionsLROPoller.from_continuation_token(
polling_method=AnalyzeActionsLROPollingMethod(
timeout=polling_interval,
**kwargs
),
client=self._client._client, # pylint: disable=protected-access
deserialization_callback=get_result_from_cont_token,
continuation_token=continuation_token
)

docs = self._client.models(
api_version=self._api_version
).MultiLanguageBatchInput(
documents=_validate_input(documents, "language", language)
)
show_stats = kwargs.pop("show_stats", False)
polling_interval = kwargs.pop("polling_interval", 5)
continuation_token = kwargs.pop("continuation_token", None)

doc_id_order = [doc.get("id") for doc in docs.documents]
try:
generated_tasks = [
@@ -1012,8 +1057,13 @@ def begin_analyze_actions( # type: ignore
),
polling=AnalyzeActionsLROPollingMethod(
timeout=polling_interval,
show_stats=show_stats,
doc_id_order=doc_id_order,
task_id_order=task_order,
lro_algorithms=[
TextAnalyticsOperationResourcePolling(show_stats=show_stats)
TextAnalyticsOperationResourcePolling(
show_stats=show_stats,
)
],
**kwargs
),