From acc1df83f73cda03c10722fbb739c7920b6b9048 Mon Sep 17 00:00:00 2001 From: Tin Rabzelj Date: Sun, 19 Nov 2023 01:26:34 +0100 Subject: [PATCH] Fix code generation for Name trait (#944) --- prost-build/src/code_generator.rs | 77 ++++++++++++++++++++----------- src/name.rs | 14 ++++-- tests/src/type_names.proto | 6 ++- tests/src/type_names.rs | 14 ++++-- 4 files changed, 74 insertions(+), 37 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index ccf7e2415..a6a06d484 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -29,6 +29,7 @@ enum Syntax { pub struct CodeGenerator<'a> { config: &'a mut Config, package: String, + type_path: Vec, source_info: Option, syntax: Syntax, message_graph: &'a MessageGraph, @@ -69,6 +70,7 @@ impl<'a> CodeGenerator<'a> { let mut code_gen = CodeGenerator { config, package: file.package.unwrap_or_default(), + type_path: Vec::new(), source_info, syntax, message_graph, @@ -84,13 +86,6 @@ impl<'a> CodeGenerator<'a> { code_gen.package ); - if code_gen.config.enable_type_names { - code_gen.buf.push_str(&format!( - "const PACKAGE: &str = \"{}\";\n", - code_gen.package, - )); - } - code_gen.path.push(4); for (idx, message) in file.message_type.into_iter().enumerate() { code_gen.path.push(idx as i32); @@ -128,10 +123,16 @@ impl<'a> CodeGenerator<'a> { let message_name = message.name().to_string(); let fq_message_name = format!( - "{}{}.{}", - if self.package.is_empty() { "" } else { "." }, - self.package, - message.name() + "{}{}{}{}.{}", + if self.package.is_empty() && self.type_path.is_empty() { + "" + } else { + "." + }, + self.package.trim_matches('.'), + if self.type_path.is_empty() { "" } else { "." }, + self.type_path.join("."), + message_name, ); // Skip external types. @@ -282,19 +283,34 @@ impl<'a> CodeGenerator<'a> { )); self.depth += 1; - self.buf - .push_str("const PACKAGE: &'static str = PACKAGE;\n"); self.buf.push_str(&format!( "const NAME: &'static str = \"{}\";\n", - message_name + message_name, + )); + self.buf.push_str(&format!( + "const PACKAGE: &'static str = \"{}\";\n", + self.package, + )); + + let prost_path = self.config.prost_path.as_deref().unwrap_or("::prost"); + let string_path = format!("{}::alloc::string::String", prost_path); + let format_path = format!("{}::alloc::format", prost_path); + + self.buf.push_str(&format!( + r#"fn full_name() -> {string_path} {{ + {format_path}!("{}{}{}{}{{}}", Self::NAME) + }}"#, + self.package.trim_matches('.'), + if self.package.is_empty() { "" } else { "." }, + self.type_path.join("."), + if self.type_path.is_empty() { "" } else { "." }, )); if let Some(domain_name) = self.config.type_name_domains.get_first(fq_message_name) { self.buf.push_str(&format!( - r#"fn type_url() -> String {{ - format!("{}/{{}}", Self::full_name()) + r#"fn type_url() -> {string_path} {{ + {format_path}!("{domain_name}/{{}}", Self::full_name()) }}"#, - domain_name )); } @@ -684,11 +700,18 @@ impl<'a> CodeGenerator<'a> { let enum_values = &desc.value; let fq_proto_enum_name = format!( - "{}{}.{}", - if self.package.is_empty() { "" } else { "." }, - self.package, - proto_enum_name + "{}{}{}{}.{}", + if self.package.is_empty() && self.type_path.is_empty() { + "" + } else { + "." + }, + self.package.trim_matches('.'), + if self.type_path.is_empty() { "" } else { "." }, + self.type_path.join("."), + proto_enum_name, ); + if self .extern_paths .resolve_ident(&fq_proto_enum_name) @@ -906,8 +929,7 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str(&to_snake(module)); self.buf.push_str(" {\n"); - self.package.push('.'); - self.package.push_str(module); + self.type_path.push(module.into()); self.depth += 1; } @@ -915,8 +937,7 @@ impl<'a> CodeGenerator<'a> { fn pop_mod(&mut self) { self.depth -= 1; - let idx = self.package.rfind('.').unwrap(); - self.package.truncate(idx); + self.type_path.pop(); self.push_indent(); self.buf.push_str("}\n"); @@ -954,7 +975,11 @@ impl<'a> CodeGenerator<'a> { return proto_ident; } - let mut local_path = self.package.split('.').peekable(); + let mut local_path = self + .package + .split('.') + .chain(self.type_path.iter().map(String::as_str)) + .peekable(); // If no package is specified the start of the package name will be '.' // and split will return an empty string ("") which breaks resolution diff --git a/src/name.rs b/src/name.rs index 7650c3273..2dbc5ee46 100644 --- a/src/name.rs +++ b/src/name.rs @@ -5,23 +5,27 @@ use alloc::{format, string::String}; /// Associate a type name with a [`Message`] type. pub trait Name: Message { - /// Type name for this [`Message`]. This is the camel case name, - /// e.g. `TypeName`. + /// Simple name for this [`Message`]. + /// This name is the same as it appears in the source .proto file, e.g. `FooBar`. const NAME: &'static str; /// Package name this message type is contained in. They are domain-like /// and delimited by `.`, e.g. `google.protobuf`. const PACKAGE: &'static str; - /// Full name of this message type containing both the package name and - /// type name, e.g. `google.protobuf.TypeName`. + /// Fully-qualified unique name for this [`Message`]. + /// It's prefixed with the package name and names of any parent messages, + /// e.g. `google.rpc.BadRequest.FieldViolation`. + /// By default, this is the package name followed by the message name. + /// Fully-qualified names must be unique within a domain of Type URLs. fn full_name() -> String { format!("{}.{}", Self::PACKAGE, Self::NAME) } - /// Type URL for this message, which by default is the full name with a + /// Type URL for this [`Message`], which by default is the full name with a /// leading slash, but may also include a leading domain name, e.g. /// `type.googleapis.com/google.profile.Person`. + /// This can be used when serializing with the [`Any`] type. fn type_url() -> String { format!("/{}", Self::full_name()) } diff --git a/tests/src/type_names.proto b/tests/src/type_names.proto index 68c130a9b..58c599d17 100644 --- a/tests/src/type_names.proto +++ b/tests/src/type_names.proto @@ -3,7 +3,9 @@ syntax = "proto3"; package type_names; message Foo { + message Bar { + } } -message Bar { -} +message Baz { +} \ No newline at end of file diff --git a/tests/src/type_names.rs b/tests/src/type_names.rs index 2843d115a..09fb41aa1 100644 --- a/tests/src/type_names.rs +++ b/tests/src/type_names.rs @@ -1,4 +1,3 @@ -use prost::alloc::{format, string::String}; use prost::Name; include!(concat!(env!("OUT_DIR"), "/type_names.rs")); @@ -7,9 +6,16 @@ include!(concat!(env!("OUT_DIR"), "/type_names.rs")); fn valid_type_names() { assert_eq!("Foo", Foo::NAME); assert_eq!("type_names", Foo::PACKAGE); + assert_eq!("type_names.Foo", Foo::full_name()); assert_eq!("tests/type_names.Foo", Foo::type_url()); - assert_eq!("Bar", Bar::NAME); - assert_eq!("type_names", Bar::PACKAGE); - assert_eq!("/type_names.Bar", Bar::type_url()); + assert_eq!("Bar", foo::Bar::NAME); + assert_eq!("type_names", foo::Bar::PACKAGE); + assert_eq!("type_names.Foo.Bar", foo::Bar::full_name()); + assert_eq!("tests/type_names.Foo.Bar", foo::Bar::type_url()); + + assert_eq!("Baz", Baz::NAME); + assert_eq!("type_names", Baz::PACKAGE); + assert_eq!("type_names.Baz", Baz::full_name()); + assert_eq!("/type_names.Baz", Baz::type_url()); }