Skip to content

Commit

Permalink
fix: pass metadata to pagers (#470)
Browse files Browse the repository at this point in the history
Closes #469
  • Loading branch information
busunkim96 authored Jul 7, 2020
1 parent f49bc3f commit c43c6d9
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 12 deletions.
5 changes: 5 additions & 0 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,11 @@ def has_lro(self) -> bool:
"""Return whether the service has a long-running method."""
return any([m.lro for m in self.methods.values()])

@property
def has_pagers(self) -> bool:
"""Return whether the service has paged methods."""
return any(m.paged_result_field for m in self.methods.values())

@property
def host(self) -> str:
"""Return the hostname for this service, if specified.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ class {{ service.async_client_name }}:
method=rpc,
request=request,
response=response,
metadata=metadata,
)
{%- endif %}
{%- if not method.void %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
method=rpc,
request=request,
response=response,
metadata=metadata,
)
{%- endif %}
{%- if not method.void %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
{# This lives within the loop in order to ensure that this template
is empty if there are no paged methods.
-#}
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple

{% filter sort_lines -%}
{% for method in service.methods.values() | selectattr('paged_result_field') -%}
Expand Down Expand Up @@ -35,10 +35,11 @@ class {{ method.name }}Pager:
the most recent response is retained, and thus used for attribute lookup.
"""
def __init__(self,
method: Callable[[{{ method.input.ident }}],
{{ method.output.ident }}],
method: Callable[..., {{ method.output.ident }}],
request: {{ method.input.ident }},
response: {{ method.output.ident }}):
response: {{ method.output.ident }},
*,
metadata: Sequence[Tuple[str, str]] = ()):
"""Instantiate the pager.

Args:
Expand All @@ -48,10 +49,13 @@ class {{ method.name }}Pager:
The initial request object.
response (:class:`{{ method.output.ident.sphinx }}`):
The initial response object.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
"""
self._method = method
self._request = {{ method.input.ident }}(request)
self._response = response
self._metadata = metadata

def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
Expand All @@ -61,7 +65,7 @@ class {{ method.name }}Pager:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request)
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}:
Expand Down Expand Up @@ -90,10 +94,11 @@ class {{ method.name }}AsyncPager:
the most recent response is retained, and thus used for attribute lookup.
"""
def __init__(self,
method: Callable[[{{ method.input.ident }}],
Awaitable[{{ method.output.ident }}]],
method: Callable[..., Awaitable[{{ method.output.ident }}]],
request: {{ method.input.ident }},
response: {{ method.output.ident }}):
response: {{ method.output.ident }},
*,
metadata: Sequence[Tuple[str, str]] = ()):
"""Instantiate the pager.

Args:
Expand All @@ -103,10 +108,13 @@ class {{ method.name }}AsyncPager:
The initial request object.
response (:class:`{{ method.output.ident.sphinx }}`):
The initial response object.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
"""
self._method = method
self._request = {{ method.input.ident }}(request)
self._response = response
self._metadata = metadata

def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
Expand All @@ -116,7 +124,7 @@ class {{ method.name }}AsyncPager:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request)
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response

def __aiter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'AsyncIterable') }}:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ from google.api_core import future
from google.api_core import operations_v1
from google.longrunning import operations_pb2
{% endif -%}
{% if service.has_pagers -%}
from google.api_core import gapic_v1
{% endif -%}
{% for method in service.methods.values() -%}
{% for ref_type in method.ref_types
if not ((ref_type.ident.python_import.package == ('google', 'api_core') and ref_type.ident.python_import.module == 'operation')
Expand Down Expand Up @@ -695,9 +698,24 @@ def test_{{ method.name|snake_case }}_pager():
),
RuntimeError,
)
results = [i for i in client.{{ method.name|snake_case }}(
request={},
)]

metadata = ()
{% if method.field_headers -%}
metadata = tuple(metadata) + (
gapic_v1.routing_header.to_grpc_metadata((
{%- for field_header in method.field_headers %}
{%- if not method.client_streaming %}
('{{ field_header }}', ''),
{%- endif %}
{%- endfor %}
)),
)
{% endif -%}
pager = client.{{ method.name|snake_case }}(request={})

assert pager._metadata == metadata

results = [i for i in pager]
assert len(results) == 6
assert all(isinstance(i, {{ method.paged_result_field.message.ident }})
for i in results)
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/schema/wrappers/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,38 @@ def test_service_any_streaming():

assert service.any_client_streaming == client
assert service.any_server_streaming == server


def test_has_pagers():
paged = make_field(name='foos', message=make_message('Foo'), repeated=True)
input_msg = make_message(
name='ListFoosRequest',
fields=(
make_field(name='parent', type=9), # str
make_field(name='page_size', type=5), # int
make_field(name='page_token', type=9), # str
),
)
output_msg = make_message(
name='ListFoosResponse',
fields=(
paged,
make_field(name='next_page_token', type=9), # str
),
)
method = make_method(
'ListFoos',
input_message=input_msg,
output_message=output_msg,
)

service = make_service(name="Fooer", methods=(method,),)
assert service.has_pagers

other_service = make_service(
name="Unfooer",
methods=(
get_method("Unfoo", "foo.bar.UnfooReq", "foo.bar.UnFooResp"),
),
)
assert not other_service.has_pagers

0 comments on commit c43c6d9

Please sign in to comment.