Skip to content

Commit

Permalink
Use bytes to represent gRPC message body (#833)
Browse files Browse the repository at this point in the history
This will also be useful for #819
  • Loading branch information
tiziano88 authored Apr 10, 2020
1 parent 1d0e436 commit 64d3f90
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 99 deletions.
6 changes: 2 additions & 4 deletions examples/abitest/module_0/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1417,10 +1417,8 @@ impl FrontendNode {
oak::grpc::Code::Ok as i32,
grpc_rsp.status.unwrap_or_default().code
);
let rsp = GrpcTestResponse::decode(
grpc_rsp.rsp_msg.unwrap_or_default().value.as_slice(),
)
.expect("failed to parse GrpcTestResponse");
let rsp = GrpcTestResponse::decode(grpc_rsp.rsp_msg.as_slice())
.expect("failed to parse GrpcTestResponse");
expect_eq!(rsp.clone(), ok_rsp.clone());
count += 1;
}
Expand Down
1 change: 0 additions & 1 deletion oak/proto/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ proto_library(
srcs = ["grpc_encap.proto"],
deps = [
"//third_party/google/rpc:status_proto",
"@com_google_protobuf//:any_proto",
],
)

Expand Down
9 changes: 6 additions & 3 deletions oak/proto/grpc_encap.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ syntax = "proto3";

package oak.encap;

import "google/protobuf/any.proto";
import "third_party/google/rpc/status.proto";

// Protocol buffer encoding to hold additional metadata that accompanies a gRPC
Expand All @@ -29,12 +28,16 @@ import "third_party/google/rpc/status.proto";

message GrpcRequest {
string method_name = 1;
google.protobuf.Any req_msg = 2;
// The body of the request. Usually a serialized protobuf message.
// The message type is deduced from the `method_name` field.
bytes req_msg = 2;
bool last = 3;
}

message GrpcResponse {
google.protobuf.Any rsp_msg = 1;
// The body of the response. Usually a serialized protobuf message.
// The message type is deduced from the `method_name` field of the request.
bytes rsp_msg = 1;
google.rpc.Status status = 2;

// The last field indicates that this is definitely the final response for a
Expand Down
6 changes: 2 additions & 4 deletions oak/server/grpc_client_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ bool GrpcClientNode::HandleInvocation(Handle invocation_handle) {
oak::encap::GrpcRequest grpc_req;
grpc_req.ParseFromString(std::string(req_result.msg->data.data(), req_result.msg->data.size()));
std::string method_name = grpc_req.method_name();
const grpc::string& req_data = grpc_req.req_msg().value();
const grpc::string& req_data = grpc_req.req_msg();

// Use a completion queue together with a generic client reader/writer to
// perform the method invocation. All steps are done in serial, so just use
Expand Down Expand Up @@ -117,9 +117,7 @@ bool GrpcClientNode::HandleInvocation(Handle invocation_handle) {
// Build an encapsulation of the gRPC response and put it in an Oak Message.
oak::encap::GrpcResponse grpc_rsp;
grpc_rsp.set_last(false);
google::protobuf::Any* any = new google::protobuf::Any();
any->set_value(rsp_data.data(), rsp_data.size());
grpc_rsp.set_allocated_rsp_msg(any);
grpc_rsp.set_rsp_msg(rsp_data);

auto rsp_msg = absl::make_unique<NodeMessage>();
size_t serialized_size = grpc_rsp.ByteSizeLong();
Expand Down
27 changes: 16 additions & 11 deletions oak/server/module_invocation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,7 @@ void ModuleInvocation::ProcessRequest(bool ok) {
// Build an encapsulation of the gRPC request invocation and put it in a Message.
oak::encap::GrpcRequest grpc_request;
grpc_request.set_method_name(context_.method());
google::protobuf::Any* any = new google::protobuf::Any();
any->set_value(request_msg->data.data(), request_msg->data.size());
grpc_request.set_allocated_req_msg(any);
grpc_request.set_req_msg(request_msg->data.data(), request_msg->data.size());
grpc_request.set_last(true);

auto req_msg = absl::make_unique<NodeMessage>();
Expand Down Expand Up @@ -210,7 +208,7 @@ void ModuleInvocation::BlockingSendResponse() {
}
// Any channel references included with the message will be dropped.

const grpc::string& inner_msg = grpc_response.rsp_msg().value();
const grpc::string& inner_msg = grpc_response.rsp_msg();
grpc::Slice slice(inner_msg.data(), inner_msg.size());
grpc::ByteBuffer bb(&slice, /*nslices=*/1);

Expand All @@ -223,14 +221,21 @@ void ModuleInvocation::BlockingSendResponse() {
auto callback = new std::function<void(bool)>(
std::bind(&ModuleInvocation::SendResponse, this, std::placeholders::_1));
stream_.Write(bb, options, callback);
} else if (!grpc_response.has_rsp_msg()) {
// Final iteration but no response, just Finish.
google::rpc::Status status = grpc_response.status();
OAK_LOG(INFO) << "invocation#" << stream_id_
<< " SendResponse: Final inner response empty, status " << status.code();
FinishAndCleanUp(grpc::Status(static_cast<grpc::StatusCode>(status.code()), status.message()));
} else if (grpc_response.has_status()) {
if (inner_msg.empty()) {
// Final iteration with status and no response bytes, just Finish.
google::rpc::Status status = grpc_response.status();
OAK_LOG(INFO) << "invocation#" << stream_id_
<< " SendResponse: Final inner response empty, status " << status.code();
FinishAndCleanUp(
grpc::Status(static_cast<grpc::StatusCode>(status.code()), status.message()));
} else {
// Both status and response bytes, error out.
OAK_LOG(ERROR) << "invocation#" << stream_id_
<< " SendResponse: both status and response provided";
}
} else {
// Final response, so WriteAndFinish.
// Final iteration with response bytes, so WriteAndFinish.
OAK_LOG(INFO) << "invocation#" << stream_id_ << " SendResponse: Final inner response of size "
<< inner_msg.size() << ", request stream->WriteAndFinish => CleanUp";
options.set_last_message();
Expand Down
22 changes: 8 additions & 14 deletions oak/server/rust/oak_abi/src/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,17 @@ use log::warn;

/// Encapsulate a protocol buffer message in a GrpcRequest wrapper using the
/// given method name.
pub fn encap_request<T: prost::Message>(
req: &T,
req_type_url: Option<&str>,
method_name: &str,
) -> Option<GrpcRequest> {
pub fn encap_request<T: prost::Message>(req: &T, method_name: &str) -> Option<GrpcRequest> {
// Put the request in a GrpcRequest wrapper and serialize it.
let mut grpc_req = GrpcRequest::default();
grpc_req.method_name = method_name.to_string();
let mut any = prost_types::Any::default();
if let Err(e) = req.encode(&mut any.value) {
let mut bytes = Vec::new();
if let Err(e) = req.encode(&mut bytes) {
warn!("failed to serialize gRPC request: {}", e);
return None;
};
if let Some(type_url) = req_type_url {
any.type_url = type_url.to_string();
}
grpc_req.req_msg = Some(any);
grpc_req.last = true;
let grpc_req = GrpcRequest {
method_name: method_name.to_string(),
req_msg: bytes,
last: true,
};
Some(grpc_req)
}
2 changes: 1 addition & 1 deletion oak/server/rust/oak_runtime/src/node/grpc_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ impl GrpcServerNode {
})?;

// Create a gRPC request.
encap_request(&grpc_request_body, None, http_request_path).ok_or_else(|| {
encap_request(&grpc_request_body, http_request_path).ok_or_else(|| {
error!("Failed to create a GrpcRequest");
GrpcServerError::BadProtobufMessage
})
Expand Down
46 changes: 14 additions & 32 deletions oak/server/storage/storage_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,8 @@ oak::StatusOr<std::unique_ptr<oak::encap::GrpcResponse>> StorageNode::ProcessMet

if (method_name == "/oak.storage.StorageService/Read") {
oak::storage::StorageChannelReadRequest read_req;
// Assume the type of the embedded request is correct.
grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" +
read_req.GetDescriptor()->full_name());
if (!grpc_req->req_msg().UnpackTo(&read_req)) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request");
if (!read_req.ParseFromString(grpc_req->req_msg())) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request");
}
oak::storage::StorageChannelReadResponse read_rsp;
std::string value;
Expand All @@ -105,15 +102,12 @@ oak::StatusOr<std::unique_ptr<oak::encap::GrpcResponse>> StorageNode::ProcessMet
read_req.transaction_id()));
read_rsp.mutable_item()->ParseFromString(value);
// TODO(#449): Check security policy for item.
grpc_rsp->mutable_rsp_msg()->PackFrom(read_rsp);
read_rsp.SerializeToString(grpc_rsp->mutable_rsp_msg());

} else if (method_name == "/oak.storage.StorageService/Write") {
oak::storage::StorageChannelWriteRequest write_req;
// Assume the type of the embedded request is correct.
grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" +
write_req.GetDescriptor()->full_name());
if (!grpc_req->req_msg().UnpackTo(&write_req)) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request");
if (!write_req.ParseFromString(grpc_req->req_msg())) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request");
}
// TODO(#449): Check integrity policy for item.
std::string item;
Expand All @@ -123,48 +117,36 @@ oak::StatusOr<std::unique_ptr<oak::encap::GrpcResponse>> StorageNode::ProcessMet

} else if (method_name == "/oak.storage.StorageService/Delete") {
oak::storage::StorageChannelDeleteRequest delete_req;
// Assume the type of the embedded request is correct.
grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" +
delete_req.GetDescriptor()->full_name());
if (!grpc_req->req_msg().UnpackTo(&delete_req)) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request");
if (!delete_req.ParseFromString(grpc_req->req_msg())) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request");
}
// TODO(#449): Check integrity policy for item.
OAK_RETURN_IF_ERROR(storage_processor_.Delete(
delete_req.storage_name(), delete_req.item().name(), delete_req.transaction_id()));

} else if (method_name == "/oak.storage.StorageService/Begin") {
oak::storage::StorageChannelBeginRequest begin_req;
// Assume the type of the embedded request is correct.
grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" +
begin_req.GetDescriptor()->full_name());
if (!grpc_req->req_msg().UnpackTo(&begin_req)) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request");
if (!begin_req.ParseFromString(grpc_req->req_msg())) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request");
}
oak::storage::StorageChannelBeginResponse begin_rsp;
std::string transaction_id;
OAK_ASSIGN_OR_RETURN(transaction_id, storage_processor_.Begin(begin_req.storage_name()));
begin_rsp.set_transaction_id(transaction_id);
grpc_rsp->mutable_rsp_msg()->PackFrom(begin_rsp);
begin_rsp.SerializeToString(grpc_rsp->mutable_rsp_msg());

} else if (method_name == "/oak.storage.StorageService/Commit") {
oak::storage::StorageChannelCommitRequest commit_req;
// Assume the type of the embedded request is correct.
grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" +
commit_req.GetDescriptor()->full_name());
if (!grpc_req->req_msg().UnpackTo(&commit_req)) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request");
if (!commit_req.ParseFromString(grpc_req->req_msg())) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request");
}
OAK_RETURN_IF_ERROR(
storage_processor_.Commit(commit_req.storage_name(), commit_req.transaction_id()));

} else if (method_name == "/oak.storage.StorageService/Rollback") {
oak::storage::StorageChannelRollbackRequest rollback_req;
// Assume the type of the embedded request is correct.
grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" +
rollback_req.GetDescriptor()->full_name());
if (!grpc_req->req_msg().UnpackTo(&rollback_req)) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request");
if (!rollback_req.ParseFromString(grpc_req->req_msg())) {
return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request");
}
OAK_RETURN_IF_ERROR(
storage_processor_.Rollback(rollback_req.storage_name(), rollback_req.transaction_id()));
Expand Down
45 changes: 21 additions & 24 deletions sdk/rust/oak/src/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,15 @@ impl ChannelResponseWriter {
) -> std::result::Result<(), OakError> {
// Put the serialized response into a GrpcResponse message wrapper and
// serialize it into the channel.
let mut grpc_rsp = GrpcResponse::default();
let mut any = prost_types::Any::default();
rsp.encode(&mut any.value)?;
grpc_rsp.rsp_msg = Some(any);
grpc_rsp.last = match mode {
WriteMode::KeepOpen => false,
WriteMode::Close => true,
let mut bytes = Vec::new();
rsp.encode(&mut bytes)?;
let grpc_rsp = GrpcResponse {
rsp_msg: bytes,
status: None,
last: match mode {
WriteMode::KeepOpen => false,
WriteMode::Close => true,
},
};
self.sender.send(&grpc_rsp)?;
if mode == WriteMode::Close {
Expand All @@ -90,11 +92,13 @@ impl ChannelResponseWriter {
/// Write an empty gRPC response and optionally close out the method
/// invocation. Any errors from the channel are silently dropped.
pub fn write_empty(&mut self, mode: WriteMode) -> std::result::Result<(), OakError> {
let mut grpc_rsp = GrpcResponse::default();
grpc_rsp.rsp_msg = Some(prost_types::Any::default());
grpc_rsp.last = match mode {
WriteMode::KeepOpen => false,
WriteMode::Close => true,
let grpc_rsp = GrpcResponse {
rsp_msg: Vec::new(),
status: None,
last: match mode {
WriteMode::KeepOpen => false,
WriteMode::Close => true,
},
};
self.sender.send(&grpc_rsp)?;
if mode == WriteMode::Close {
Expand Down Expand Up @@ -160,10 +164,7 @@ impl<T: ServerNode> crate::Node<Invocation> for T {
}
self.invoke(
&req.method_name,
req.req_msg
.map(|any| any.value)
.unwrap_or_default()
.as_slice(),
req.req_msg.as_slice(),
ChannelResponseWriter::new(invocation.response_sender),
);
Ok(())
Expand All @@ -178,8 +179,7 @@ pub fn decap_response<T: prost::Message + Default>(grpc_rsp: GrpcResponse) -> Re
if status.code != rpc::Code::Ok as i32 {
return Err(status);
}
let bytes = grpc_rsp.rsp_msg.unwrap_or_default().value;
let rsp = T::decode(bytes.as_slice()).map_err(|proto_err| {
let rsp = T::decode(grpc_rsp.rsp_msg.as_slice()).map_err(|proto_err| {
build_status(
rpc::Code::InvalidArgument,
&format!("message parsing failed: {}", proto_err),
Expand All @@ -194,7 +194,6 @@ pub fn decap_response<T: prost::Message + Default>(grpc_rsp: GrpcResponse) -> Re
pub fn invoke_grpc_method_stream<R>(
method_name: &str,
req: &R,
req_type_url: Option<&str>,
invocation_channel: &crate::io::Sender<Invocation>,
) -> Result<crate::io::Receiver<GrpcResponse>>
where
Expand All @@ -206,8 +205,8 @@ where

// Put the request in a GrpcRequest wrapper and send it into the request
// message channel.
let req = oak_abi::grpc::encap_request(req, req_type_url, method_name)
.expect("failed to serialize GrpcRequest");
let req =
oak_abi::grpc::encap_request(req, method_name).expect("failed to serialize GrpcRequest");
req_sender.send(&req).expect("failed to write to channel");
req_sender.close().expect("failed to close channel");

Expand Down Expand Up @@ -236,15 +235,13 @@ where
pub fn invoke_grpc_method<R, Q>(
method_name: &str,
req: &R,
req_type_url: Option<&str>,
invocation_channel: &crate::io::Sender<Invocation>,
) -> Result<Q>
where
R: prost::Message,
Q: prost::Message + Default,
{
let rsp_receiver =
invoke_grpc_method_stream(method_name, req, req_type_url, invocation_channel)?;
let rsp_receiver = invoke_grpc_method_stream(method_name, req, invocation_channel)?;
// Read a single encapsulated response.
let result = rsp_receiver.receive();
rsp_receiver.close().expect("failed to close channel");
Expand Down
6 changes: 3 additions & 3 deletions sdk/rust/oak_tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ where
{
// Put the request in a GrpcRequest wrapper and serialize into a message.
let grpc_req =
oak_abi::grpc::encap_request(req, None, method_name).expect("failed to build GrpcRequest");
oak_abi::grpc::encap_request(req, method_name).expect("failed to build GrpcRequest");
let mut req_msg = oak_runtime::Message {
data: vec![],
channels: vec![],
Expand Down Expand Up @@ -170,8 +170,8 @@ where
if status.code != oak::grpc::Code::Ok as i32 {
return Err(status);
}
let rsp = Q::decode(rsp.rsp_msg.unwrap_or_default().value.as_slice())
.expect("Failed to parse response protobuf message");
let rsp =
Q::decode(rsp.rsp_msg.as_slice()).expect("Failed to parse response protobuf message");
return Ok(rsp);
}
}
4 changes: 2 additions & 2 deletions sdk/rust/oak_utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl prost_build::ServiceGenerator for OakServiceGenerator {
false =>
quote! {
pub fn #method_name(&self, req: #input_type) -> #oak_package::grpc::Result<#output_type> {
#oak_package::grpc::invoke_grpc_method(#method_name_string, &req, None, &self.0.invocation_sender)
#oak_package::grpc::invoke_grpc_method(#method_name_string, &req, &self.0.invocation_sender)
}
},
true =>
Expand All @@ -115,7 +115,7 @@ impl prost_build::ServiceGenerator for OakServiceGenerator {
// the underlying handle.
quote! {
pub fn #method_name(&self, req: #input_type) -> #oak_package::grpc::Result<#oak_package::io::Receiver<#oak_package::grpc::GrpcResponse>> {
#oak_package::grpc::invoke_grpc_method_stream(#method_name_string, &req, None, &self.0.invocation_sender)
#oak_package::grpc::invoke_grpc_method_stream(#method_name_string, &req, &self.0.invocation_sender)
}
},
}
Expand Down

0 comments on commit 64d3f90

Please sign in to comment.