-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: provide AsyncIO support for generated code (#365)
- Loading branch information
Showing
24 changed files
with
1,397 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
271 changes: 271 additions & 0 deletions
271
gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,271 @@ | ||
{% extends '_base.py.j2' %} | ||
|
||
{% block content %} | ||
from collections import OrderedDict | ||
import functools | ||
import re | ||
from typing import Dict, {% if service.any_server_streaming %}AsyncIterable, {% endif %}{% if service.any_client_streaming %}AsyncIterator, {% endif %}Sequence, Tuple, Type, Union | ||
import pkg_resources | ||
|
||
import google.api_core.client_options as ClientOptions # type: ignore | ||
from google.api_core import exceptions # type: ignore | ||
from google.api_core import gapic_v1 # type: ignore | ||
from google.api_core import retry as retries # type: ignore | ||
from google.auth import credentials # type: ignore | ||
from google.oauth2 import service_account # type: ignore | ||
|
||
{% filter sort_lines -%} | ||
{% for method in service.methods.values() -%} | ||
{% for ref_type in method.flat_ref_types -%} | ||
{{ ref_type.ident.python_import }} | ||
{% endfor -%} | ||
{% endfor -%} | ||
{% endfilter %} | ||
from .transports.base import {{ service.name }}Transport | ||
from .transports.grpc_asyncio import {{ service.grpc_asyncio_transport_name }} | ||
from .client import {{ service.client_name }} | ||
|
||
|
||
class {{ service.async_client_name }}: | ||
"""{{ service.meta.doc|rst(width=72, indent=4) }}""" | ||
|
||
_client: {{ service.client_name }} | ||
|
||
DEFAULT_ENDPOINT = {{ service.client_name }}.DEFAULT_ENDPOINT | ||
DEFAULT_MTLS_ENDPOINT = {{ service.client_name }}.DEFAULT_MTLS_ENDPOINT | ||
|
||
{% for message in service.resource_messages -%} | ||
{{ message.resource_type|snake_case }}_path = staticmethod({{ service.client_name }}.{{ message.resource_type|snake_case }}_path) | ||
|
||
{% endfor %} | ||
|
||
from_service_account_file = {{ service.client_name }}.from_service_account_file | ||
from_service_account_json = from_service_account_file | ||
|
||
get_transport_class = functools.partial(type({{ service.client_name }}).get_transport_class, type({{ service.client_name }})) | ||
|
||
def __init__(self, *, | ||
credentials: credentials.Credentials = None, | ||
transport: Union[str, {{ service.name }}Transport] = 'grpc_asyncio', | ||
client_options: ClientOptions = None, | ||
) -> None: | ||
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}. | ||
|
||
Args: | ||
credentials (Optional[google.auth.credentials.Credentials]): The | ||
authorization credentials to attach to requests. These | ||
credentials identify the application to the service; if none | ||
are specified, the client will attempt to ascertain the | ||
credentials from the environment. | ||
transport (Union[str, ~.{{ service.name }}Transport]): The | ||
transport to use. If set to None, a transport is chosen | ||
automatically. | ||
client_options (ClientOptions): Custom options for the client. It | ||
won't take effect if a ``transport`` instance is provided. | ||
(1) The ``api_endpoint`` property can be used to override the | ||
default endpoint provided by the client. GOOGLE_API_USE_MTLS | ||
environment variable can also be used to override the endpoint: | ||
"always" (always use the default mTLS endpoint), "never" (always | ||
use the default regular endpoint, this is the default value for | ||
the environment variable) and "auto" (auto switch to the default | ||
mTLS endpoint if client SSL credentials is present). However, | ||
the ``api_endpoint`` property takes precedence if provided. | ||
(2) The ``client_cert_source`` property is used to provide client | ||
SSL credentials for mutual TLS transport. If not provided, the | ||
default SSL credentials will be used if present. | ||
|
||
Raises: | ||
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport | ||
creation failed for any reason. | ||
""" | ||
{# NOTE(lidiz) Not using kwargs since we want the docstring and types. #} | ||
self._client = {{ service.client_name }}( | ||
credentials=credentials, | ||
transport=transport, | ||
client_options=client_options, | ||
) | ||
|
||
{% for method in service.methods.values() -%} | ||
{% if not method.server_streaming %}async {% endif -%}def {{ method.name|snake_case }}(self, | ||
{%- if not method.client_streaming %} | ||
request: {{ method.input.ident }} = None, | ||
*, | ||
{% for field in method.flattened_fields.values() -%} | ||
{{ field.name }}: {{ field.ident }} = None, | ||
{% endfor -%} | ||
{%- else %} | ||
requests: AsyncIterator[{{ method.input.ident }}] = None, | ||
*, | ||
{% endif -%} | ||
retry: retries.Retry = gapic_v1.method.DEFAULT, | ||
timeout: float = None, | ||
metadata: Sequence[Tuple[str, str]] = (), | ||
{%- if not method.server_streaming %} | ||
) -> {{ method.client_output_async.ident }}: | ||
{%- else %} | ||
) -> AsyncIterable[{{ method.client_output_async.ident }}]: | ||
{%- endif %} | ||
r"""{{ method.meta.doc|rst(width=72, indent=8) }} | ||
|
||
Args: | ||
{%- if not method.client_streaming %} | ||
request (:class:`{{ method.input.ident.sphinx }}`): | ||
The request object.{{ ' ' -}} | ||
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }} | ||
{% for key, field in method.flattened_fields.items() -%} | ||
{{ field.name }} (:class:`{{ field.ident.sphinx }}`): | ||
{{ field.meta.doc|rst(width=72, indent=16, nl=False) }} | ||
This corresponds to the ``{{ key }}`` field | ||
on the ``request`` instance; if ``request`` is provided, this | ||
should not be set. | ||
{% endfor -%} | ||
{%- else %} | ||
requests (AsyncIterator[`{{ method.input.ident.sphinx }}`]): | ||
The request object AsyncIterator.{{ ' ' -}} | ||
{{ method.input.meta.doc|wrap(width=72, offset=36, indent=16) }} | ||
{%- endif %} | ||
retry (google.api_core.retry.Retry): Designation of what errors, if any, | ||
should be retried. | ||
timeout (float): The timeout for this request. | ||
metadata (Sequence[Tuple[str, str]]): Strings which should be | ||
sent along with the request as metadata. | ||
{%- if not method.void %} | ||
|
||
Returns: | ||
{%- if not method.server_streaming %} | ||
{{ method.client_output_async.ident.sphinx }}: | ||
{%- else %} | ||
AsyncIterable[{{ method.client_output_async.ident.sphinx }}]: | ||
{%- endif %} | ||
{{ method.client_output_async.meta.doc|rst(width=72, indent=16) }} | ||
{%- endif %} | ||
""" | ||
{%- if not method.client_streaming %} | ||
# Create or coerce a protobuf request object. | ||
{% if method.flattened_fields -%} | ||
# Sanity check: If we got a request object, we should *not* have | ||
# gotten any keyword arguments that map to the request. | ||
if request is not None and any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]): | ||
raise ValueError('If the `request` argument is set, then none of ' | ||
'the individual field arguments should be set.') | ||
|
||
{% endif -%} | ||
{% if method.input.ident.package != method.ident.package -%} {# request lives in a different package, so there is no proto wrapper #} | ||
# The request isn't a proto-plus wrapped type, | ||
# so it must be constructed via keyword expansion. | ||
if isinstance(request, dict): | ||
request = {{ method.input.ident }}(**request) | ||
{% if method.flattened_fields -%}{# Cross-package req and flattened fields #} | ||
elif not request: | ||
request = {{ method.input.ident }}() | ||
{% endif -%}{# Cross-package req and flattened fields #} | ||
{%- else %} | ||
request = {{ method.input.ident }}(request) | ||
{% endif %} {# different request package #} | ||
|
||
{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #} | ||
{% if method.flattened_fields -%} | ||
# If we have keyword arguments corresponding to fields on the | ||
# request, apply these. | ||
{% endif -%} | ||
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %} | ||
if {{ field.name }} is not None: | ||
request.{{ key }} = {{ field.name }} | ||
{%- endfor %} | ||
{# They can be _extended_, however -#} | ||
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %} | ||
if {{ field.name }}: | ||
request.{{ key }}.extend({{ field.name }}) | ||
{%- endfor %} | ||
{%- endif %} | ||
|
||
# Wrap the RPC method; this adds retry and timeout information, | ||
# and friendly error handling. | ||
rpc = gapic_v1.method_async.wrap_method( | ||
self._client._transport.{{ method.name|snake_case }}, | ||
{%- if method.retry %} | ||
default_retry=retries.Retry( | ||
{% if method.retry.initial_backoff %}initial={{ method.retry.initial_backoff }},{% endif %} | ||
{% if method.retry.max_backoff %}maximum={{ method.retry.max_backoff }},{% endif %} | ||
{% if method.retry.backoff_multiplier %}multiplier={{ method.retry.backoff_multiplier }},{% endif %} | ||
predicate=retries.if_exception_type( | ||
{%- filter sort_lines %} | ||
{%- for ex in method.retry.retryable_exceptions %} | ||
exceptions.{{ ex.__name__ }}, | ||
{%- endfor %} | ||
{%- endfilter %} | ||
), | ||
), | ||
{%- endif %} | ||
default_timeout={{ method.timeout }}, | ||
client_info=_client_info, | ||
) | ||
{%- if method.field_headers %} | ||
|
||
# Certain fields should be provided within the metadata header; | ||
# add these here. | ||
metadata = tuple(metadata) + ( | ||
gapic_v1.routing_header.to_grpc_metadata(( | ||
{%- for field_header in method.field_headers %} | ||
{%- if not method.client_streaming %} | ||
('{{ field_header }}', request.{{ field_header }}), | ||
{%- endif %} | ||
{%- endfor %} | ||
)), | ||
) | ||
{%- endif %} | ||
|
||
# Send the request. | ||
{% if not method.void %}response = {% endif %} | ||
{%- if not method.server_streaming %}await {% endif %}rpc( | ||
{%- if not method.client_streaming %} | ||
request, | ||
{%- else %} | ||
requests, | ||
{%- endif %} | ||
retry=retry, | ||
timeout=timeout, | ||
metadata=metadata, | ||
) | ||
{%- if method.lro %} | ||
|
||
# Wrap the response in an operation future. | ||
response = operation_async.from_gapic( | ||
response, | ||
self._client._transport.operations_client, | ||
{{ method.lro.response_type.ident }}, | ||
metadata_type={{ method.lro.metadata_type.ident }}, | ||
) | ||
{%- elif method.paged_result_field %} | ||
|
||
# This method is paged; wrap the response in a pager, which provides | ||
# an `__aiter__` convenience method. | ||
response = {{ method.client_output_async.ident }}( | ||
method=rpc, | ||
request=request, | ||
response=response, | ||
) | ||
{%- endif %} | ||
{%- if not method.void %} | ||
|
||
# Done; return the response. | ||
return response | ||
{%- endif %} | ||
{{ '\n' }} | ||
{% endfor %} | ||
|
||
|
||
try: | ||
_client_info = gapic_v1.client_info.ClientInfo( | ||
gapic_version=pkg_resources.get_distribution( | ||
'{{ api.naming.warehouse_package_name }}', | ||
).version, | ||
) | ||
except pkg_resources.DistributionNotFound: | ||
_client_info = gapic_v1.client_info.ClientInfo() | ||
|
||
|
||
__all__ = ( | ||
'{{ service.async_client_name }}', | ||
) | ||
{% endblock %} |
Oops, something went wrong.