diff --git a/gapic/schema/api.py b/gapic/schema/api.py index df3e1daa8e..3c79d8f7cd 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -615,7 +615,7 @@ def _get_fields(self, # `_load_message` method. answer: Dict[str, wrappers.Field] = collections.OrderedDict() for i, field_pb in enumerate(field_pbs): - is_oneof = oneofs and field_pb.oneof_index > 0 + is_oneof = oneofs and field_pb.HasField('oneof_index') oneof_name = nth( (oneofs or {}).keys(), field_pb.oneof_index diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 1061620378..bbeeec679b 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -239,6 +239,15 @@ def __hash__(self): # Identity is sufficiently unambiguous. return hash(self.ident) + def oneof_fields(self, include_optional=False): + oneof_fields = collections.defaultdict(list) + for field in self.fields.values(): + # Only include proto3 optional oneofs if explicitly looked for. + if field.oneof and not field.proto3_optional or include_optional: + oneof_fields[field.oneof].append(field) + + return oneof_fields + @utils.cached_property def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: answer = tuple( @@ -583,6 +592,15 @@ def client_output(self): def client_output_async(self): return self._client_output(enable_asyncio=True) + def flattened_oneof_fields(self, include_optional=False): + oneof_fields = collections.defaultdict(list) + for field in self.flattened_fields.values(): + # Only include proto3 optional oneofs if explicitly looked for. + if field.oneof and not field.proto3_optional or include_optional: + oneof_fields[field.oneof].append(field) + + return oneof_fields + def _client_output(self, enable_asyncio: bool): """Return the output from the client layer. @@ -685,6 +703,10 @@ def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: return answer + @utils.cached_property + def flattened_field_to_key(self): + return {field.name: key for key, field in self.flattened_fields.items()} + @utils.cached_property def legacy_flattened_fields(self) -> Mapping[str, Field]: """Return the legacy flattening interface: top level fields only, diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 4f30579121..322094d226 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -288,9 +288,15 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'): call.return_value = iter([{{ method.output.ident }}()]) {% else -%} call.return_value = {{ method.output.ident }}( - {%- for field in method.output.fields.values() | rejectattr('message')%}{% if not (field.oneof and not field.proto3_optional) %} + {%- for field in method.output.fields.values() | rejectattr('message')%}{% if not field.oneof or field.proto3_optional %} {{ field.name }}={{ field.mock_value }}, {% endif %}{%- endfor %} + {#- This is a hack to only pick one field #} + {%- for oneof_fields in method.output.oneof_fields().values() %} + {% with field = oneof_fields[0] %} + {{ field.name }}={{ field.mock_value }}, + {%- endwith %} + {%- endfor %} ) {% endif -%} {% if method.client_streaming %} @@ -567,9 +573,15 @@ def test_{{ method.name|snake_case }}_flattened(): # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - {% for key, field in method.flattened_fields.items() -%} + {% for key, field in method.flattened_fields.items() -%}{%- if not field.oneof or field.proto3_optional %} assert args[0].{{ key }} == {{ field.mock_value }} - {% endfor %} + {% endif %}{% endfor %} + {%- for oneofs in method.flattened_oneof_fields().values() %} + {%- with field = oneofs[-1] %} + assert args[0].{{ method.flattened_field_to_key[field.name] }} == {{ field.mock_value }} + {%- endwith %} + {%- endfor %} + def test_{{ method.name|snake_case }}_flattened_error(): @@ -640,9 +652,14 @@ async def test_{{ method.name|snake_case }}_flattened_async(): # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - {% for key, field in method.flattened_fields.items() -%} + {% for key, field in method.flattened_fields.items() -%}{%- if not field.oneof or field.proto3_optional %} assert args[0].{{ key }} == {{ field.mock_value }} - {% endfor %} + {% endif %}{% endfor %} + {%- for oneofs in method.flattened_oneof_fields().values() %} + {%- with field = oneofs[-1] %} + assert args[0].{{ method.flattened_field_to_key[field.name] }} == {{ field.mock_value }} + {%- endwith %} + {%- endfor %} @pytest.mark.asyncio diff --git a/tests/unit/schema/wrappers/test_message.py b/tests/unit/schema/wrappers/test_message.py index 7ae95d0299..7d8cca169a 100644 --- a/tests/unit/schema/wrappers/test_message.py +++ b/tests/unit/schema/wrappers/test_message.py @@ -235,3 +235,26 @@ def test_field_map(): entry_field = make_field('foos', message=entry_msg, repeated=True) assert entry_msg.map assert entry_field.map + + +def test_oneof_fields(): + mass_kg = make_field(name="mass_kg", oneof="mass", type=5) + mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5) + length_m = make_field(name="length_m", oneof="length", type=5) + length_f = make_field(name="length_f", oneof="length", type=5) + color = make_field(name="color", type=5) + request = make_message( + name="CreateMolluscReuqest", + fields=( + mass_kg, + mass_lbs, + length_m, + length_f, + color, + ), + ) + actual_oneofs = request.oneof_fields() + expected_oneofs = { + "mass": [mass_kg, mass_lbs], + "length": [length_m, length_f], + } diff --git a/tests/unit/schema/wrappers/test_method.py b/tests/unit/schema/wrappers/test_method.py index c0102402c2..f10bb078cd 100644 --- a/tests/unit/schema/wrappers/test_method.py +++ b/tests/unit/schema/wrappers/test_method.py @@ -364,3 +364,59 @@ def test_method_legacy_flattened_fields(): ]) assert method.legacy_flattened_fields == expected + + +def test_flattened_oneof_fields(): + mass_kg = make_field(name="mass_kg", oneof="mass", type=5) + mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5) + + length_m = make_field(name="length_m", oneof="length", type=5) + length_f = make_field(name="length_f", oneof="length", type=5) + + color = make_field(name="color", type=5) + mantle = make_field( + name="mantle", + message=make_message( + name="Mantle", + fields=( + make_field(name="color", type=5), + mass_kg, + mass_lbs, + ), + ), + ) + request = make_message( + name="CreateMolluscReuqest", + fields=( + length_m, + length_f, + color, + mantle, + ), + ) + method = make_method( + name="CreateMollusc", + input_message=request, + signatures=[ + "length_m,", + "length_f,", + "mantle.mass_kg,", + "mantle.mass_lbs,", + "color", + ] + ) + + expected = {"mass": [mass_kg, mass_lbs], "length": [length_m, length_f]} + actual = method.flattened_oneof_fields() + assert expected == actual + + # Check this method too becasue the setup is a lot of work. + expected = { + "color": "color", + "length_m": "length_m", + "length_f": "length_f", + "mass_kg": "mantle.mass_kg", + "mass_lbs": "mantle.mass_lbs", + } + actual = method.flattened_field_to_key + assert expected == actual