diff --git a/gapic/generator/generator.py b/gapic/generator/generator.py index 6a3446cbf8..d6eb3aca9d 100644 --- a/gapic/generator/generator.py +++ b/gapic/generator/generator.py @@ -59,6 +59,10 @@ def __init__(self, opts: Options) -> None: self._env.filters["wrap"] = utils.wrap self._env.filters["coerce_response_name"] = coerce_response_name + # Add tests to determine type of expressions stored in strings + self._env.tests["str_field_pb"] = utils.is_str_field_pb + self._env.tests["msg_field_pb"] = utils.is_msg_field_pb + self._sample_configs = opts.sample_configs def get_response( @@ -278,6 +282,10 @@ def _render_template( or ('transport' in template_name and not self._is_desired_transport(template_name, opts)) + or + # TODO(yon-mg) - remove when rest async implementation resolved + # temporarily stop async client gen while rest async is unkown + ('async' in template_name and 'grpc' not in opts.transport) ): continue diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 6020a26061..eefe0cdc7e 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -411,7 +411,9 @@ def get_field(self, *field_path: str, collisions = collisions or self.meta.address.collisions # Get the first field in the path. - cursor = self.fields[field_path[0]] + first_field = field_path[0] + cursor = self.fields[first_field + + ('_' if first_field in utils.RESERVED_NAMES else '')] # Base case: If this is the last field in the path, return it outright. if len(field_path) == 1: @@ -805,6 +807,7 @@ def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: continue name = f.strip() field = self.input.get_field(*name.split('.')) + name += '_' if field.field_pb.name in utils.RESERVED_NAMES else '' if cross_pkg_request and not field.is_primitive: # This is not a proto-plus wrapped message type, # and setting a non-primitive field directly is verboten. diff --git a/gapic/templates/%namespace/%name/__init__.py.j2 b/gapic/templates/%namespace/%name/__init__.py.j2 index d777dc86e3..7ffe67b3fe 100644 --- a/gapic/templates/%namespace/%name/__init__.py.j2 +++ b/gapic/templates/%namespace/%name/__init__.py.j2 @@ -12,8 +12,10 @@ from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.' if service.meta.address.subpackage == api.subpackage_view -%} from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%} {{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.client import {{ service.client_name }} +{%- if 'grpc' in opts.transport %} from {% if api.naming.module_namespace %}{{ api.naming.module_namespace|join('.') }}.{% endif -%} {{ api.naming.versioned_module_name }}.services.{{ service.name|snake_case }}.async_client import {{ service.async_client_name }} +{%- endif %} {% endfor -%} {# Import messages and enums from each proto. @@ -50,7 +52,9 @@ __all__ = ( {% for service in api.services.values()|sort(attribute='name') if service.meta.address.subpackage == api.subpackage_view -%} '{{ service.client_name }}', + {%- if 'grpc' in opts.transport %} '{{ service.async_client_name }}', + {%- endif %} {% endfor -%} {% for proto in api.protos.values()|sort(attribute='module_name') if proto.meta.address.subpackage == api.subpackage_view -%} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 index c99b2a5f91..e0112041c3 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/__init__.py.j2 @@ -2,10 +2,14 @@ {% block content %} from .client import {{ service.client_name }} +{%- if 'grpc' in opts.transport %} from .async_client import {{ service.async_client_name }} +{%- endif %} __all__ = ( '{{ service.client_name }}', + {%- if 'grpc' in opts.transport %} '{{ service.async_client_name }}', + {%- endif %} ) {% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index ad7e4051b9..338bfcbaff 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -78,7 +78,13 @@ class {{ service.name }}RestTransport({{ service.name }}Transport): Generally, you only need to set this if you're developing your own client library. """ - super().__init__(host=host, credentials=credentials) + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + ) self._session = AuthorizedSession(self._credentials) {%- if service.has_lro %} self._operations_client = None @@ -163,7 +169,7 @@ class {{ service.name }}RestTransport({{ service.name }}Transport): url = 'https://{host}{{ method.http_opt['url'] }}'.format( host=self._host, {%- for field in method.path_params %} - {{ field }}=request.{{ field }}, + {{ field }}=request.{{ method.input.get_field(field).name }}, {%- endfor %} ) @@ -180,10 +186,8 @@ class {{ service.name }}RestTransport({{ service.name }}Transport): # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values # TODO(yon-mg): add test for proper url encoded strings - query_params = ((k, v) for k, v in query_params.items() if v) - for i, (param_name, param_value) in enumerate(query_params): - q = '?' if i == 0 else '&' - url += "{q}{name}={value}".format(q=q, name=param_name, value=param_value.replace(' ', '+')) + query_params = ['{k}={v}'.format(k=k, v=v) for k, v in query_params.items() if v] + url += '?{}'.format('&'.join(query_params)).replace(' ', '+') # Send the request {% if not method.void %}response = {% endif %}self._session.{{ method.http_opt['verb'] }}( diff --git a/gapic/templates/noxfile.py.j2 b/gapic/templates/noxfile.py.j2 index 65397ea71c..ee97ea01cb 100644 --- a/gapic/templates/noxfile.py.j2 +++ b/gapic/templates/noxfile.py.j2 @@ -7,7 +7,7 @@ import shutil import nox # type: ignore -@nox.session(python=['3.6', '3.7']) +@nox.session(python=['3.6', '3.7', '3.8', '3.9']) def unit(session): """Run the unit test suite.""" @@ -21,7 +21,7 @@ def unit(session): '--cov-config=.coveragerc', '--cov-report=term', '--cov-report=html', - os.path.join('tests', 'unit',) + os.path.join('tests', 'unit', ''.join(session.posargs)) ) diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 570e997a6f..f912c479a3 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -10,6 +10,11 @@ import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule +{%- if 'rest' in opts.transport %} +from requests import Response +from requests.sessions import Session +{%- endif %} + {# Import the service itself as well as every proto module that it imports. -#} {% filter sort_lines -%} from google import auth @@ -17,7 +22,9 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.oauth2 import service_account from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }} +{%- if 'grpc' in opts.transport %} from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.async_client_name }} +{%- endif %} from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports from google.api_core import client_options from google.api_core import exceptions @@ -81,7 +88,12 @@ def test_{{ service.client_name|snake_case }}_from_service_account_info(): {% if service.host %}assert client.transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %} -@pytest.mark.parametrize("client_class", [{{ service.client_name }}, {{ service.async_client_name }}]) +@pytest.mark.parametrize("client_class", [ + {{ service.client_name }}, + {%- if 'grpc' in opts.transport %} + {{ service.async_client_name }}, + {%- endif %} +]) def test_{{ service.client_name|snake_case }}_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: @@ -97,18 +109,29 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file(client_c def test_{{ service.client_name|snake_case }}_get_transport_class(): transport = {{ service.client_name }}.get_transport_class() - assert transport == transports.{{ service.name }}GrpcTransport + available_transports = [ + {%- for transport_name in opts.transport %} + transports.{{ service.name }}{{ transport_name.capitalize() }}Transport, + {%- endfor %} + ] + assert transport in available_transports - transport = {{ service.client_name }}.get_transport_class("grpc") - assert transport == transports.{{ service.name }}GrpcTransport + transport = {{ service.client_name }}.get_transport_class("{{ opts.transport[0] }}") + assert transport == transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport @pytest.mark.parametrize("client_class,transport_class,transport_name", [ + {%- if 'grpc' in opts.transport %} ({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc"), - ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio") + ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio"), + {%- elif 'rest' in opts.transport %} + ({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest"), + {%- endif %} ]) @mock.patch.object({{ service.client_name }}, "DEFAULT_ENDPOINT", modify_default_endpoint({{ service.client_name }})) +{%- if 'grpc' in opts.transport %} @mock.patch.object({{ service.async_client_name }}, "DEFAULT_ENDPOINT", modify_default_endpoint({{ service.async_client_name }})) +{%- endif %} def test_{{ service.client_name|snake_case }}_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. with mock.patch.object({{ service.client_name }}, 'get_transport_class') as gtc: @@ -197,13 +220,20 @@ def test_{{ service.client_name|snake_case }}_client_options(client_class, trans ) @pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + {% if 'grpc' in opts.transport %} ({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc", "true"), ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio", "true"), ({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc", "false"), - ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio", "false") + ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio", "false"), + {% elif 'rest' in opts.transport %} + ({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest", "true"), + ({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest", "false"), + {%- endif %} ]) @mock.patch.object({{ service.client_name }}, "DEFAULT_ENDPOINT", modify_default_endpoint({{ service.client_name }})) +{%- if 'grpc' in opts.transport %} @mock.patch.object({{ service.async_client_name }}, "DEFAULT_ENDPOINT", modify_default_endpoint({{ service.async_client_name }})) +{%- endif %} @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default @@ -286,8 +316,12 @@ def test_{{ service.client_name|snake_case }}_mtls_env_auto(client_class, transp @pytest.mark.parametrize("client_class,transport_class,transport_name", [ + {%- if 'grpc' in opts.transport %} ({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc"), - ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio") + ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio"), + {%- elif 'rest' in opts.transport %} + ({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest"), + {%- endif %} ]) def test_{{ service.client_name|snake_case }}_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. @@ -308,8 +342,12 @@ def test_{{ service.client_name|snake_case }}_client_options_scopes(client_class ) @pytest.mark.parametrize("client_class,transport_class,transport_name", [ + {%- if 'grpc' in opts.transport %} ({{ service.client_name }}, transports.{{ service.grpc_transport_name }}, "grpc"), - ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio") + ({{ service.async_client_name }}, transports.{{ service.grpc_asyncio_transport_name }}, "grpc_asyncio"), + {%- elif 'rest' in opts.transport %} + ({{ service.client_name }}, transports.{{ service.rest_transport_name }}, "rest"), + {%- endif %} ]) def test_{{ service.client_name|snake_case }}_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. @@ -328,6 +366,7 @@ def test_{{ service.client_name|snake_case }}_client_options_credentials_file(cl quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) +{%- if 'grpc' in opts.transport %} def test_{{ service.client_name|snake_case }}_client_options_from_dict(): @@ -345,9 +384,10 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict(): quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) +{%- endif %} -{% for method in service.methods.values() -%} +{% for method in service.methods.values() if 'grpc' in opts.transport -%} def test_{{ method.name|snake_case }}(transport: str = 'grpc', request_type={{ method.input.ident }}): client = {{ service.client_name }}( credentials=credentials.AnonymousCredentials(), @@ -991,9 +1031,148 @@ def test_{{ method.name|snake_case }}_raw_page_lro(): {% endfor -%} {#- method in methods #} +{% for method in service.methods.values() if 'rest' in opts.transport -%} +def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type={{ method.input.ident }}): + client = {{ service.client_name }}( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + {% if method.client_streaming %} + requests = [request] + {% endif %} + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # Designate an appropriate value for the returned response. + {% if method.void -%} + return_value = None + {% elif method.lro -%} + return_value = operations_pb2.Operation(name='operations/spam') + {% elif method.server_streaming -%} + return_value = iter([{{ method.output.ident }}()]) + {% else -%} + return_value = {{ method.output.ident }}( + {%- for field in method.output.fields.values() %} + {{ field.name }}={{ field.mock_value }}, + {%- endfor %} + ) + {% endif -%} + + # Wrap the value into a proper Response obj + json_return_value = {{ method.output.ident }}.to_json(return_value) + response_value = Response() + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + {% if method.client_streaming %} + response = client.{{ method.name|snake_case }}(iter(requests)) + {% else %} + response = client.{{ method.name|snake_case }}(request) + {% endif %} + + {% if "next_page_token" in method.output.fields.values()|map(attribute='name') and not method.paged_result_field %} + {# Cheeser assertion to force code coverage for bad paginated methods #} + assert response.raw_page is response + {% endif %} + + # Establish that the response is the type that we expect. + {% if method.void -%} + assert response is None + {% else %} + assert isinstance(response, {{ method.client_output.ident }}) + {% for field in method.output.fields.values() -%} + {% if field.field_pb.type in [1, 2] -%} {# Use approx eq for floats -#} + assert math.isclose(response.{{ field.name }}, {{ field.mock_value }}, rel_tol=1e-6) + {% elif field.field_pb.type == 8 -%} {# Use 'is' for bools #} + assert response.{{ field.name }} is {{ field.mock_value }} + {% else -%} + assert response.{{ field.name }} == {{ field.mock_value }} + {% endif -%} + {% endfor %} + {% endif %} + + +def test_{{ method.name|snake_case }}_rest_from_dict(): + test_{{ method.name|snake_case }}_rest(request_type=dict) + + +def test_{{ method.name|snake_case }}_rest_flattened(): + client = {{ service.client_name }}( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # Designate an appropriate value for the returned response. + {% if method.void -%} + return_value = None + {% elif method.lro -%} + return_value = operations_pb2.Operation(name='operations/spam') + {% elif method.server_streaming -%} + return_value = iter([{{ method.output.ident }}()]) + {% else -%} + return_value = {{ method.output.ident }}() + {% endif %} + + # Wrap the value into a proper Response obj + json_return_value = {{ method.output.ident }}.to_json(return_value) + response_value = Response() + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + {%- for field in method.flattened_fields.values() if field.field_pb is msg_field_pb %} + {{ field.name }} = {{ field.mock_value }} + {% endfor %} + client.{{ method.name|snake_case }}( + {%- for field in method.flattened_fields.values() %} + {% if field.field_pb is msg_field_pb %}{{ field.name }}={{ field.name }},{% else %}{{ field.name }}={{ field.mock_value }},{% endif %} + {%- endfor %} + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, http_call, http_params = req.mock_calls[0] + body = http_params.get('json') + {% for key, field in method.flattened_fields.items() -%}{%- if not field.oneof or field.proto3_optional %} + {% if field.ident|string() == 'timestamp.Timestamp' -%} + assert TimestampRule().to_proto(http_call[0].{{ key }}) == {{ field.mock_value }} + {% elif field.ident|string() == 'duration.Duration' -%} + assert DurationRule().to_proto(http_call[0].{{ key }}) == {{ field.mock_value }} + {% else -%} + assert {% if field.field_pb is msg_field_pb %}{{ field.ident }}.to_json({{ field.name }}, including_default_value_fields=False) + {%- elif field.field_pb is str_field_pb %}{{ field.mock_value }} + {%- else %}str({{ field.mock_value }}) + {%- endif %} in http_call[1] + str(body) + {% endif %} + {% endif %}{% endfor %} + + +def test_{{ method.name|snake_case }}_rest_flattened_error(): + client = {{ service.client_name }}( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.{{ method.name|snake_case }}( + {{ method.input.ident }}(), + {%- for field in method.flattened_fields.values() %} + {{ field.name }}={{ field.mock_value }}, + {%- endfor %} + ) + + +{% endfor -%} def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. - transport = transports.{{ service.name }}GrpcTransport( + transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport( credentials=credentials.AnonymousCredentials(), ) with pytest.raises(ValueError): @@ -1003,7 +1182,7 @@ def test_credentials_transport_error(): ) # It is an error to provide a credentials file and a transport instance. - transport = transports.{{ service.name }}GrpcTransport( + transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport( credentials=credentials.AnonymousCredentials(), ) with pytest.raises(ValueError): @@ -1013,7 +1192,7 @@ def test_credentials_transport_error(): ) # It is an error to provide scopes and a transport instance. - transport = transports.{{ service.name }}GrpcTransport( + transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport( credentials=credentials.AnonymousCredentials(), ) with pytest.raises(ValueError): @@ -1023,16 +1202,15 @@ def test_credentials_transport_error(): ) - def test_transport_instance(): # A client may be instantiated with a custom transport instance. - transport = transports.{{ service.name }}GrpcTransport( + transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport( credentials=credentials.AnonymousCredentials(), ) client = {{ service.client_name }}(transport=transport) assert client.transport is transport - +{% if 'grpc' in opts.transport %} def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. transport = transports.{{ service.name }}GrpcTransport( @@ -1046,11 +1224,15 @@ def test_transport_get_channel(): ) channel = transport.grpc_channel assert channel - +{% endif %} @pytest.mark.parametrize("transport_class", [ + {%- if 'grpc' in opts.transport %} transports.{{ service.grpc_transport_name }}, - transports.{{ service.grpc_asyncio_transport_name }} + transports.{{ service.grpc_asyncio_transport_name }}, + {%- elif 'rest' in opts.transport %} + transports.{{ service.rest_transport_name }}, + {%- endif %} ]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. @@ -1059,7 +1241,7 @@ def test_transport_adc(transport_class): transport_class() adc.assert_called_once() - +{% if 'grpc' in opts.transport %} def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = {{ service.client_name }}( @@ -1069,7 +1251,7 @@ def test_transport_grpc_default(): client.transport, transports.{{ service.name }}GrpcTransport, ) - +{% endif %} def test_{{ service.name|snake_case }}_base_transport_error(): # Passing both a credentials object and credentials_file should raise an error @@ -1151,7 +1333,7 @@ def test_{{ service.name|snake_case }}_auth_adc(): quota_project_id=None, ) - +{% if 'grpc' in opts.transport %} def test_{{ service.name|snake_case }}_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. @@ -1164,6 +1346,7 @@ def test_{{ service.name|snake_case }}_transport_auth_adc(): {%- endfor %}), quota_project_id="octopus", ) +{% endif %} def test_{{ service.name|snake_case }}_host_no_port(): {% with host = (service.host|default('localhost', true)).split(':')[0] -%} @@ -1184,7 +1367,7 @@ def test_{{ service.name|snake_case }}_host_with_port(): assert client.transport._host == '{{ host }}:8000' {% endwith %} - +{% if 'grpc' in opts.transport %} def test_{{ service.name|snake_case }}_grpc_transport_channel(): channel = grpc.insecure_channel('http://localhost/') @@ -1334,6 +1517,7 @@ def test_{{ service.name|snake_case }}_grpc_lro_async_client(): assert transport.operations_client is transport.operations_client {% endif -%} +{% endif %} {# if grpc in opts #} {% with molluscs = cycler("squid", "clam", "whelk", "octopus", "oyster", "nudibranch", "cuttlefish", "mussel", "winkle", "nautilus", "scallop", "abalone") -%} {% for message in service.resource_messages|sort(attribute="resource_type") -%} @@ -1404,7 +1588,7 @@ def test_client_withDEFAULT_CLIENT_INFO(): prep.assert_called_once_with(client_info) -{% if opts.add_iam_methods %} +{% if opts.add_iam_methods and 'grpc' in opts.transport %} def test_set_iam_policy(transport: str = "grpc"): client = {{ service.client_name }}( credentials=credentials.AnonymousCredentials(), transport=transport, diff --git a/gapic/utils/__init__.py b/gapic/utils/__init__.py index 9729591c3c..98d31c283f 100644 --- a/gapic/utils/__init__.py +++ b/gapic/utils/__init__.py @@ -15,6 +15,8 @@ from gapic.utils.cache import cached_property from gapic.utils.case import to_snake_case from gapic.utils.case import to_camel_case +from gapic.utils.checks import is_msg_field_pb +from gapic.utils.checks import is_str_field_pb from gapic.utils.code import empty from gapic.utils.code import nth from gapic.utils.code import partition @@ -32,6 +34,8 @@ 'cached_property', 'doc', 'empty', + 'is_msg_field_pb', + 'is_str_field_pb', 'nth', 'Options', 'partition', diff --git a/gapic/utils/case.py b/gapic/utils/case.py index f58aa4adc6..635d2945c5 100644 --- a/gapic/utils/case.py +++ b/gapic/utils/case.py @@ -21,7 +21,7 @@ def to_snake_case(s: str) -> str: This is provided to templates as the ``snake_case`` filter. Args: - s (str): The input string, provided in any sane case system. + s (str): The input string, provided in any sane case system without spaces. Returns: str: The string in snake case (and all lower-cased). @@ -53,7 +53,7 @@ def to_camel_case(s: str) -> str: This is provided to templates as the ``camel_case`` filter. Args: - s (str): The input string, provided in any sane case system + s (str): The input string, provided in any sane case system without spaces. Returns: str: The string in lower camel case. diff --git a/gapic/utils/checks.py b/gapic/utils/checks.py new file mode 100644 index 0000000000..a4f7ec7445 --- /dev/null +++ b/gapic/utils/checks.py @@ -0,0 +1,33 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.protobuf.descriptor_pb2 import FieldDescriptorProto + + +def is_str_field_pb(field_pb: FieldDescriptorProto) -> bool: + """Determine if field_pb is of type string. + + Args: + field (Field): The input field as a FieldDescriptorProto + """ + return field_pb.type == FieldDescriptorProto.TYPE_STRING + + +def is_msg_field_pb(field_pb: FieldDescriptorProto) -> bool: + """Determine if field_pb is of type Message. + + Args: + field (Field): The input field as a FieldDescriptorProto. + """ + return field_pb.type == FieldDescriptorProto.TYPE_MESSAGE diff --git a/tests/unit/generator/test_generator.py b/tests/unit/generator/test_generator.py index 97793e4433..3f66033d42 100644 --- a/tests/unit/generator/test_generator.py +++ b/tests/unit/generator/test_generator.py @@ -116,7 +116,7 @@ def test_get_response_fails_invalid_file_paths(): assert "%proto" in ex_str and "%service" in ex_str -def test_get_response_ignores_unwanted_transports(): +def test_get_response_ignores_unwanted_transports_and_clients(): g = make_generator() with mock.patch.object(jinja2.FileSystemLoader, "list_templates") as lt: lt.return_value = [ @@ -125,31 +125,49 @@ def test_get_response_ignores_unwanted_transports(): "foo/%service/transports/grpc.py.j2", "foo/%service/transports/__init__.py.j2", "foo/%service/transports/base.py.j2", + "foo/%service/async_client.py.j2", + "foo/%service/client.py.j2", "mollusks/squid/sample.py.j2", ] with mock.patch.object(jinja2.Environment, "get_template") as gt: gt.return_value = jinja2.Template("Service: {{ service.name }}") + api_schema = make_api( + make_proto( + descriptor_pb2.FileDescriptorProto( + service=[ + descriptor_pb2.ServiceDescriptorProto( + name="SomeService"), + ] + ), + ) + ) + cgr = g.get_response( - api_schema=make_api( - make_proto( - descriptor_pb2.FileDescriptorProto( - service=[ - descriptor_pb2.ServiceDescriptorProto( - name="SomeService"), - ] - ), - ) - ), + api_schema=api_schema, opts=Options.build("transport=river+car") ) - - assert len(cgr.file) == 4 + assert len(cgr.file) == 5 assert {i.name for i in cgr.file} == { "foo/some_service/transports/river.py", "foo/some_service/transports/car.py", "foo/some_service/transports/__init__.py", "foo/some_service/transports/base.py", + # Only generate async client with grpc transport + "foo/some_service/client.py", + } + + cgr = g.get_response( + api_schema=api_schema, + opts=Options.build("transport=grpc") + ) + assert len(cgr.file) == 5 + assert {i.name for i in cgr.file} == { + "foo/some_service/transports/grpc.py", + "foo/some_service/transports/__init__.py", + "foo/some_service/transports/base.py", + "foo/some_service/client.py", + "foo/some_service/async_client.py", } diff --git a/tests/unit/utils/test_checks.py b/tests/unit/utils/test_checks.py new file mode 100644 index 0000000000..32d5b33b49 --- /dev/null +++ b/tests/unit/utils/test_checks.py @@ -0,0 +1,34 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from gapic.utils import checks +from test_utils.test_utils import make_field, make_message + + +def test_is_str_field_pb(): + msg_field = make_field('msg_field', message=make_message('test_msg')) + str_field = make_field('str_field', type=9) + int_field = make_field('int_field', type=5) + assert not checks.is_str_field_pb(msg_field.field_pb) + assert checks.is_str_field_pb(str_field.field_pb) + assert not checks.is_str_field_pb(int_field.field_pb) + + +def test_is_msg_field_pb(): + msg_field = make_field('msg_field', message=make_message('test_msg')) + str_field = make_field('str_field', type=9) + int_field = make_field('int_field', type=5) + assert checks.is_msg_field_pb(msg_field.field_pb) + assert not checks.is_msg_field_pb(str_field.field_pb) + assert not checks.is_msg_field_pb(int_field.field_pb)