Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update paging implementation to handle unconventional pagination #750

Merged
merged 8 commits into from
Feb 4, 2021
Merged
17 changes: 13 additions & 4 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,13 +866,22 @@ def paged_result_field(self) -> Optional[Field]:
"""Return the response pagination field if the method is paginated."""
# If the request field lacks any of the expected pagination fields,
# then the method is not paginated.
for page_field in ((self.input, int, 'page_size'),
(self.input, str, 'page_token'),

# The request must have page_token and next_page_token as they keep track of pages
for source, source_type, name in ((self.input, str, 'page_token'),
(self.output, str, 'next_page_token')):
field = page_field[0].fields.get(page_field[2], None)
if not field or field.type != page_field[1]:
field = source.fields.get(name, None)
if not field or field.type != source_type:
return None

# The request must have max_results or page_size
page_fields = (self.input.fields.get('max_results', None),
self.input.fields.get('page_size', None))
page_field_size = next(
(field for field in page_fields if field), None)
if not page_field_size or page_field_size.type != int:
return None

# Return the first repeated field.
for field in self.output.fields.values():
if field.repeated:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
{# This lives within the loop in order to ensure that this template
is empty if there are no paged methods.
-#}
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional

{% filter sort_lines -%}
{% for method in service.methods.values() | selectattr('paged_result_field') -%}
Expand Down Expand Up @@ -68,14 +68,25 @@ class {{ method.name }}Pager:
self._response = self._method(self._request, metadata=self._metadata)
yield self._response

{% if method.paged_result_field.map %}
def __iter__(self) -> Iterable[Tuple[str, {{ method.paged_result_field.type.fields.get('value').ident }}]]:
for page in self.pages:
yield from page.{{ method.paged_result_field.name}}.items()

def get(self, key: str) -> Optional[{{ method.paged_result_field.type.fields.get('value').ident }}]:
return self._response.items.get(key)
{% else %}
def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}:
for page in self.pages:
yield from page.{{ method.paged_result_field.name }}
{% endif %}

def __repr__(self) -> str:
return '{0}<{1!r}>'.format(self.__class__.__name__, self._response)


{# TODO(yon-mg): remove on rest async transport impl #}
{% if 'grpc' in opts.transport %}
class {{ method.name }}AsyncPager:
"""A pager for iterating through ``{{ method.name|snake_case }}`` requests.

Expand Down Expand Up @@ -138,5 +149,6 @@ class {{ method.name }}AsyncPager:
def __repr__(self) -> str:
return '{0}<{1!r}>'.format(self.__class__.__name__, self._response)

{% endif %}
{% endfor %}
{% endblock %}
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,9 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
# TODO(yon-mg): handle nested fields corerctly rather than using only top level fields
# not required for GCE
query_params = {
{% filter sort_lines -%}
{%- for field in method.query_params %}
{%- for field in method.query_params | sort%}
'{{ field|camel_case }}': request.{{ field }},
{%- endfor %}
{% endfilter -%}
}
# TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here
# discards default values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,7 @@ def test_{{ method.name|snake_case }}_raw_page_lro():
assert response.raw_page is response
{% endif %} {#- method.paged_result_field #}

{% endfor -%} {#- method in methods #}
{% endfor -%} {#- method in methods for grpc #}

{% for method in service.methods.values() if 'rest' in opts.transport -%}
def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type={{ method.input.ident }}):
Expand Down Expand Up @@ -1162,7 +1162,126 @@ def test_{{ method.name|snake_case }}_rest_flattened_error():
)


{% endfor -%}
{% if method.paged_result_field %}
def test_{{ method.name|snake_case }}_pager():
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
)

# Mock the http request call within the method and fake a response.
with mock.patch.object(Session, 'request') as req:
# Set the response as a series of pages
{% if method.paged_result_field.map%}
response = (
{{ method.output.ident }}(
{{ method.paged_result_field.name }}={
'a':{{ method.paged_result_field.type.fields.get('value').ident }}(),
'b':{{ method.paged_result_field.type.fields.get('value').ident }}(),
'c':{{ method.paged_result_field.type.fields.get('value').ident }}(),
},
next_page_token='abc',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}={},
next_page_token='def',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}={
'g':{{ method.paged_result_field.type.fields.get('value').ident }}(),
},
next_page_token='ghi',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}={
'h':{{ method.paged_result_field.type.fields.get('value').ident }}(),
'i':{{ method.paged_result_field.type.fields.get('value').ident }}(),
},
),
)
{% else %}
response = (
{{ method.output.ident }}(
{{ method.paged_result_field.name }}=[
{{ method.paged_result_field.type.ident }}(),
{{ method.paged_result_field.type.ident }}(),
{{ method.paged_result_field.type.ident }}(),
],
next_page_token='abc',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}=[],
next_page_token='def',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}=[
{{ method.paged_result_field.type.ident }}(),
],
next_page_token='ghi',
),
{{ method.output.ident }}(
{{ method.paged_result_field.name }}=[
{{ method.paged_result_field.type.ident }}(),
{{ method.paged_result_field.type.ident }}(),
],
),
)
{% endif %}
# Two responses for two calls
response = response + response

# Wrap the values into proper Response objs
response = tuple({{ method.output.ident }}.to_json(x) for x in response)
return_values = tuple(Response() for i in response)
for return_val, response_val in zip(return_values, response):
return_val._content = response_val.encode('UTF-8')
return_val.status_code = 200
req.side_effect = return_values

metadata = ()
{% if method.field_headers -%}
metadata = tuple(metadata) + (
gapic_v1.routing_header.to_grpc_metadata((
{%- for field_header in method.field_headers %}
{%- if not method.client_streaming %}
('{{ field_header }}', ''),
{%- endif %}
{%- endfor %}
)),
)
{% endif -%}
pager = client.{{ method.name|snake_case }}(request={})

assert pager._metadata == metadata

{% if method.paged_result_field.map %}
assert isinstance(pager.get('a'), {{ method.paged_result_field.type.fields.get('value').ident }})
assert pager.get('h') is None
{% endif %}

results = list(pager)
assert len(results) == 6
{% if method.paged_result_field.map %}
assert all(
isinstance(i, tuple)
for i in results)
for result in results:
assert isinstance(result, tuple)
assert tuple(type(t) for t in result) == (str, {{ method.paged_result_field.type.fields.get('value').ident }})

assert pager.get('a') is None
assert isinstance(pager.get('h'), {{ method.paged_result_field.type.fields.get('value').ident }})
{% else %}
assert all(isinstance(i, {{ method.paged_result_field.type.ident }})
for i in results)
{% endif %}

pages = list(client.{{ method.name|snake_case }}(request={}).pages)
for page_, token in zip(pages, ['abc','def','ghi', '']):
assert page_.raw_page.next_page_token == token


{% endif %} {# paged methods #}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the results of this template visible in a generated file within this repo? (just asking, the repo may not be structured that way)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. It may be useful to include a generated client but that might clog the repo. The showcase client is generated as part of CircleCI tests though.

{% endfor -%} {#- method in methods for rest #}
def test_credentials_transport_error():
# It is an error to provide credentials and a transport instance.
transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport(
Expand Down
50 changes: 41 additions & 9 deletions tests/unit/schema/wrappers/test_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,38 @@ def test_method_client_output_empty():

def test_method_client_output_paged():
paged = make_field(name='foos', message=make_message('Foo'), repeated=True)
parent = make_field(name='parent', type=9) # str
page_size = make_field(name='page_size', type=5) # int
page_token = make_field(name='page_token', type=9) # str

input_msg = make_message(name='ListFoosRequest', fields=(
make_field(name='parent', type=9), # str
make_field(name='page_size', type=5), # int
make_field(name='page_token', type=9), # str
parent,
page_size,
page_token,
))
output_msg = make_message(name='ListFoosResponse', fields=(
paged,
make_field(name='next_page_token', type=9), # str
))
method = make_method('ListFoos',
input_message=input_msg,
output_message=output_msg,
)
method = make_method(
'ListFoos',
input_message=input_msg,
output_message=output_msg,
)
assert method.paged_result_field == paged
assert method.client_output.ident.name == 'ListFoosPager'

max_results = make_field(name='max_results', type=5) # int
input_msg = make_message(name='ListFoosRequest', fields=(
parent,
max_results,
page_token,
))
method = make_method(
'ListFoos',
input_message=input_msg,
output_message=output_msg,
)
assert method.paged_result_field == paged
assert method.client_output.ident.name == 'ListFoosPager'

Expand Down Expand Up @@ -123,6 +142,19 @@ def test_method_paged_result_field_no_page_field():
)
assert method.paged_result_field is None

method = make_method(
name='Foo',
input_message=make_message(
name='FooRequest',
fields=(make_field(name='page_token', type=9),) # str
),
output_message=make_message(
name='FooResponse',
fields=(make_field(name='next_page_token', type=9),) # str
)
)
assert method.paged_result_field is None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we add tests in this file for max_results and for mapped responses?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mapped responses are treated the same here. Checking for repeated fields should be sufficient since mapped fields are also repeated. Test for max_results is now also added.



def test_method_paged_result_ref_types():
input_msg = make_message(
Expand All @@ -139,7 +171,7 @@ def test_method_paged_result_ref_types():
name='ListMolluscsResponse',
fields=(
make_field(name='molluscs', message=mollusc_msg, repeated=True),
make_field(name='next_page_token', type=9)
make_field(name='next_page_token', type=9) # str
),
module='mollusc'
)
Expand Down Expand Up @@ -207,7 +239,7 @@ def test_flattened_ref_types():


def test_method_paged_result_primitive():
paged = make_field(name='squids', type=9, repeated=True)
paged = make_field(name='squids', type=9, repeated=True) # str
input_msg = make_message(
name='ListSquidsRequest',
fields=(
Expand Down