Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom codec #461

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions tonic-build/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ fn generate_methods<T: Service>(service: &T, proto_path: &str) -> TokenStream {
}

fn generate_unary<T: Method>(method: &T, proto_path: &str, path: String) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());
let (request, response) = method.request_response_name(proto_path);

Expand All @@ -132,14 +132,14 @@ fn generate_unary<T: Method>(method: &T, proto_path: &str, path: String) -> Toke
tonic::Status::new(tonic::Code::Unknown, format!("Service was not ready: {}", e.into()))
})?;
let codec = #codec_name::default();
let path = http::uri::PathAndQuery::from_static(#path);
self.inner.unary(request.into_request(), path, codec).await
let path = http::uri::PathAndQuery::from_static(#path);
self.inner.unary(request.into_request(), path, codec).await
}
}
}

fn generate_server_streaming<T: Method>(method: &T, proto_path: &str, path: String) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());

let (request, response) = method.request_response_name(proto_path);
Expand All @@ -153,14 +153,14 @@ fn generate_server_streaming<T: Method>(method: &T, proto_path: &str, path: Stri
tonic::Status::new(tonic::Code::Unknown, format!("Service was not ready: {}", e.into()))
})?;
let codec = #codec_name::default();
let path = http::uri::PathAndQuery::from_static(#path);
self.inner.server_streaming(request.into_request(), path, codec).await
let path = http::uri::PathAndQuery::from_static(#path);
self.inner.server_streaming(request.into_request(), path, codec).await
}
}
}

fn generate_client_streaming<T: Method>(method: &T, proto_path: &str, path: String) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());

let (request, response) = method.request_response_name(proto_path);
Expand All @@ -181,7 +181,7 @@ fn generate_client_streaming<T: Method>(method: &T, proto_path: &str, path: Stri
}

fn generate_streaming<T: Method>(method: &T, proto_path: &str, path: String) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());

let (request, response) = method.request_response_name(proto_path);
Expand All @@ -195,8 +195,8 @@ fn generate_streaming<T: Method>(method: &T, proto_path: &str, path: String) ->
tonic::Status::new(tonic::Code::Unknown, format!("Service was not ready: {}", e.into()))
})?;
let codec = #codec_name::default();
let path = http::uri::PathAndQuery::from_static(#path);
self.inner.streaming(request.into_streaming_request(), path, codec).await
let path = http::uri::PathAndQuery::from_static(#path);
self.inner.streaming(request.into_streaming_request(), path, codec).await
}
}
}
9 changes: 4 additions & 5 deletions tonic-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,14 @@ pub mod server;
/// to allow any codegen module to generate service
/// abstractions.
pub trait Service {
/// Path to the codec.
const CODEC_PATH: &'static str;

/// Comment type.
type Comment: AsRef<str>;

/// Method type.
type Method: Method;

/// Path to the codec.
fn codec_path(&self) -> &str;
/// Name of service.
fn name(&self) -> &str;
/// Package name of service.
Expand All @@ -131,11 +130,11 @@ pub trait Service {
/// to generate abstraction implementations for
/// the provided methods.
pub trait Method {
/// Path to the codec.
const CODEC_PATH: &'static str;
/// Comment type.
type Comment: AsRef<str>;

/// Path to the codec.
fn codec_path(&self) -> &str;
/// Name of method.
fn name(&self) -> &str;
/// Identifier used to generate type name.
Expand Down
91 changes: 71 additions & 20 deletions tonic-build/src/prost.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{client, server};
use proc_macro2::TokenStream;
use prost_build::{Config, Method, Service};
use prost_build::{Config, Method as ProstMethod, Service as ProstService};
use quote::ToTokens;
use std::io;
use std::path::{Path, PathBuf};
Expand All @@ -19,6 +19,7 @@ pub fn configure() -> Builder {
proto_path: "super".to_string(),
#[cfg(feature = "rustfmt")]
format: true,
codec_path: DEFAULT_PROST_CODEC_PATH.to_string(),
}
}

Expand All @@ -39,14 +40,45 @@ pub fn compile_protos(proto: impl AsRef<Path>) -> io::Result<()> {
Ok(())
}

const PROST_CODEC_PATH: &'static str = "tonic::codec::ProstCodec";
const DEFAULT_PROST_CODEC_PATH: &str = "tonic::codec::ProstCodec";

impl crate::Service for Service {
const CODEC_PATH: &'static str = PROST_CODEC_PATH;
struct Service {
name: String,
package: String,
proto_name: String,
comments: Vec<String>,
methods: Vec<Method>,
codec_path: String,
}

impl Service {
fn new(service: ProstService, codec_path: String) -> Service {
Service {
name: service.name,
package: service.package,
proto_name: service.proto_name,
comments: service.comments.leading,
methods: service
.methods
.into_iter()
.map(|v| Method {
inner: v,
codec_path: codec_path.clone(),
})
.collect(),
codec_path,
}
}
}

impl crate::Service for Service {
type Method = Method;
type Comment = String;

fn codec_path(&self) -> &str {
&self.codec_path
}

fn name(&self) -> &str {
&self.name
}
Expand All @@ -60,55 +92,63 @@ impl crate::Service for Service {
}

fn comment(&self) -> &[Self::Comment] {
&self.comments.leading[..]
&self.comments[..]
}

fn methods(&self) -> &[Self::Method] {
&self.methods[..]
}
}

struct Method {
inner: ProstMethod,
codec_path: String,
}

impl crate::Method for Method {
const CODEC_PATH: &'static str = PROST_CODEC_PATH;
type Comment = String;

fn codec_path(&self) -> &str {
&self.codec_path
}

fn name(&self) -> &str {
&self.name
&self.inner.name
}

fn identifier(&self) -> &str {
&self.proto_name
&self.inner.proto_name
}

fn client_streaming(&self) -> bool {
self.client_streaming
self.inner.client_streaming
}

fn server_streaming(&self) -> bool {
self.server_streaming
self.inner.server_streaming
}

fn comment(&self) -> &[Self::Comment] {
&self.comments.leading[..]
&self.inner.comments.leading[..]
}

fn request_response_name(&self, proto_path: &str) -> (TokenStream, TokenStream) {
let request = if self.input_proto_type.starts_with(".google.protobuf")
|| self.input_type.starts_with("::")
let request = if self.inner.input_proto_type.starts_with(".google.protobuf")
|| self.inner.input_type.starts_with("::")
{
self.input_type.parse::<TokenStream>().unwrap()
self.inner.input_type.parse::<TokenStream>().unwrap()
} else {
syn::parse_str::<syn::Path>(&format!("{}::{}", proto_path, self.input_type))
syn::parse_str::<syn::Path>(&format!("{}::{}", proto_path, self.inner.input_type))
.unwrap()
.to_token_stream()
};

let response = if self.output_proto_type.starts_with(".google.protobuf")
|| self.output_type.starts_with("::")
let response = if self.inner.output_proto_type.starts_with(".google.protobuf")
|| self.inner.output_type.starts_with("::")
{
self.output_type.parse::<TokenStream>().unwrap()
self.inner.output_type.parse::<TokenStream>().unwrap()
} else {
syn::parse_str::<syn::Path>(&format!("{}::{}", proto_path, self.output_type))
syn::parse_str::<syn::Path>(&format!("{}::{}", proto_path, self.inner.output_type))
.unwrap()
.to_token_stream()
};
Expand All @@ -134,7 +174,8 @@ impl ServiceGenerator {
}

impl prost_build::ServiceGenerator for ServiceGenerator {
fn generate(&mut self, service: prost_build::Service, _buf: &mut String) {
fn generate(&mut self, service: ProstService, _buf: &mut String) {
let service = Service::new(service, self.builder.codec_path.clone());
if self.builder.build_server {
let server = server::generate(&service, &self.builder.proto_path);
self.servers.extend(server);
Expand Down Expand Up @@ -184,6 +225,7 @@ pub struct Builder {
pub(crate) field_attributes: Vec<(String, String)>,
pub(crate) type_attributes: Vec<(String, String)>,
pub(crate) proto_path: String,
pub(crate) codec_path: String,

out_dir: Option<PathBuf>,
#[cfg(feature = "rustfmt")]
Expand Down Expand Up @@ -258,6 +300,15 @@ impl Builder {
self
}

/// Set the module path to the `tonic::codec::Codec` implementation to use
/// to serialize/deserialize the protobuf messages.
///
/// This defaults to `tonic::codec::ProstCodec`
pub fn codec_path(mut self, codec_path: impl AsRef<str>) -> Self {
self.codec_path = codec_path.as_ref().to_string();
self
}

/// Compile the .proto files and execute code generation.
pub fn compile<P>(self, protos: &[P], includes: &[P]) -> io::Result<()>
where
Expand Down
8 changes: 4 additions & 4 deletions tonic-build/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ fn generate_unary<T: Method>(
method_ident: Ident,
server_trait: Ident,
) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();

let service_ident = quote::format_ident!("{}Svc", method.identifier());

Expand Down Expand Up @@ -311,7 +311,7 @@ fn generate_server_streaming<T: Method>(
method_ident: Ident,
server_trait: Ident,
) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();

let service_ident = quote::format_ident!("{}Svc", method.identifier());

Expand Down Expand Up @@ -368,7 +368,7 @@ fn generate_client_streaming<T: Method>(
let service_ident = quote::format_ident!("{}Svc", method.identifier());

let (request, response) = method.request_response_name(proto_path);
let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();

quote! {
#[allow(non_camel_case_types)]
Expand Down Expand Up @@ -416,7 +416,7 @@ fn generate_streaming<T: Method>(
method_ident: Ident,
server_trait: Ident,
) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(T::CODEC_PATH).unwrap();
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();

let service_ident = quote::format_ident!("{}Svc", method.identifier());

Expand Down