diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 9dabdfa4da..aa474e0e4e 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -32,7 +32,7 @@ import json import re from itertools import chain -from typing import (Any, cast, Dict, FrozenSet, Iterable, List, Mapping, +from typing import (Any, cast, Dict, FrozenSet, Iterator, Iterable, List, Mapping, ClassVar, Optional, Sequence, Set, Tuple, Union) from google.api import annotations_pb2 # type: ignore from google.api import client_pb2 @@ -757,17 +757,79 @@ class HttpRule: uri: str body: Optional[str] - @property - def path_fields(self) -> List[Tuple[str, str]]: + def path_fields(self, method: "~.Method") -> List[Tuple[Field, str, str]]: """return list of (name, template) tuples extracted from uri.""" - return [(match.group("name"), match.group("template")) + input = method.input + return [(input.get_field(*match.group("name").split(".")), match.group("name"), match.group("template")) for match in path_template._VARIABLE_RE.finditer(self.uri)] - @property - def sample_request(self) -> str: + def sample_request(self, method: "~.Method") -> str: """return json dict for sample request matching the uri template.""" - sample = utils.sample_from_path_fields(self.path_fields) - return json.dumps(sample) + + def sample_from_path_fields(paths: List[Tuple["wrappers.Field", str, str]]) -> Dict[Any, Any]: + """Construct a dict for a sample request object from a list of fields + and template patterns. + + Args: + paths: a list of tuples, each with a (segmented) name and a pattern. + Returns: + A new nested dict with the templates instantiated. + """ + + request: Dict[str, Any] = {} + + def _sample_names() -> Iterator[str]: + sample_num: int = 0 + while True: + sample_num += 1 + yield "sample{}".format(sample_num) + + def add_field(obj, path, value): + """Insert a field into a nested dict and return the (outer) dict. + Keys and sub-dicts are inserted if necessary to create the path. + e.g. if obj, as passed in, is {}, path is "a.b.c", and value is + "hello", obj will be updated to: + {'a': + {'b': + { + 'c': 'hello' + } + } + } + + Args: + obj: a (possibly) nested dict (parsed json) + path: a segmented field name, e.g. "a.b.c" + where each part is a dict key. + value: the value of the new key. + Returns: + obj, possibly modified + Raises: + AttributeError if the path references a key that is + not a dict.: e.g. path='a.b', obj = {'a':'abc'} + """ + + segments = path.split('.') + leaf = segments.pop() + subfield = obj + for segment in segments: + subfield = subfield.setdefault(segment, {}) + subfield[leaf] = value + return obj + + sample_names = _sample_names() + for field, path, template in paths: + sample_value = re.sub( + r"(\*\*|\*)", + lambda n: next(sample_names), + template or '*' + ) if field.type == PrimitiveType.build(str) else field.mock_value_original_type + add_field(request, path, sample_value) + + return request + + sample = sample_from_path_fields(self.path_fields(method)) + return sample @classmethod def try_parse_http_rule(cls, http_rule) -> Optional['HttpRule']: diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index a94ce42c53..0d0998cf43 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -170,7 +170,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% if method.input.required_fields %} __{{ method.name | snake_case }}_required_fields_default_values = { {% for req_field in method.input.required_fields if req_field.is_primitive %} - "{{ req_field.name | camel_case }}" : {% if req_field.field_pb.default_value is string %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value }}{% endif %},{# default is str #} + "{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %},{# default is str #} {% endfor %} } diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index fdcec0eeaa..4cd8c37bf3 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -1134,11 +1134,11 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method. ) # send a request that will satisfy transcoding - request_init = {{ method.http_options[0].sample_request}} + request_init = {{ method.http_options[0].sample_request(method) }} {% for field in method.body_fields.values() %} {% if not field.oneof or field.proto3_optional %} {# ignore oneof fields that might conflict with sample_request #} - request_init["{{ field.name }}"] = {{ field.mock_value }} + request_init["{{ field.name }}"] = {{ field.mock_value_original_type }} {% endif %} {% endfor %} request = request_type(request_init) @@ -1221,10 +1221,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide request_init = {} {% for req_field in method.input.required_fields if req_field.is_primitive %} - {% if req_field.field_pb.default_value is string %} + {% if req_field.field_pb.type == 9 %} request_init["{{ req_field.name }}"] = "{{ req_field.field_pb.default_value }}" {% else %} - request_init["{{ req_field.name }}"] = {{ req_field.field_pb.default_value }} + request_init["{{ req_field.name }}"] = {{ req_field.type.python_type(req_field.field_pb.default_value or 0) }} {% endif %}{# default is str #} {% endfor %} request = request_type(request_init) @@ -1324,10 +1324,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide {% for req_field in method.input.required_fields if req_field.is_primitive %} ( "{{ req_field.name | camel_case }}", - {% if req_field.field_pb.default_value is string %} + {% if req_field.field_pb.type == 9 %} "{{ req_field.field_pb.default_value }}", {% else %} - {{ req_field.field_pb.default_value }}, + {{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}, {% endif %}{# default is str #} ), {% endfor %} @@ -1346,11 +1346,11 @@ def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_typ ) # send a request that will satisfy transcoding - request_init = {{ method.http_options[0].sample_request}} + request_init = {{ method.http_options[0].sample_request(method) }} {% for field in method.body_fields.values() %} {% if not field.oneof or field.proto3_optional %} {# ignore oneof fields that might conflict with sample_request #} - request_init["{{ field.name }}"] = {{ field.mock_value }} + request_init["{{ field.name }}"] = {{ field.mock_value_original_type }} {% endif %} {% endfor %} request = request_type(request_init) @@ -1411,7 +1411,7 @@ def test_{{ method_name }}_rest_flattened(transport: str = 'rest'): req.return_value = response_value # get arguments that satisfy an http rule for this method - sample_request = {{ method.http_options[0].sample_request }} + sample_request = {{ method.http_options[0].sample_request(method) }} # get truthy value for each flattened field mock_args = dict( @@ -1531,7 +1531,7 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'): return_val.status_code = 200 req.side_effect = return_values - sample_request = {{ method.http_options[0].sample_request }} + sample_request = {{ method.http_options[0].sample_request(method) }} {% for field in method.body_fields.values() %} {% if not field.oneof or field.proto3_optional %} {# ignore oneof fields that might conflict with sample_request #} diff --git a/gapic/utils/__init__.py b/gapic/utils/__init__.py index 5000d78b49..047cc4f300 100644 --- a/gapic/utils/__init__.py +++ b/gapic/utils/__init__.py @@ -29,7 +29,6 @@ from gapic.utils.reserved_names import RESERVED_NAMES from gapic.utils.rst import rst from gapic.utils.uri_conv import convert_uri_fieldnames -from gapic.utils.uri_sample import sample_from_path_fields __all__ = ( @@ -44,7 +43,6 @@ 'partition', 'RESERVED_NAMES', 'rst', - 'sample_from_path_fields', 'sort_lines', 'to_snake_case', 'to_camel_case', diff --git a/gapic/utils/uri_sample.py b/gapic/utils/uri_sample.py deleted file mode 100644 index 0eba82220f..0000000000 --- a/gapic/utils/uri_sample.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Generator, Dict, List, Tuple -import re - - -def _sample_names() -> Generator[str, None, None]: - sample_num: int = 0 - while True: - sample_num += 1 - yield "sample{}".format(sample_num) - - -def add_field(obj, path, value): - """Insert a field into a nested dict and return the (outer) dict. - Keys and sub-dicts are inserted if necessary to create the path. - e.g. if obj, as passed in, is {}, path is "a.b.c", and value is - "hello", obj will be updated to: - {'a': - {'b': - { - 'c': 'hello' - } - } - } - - Args: - obj: a (possibly) nested dict (parsed json) - path: a segmented field name, e.g. "a.b.c" - where each part is a dict key. - value: the value of the new key. - Returns: - obj, possibly modified - Raises: - AttributeError if the path references a key that is - not a dict.: e.g. path='a.b', obj = {'a':'abc'} - """ - segments = path.split('.') - leaf = segments.pop() - subfield = obj - for segment in segments: - subfield = subfield.setdefault(segment, {}) - subfield[leaf] = value - return obj - - -def sample_from_path_fields(paths: List[Tuple[str, str]]) -> Dict[Any, Any]: - """Construct a dict for a sample request object from a list of fields - and template patterns. - - Args: - paths: a list of tuples, each with a (segmented) name and a pattern. - Returns: - A new nested dict with the templates instantiated. - """ - - request: Dict[str, Any] = {} - sample_names = _sample_names() - - for path, template in paths: - sample_value = re.sub( - r"(\*\*|\*)", - lambda n: next(sample_names), template if template else '*' - ) - add_field(request, path, sample_value) - return request diff --git a/tests/fragments/test_required_non_string.proto b/tests/fragments/test_required_non_string.proto new file mode 100644 index 0000000000..fb055d6019 --- /dev/null +++ b/tests/fragments/test_required_non_string.proto @@ -0,0 +1,41 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.fragment; + +import "google/api/client.proto"; +import "google/api/field_behavior.proto"; +import "google/api/annotations.proto"; + +service RestService { + option (google.api.default_host) = "my.example.com"; + + rpc MyMethod(MethodRequest) returns (MethodResponse) { + option (google.api.http) = { + get: "/restservice/v1/mass_kg/{mass_kg}/length_cm/{length_cm}" + }; + } +} + + +message MethodRequest { + int32 mass_kg = 1 [(google.api.field_behavior) = REQUIRED]; + float length_cm = 2 [(google.api.field_behavior) = REQUIRED]; +} + +message MethodResponse { + string name = 1; +} \ No newline at end of file diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index 774f81b172..889dc629a3 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -470,9 +470,29 @@ def test_method_http_options_generate_sample(): http_rule = http_pb2.HttpRule( get='/v1/{resource.id=projects/*/regions/*/id/**}/stuff', ) - method = make_method('DoSomething', http_rule=http_rule) - sample = method.http_options[0].sample_request - assert json.loads(sample) == {'resource': { + + method = make_method( + 'DoSomething', + make_message( + name="Input", + fields=[ + make_field( + name="resource", + number=1, + type=11, + message=make_message( + "Resource", + fields=[ + make_field(name="id", type=9), + ], + ), + ), + ], + ), + http_rule=http_rule, + ) + sample = method.http_options[0].sample_request(method) + assert sample == {'resource': { 'id': 'projects/sample1/regions/sample2/id/sample3'}} @@ -480,9 +500,28 @@ def test_method_http_options_generate_sample_implicit_template(): http_rule = http_pb2.HttpRule( get='/v1/{resource.id}/stuff', ) - method = make_method('DoSomething', http_rule=http_rule) - sample = method.http_options[0].sample_request - assert json.loads(sample) == {'resource': { + method = make_method( + 'DoSomething', + make_message( + name="Input", + fields=[ + make_field( + name="resource", + number=1, + message=make_message( + "Resource", + fields=[ + make_field(name="id", type=9), + ], + ), + ), + ], + ), + http_rule=http_rule, + ) + + sample = method.http_options[0].sample_request(method) + assert sample == {'resource': { 'id': 'sample1'}}