From 1141010c3de40202a65c93fe0c1801d6309f9816 Mon Sep 17 00:00:00 2001 From: Yonatan Getahun Date: Mon, 25 Jan 2021 18:14:00 +0000 Subject: [PATCH 1/7] fix: update paging implementation to handle unconventional pagination --- gapic/schema/wrappers.py | 9 +- .../%sub/services/%service/pagers.py.j2 | 14 ++- .../%name_%version/%sub/test_%service.py.j2 | 119 +++++++++++++++++- tests/unit/schema/wrappers/test_method.py | 14 ++- 4 files changed, 149 insertions(+), 7 deletions(-) diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index eefe0cdc7e..94f29be8a0 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -866,12 +866,17 @@ def paged_result_field(self) -> Optional[Field]: """Return the response pagination field if the method is paginated.""" # If the request field lacks any of the expected pagination fields, # then the method is not paginated. - for page_field in ((self.input, int, 'page_size'), - (self.input, str, 'page_token'), + for page_field in ((self.input, str, 'page_token'), (self.output, str, 'next_page_token')): field = page_field[0].fields.get(page_field[2], None) if not field or field.type != page_field[1]: return None + page_fields = [self.input.fields.get('max_results', None), + self.input.fields.get('page_size', None)] + page_field = next( + (field for field in page_fields if field is not None), None) + if not page_field or page_field.type != int: + return None # Return the first repeated field. for field in self.output.fields.values(): diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 index ea08466ba0..5f04eed61e 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 @@ -6,7 +6,7 @@ {# This lives within the loop in order to ensure that this template is empty if there are no paged methods. -#} -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple, Optional {% filter sort_lines -%} {% for method in service.methods.values() | selectattr('paged_result_field') -%} @@ -68,14 +68,25 @@ class {{ method.name }}Pager: self._response = self._method(self._request, metadata=self._metadata) yield self._response + {% if method.paged_result_field.map %} + def __iter__(self) -> Iterable[Tuple[str, {{ method.paged_result_field.ident | replace('Sequence[', '') | replace(']', '') }}]]: + for page in self.pages: + yield from page.{{ method.paged_result_field.name}}.items() + + def get(self, key: str) -> {{ method.paged_result_field.ident | replace('Sequence', 'Optional') }}: + return self._response.items.get(key) + {% else %} def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}: for page in self.pages: yield from page.{{ method.paged_result_field.name }} + {% endif %} def __repr__(self) -> str: return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) +{# TODO(yon-mg): remove on rest async transport impl #} +{% if 'grpc' in opts.transport %} class {{ method.name }}AsyncPager: """A pager for iterating through ``{{ method.name|snake_case }}`` requests. @@ -138,5 +149,6 @@ class {{ method.name }}AsyncPager: def __repr__(self) -> str: return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) +{% endif %} {% endfor %} {% endblock %} diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 6affa40bc8..b99aa34262 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -1029,7 +1029,7 @@ def test_{{ method.name|snake_case }}_raw_page_lro(): assert response.raw_page is response {% endif %} {#- method.paged_result_field #} -{% endfor -%} {#- method in methods #} +{% endfor -%} {#- method in methods for grpc #} {% for method in service.methods.values() if 'rest' in opts.transport -%} def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type={{ method.input.ident }}): @@ -1169,7 +1169,122 @@ def test_{{ method.name|snake_case }}_rest_flattened_error(): ) -{% endfor -%} +{% if method.paged_result_field %} +def test_{{ method.name|snake_case }}_pager(): + client = {{ service.client_name }}( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # Set the response as a series of pages + {% if method.paged_result_field.map%} + response = ( + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={ + 'a':{{ method.paged_result_field.type.fields.get('value').ident }}(), + 'b':{{ method.paged_result_field.type.fields.get('value').ident }}(), + 'c':{{ method.paged_result_field.type.fields.get('value').ident }}(), + }, + next_page_token='abc', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={}, + next_page_token='def', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={ + 'g':{{ method.paged_result_field.type.fields.get('value').ident }}(), + }, + next_page_token='ghi', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={ + 'h':{{ method.paged_result_field.type.fields.get('value').ident }}(), + 'i':{{ method.paged_result_field.type.fields.get('value').ident }}(), + }, + ), + ) + {% else %} + response = ( + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.type.ident }}(), + {{ method.paged_result_field.type.ident }}(), + {{ method.paged_result_field.type.ident }}(), + ], + next_page_token='abc', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[], + next_page_token='def', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.type.ident }}(), + ], + next_page_token='ghi', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.type.ident }}(), + {{ method.paged_result_field.type.ident }}(), + ], + ), + ) + {% endif %} + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(map(lambda x: {{ method.output.ident }}.to_json(x), response)) + side_effect = tuple(map(lambda x: Response(), response)) + for return_val, response_val in zip(side_effect, response): + return_val._content = response_val.encode('UTF-8') + req.side_effect = side_effect + + metadata = () + {% if method.field_headers -%} + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + {%- for field_header in method.field_headers %} + {%- if not method.client_streaming %} + ('{{ field_header }}', ''), + {%- endif %} + {%- endfor %} + )), + ) + {% endif -%} + pager = client.{{ method.name|snake_case }}(request={}) + + assert pager._metadata == metadata + + {% if method.paged_result_field.map %} + assert isinstance(pager.get('a'), {{ method.paged_result_field.type.fields.get('value').ident }}) + assert pager.get('h') is None + {% endif %} + + results = [i for i in pager] + assert len(results) == 6 + {% if method.paged_result_field.map %} + assert all( + isinstance(i, tuple) and + tuple(map(lambda x: type(x), results[0])) == (str, {{ method.paged_result_field.type.fields.get('value').ident }}) + for i in results) + assert pager.get('a') is None + assert isinstance(pager.get('h'), {{ method.paged_result_field.type.fields.get('value').ident }}) + {% else %} + assert all(isinstance(i, {{ method.paged_result_field.type.ident }}) + for i in results) + {% endif %} + + pages = list(client.{{ method.name|snake_case }}(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +{% endif %} {# paged methods #} +{% endfor -%} {#- method in methods for rest #} def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport( diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index bcaeb68800..256ef27951 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -123,6 +123,16 @@ def test_method_paged_result_field_no_page_field(): ) assert method.paged_result_field is None + method = make_method('Foo', + input_message=make_message(name='FooRequest', fields=( + make_field(name='page_token', type=9), # str + )), + output_message=make_message(name='FooResponse', fields=( + make_field(name='next_page_token', type=9), # str + )) + ) + assert method.paged_result_field == None + def test_method_paged_result_ref_types(): input_msg = make_message( @@ -139,7 +149,7 @@ def test_method_paged_result_ref_types(): name='ListMolluscsResponse', fields=( make_field(name='molluscs', message=mollusc_msg, repeated=True), - make_field(name='next_page_token', type=9) + make_field(name='next_page_token', type=9) # str ), module='mollusc' ) @@ -207,7 +217,7 @@ def test_flattened_ref_types(): def test_method_paged_result_primitive(): - paged = make_field(name='squids', type=9, repeated=True) + paged = make_field(name='squids', type=9, repeated=True) # str input_msg = make_message( name='ListSquidsRequest', fields=( From d226f360ea3c191d49754a4140826c311371a2c7 Mon Sep 17 00:00:00 2001 From: Yonatan Getahun Date: Mon, 25 Jan 2021 21:58:22 +0000 Subject: [PATCH 2/7] fix: typing errors, mypy cli update --- gapic/schema/wrappers.py | 15 ++++++++------- gapic/templates/noxfile.py.j2 | 1 + noxfile.py | 2 +- tests/unit/schema/wrappers/test_method.py | 2 +- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 94f29be8a0..6d3a4c7080 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -866,16 +866,17 @@ def paged_result_field(self) -> Optional[Field]: """Return the response pagination field if the method is paginated.""" # If the request field lacks any of the expected pagination fields, # then the method is not paginated. - for page_field in ((self.input, str, 'page_token'), + for page_field_token in ((self.input, str, 'page_token'), (self.output, str, 'next_page_token')): - field = page_field[0].fields.get(page_field[2], None) - if not field or field.type != page_field[1]: + field = page_field_token[0].fields.get(page_field_token[2], None) + if not field or field.type != page_field_token[1]: return None - page_fields = [self.input.fields.get('max_results', None), - self.input.fields.get('page_size', None)] - page_field = next( + + page_fields = (self.input.fields.get('max_results', None), + self.input.fields.get('page_size', None)) + page_field_size = next( (field for field in page_fields if field is not None), None) - if not page_field or page_field.type != int: + if not page_field_size or page_field_size.type != int: return None # Return the first repeated field. diff --git a/gapic/templates/noxfile.py.j2 b/gapic/templates/noxfile.py.j2 index ee97ea01cb..b6225d867d 100644 --- a/gapic/templates/noxfile.py.j2 +++ b/gapic/templates/noxfile.py.j2 @@ -32,6 +32,7 @@ def mypy(session): session.install('.') session.run( 'mypy', + '--explicit-package-bases', {%- if api.naming.module_namespace %} '{{ api.naming.module_namespace[0] }}', {%- else %} diff --git a/noxfile.py b/noxfile.py index a50376efe1..ca83576363 100644 --- a/noxfile.py +++ b/noxfile.py @@ -262,4 +262,4 @@ def mypy(session): session.install("mypy") session.install(".") - session.run("mypy", "gapic") + session.run("mypy", "-p", "gapic") diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index 256ef27951..86f72a65b5 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -131,7 +131,7 @@ def test_method_paged_result_field_no_page_field(): make_field(name='next_page_token', type=9), # str )) ) - assert method.paged_result_field == None + assert method.paged_result_field is None def test_method_paged_result_ref_types(): From 01adad1f91b623e5844b8c0dc70e641788508197 Mon Sep 17 00:00:00 2001 From: Yonatan Getahun Date: Mon, 25 Jan 2021 22:08:22 +0000 Subject: [PATCH 3/7] fix: mypy cli flag --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index ca83576363..e1b881ac57 100644 --- a/noxfile.py +++ b/noxfile.py @@ -227,7 +227,7 @@ def showcase_mypy( session.chdir(lib) # Run the tests. - session.run("mypy", "google") + session.run("mypy", "--explicit-package-bases", "google") @nox.session(python="3.8") From 0f3e63d08f71cddd0ef81ebb894c119b17a098ca Mon Sep 17 00:00:00 2001 From: Yonatan Getahun Date: Mon, 25 Jan 2021 23:30:42 +0000 Subject: [PATCH 4/7] fix: delete __init__.py, remove -p mypy flag --- .../tests/unit/gapic/%name_%version/%sub/__init__.py | 0 noxfile.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/__init__.py diff --git a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/__init__.py b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/noxfile.py b/noxfile.py index e1b881ac57..7dbe33ebc3 100644 --- a/noxfile.py +++ b/noxfile.py @@ -262,4 +262,4 @@ def mypy(session): session.install("mypy") session.install(".") - session.run("mypy", "-p", "gapic") + session.run("mypy", "gapic") From 48253d346b0aab050f0b0853471fa27a585c528c Mon Sep 17 00:00:00 2001 From: Yonatan Getahun Date: Wed, 27 Jan 2021 22:43:35 +0000 Subject: [PATCH 5/7] fix: clearing up statements, tests, minor bug in filter usage --- gapic/schema/wrappers.py | 11 ++-- .../services/%service/transports/rest.py.j2 | 4 +- .../%name_%version/%sub/test_%service.py.j2 | 18 ++++--- tests/unit/schema/wrappers/test_method.py | 52 +++++++++++++------ 4 files changed, 56 insertions(+), 29 deletions(-) diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 6d3a4c7080..812630720b 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -866,16 +866,19 @@ def paged_result_field(self) -> Optional[Field]: """Return the response pagination field if the method is paginated.""" # If the request field lacks any of the expected pagination fields, # then the method is not paginated. - for page_field_token in ((self.input, str, 'page_token'), + + # The request must have page_token and next_page_token as they keep track of pages + for source, source_type, name in ((self.input, str, 'page_token'), (self.output, str, 'next_page_token')): - field = page_field_token[0].fields.get(page_field_token[2], None) - if not field or field.type != page_field_token[1]: + field = source.fields.get(name, None) + if not field or field.type != source_type: return None + # The request must have max_results or page_size page_fields = (self.input.fields.get('max_results', None), self.input.fields.get('page_size', None)) page_field_size = next( - (field for field in page_fields if field is not None), None) + (field for field in page_fields if field), None) if not page_field_size or page_field_size.type != int: return None diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index 54ec5ca92e..c21ad5b27e 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -182,11 +182,9 @@ class {{ service.name }}RestTransport({{ service.name }}Transport): # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - {% filter sort_lines -%} - {%- for field in method.query_params %} + {%- for field in method.query_params | sort%} '{{ field|camel_case }}': request.{{ field }}, {%- endfor %} - {% endfilter -%} } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index ea36c6fe35..59611cfd33 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -1230,11 +1230,12 @@ def test_{{ method.name|snake_case }}_pager(): response = response + response # Wrap the values into proper Response objs - response = tuple(map(lambda x: {{ method.output.ident }}.to_json(x), response)) - side_effect = tuple(map(lambda x: Response(), response)) - for return_val, response_val in zip(side_effect, response): + response = tuple({{ method.output.ident }}.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): return_val._content = response_val.encode('UTF-8') - req.side_effect = side_effect + return_val.status_code = 200 + req.side_effect = return_values metadata = () {% if method.field_headers -%} @@ -1257,13 +1258,16 @@ def test_{{ method.name|snake_case }}_pager(): assert pager.get('h') is None {% endif %} - results = [i for i in pager] + results = list(pager) assert len(results) == 6 {% if method.paged_result_field.map %} assert all( - isinstance(i, tuple) and - tuple(map(lambda x: type(x), results[0])) == (str, {{ method.paged_result_field.type.fields.get('value').ident }}) + isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, {{ method.paged_result_field.type.fields.get('value').ident }}) + assert pager.get('a') is None assert isinstance(pager.get('h'), {{ method.paged_result_field.type.fields.get('value').ident }}) {% else %} diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index 86f72a65b5..2162effbbb 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -66,19 +66,38 @@ def test_method_client_output_empty(): def test_method_client_output_paged(): paged = make_field(name='foos', message=make_message('Foo'), repeated=True) + parent = make_field(name='parent', type=9) # str + page_size = make_field(name='page_size', type=5) # int + page_token = make_field(name='page_token', type=9) # str + input_msg = make_message(name='ListFoosRequest', fields=( - make_field(name='parent', type=9), # str - make_field(name='page_size', type=5), # int - make_field(name='page_token', type=9), # str + parent, + page_size, + page_token, )) output_msg = make_message(name='ListFoosResponse', fields=( paged, make_field(name='next_page_token', type=9), # str )) - method = make_method('ListFoos', - input_message=input_msg, - output_message=output_msg, - ) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) + assert method.paged_result_field == paged + assert method.client_output.ident.name == 'ListFoosPager' + + max_results = make_field(name='max_results', type=5) # int + input_msg = make_message(name='ListFoosRequest', fields=( + parent, + max_results, + page_token, + )) + method = make_method( + 'ListFoos', + input_message=input_msg, + output_message=output_msg, + ) assert method.paged_result_field == paged assert method.client_output.ident.name == 'ListFoosPager' @@ -123,14 +142,17 @@ def test_method_paged_result_field_no_page_field(): ) assert method.paged_result_field is None - method = make_method('Foo', - input_message=make_message(name='FooRequest', fields=( - make_field(name='page_token', type=9), # str - )), - output_message=make_message(name='FooResponse', fields=( - make_field(name='next_page_token', type=9), # str - )) - ) + method = make_method( + name='Foo', + input_message=make_message( + name='FooRequest', + fields=(make_field(name='page_token', type=9),) # str + ), + output_message=make_message( + name='FooResponse', + fields=(make_field(name='next_page_token', type=9),) # str + ) + ) assert method.paged_result_field is None From 86e2c43737433093b339becc7a31633c46e81089 Mon Sep 17 00:00:00 2001 From: Yonatan Getahun Date: Fri, 29 Jan 2021 20:03:32 +0000 Subject: [PATCH 6/7] fix: wrong genereated type hints --- .../%name_%version/%sub/services/%service/pagers.py.j2 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 index 5f04eed61e..ca3cc8d40e 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2 @@ -69,11 +69,11 @@ class {{ method.name }}Pager: yield self._response {% if method.paged_result_field.map %} - def __iter__(self) -> Iterable[Tuple[str, {{ method.paged_result_field.ident | replace('Sequence[', '') | replace(']', '') }}]]: + def __iter__(self) -> Iterable[Tuple[str, {{ method.paged_result_field.type.fields.get('value').ident }}]]: for page in self.pages: yield from page.{{ method.paged_result_field.name}}.items() - def get(self, key: str) -> {{ method.paged_result_field.ident | replace('Sequence', 'Optional') }}: + def get(self, key: str) -> Optional[{{ method.paged_result_field.type.fields.get('value').ident }}]: return self._response.items.get(key) {% else %} def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}: From 5f1ba345ea066baef4115ce49d7910955de1a979 Mon Sep 17 00:00:00 2001 From: Yonatan Getahun Date: Sat, 13 Feb 2021 02:10:26 +0000 Subject: [PATCH 7/7] feat: add grpc_transcoding, clean up rest generation unimplemented methods/tests --- gapic/generator/generator.py | 2 +- gapic/schema/wrappers.py | 36 ++ .../services/%service/transports/rest.py.j2 | 126 ++--- .../%name_%version/%sub/test_%service.py.j2 | 437 ++++++++++-------- tests/unit/schema/wrappers/test_method.py | 82 +++- 5 files changed, 439 insertions(+), 244 deletions(-) diff --git a/gapic/generator/generator.py b/gapic/generator/generator.py index d6eb3aca9d..24cb2381ac 100644 --- a/gapic/generator/generator.py +++ b/gapic/generator/generator.py @@ -59,7 +59,7 @@ def __init__(self, opts: Options) -> None: self._env.filters["wrap"] = utils.wrap self._env.filters["coerce_response_name"] = coerce_response_name - # Add tests to determine type of expressions stored in strings + # Add tests to determine FieldDescriptorProto type self._env.tests["str_field_pb"] = utils.is_str_field_pb self._env.tests["msg_field_pb"] = utils.is_msg_field_pb diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 812630720b..6202d43a98 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -30,6 +30,7 @@ import collections import dataclasses import re +import copy from itertools import chain from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping, ClassVar, Optional, Sequence, Set, Tuple, Union) @@ -741,6 +742,41 @@ def field_headers(self) -> Sequence[str]: return next((tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ()) + @property + def http_options(self) -> List[Dict[str, str]]: + """Return a list of the http options for this method. + + e.g. [{'method': 'post' + 'uri': '/some/path' + 'body': '*'},] + + """ + http = self.options.Extensions[annotations_pb2.http] + http_options = copy.deepcopy(http.additional_bindings) + http_options.append(http) + answers : List[Dict[str, str]] = [] + + for http_rule in http_options: + try: + method, uri = next((method, uri) for method,uri in [ + ('get',http_rule.get), + ('put',http_rule.put), + ('post',http_rule.post), + ('delete',http_rule.delete), + ('patch',http_rule.patch), + ('custom.path',http_rule.custom.path), + ] if uri + ) + except StopIteration: + continue + answer : Dict[str, str] = {} + answer['method'] = method + answer['uri'] = uri + if http_rule.body: + answer['body'] = http_rule.body + answers.append(answer) + return answers + @property def http_opt(self) -> Optional[Dict[str, str]]: """Return the http option for this method. diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index c21ad5b27e..471418e595 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -5,16 +5,17 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple {% if service.has_lro %} -from google.api_core import operations_v1 +from google.api_core import operations_v1 # type: ignore {%- endif %} -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore +from google.api_core import gapic_v1, path_template # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore +import json -from google.auth.transport.requests import AuthorizedSession +from google.auth.transport.requests import AuthorizedSession # type: ignore {# TODO(yon-mg): re-add python_import/ python_modules from removed diff/current grpc template code #} {% filter sort_lines -%} @@ -121,9 +122,9 @@ class {{ service.name }}RestTransport({{ service.name }}Transport): return self._operations_client {%- endif %} {%- for method in service.methods.values() %} - {%- if method.http_opt %} + {%- if method.http_options and not method.lro and not (method.server_streaming or method.client_streaming) %} - def {{ method.name|snake_case }}(self, + def _{{ method.name|snake_case }}(self, request: {{ method.input.ident }}, *, metadata: Sequence[Tuple[str, str]] = (), ) -> {{ method.output.ident }}: @@ -146,57 +147,49 @@ class {{ service.name }}RestTransport({{ service.name }}Transport): {%- endif %} """ - {# TODO(yon-mg): refactor when implementing grpc transcoding - - parse request pb & assign body, path params - - shove leftovers into query params - - make sure dotted nested fields preserved - - format url and send the request - #} - {%- if 'body' in method.http_opt %} + http_options = [ + {%- for rule in method.http_options %}{ + {%- for field, argument in rule.items() | sort %} + '{{ field }}':'{{ argument }}', + {%- endfor %} + }, + {%- endfor %}] + + request_kwargs = { + field.name:value + for field, value + in {{ method.input.ident }}.pb(request).ListFields() + } + transcoded_request = path_template.transcode(http_options, **request_kwargs) + {%- if 'body' in method.http_options[0] %} + # Jsonify the request body - {%- if method.http_opt['body'] != '*' %} - body = {{ method.input.fields[method.http_opt['body']].type.ident }}.to_json( - request.{{ method.http_opt['body'] }}, - including_default_value_fields=False - ) - {%- else %} - body = {{ method.input.ident }}.to_json( - request - ) - {%- endif %} + body = {% if method.http_options[0]['body'] == '*' -%} + {{ method.input.ident }}.to_json( + {{ method.input.ident }}(transcoded_request['body']), + {%- else -%} + {{ method.input.fields[method.http_opt['body']].type.ident }}.to_json( + {{ method.input.fields[method.http_opt['body']].type.ident }}(transcoded_request['body']), {%- endif %} - - {# TODO(yon-mg): Write helper method for handling grpc transcoding url #} - # TODO(yon-mg): need to handle grpc transcoding and parse url correctly - # current impl assumes basic case of grpc transcoding - url = 'https://{host}{{ method.http_opt['url'] }}'.format( - host=self._host, - {%- for field in method.path_params %} - {{ field }}=request.{{ method.input.get_field(field).name }}, - {%- endfor %} + including_default_value_fields=False, + use_integers_for_enums=False ) + {%- endif %} - {# TODO(yon-mg): move all query param logic out of wrappers into here to handle - nested fields correctly (can't just use set of top level fields - #} - # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields - # not required for GCE - query_params = { - {%- for field in method.query_params | sort%} - '{{ field|camel_case }}': request.{{ field }}, - {%- endfor %} - } - # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here - # discards default values - # TODO(yon-mg): add test for proper url encoded strings - query_params = ['{k}={v}'.format(k=k, v=v) for k, v in query_params.items() if v] - url += '?{}'.format('&'.join(query_params)).replace(' ', '+') + # Jsonify the query params + query_params = json.loads({{ method.input.ident }}.to_json( + {{ method.input.ident }}(transcoded_request['query_params']), + including_default_value_fields=False, + use_integers_for_enums=False + )) # Send the request - {% if not method.void %}response = {% endif %}self._session.{{ method.http_opt['verb'] }}( - url + response = self._session.request( + transcoded_request['method'], + self._host.join(('https://', transcoded_request['uri'])), + params=query_params {%- if 'body' in method.http_opt %}, - json=body, + data=body, {%- endif %} ) @@ -206,9 +199,38 @@ class {{ service.name }}RestTransport({{ service.name }}Transport): # Return the response return {{ method.output.ident }}.from_json(response.content) + {%- else %} + + # Returh the response + return {{ method.output.ident }}() + {%- endif %} + {%- else %} + + def _{{ method.name|snake_case }}(self, + request: {{ method.input.ident }}, *, + metadata: Sequence[Tuple[str, str]] = (), + ) -> {{ method.output.ident }}: + r"""Placeholder: Unable to implement over REST + """ + {%- if not method.http_options %} + raise RuntimeError('Cannot define a method without a valid `google.api.http` annotation.') + {%- elif method.lro %} + raise NotImplementedError('LRO over REST is not yet defined for python client.') + {%- elif method.server_streaming or method.client_streaming %} + raise NotImplementedError('Streaming over REST is not yet defined for python client') + {%- else %} + raise NotImplementedError() {%- endif %} {%- endif %} {%- endfor %} + {%- for method in service.methods.values() %} + + @property + def {{ method.name|snake_case }}(self) -> Callable[ + [{{ method.input.ident }}], + {{ method.output.ident }}]: + return self._{{ method.name|snake_case }} + {%- endfor %} __all__ = ( diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 59611cfd33..722e5aed75 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -30,6 +30,7 @@ from google.api_core import client_options from google.api_core import exceptions from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import path_template {% if service.has_lro -%} from google.api_core import future from google.api_core import operations_v1 @@ -1023,14 +1024,17 @@ def test_{{ method.name|snake_case }}_raw_page_lro(): {% endfor -%} {#- method in methods for grpc #} {% for method in service.methods.values() if 'rest' in opts.transport -%} -def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type={{ method.input.ident }}): +{% if method.http_options and not method.lro and not (method.server_streaming or method.client_streaming) %} +def test_{{ method.name|snake_case }}_rest(request_type={{ method.input.ident }}): client = {{ service.client_name }}( credentials=credentials.AnonymousCredentials(), - transport=transport, + transport='rest', ) # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. + # but since we have to encode/decode protos, we need to mock out the + # request itself as well as the transcoding logic, the two pieces that + # have to deal with non-protobuf messages. request = request_type() {% if method.client_streaming %} requests = [request] @@ -1038,32 +1042,33 @@ def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type # Mock the http request call within the method and fake a response. with mock.patch.object(Session, 'request') as req: - # Designate an appropriate value for the returned response. - {% if method.void -%} - return_value = None - {% elif method.lro -%} - return_value = operations_pb2.Operation(name='operations/spam') - {% elif method.server_streaming -%} - return_value = iter([{{ method.output.ident }}()]) - {% else -%} - return_value = {{ method.output.ident }}( - {%- for field in method.output.fields.values() %} - {{ field.name }}={{ field.mock_value }}, - {%- endfor %} - ) - {% endif -%} + with mock.patch.object(path_template, 'transcode') as transcode: + # Designate an appropriate value for the returned response. + {% if method.void -%} + return_value = '{}' + {% else -%} + return_value = {{ method.output.ident }}( + {%- for field in method.output.fields.values() | rejectattr('message') %}{% if not field.oneof or field.proto3_optional %} + {{ field.name }}={{ field.mock_value }}, + {% endif %} + {%- endfor -%} + ) + {% endif -%} - # Wrap the value into a proper Response obj - json_return_value = {{ method.output.ident }}.to_json(return_value) - response_value = Response() - response_value.status_code = 200 - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - {% if method.client_streaming %} - response = client.{{ method.name|snake_case }}(iter(requests)) - {% else %} - response = client.{{ method.name|snake_case }}(request) - {% endif %} + # Wrap the value into a proper Response obj + json_return_value = + {%- if method.void %} return_value + {%- else %} {{ method.output.ident }}.to_json(return_value) + {%- endif %} + response_value = Response() + response_value.status_code = 200 + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + # Mock grpc transcoding to return empty stuff + transcode.return_value = {'method':'', 'uri':'', 'body':{}, 'query_params':{}} + + response = client.{{ method.name|snake_case }}(request) {% if "next_page_token" in method.output.fields.values()|map(attribute='name') and not method.paged_result_field %} {# Cheeser assertion to force code coverage for bad paginated methods #} @@ -1075,7 +1080,7 @@ def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type assert response is None {% else %} assert isinstance(response, {{ method.client_output.ident }}) - {% for field in method.output.fields.values() -%} + {% for field in method.output.fields.values() | rejectattr('message') -%}{% if not field.oneof or field.proto3_optional %} {% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#} assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6) {% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #} @@ -1083,6 +1088,7 @@ def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type {% else -%} assert response.{{ field.name }} == {{ field.mock_value }} {% endif -%} + {% endif -%} {# end oneof/optional #} {% endfor %} {% endif %} @@ -1090,65 +1096,75 @@ def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type def test_{{ method.name|snake_case }}_rest_from_dict(): test_{{ method.name|snake_case }}_rest(request_type=dict) - +{% if method.flattened_fields %} def test_{{ method.name|snake_case }}_rest_flattened(): client = {{ service.client_name }}( credentials=credentials.AnonymousCredentials(), + transport='rest' ) # Mock the http request call within the method and fake a response. with mock.patch.object(Session, 'request') as req: - # Designate an appropriate value for the returned response. - {% if method.void -%} - return_value = None - {% elif method.lro -%} - return_value = operations_pb2.Operation(name='operations/spam') - {% elif method.server_streaming -%} - return_value = iter([{{ method.output.ident }}()]) - {% else -%} - return_value = {{ method.output.ident }}() - {% endif %} - - # Wrap the value into a proper Response obj - json_return_value = {{ method.output.ident }}.to_json(return_value) - response_value = Response() - response_value.status_code = 200 - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - {%- for field in method.flattened_fields.values() if field.field_pb is msg_field_pb %} - {{ field.name }} = {{ field.mock_value }} - {% endfor %} - client.{{ method.name|snake_case }}( - {%- for field in method.flattened_fields.values() %} - {% if field.field_pb is msg_field_pb %}{{ field.name }}={{ field.name }},{% else %}{{ field.name }}={{ field.mock_value }},{% endif %} - {%- endfor %} - ) + with mock.patch.object(path_template, 'validate') as validate: + # Designate an appropriate value for the returned response. + {% if method.void -%} + return_value = '{}' + {% else -%} + return_value = {{ method.output.ident }}() + {% endif %} + + # Wrap the value into a proper Response obj + json_return_value = + {%- if method.void %} return_value + {%- else %} {{ method.output.ident }}.to_json(return_value) + {%- endif %} + response_value = Response() + response_value.status_code = 200 + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + # Mock grpc transcoding to include flattened fields + validate.return_value = True + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + {%- for field in method.flattened_fields.values() if field.field_pb is msg_field_pb %} + {{ field.name }} = {{ field.mock_value }} + {% endfor %} + client.{{ method.name|snake_case }}( + {%- for field in method.flattened_fields.values() %} + {% if field.field_pb is msg_field_pb %}{{ field.name }}={{ field.name }},{% else %}{{ field.name }}={{ field.mock_value }},{% endif %} + {%- endfor %} + ) - # Establish that the underlying call was made with the expected - # request object values. - assert len(req.mock_calls) == 1 - _, http_call, http_params = req.mock_calls[0] - body = http_params.get('json') - {% for key, field in method.flattened_fields.items() -%}{%- if not field.oneof or field.proto3_optional %} - {% if field.ident|string() == 'timestamp.Timestamp' -%} - assert TimestampRule().to_proto(http_call[0].{{ key }}) == {{ field.mock_value }} - {% elif field.ident|string() == 'duration.Duration' -%} - assert DurationRule().to_proto(http_call[0].{{ key }}) == {{ field.mock_value }} - {% else -%} - assert {% if field.field_pb is msg_field_pb %}{{ field.ident }}.to_json({{ field.name }}, including_default_value_fields=False) - {%- elif field.field_pb is str_field_pb %}{{ field.mock_value }} - {%- else %}str({{ field.mock_value }}) - {%- endif %} in http_call[1] + str(body) - {% endif %} - {% endif %}{% endfor %} + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, http_call, http_params = req.mock_calls[0] + body = http_params.get('data') or http_params.get('json') + query_params = http_params.get('params', {}) + {% for key, field in method.flattened_fields.items() -%}{%- if not field.oneof or field.proto3_optional %} + {% if field.ident|string() == 'timestamp.Timestamp' -%} + assert TimestampRule().to_proto(http_call[0].{{ key }}) == {{ field.mock_value }} + {% elif field.ident|string() == 'duration.Duration' -%} + assert DurationRule().to_proto(http_call[0].{{ key }}) == {{ field.mock_value }} + {% else -%} + assert {% if field.field_pb is msg_field_pb %}{{ field.ident }}.to_json( + {{ field.name }}, + including_default_value_fields=False, + use_integers_for_enums=False + ) + {%- elif field.field_pb is str_field_pb %}{{ field.mock_value }} + {%- else %}str({{ field.mock_value }}).lower() + {%- endif %} in http_call[1] + str(body) + str(query_params.values()) + {% endif %} + {% endif %}{% endfor %} def test_{{ method.name|snake_case }}_rest_flattened_error(): client = {{ service.client_name }}( credentials=credentials.AnonymousCredentials(), + transport='rest' ) # Attempting to call a method with both a request object and flattened @@ -1160,127 +1176,153 @@ def test_{{ method.name|snake_case }}_rest_flattened_error(): {{ field.name }}={{ field.mock_value }}, {%- endfor %} ) +{% endif %} -{% if method.paged_result_field %} -def test_{{ method.name|snake_case }}_pager(): +{% if method.paged_result_field %} +def test_{{ method.name|snake_case }}_rest_pager(): client = {{ service.client_name }}( credentials=credentials.AnonymousCredentials(), + transport='rest' ) # Mock the http request call within the method and fake a response. with mock.patch.object(Session, 'request') as req: - # Set the response as a series of pages - {% if method.paged_result_field.map%} - response = ( - {{ method.output.ident }}( - {{ method.paged_result_field.name }}={ - 'a':{{ method.paged_result_field.type.fields.get('value').ident }}(), - 'b':{{ method.paged_result_field.type.fields.get('value').ident }}(), - 'c':{{ method.paged_result_field.type.fields.get('value').ident }}(), - }, - next_page_token='abc', - ), - {{ method.output.ident }}( - {{ method.paged_result_field.name }}={}, - next_page_token='def', - ), - {{ method.output.ident }}( - {{ method.paged_result_field.name }}={ - 'g':{{ method.paged_result_field.type.fields.get('value').ident }}(), - }, - next_page_token='ghi', - ), - {{ method.output.ident }}( - {{ method.paged_result_field.name }}={ - 'h':{{ method.paged_result_field.type.fields.get('value').ident }}(), - 'i':{{ method.paged_result_field.type.fields.get('value').ident }}(), - }, - ), - ) - {% else %} - response = ( - {{ method.output.ident }}( - {{ method.paged_result_field.name }}=[ - {{ method.paged_result_field.type.ident }}(), - {{ method.paged_result_field.type.ident }}(), - {{ method.paged_result_field.type.ident }}(), - ], - next_page_token='abc', - ), - {{ method.output.ident }}( - {{ method.paged_result_field.name }}=[], - next_page_token='def', - ), - {{ method.output.ident }}( - {{ method.paged_result_field.name }}=[ - {{ method.paged_result_field.type.ident }}(), - ], - next_page_token='ghi', - ), - {{ method.output.ident }}( - {{ method.paged_result_field.name }}=[ - {{ method.paged_result_field.type.ident }}(), - {{ method.paged_result_field.type.ident }}(), - ], - ), - ) - {% endif %} - # Two responses for two calls - response = response + response - - # Wrap the values into proper Response objs - response = tuple({{ method.output.ident }}.to_json(x) for x in response) - return_values = tuple(Response() for i in response) - for return_val, response_val in zip(return_values, response): - return_val._content = response_val.encode('UTF-8') - return_val.status_code = 200 - req.side_effect = return_values - - metadata = () - {% if method.field_headers -%} - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - {%- for field_header in method.field_headers %} - {%- if not method.client_streaming %} - ('{{ field_header }}', ''), - {%- endif %} - {%- endfor %} - )), - ) - {% endif -%} - pager = client.{{ method.name|snake_case }}(request={}) - - assert pager._metadata == metadata - - {% if method.paged_result_field.map %} - assert isinstance(pager.get('a'), {{ method.paged_result_field.type.fields.get('value').ident }}) - assert pager.get('h') is None - {% endif %} - - results = list(pager) - assert len(results) == 6 - {% if method.paged_result_field.map %} - assert all( - isinstance(i, tuple) - for i in results) - for result in results: - assert isinstance(result, tuple) - assert tuple(type(t) for t in result) == (str, {{ method.paged_result_field.type.fields.get('value').ident }}) - - assert pager.get('a') is None - assert isinstance(pager.get('h'), {{ method.paged_result_field.type.fields.get('value').ident }}) - {% else %} - assert all(isinstance(i, {{ method.paged_result_field.type.ident }}) - for i in results) - {% endif %} - - pages = list(client.{{ method.name|snake_case }}(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): - assert page_.raw_page.next_page_token == token - - -{% endif %} {# paged methods #} + with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + {% if method.paged_result_field.map%} + response = ( + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={ + 'a':{{ method.paged_result_field.type.fields.get('value').ident }}(), + 'b':{{ method.paged_result_field.type.fields.get('value').ident }}(), + 'c':{{ method.paged_result_field.type.fields.get('value').ident }}(), + }, + next_page_token='abc', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={}, + next_page_token='def', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={ + 'g':{{ method.paged_result_field.type.fields.get('value').ident }}(), + }, + next_page_token='ghi', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}={ + 'h':{{ method.paged_result_field.type.fields.get('value').ident }}(), + 'i':{{ method.paged_result_field.type.fields.get('value').ident }}(), + }, + ), + ) + {% else %} + response = ( + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.type.ident }}(), + {{ method.paged_result_field.type.ident }}(), + {{ method.paged_result_field.type.ident }}(), + ], + next_page_token='abc', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[], + next_page_token='def', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.type.ident }}(), + ], + next_page_token='ghi', + ), + {{ method.output.ident }}( + {{ method.paged_result_field.name }}=[ + {{ method.paged_result_field.type.ident }}(), + {{ method.paged_result_field.type.ident }}(), + ], + ), + ) + {% endif %} + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple({{ method.output.ident }}.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode('UTF-8') + return_val.status_code = 200 + req.side_effect = return_values + + # Mock grpc transcoding to return empty stuff + transcode.return_value = {'method':'', 'uri':'', 'body':{}, 'query_params':{}} + + metadata = () + {% if method.field_headers -%} + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + {%- for field_header in method.field_headers %} + {%- if not method.client_streaming %} + ('{{ field_header }}', ''), + {%- endif %} + {%- endfor %} + )), + ) + {% endif -%} + pager = client.{{ method.name|snake_case }}(request={}) + + assert pager._metadata == metadata + + {% if method.paged_result_field.map %} + assert isinstance(pager.get('a'), {{ method.paged_result_field.type.fields.get('value').ident }}) + assert pager.get('h') is None + {% endif %} + + results = list(pager) + assert len(results) == 6 + {% if method.paged_result_field.map %} + assert all( + isinstance(i, tuple) + for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, {{ method.paged_result_field.type.fields.get('value').ident }}) + + assert pager.get('a') is None + assert isinstance(pager.get('h'), {{ method.paged_result_field.type.fields.get('value').ident }}) + {% else %} + assert all(isinstance(i, {{ method.paged_result_field.type.ident }}) + for i in results) + {% endif %} + + pages = list(client.{{ method.name|snake_case }}(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +{% endif %} {# paged methods #} +{% else %} +def test_{{ method.name|snake_case }}_rest_error(): + client = {{ service.client_name }}( + credentials=credentials.AnonymousCredentials(), + transport='rest' + ) + {%- if not method.http_options %} + # Since a `google.api.http` annotation is required for using a rest transport + # method, this should error. + with pytest.raises(RuntimeError) as runtime_error: + client.{{ method.name|snake_case }}({}) + assert ('Cannot define a method without a valid `google.api.http` annotation.' + in str(runtime_error.value)) + {%- else %} + # TODO(yon-mg): Remove when this method has a working implementation + # or testing straegy + with pytest.raises(NotImplementedError): + client.{{ method.name|snake_case }}({}) + {%- endif %} +{% endif %} {% endfor -%} {#- method in methods for rest #} def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. @@ -1516,6 +1558,25 @@ def test_{{ service.name|snake_case }}_http_transport_client_cert_source_for_mtl client_cert_source_for_mtls=client_cert_source_callback ) mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +{% if service.has_lro -%} +def test_{{ service.name|snake_case }}_rest_lro_client(): + client = {{ service.client_name }}( + credentials=credentials.AnonymousCredentials(), + transport='rest', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client +{%- endif %} {% endif %} def test_{{ service.name|snake_case }}_host_no_port(): diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index 2162effbbb..9917805ab2 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -292,9 +292,6 @@ def test_method_http_opt(): 'url': '/v1/{parent=projects/*}/topics', 'body': '*' } -# TODO(yon-mg) to test: grpc transcoding, -# correct handling of path/query params -# correct handling of body & additional binding def test_method_http_opt_no_body(): @@ -323,6 +320,85 @@ def test_method_path_params_no_http_rule(): assert method.path_params == [] +def test_method_http_options(): + verbs = [ + 'get', + 'put', + 'post', + 'delete', + 'patch' + ] + for v in verbs: + http_rule = http_pb2.HttpRule(**{v:'/v1/{parent=projects/*}/topics'}) + method = make_method('DoSomething', http_rule=http_rule) + assert method.http_options == [{ + 'method': v, + 'uri':'/v1/{parent=projects/*}/topics' + }] + + +def test_method_http_options_empty_http_rule(): + http_rule = http_pb2.HttpRule() + method = make_method('DoSomething', http_rule=http_rule) + assert method.http_options == [] + + http_rule = http_pb2.HttpRule(get='') + method = make_method('DoSomething', http_rule=http_rule) + assert method.http_options == [] + + +def test_method_http_options_no_http_rule(): + method = make_method('DoSomething') + assert method.path_params == [] + + +def test_method_http_options_body(): + http_rule = http_pb2.HttpRule( + post='/v1/{parent=projects/*}/topics', + body='*' + ) + method = make_method('DoSomething', http_rule=http_rule) + assert method.http_options == [{ + 'method': 'post', + 'uri': '/v1/{parent=projects/*}/topics', + 'body': '*' + }] + + +def test_method_http_options_additional_bindings(): + http_rule = http_pb2.HttpRule( + post='/v1/{parent=projects/*}/topics', + body='*', + additional_bindings=[ + http_pb2.HttpRule( + post='/v1/{parent=projects/*/regions/*}/topics', + body='*', + ), + http_pb2.HttpRule( + post='/v1/projects/p1/topics', + body='body_field', + ), + ] + ) + method = make_method('DoSomething', http_rule=http_rule) + assert len(method.http_options) == 3 + assert { + 'method':'post', + 'uri':'/v1/{parent=projects/*}/topics', + 'body':'*' + } in method.http_options + assert { + 'method':'post', + 'uri':'/v1/{parent=projects/*/regions/*}/topics', + 'body':'*' + } in method.http_options + assert { + 'method':'post', + 'uri':'/v1/projects/p1/topics', + 'body':'body_field' + } in method.http_options + + def test_method_query_params(): # tests only the basic case of grpc transcoding http_rule = http_pb2.HttpRule(