Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix rest transport logic #1039

Merged
merged 2 commits into from
Oct 25, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,9 +925,21 @@ def query_params(self) -> Set[str]:

return set(self.input.fields) - params

@property
def body_fields(self) -> Mapping[str, Field]:
bindings = self.http_options
if bindings and bindings[0].body and bindings[0].body != "*":
return self._fields_mapping([bindings[0].body])
return {}

# TODO(yon-mg): refactor as there may be more than one method signature
@utils.cached_property
def flattened_fields(self) -> Mapping[str, Field]:
signatures = self.options.Extensions[client_pb2.method_signature]
return self._fields_mapping(signatures)

# TODO(yon-mg): refactor as there may be more than one method signature
def _fields_mapping(self, signatures) -> Mapping[str, Field]:
"""Return the signature defined for this method."""
cross_pkg_request = self.input.ident.package != self.ident.package

Expand All @@ -946,7 +958,6 @@ def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]:

yield name, field

signatures = self.options.Extensions[client_pb2.method_signature]
answer: Dict[str, Field] = collections.OrderedDict(
name_and_field
for sig in signatures
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
{% if "grpc" in opts.transport %}
always_use_jwt_access=True,
{% endif %}
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from google.auth.transport.requests import AuthorizedSession
from google.auth.transport.requests import AuthorizedSession # type: ignore
import json # type: ignore
import grpc # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.api_core import exceptions as core_exceptions # type: ignore
from google.api_core import retry as retries # type: ignore
from google.api_core import rest_helpers # type: ignore
from google.api_core import path_template # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.api_core import rest_helpers # type: ignore
from google.api_core import path_template # type: ignore
from google.api_core import gapic_v1 # type: ignore
{% if service.has_lro %}
from google.api_core import operations_v1
{% endif %}
from requests import __version__ as requests_version
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import warnings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,14 @@ def test_{{ method.name|snake_case }}_rest(transport: str = 'rest', request_type
)

# send a request that will satisfy transcoding
request = request_type({{ method.http_options[0].sample_request}})
request_init = {{ method.http_options[0].sample_request}}
{% for field in method.body_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
request_init["{{ field.name }}"] = {{ field.mock_value }}
{% endif %}
{% endfor %}
request = request_type(request_init)
{% if method.client_streaming %}
requests = [request]
{% endif %}
Expand Down Expand Up @@ -2417,19 +2424,18 @@ async def test_test_iam_permissions_from_dict_async():
)
call.assert_called()

@pytest.mark.asyncio
async def test_transport_close_async():
client = {{ service.async_client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport="grpc_asyncio",
)
with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close:
async with client:
close.assert_not_called()
close.assert_called_once()
{% endif %}

@pytest.mark.asyncio
async def test_transport_close_async():
client = {{ service.async_client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport="grpc_asyncio",
)
with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close:
async with client:
close.assert_not_called()
close.assert_called_once()

def test_transport_close():
transports = {
{% if 'rest' in opts.transport %}
Expand Down
44 changes: 43 additions & 1 deletion tests/unit/schema/wrappers/test_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,35 @@ def test_method_path_params_no_http_rule():
assert method.path_params == []


def test_body_fields():
http_rule = http_pb2.HttpRule(
post='/v1/{arms_shape=arms/*}/squids',
body='mantle'
)

mantle_stuff = make_field(name='mantle_stuff', type=9)
message = make_message('Mantle', fields=(mantle_stuff,))
mantle = make_field('mantle', type=11, type_name='Mantle', message=message)
arms_shape = make_field('arms_shape', type=9)
input_message = make_message('Squid', fields=(mantle, arms_shape))
method = make_method(
'PutSquid', input_message=input_message, http_rule=http_rule)
assert set(method.body_fields) == {'mantle'}
mock_value = method.body_fields['mantle'].mock_value
assert mock_value == "baz.Mantle(mantle_stuff='mantle_stuff_value')"


def test_body_fields_no_body():
http_rule = http_pb2.HttpRule(
post='/v1/{arms_shape=arms/*}/squids',
)

method = make_method(
'PutSquid', http_rule=http_rule)

assert not method.body_fields


def test_method_http_options():
verbs = [
'get',
Expand Down Expand Up @@ -363,7 +392,7 @@ def test_method_http_options_no_http_rule():
assert method.path_params == []


def test_method_http_options_body():
def test_method_http_options_body_star():
http_rule = http_pb2.HttpRule(
post='/v1/{parent=projects/*}/topics',
body='*'
Expand All @@ -376,6 +405,19 @@ def test_method_http_options_body():
}]


def test_method_http_options_body_field():
http_rule = http_pb2.HttpRule(
post='/v1/{parent=projects/*}/topics',
body='body_field'
)
method = make_method('DoSomething', http_rule=http_rule)
assert [dataclasses.asdict(http) for http in method.http_options] == [{
'method': 'post',
'uri': '/v1/{parent=projects/*}/topics',
'body': 'body_field'
}]


def test_method_http_options_additional_bindings():
http_rule = http_pb2.HttpRule(
post='/v1/{parent=projects/*}/topics',
Expand Down