Skip to content

Commit

Permalink
[text analytics] Expose 'string_index_type' parameter in all service …
Browse files Browse the repository at this point in the history
…client methods where applicable (#16412)
  • Loading branch information
abhahn authored Feb 1, 2021
1 parent bf676d6 commit b547fbb
Show file tree
Hide file tree
Showing 34 changed files with 1,285 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1168,19 +1168,24 @@ class EntitiesRecognitionTask(DictMixin):
"""EntitiesRecognitionTask encapsulates the parameters for starting a long-running Entities Recognition operation.
:ivar str model_version: The model version to use for the analysis.
:ivar str string_index_type: Specifies the method used to interpret string offsets.
Can be one of 'UnicodeCodePoint' (default), 'Utf16CodePoint', or 'TextElements_v8'.
For additional information see https://aka.ms/text-analytics-offsets
"""

def __init__(self, **kwargs):
self.model_version = kwargs.get("model_version", "latest")
self.string_index_type = kwargs.get("string_index_type", "UnicodeCodePoint")

def __repr__(self, **kwargs):
return "EntitiesRecognitionTask(model_version={})" \
.format(self.model_version)[:1024]
return "EntitiesRecognitionTask(model_version={}, string_index_type={})" \
.format(self.model_version, self.string_index_type)[:1024]

def to_generated(self):
return _v3_1_preview_3_models.EntitiesTask(
parameters=_v3_1_preview_3_models.EntitiesTaskParameters(
model_version=self.model_version
model_version=self.model_version,
string_index_type=self.string_index_type
)
)

Expand Down Expand Up @@ -1210,21 +1215,26 @@ class PiiEntitiesRecognitionTask(DictMixin):
:ivar str model_version: The model version to use for the analysis.
:ivar str domain: An optional string to set the PII domain to include only a
subset of the entity categories. Possible values include 'PHI' or None.
:ivar str string_index_type: Specifies the method used to interpret string offsets.
Can be one of 'UnicodeCodePoint' (default), 'Utf16CodePoint', or 'TextElements_v8'.
For additional information see https://aka.ms/text-analytics-offsets
"""

def __init__(self, **kwargs):
self.model_version = kwargs.get("model_version", "latest")
self.domain = kwargs.get("domain", None)
self.string_index_type = kwargs.get("string_index_type", "UnicodeCodePoint")

def __repr__(self, **kwargs):
return "PiiEntitiesRecognitionTask(model_version={}, domain={})" \
.format(self.model_version, self.domain)[:1024]
return "PiiEntitiesRecognitionTask(model_version={}, domain={}, string_index_type={})" \
.format(self.model_version, self.domain, self.string_index_type)[:1024]

def to_generated(self):
return _v3_1_preview_3_models.PiiTask(
parameters=_v3_1_preview_3_models.PiiTaskParameters(
model_version=self.model_version,
domain=self.domain
domain=self.domain,
string_index_type=self.string_index_type
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,22 @@ def _validate_input(documents, hint, whole_input_hint):
request_batch.append(doc)

return request_batch


def _check_string_index_type_arg(string_index_type_arg, api_version, string_index_type_default="UnicodeCodePoint"):
string_index_type = None

if api_version == "v3.0":
if string_index_type_arg is not None:
raise ValueError(
"'string_index_type' is only available for API version v3.1-preview and up"
)

else:
if string_index_type_arg is None:
string_index_type = string_index_type_default

else:
string_index_type = string_index_type_arg

return string_index_type
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from azure.core.tracing.decorator import distributed_trace
from azure.core.exceptions import HttpResponseError
from ._base_client import TextAnalyticsClientBase
from ._request_handlers import _validate_input
from ._request_handlers import _validate_input, _check_string_index_type_arg
from ._response_handlers import (
process_http_response_error,
entities_result,
Expand Down Expand Up @@ -104,11 +104,13 @@ def __init__(self, endpoint, credential, **kwargs):
credential=credential,
**kwargs
)
self._api_version = kwargs.get("api_version")
self._default_language = kwargs.pop("default_language", "en")
self._default_country_hint = kwargs.pop("default_country_hint", "US")
self._string_code_unit = None if kwargs.get("api_version") == "v3.0" else "UnicodeCodePoint"
self._string_index_type_default = None if kwargs.get("api_version") == "v3.0" else "UnicodeCodePoint"
self._deserialize = _get_deserialize()


@distributed_trace
def detect_language( # type: ignore
self,
Expand Down Expand Up @@ -208,6 +210,10 @@ def recognize_entities( # type: ignore
is not specified, the API will default to the latest, non-preview version.
:keyword bool show_stats: If set to true, response will contain document
level statistics in the `statistics` field of the document-level response.
:keyword str string_index_type: Specifies the method used to interpret string offsets.
`UnicodeCodePoint`, the Python encoding, is the default. To override the Python default,
you can also pass in `Utf16CodePoint` or `TextElements_v8`. For additional information
see https://aka.ms/text-analytics-offsets
:return: The combined list of :class:`~azure.ai.textanalytics.RecognizeEntitiesResult` and
:class:`~azure.ai.textanalytics.DocumentError` in the order the original documents
were passed in.
Expand All @@ -229,8 +235,14 @@ def recognize_entities( # type: ignore
docs = _validate_input(documents, "language", language)
model_version = kwargs.pop("model_version", None)
show_stats = kwargs.pop("show_stats", False)
if self._string_code_unit:
kwargs.update({"string_index_type": self._string_code_unit})
string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default
)
if string_index_type:
kwargs.update({"string_index_type": string_index_type})

try:
return self._client.entities_recognition_general(
documents=docs,
Expand Down Expand Up @@ -280,6 +292,10 @@ def recognize_pii_entities( # type: ignore
I.e., if set to 'PHI', will only return entities in the Protected Healthcare Information domain.
See https://aka.ms/tanerpii for more information.
:paramtype domain_filter: str or ~azure.ai.textanalytics.PiiEntityDomainType
:keyword str string_index_type: Specifies the method used to interpret string offsets.
`UnicodeCodePoint`, the Python encoding, is the default. To override the Python default,
you can also pass in `Utf16CodePoint` or `TextElements_v8`. For additional information
see https://aka.ms/text-analytics-offsets
:return: The combined list of :class:`~azure.ai.textanalytics.RecognizePiiEntitiesResult`
and :class:`~azure.ai.textanalytics.DocumentError` in the order the original documents
were passed in.
Expand All @@ -302,8 +318,15 @@ def recognize_pii_entities( # type: ignore
model_version = kwargs.pop("model_version", None)
show_stats = kwargs.pop("show_stats", False)
domain_filter = kwargs.pop("domain_filter", None)
if self._string_code_unit:
kwargs.update({"string_index_type": self._string_code_unit})

string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default
)
if string_index_type:
kwargs.update({"string_index_type": string_index_type})

try:
return self._client.entities_recognition_pii(
documents=docs,
Expand Down Expand Up @@ -357,6 +380,10 @@ def recognize_linked_entities( # type: ignore
is not specified, the API will default to the latest, non-preview version.
:keyword bool show_stats: If set to true, response will contain document
level statistics in the `statistics` field of the document-level response.
:keyword str string_index_type: Specifies the method used to interpret string offsets.
`UnicodeCodePoint`, the Python encoding, is the default. To override the Python default,
you can also pass in `Utf16CodePoint` or `TextElements_v8`. For additional information
see https://aka.ms/text-analytics-offsets
:return: The combined list of :class:`~azure.ai.textanalytics.RecognizeLinkedEntitiesResult`
and :class:`~azure.ai.textanalytics.DocumentError` in the order the original documents
were passed in.
Expand All @@ -378,8 +405,15 @@ def recognize_linked_entities( # type: ignore
docs = _validate_input(documents, "language", language)
model_version = kwargs.pop("model_version", None)
show_stats = kwargs.pop("show_stats", False)
if self._string_code_unit:
kwargs.update({"string_index_type": self._string_code_unit})

string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default
)
if string_index_type:
kwargs.update({"string_index_type": string_index_type})

try:
return self._client.entities_linking(
documents=docs,
Expand Down Expand Up @@ -430,6 +464,10 @@ def begin_analyze_healthcare( # type: ignore
be used for scoring, e.g. "latest", "2019-10-01". If a model-version
is not specified, the API will default to the latest, non-preview version.
:keyword bool show_stats: If set to true, response will contain document level statistics.
:keyword str string_index_type: Specifies the method used to interpret string offsets.
`UnicodeCodePoint`, the Python encoding, is the default. To override the Python default,
you can also pass in `Utf16CodePoint` or `TextElements_v8`. For additional information
see https://aka.ms/text-analytics-offsets
:keyword int polling_interval: Waiting time between two polls for LRO operations
if no Retry-After header is present. Defaults to 5 seconds.
:keyword str continuation_token: A continuation token to restart a poller from a saved state.
Expand All @@ -453,14 +491,15 @@ def begin_analyze_healthcare( # type: ignore
show_stats = kwargs.pop("show_stats", False)
polling_interval = kwargs.pop("polling_interval", 5)
continuation_token = kwargs.pop("continuation_token", None)
string_index_type = kwargs.pop("string_index_type", self._string_index_type_default)

doc_id_order = [doc.get("id") for doc in docs]

try:
return self._client.begin_health(
docs,
model_version=model_version,
string_index_type=self._string_code_unit,
string_index_type=string_index_type,
cls=kwargs.pop("cls", partial(self._healthcare_result_callback, doc_id_order, show_stats=show_stats)),
polling=TextAnalyticsLROPollingMethod(
timeout=polling_interval,
Expand Down Expand Up @@ -626,8 +665,13 @@ def analyze_sentiment( # type: ignore
is not specified, the API will default to the latest, non-preview version.
:keyword bool show_stats: If set to true, response will contain document
level statistics in the `statistics` field of the document-level response.
:keyword str string_index_type: Specifies the method used to interpret string offsets.
`UnicodeCodePoint`, the Python encoding, is the default. To override the Python default,
you can also pass in `Utf16CodePoint` or `TextElements_v8`. For additional information
see https://aka.ms/text-analytics-offsets
.. versionadded:: v3.1-preview
The *show_opinion_mining* parameter.
The *string_index_type* parameter.
:return: The combined list of :class:`~azure.ai.textanalytics.AnalyzeSentimentResult` and
:class:`~azure.ai.textanalytics.DocumentError` in the order the original documents were
passed in.
Expand All @@ -650,8 +694,14 @@ def analyze_sentiment( # type: ignore
model_version = kwargs.pop("model_version", None)
show_stats = kwargs.pop("show_stats", False)
show_opinion_mining = kwargs.pop("show_opinion_mining", None)
if self._string_code_unit:
kwargs.update({"string_index_type": self._string_code_unit})

string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default
)
if string_index_type:
kwargs.update({"string_index_type": string_index_type})

if show_opinion_mining is not None:
kwargs.update({"opinion_mining": show_opinion_mining})
Expand Down
Loading

0 comments on commit b547fbb

Please sign in to comment.