Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: non-string required fields provide correct values #1108

Merged
merged 4 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 70 additions & 8 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import json
import re
from itertools import chain
from typing import (Any, cast, Dict, FrozenSet, Iterable, List, Mapping,
from typing import (Any, cast, Dict, FrozenSet, Iterator, Iterable, List, Mapping,
ClassVar, Optional, Sequence, Set, Tuple, Union)
from google.api import annotations_pb2 # type: ignore
from google.api import client_pb2
Expand Down Expand Up @@ -757,17 +757,79 @@ class HttpRule:
uri: str
body: Optional[str]

@property
def path_fields(self) -> List[Tuple[str, str]]:
def path_fields(self, method: "~.Method") -> List[Tuple[Field, str, str]]:
"""return list of (name, template) tuples extracted from uri."""
return [(match.group("name"), match.group("template"))
input = method.input
return [(input.get_field(*match.group("name").split(".")), match.group("name"), match.group("template"))
for match in path_template._VARIABLE_RE.finditer(self.uri)]

@property
def sample_request(self) -> str:
def sample_request(self, method: "~.Method") -> str:
software-dov marked this conversation as resolved.
Show resolved Hide resolved
"""return json dict for sample request matching the uri template."""
sample = utils.sample_from_path_fields(self.path_fields)
return json.dumps(sample)

def sample_from_path_fields(paths: List[Tuple["wrappers.Field", str, str]]) -> Dict[Any, Any]:
"""Construct a dict for a sample request object from a list of fields
and template patterns.

Args:
paths: a list of tuples, each with a (segmented) name and a pattern.
Returns:
A new nested dict with the templates instantiated.
"""

request: Dict[str, Any] = {}

def _sample_names() -> Iterator[str]:
sample_num: int = 0
while True:
sample_num += 1
yield "sample{}".format(sample_num)

def add_field(obj, path, value):
"""Insert a field into a nested dict and return the (outer) dict.
Keys and sub-dicts are inserted if necessary to create the path.
e.g. if obj, as passed in, is {}, path is "a.b.c", and value is
"hello", obj will be updated to:
{'a':
{'b':
{
'c': 'hello'
}
}
}

Args:
obj: a (possibly) nested dict (parsed json)
path: a segmented field name, e.g. "a.b.c"
where each part is a dict key.
value: the value of the new key.
Returns:
obj, possibly modified
Raises:
AttributeError if the path references a key that is
not a dict.: e.g. path='a.b', obj = {'a':'abc'}
"""

segments = path.split('.')
leaf = segments.pop()
subfield = obj
for segment in segments:
subfield = subfield.setdefault(segment, {})
subfield[leaf] = value
return obj

sample_names = _sample_names()
for field, path, template in paths:
sample_value = re.sub(
r"(\*\*|\*)",
lambda n: next(sample_names),
template or '*'
) if field.type == PrimitiveType.build(str) else field.mock_value_original_type
add_field(request, path, sample_value)

return request

sample = sample_from_path_fields(self.path_fields(method))
return sample

@classmethod
def try_parse_http_rule(cls, http_rule) -> Optional['HttpRule']:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if method.input.required_fields %}
__{{ method.name | snake_case }}_required_fields_default_values = {
{% for req_field in method.input.required_fields if req_field.is_primitive %}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.default_value is string %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value }}{% endif %},{# default is str #}
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %},{# default is str #}
software-dov marked this conversation as resolved.
Show resolved Hide resolved
{% endfor %}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1134,11 +1134,11 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
)

# send a request that will satisfy transcoding
request_init = {{ method.http_options[0].sample_request}}
request_init = {{ method.http_options[0].sample_request(method) }}
{% for field in method.body_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
request_init["{{ field.name }}"] = {{ field.mock_value }}
request_init["{{ field.name }}"] = {{ field.mock_value_original_type }}
{% endif %}
{% endfor %}
request = request_type(request_init)
Expand Down Expand Up @@ -1221,10 +1221,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide

request_init = {}
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% if req_field.field_pb.default_value is string %}
{% if req_field.field_pb.type == 9 %}
request_init["{{ req_field.name }}"] = "{{ req_field.field_pb.default_value }}"
{% else %}
request_init["{{ req_field.name }}"] = {{ req_field.field_pb.default_value }}
request_init["{{ req_field.name }}"] = {{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}
{% endif %}{# default is str #}
{% endfor %}
request = request_type(request_init)
Expand Down Expand Up @@ -1324,10 +1324,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
{% for req_field in method.input.required_fields if req_field.is_primitive %}
(
"{{ req_field.name | camel_case }}",
{% if req_field.field_pb.default_value is string %}
{% if req_field.field_pb.type == 9 %}
"{{ req_field.field_pb.default_value }}",
{% else %}
{{ req_field.field_pb.default_value }},
{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }},
{% endif %}{# default is str #}
),
{% endfor %}
Expand All @@ -1346,11 +1346,11 @@ def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_typ
)

# send a request that will satisfy transcoding
request_init = {{ method.http_options[0].sample_request}}
request_init = {{ method.http_options[0].sample_request(method) }}
{% for field in method.body_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
request_init["{{ field.name }}"] = {{ field.mock_value }}
request_init["{{ field.name }}"] = {{ field.mock_value_original_type }}
{% endif %}
{% endfor %}
request = request_type(request_init)
Expand Down Expand Up @@ -1411,7 +1411,7 @@ def test_{{ method_name }}_rest_flattened(transport: str = 'rest'):
req.return_value = response_value

# get arguments that satisfy an http rule for this method
sample_request = {{ method.http_options[0].sample_request }}
sample_request = {{ method.http_options[0].sample_request(method) }}

# get truthy value for each flattened field
mock_args = dict(
Expand Down Expand Up @@ -1531,7 +1531,7 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
return_val.status_code = 200
req.side_effect = return_values

sample_request = {{ method.http_options[0].sample_request }}
sample_request = {{ method.http_options[0].sample_request(method) }}
{% for field in method.body_fields.values() %}
{% if not field.oneof or field.proto3_optional %}
{# ignore oneof fields that might conflict with sample_request #}
Expand Down
2 changes: 0 additions & 2 deletions gapic/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from gapic.utils.reserved_names import RESERVED_NAMES
from gapic.utils.rst import rst
from gapic.utils.uri_conv import convert_uri_fieldnames
from gapic.utils.uri_sample import sample_from_path_fields


__all__ = (
Expand All @@ -44,7 +43,6 @@
'partition',
'RESERVED_NAMES',
'rst',
'sample_from_path_fields',
'sort_lines',
'to_snake_case',
'to_camel_case',
Expand Down
78 changes: 0 additions & 78 deletions gapic/utils/uri_sample.py

This file was deleted.

41 changes: 41 additions & 0 deletions tests/fragments/test_required_non_string.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package google.fragment;

import "google/api/client.proto";
import "google/api/field_behavior.proto";
import "google/api/annotations.proto";

service RestService {
option (google.api.default_host) = "my.example.com";

rpc MyMethod(MethodRequest) returns (MethodResponse) {
option (google.api.http) = {
get: "/restservice/v1/mass_kg/{mass_kg}/length_cm/{length_cm}"
};
}
}


message MethodRequest {
int32 mass_kg = 1 [(google.api.field_behavior) = REQUIRED];
float length_cm = 2 [(google.api.field_behavior) = REQUIRED];
}

message MethodResponse {
string name = 1;
}
51 changes: 45 additions & 6 deletions tests/unit/schema/wrappers/test_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,19 +470,58 @@ def test_method_http_options_generate_sample():
http_rule = http_pb2.HttpRule(
get='/v1/{resource.id=projects/*/regions/*/id/**}/stuff',
)
method = make_method('DoSomething', http_rule=http_rule)
sample = method.http_options[0].sample_request
assert json.loads(sample) == {'resource': {

method = make_method(
'DoSomething',
make_message(
name="Input",
fields=[
make_field(
name="resource",
number=1,
type=11,
message=make_message(
"Resource",
fields=[
make_field(name="id", type=9),
],
),
),
],
),
http_rule=http_rule,
)
sample = method.http_options[0].sample_request(method)
assert sample == {'resource': {
'id': 'projects/sample1/regions/sample2/id/sample3'}}


def test_method_http_options_generate_sample_implicit_template():
http_rule = http_pb2.HttpRule(
get='/v1/{resource.id}/stuff',
)
method = make_method('DoSomething', http_rule=http_rule)
sample = method.http_options[0].sample_request
assert json.loads(sample) == {'resource': {
method = make_method(
'DoSomething',
make_message(
name="Input",
fields=[
make_field(
name="resource",
number=1,
message=make_message(
"Resource",
fields=[
make_field(name="id", type=9),
],
),
),
],
),
http_rule=http_rule,
)

sample = method.http_options[0].sample_request(method)
assert sample == {'resource': {
'id': 'sample1'}}


Expand Down