diff --git a/examples/abitest/module_0/rust/src/lib.rs b/examples/abitest/module_0/rust/src/lib.rs index ad2d71b5524..e06c36e916b 100644 --- a/examples/abitest/module_0/rust/src/lib.rs +++ b/examples/abitest/module_0/rust/src/lib.rs @@ -1417,10 +1417,8 @@ impl FrontendNode { oak::grpc::Code::Ok as i32, grpc_rsp.status.unwrap_or_default().code ); - let rsp = GrpcTestResponse::decode( - grpc_rsp.rsp_msg.unwrap_or_default().value.as_slice(), - ) - .expect("failed to parse GrpcTestResponse"); + let rsp = GrpcTestResponse::decode(grpc_rsp.rsp_msg.as_slice()) + .expect("failed to parse GrpcTestResponse"); expect_eq!(rsp.clone(), ok_rsp.clone()); count += 1; } diff --git a/oak/proto/BUILD b/oak/proto/BUILD index 98bc81bed08..74403a8f882 100644 --- a/oak/proto/BUILD +++ b/oak/proto/BUILD @@ -46,7 +46,6 @@ proto_library( srcs = ["grpc_encap.proto"], deps = [ "//third_party/google/rpc:status_proto", - "@com_google_protobuf//:any_proto", ], ) diff --git a/oak/proto/grpc_encap.proto b/oak/proto/grpc_encap.proto index ecaf89d8319..8302618c01a 100644 --- a/oak/proto/grpc_encap.proto +++ b/oak/proto/grpc_encap.proto @@ -18,7 +18,6 @@ syntax = "proto3"; package oak.encap; -import "google/protobuf/any.proto"; import "third_party/google/rpc/status.proto"; // Protocol buffer encoding to hold additional metadata that accompanies a gRPC @@ -29,12 +28,16 @@ import "third_party/google/rpc/status.proto"; message GrpcRequest { string method_name = 1; - google.protobuf.Any req_msg = 2; + // The body of the request. Usually a serialized protobuf message. + // The message type is deduced from the `method_name` field. + bytes req_msg = 2; bool last = 3; } message GrpcResponse { - google.protobuf.Any rsp_msg = 1; + // The body of the response. Usually a serialized protobuf message. + // The message type is deduced from the `method_name` field of the request. + bytes rsp_msg = 1; google.rpc.Status status = 2; // The last field indicates that this is definitely the final response for a diff --git a/oak/server/grpc_client_node.cc b/oak/server/grpc_client_node.cc index b39c1c76ee9..cf1bc5a2947 100644 --- a/oak/server/grpc_client_node.cc +++ b/oak/server/grpc_client_node.cc @@ -60,7 +60,7 @@ bool GrpcClientNode::HandleInvocation(Handle invocation_handle) { oak::encap::GrpcRequest grpc_req; grpc_req.ParseFromString(std::string(req_result.msg->data.data(), req_result.msg->data.size())); std::string method_name = grpc_req.method_name(); - const grpc::string& req_data = grpc_req.req_msg().value(); + const grpc::string& req_data = grpc_req.req_msg(); // Use a completion queue together with a generic client reader/writer to // perform the method invocation. All steps are done in serial, so just use @@ -117,9 +117,7 @@ bool GrpcClientNode::HandleInvocation(Handle invocation_handle) { // Build an encapsulation of the gRPC response and put it in an Oak Message. oak::encap::GrpcResponse grpc_rsp; grpc_rsp.set_last(false); - google::protobuf::Any* any = new google::protobuf::Any(); - any->set_value(rsp_data.data(), rsp_data.size()); - grpc_rsp.set_allocated_rsp_msg(any); + grpc_rsp.set_rsp_msg(rsp_data); auto rsp_msg = absl::make_unique(); size_t serialized_size = grpc_rsp.ByteSizeLong(); diff --git a/oak/server/module_invocation.cc b/oak/server/module_invocation.cc index e5f91a31b77..b838679a52c 100644 --- a/oak/server/module_invocation.cc +++ b/oak/server/module_invocation.cc @@ -96,9 +96,7 @@ void ModuleInvocation::ProcessRequest(bool ok) { // Build an encapsulation of the gRPC request invocation and put it in a Message. oak::encap::GrpcRequest grpc_request; grpc_request.set_method_name(context_.method()); - google::protobuf::Any* any = new google::protobuf::Any(); - any->set_value(request_msg->data.data(), request_msg->data.size()); - grpc_request.set_allocated_req_msg(any); + grpc_request.set_req_msg(request_msg->data.data(), request_msg->data.size()); grpc_request.set_last(true); auto req_msg = absl::make_unique(); @@ -210,7 +208,7 @@ void ModuleInvocation::BlockingSendResponse() { } // Any channel references included with the message will be dropped. - const grpc::string& inner_msg = grpc_response.rsp_msg().value(); + const grpc::string& inner_msg = grpc_response.rsp_msg(); grpc::Slice slice(inner_msg.data(), inner_msg.size()); grpc::ByteBuffer bb(&slice, /*nslices=*/1); @@ -223,14 +221,21 @@ void ModuleInvocation::BlockingSendResponse() { auto callback = new std::function( std::bind(&ModuleInvocation::SendResponse, this, std::placeholders::_1)); stream_.Write(bb, options, callback); - } else if (!grpc_response.has_rsp_msg()) { - // Final iteration but no response, just Finish. - google::rpc::Status status = grpc_response.status(); - OAK_LOG(INFO) << "invocation#" << stream_id_ - << " SendResponse: Final inner response empty, status " << status.code(); - FinishAndCleanUp(grpc::Status(static_cast(status.code()), status.message())); + } else if (grpc_response.has_status()) { + if (inner_msg.empty()) { + // Final iteration with status and no response bytes, just Finish. + google::rpc::Status status = grpc_response.status(); + OAK_LOG(INFO) << "invocation#" << stream_id_ + << " SendResponse: Final inner response empty, status " << status.code(); + FinishAndCleanUp( + grpc::Status(static_cast(status.code()), status.message())); + } else { + // Both status and response bytes, error out. + OAK_LOG(ERROR) << "invocation#" << stream_id_ + << " SendResponse: both status and response provided"; + } } else { - // Final response, so WriteAndFinish. + // Final iteration with response bytes, so WriteAndFinish. OAK_LOG(INFO) << "invocation#" << stream_id_ << " SendResponse: Final inner response of size " << inner_msg.size() << ", request stream->WriteAndFinish => CleanUp"; options.set_last_message(); diff --git a/oak/server/rust/oak_abi/src/grpc/mod.rs b/oak/server/rust/oak_abi/src/grpc/mod.rs index 848eecbff88..5e7d011cfda 100644 --- a/oak/server/rust/oak_abi/src/grpc/mod.rs +++ b/oak/server/rust/oak_abi/src/grpc/mod.rs @@ -19,23 +19,17 @@ use log::warn; /// Encapsulate a protocol buffer message in a GrpcRequest wrapper using the /// given method name. -pub fn encap_request( - req: &T, - req_type_url: Option<&str>, - method_name: &str, -) -> Option { +pub fn encap_request(req: &T, method_name: &str) -> Option { // Put the request in a GrpcRequest wrapper and serialize it. - let mut grpc_req = GrpcRequest::default(); - grpc_req.method_name = method_name.to_string(); - let mut any = prost_types::Any::default(); - if let Err(e) = req.encode(&mut any.value) { + let mut bytes = Vec::new(); + if let Err(e) = req.encode(&mut bytes) { warn!("failed to serialize gRPC request: {}", e); return None; }; - if let Some(type_url) = req_type_url { - any.type_url = type_url.to_string(); - } - grpc_req.req_msg = Some(any); - grpc_req.last = true; + let grpc_req = GrpcRequest { + method_name: method_name.to_string(), + req_msg: bytes, + last: true, + }; Some(grpc_req) } diff --git a/oak/server/rust/oak_runtime/src/node/grpc_server.rs b/oak/server/rust/oak_runtime/src/node/grpc_server.rs index cfbb61e5016..0395faa3f31 100644 --- a/oak/server/rust/oak_runtime/src/node/grpc_server.rs +++ b/oak/server/rust/oak_runtime/src/node/grpc_server.rs @@ -193,7 +193,7 @@ impl GrpcServerNode { })?; // Create a gRPC request. - encap_request(&grpc_request_body, None, http_request_path).ok_or_else(|| { + encap_request(&grpc_request_body, http_request_path).ok_or_else(|| { error!("Failed to create a GrpcRequest"); GrpcServerError::BadProtobufMessage }) diff --git a/oak/server/storage/storage_node.cc b/oak/server/storage/storage_node.cc index db623a77ac8..2ca73d19ae1 100644 --- a/oak/server/storage/storage_node.cc +++ b/oak/server/storage/storage_node.cc @@ -92,11 +92,8 @@ oak::StatusOr> StorageNode::ProcessMet if (method_name == "/oak.storage.StorageService/Read") { oak::storage::StorageChannelReadRequest read_req; - // Assume the type of the embedded request is correct. - grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" + - read_req.GetDescriptor()->full_name()); - if (!grpc_req->req_msg().UnpackTo(&read_req)) { - return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request"); + if (!read_req.ParseFromString(grpc_req->req_msg())) { + return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request"); } oak::storage::StorageChannelReadResponse read_rsp; std::string value; @@ -105,15 +102,12 @@ oak::StatusOr> StorageNode::ProcessMet read_req.transaction_id())); read_rsp.mutable_item()->ParseFromString(value); // TODO(#449): Check security policy for item. - grpc_rsp->mutable_rsp_msg()->PackFrom(read_rsp); + read_rsp.SerializeToString(grpc_rsp->mutable_rsp_msg()); } else if (method_name == "/oak.storage.StorageService/Write") { oak::storage::StorageChannelWriteRequest write_req; - // Assume the type of the embedded request is correct. - grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" + - write_req.GetDescriptor()->full_name()); - if (!grpc_req->req_msg().UnpackTo(&write_req)) { - return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request"); + if (!write_req.ParseFromString(grpc_req->req_msg())) { + return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request"); } // TODO(#449): Check integrity policy for item. std::string item; @@ -123,11 +117,8 @@ oak::StatusOr> StorageNode::ProcessMet } else if (method_name == "/oak.storage.StorageService/Delete") { oak::storage::StorageChannelDeleteRequest delete_req; - // Assume the type of the embedded request is correct. - grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" + - delete_req.GetDescriptor()->full_name()); - if (!grpc_req->req_msg().UnpackTo(&delete_req)) { - return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request"); + if (!delete_req.ParseFromString(grpc_req->req_msg())) { + return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request"); } // TODO(#449): Check integrity policy for item. OAK_RETURN_IF_ERROR(storage_processor_.Delete( @@ -135,36 +126,27 @@ oak::StatusOr> StorageNode::ProcessMet } else if (method_name == "/oak.storage.StorageService/Begin") { oak::storage::StorageChannelBeginRequest begin_req; - // Assume the type of the embedded request is correct. - grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" + - begin_req.GetDescriptor()->full_name()); - if (!grpc_req->req_msg().UnpackTo(&begin_req)) { - return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request"); + if (!begin_req.ParseFromString(grpc_req->req_msg())) { + return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request"); } oak::storage::StorageChannelBeginResponse begin_rsp; std::string transaction_id; OAK_ASSIGN_OR_RETURN(transaction_id, storage_processor_.Begin(begin_req.storage_name())); begin_rsp.set_transaction_id(transaction_id); - grpc_rsp->mutable_rsp_msg()->PackFrom(begin_rsp); + begin_rsp.SerializeToString(grpc_rsp->mutable_rsp_msg()); } else if (method_name == "/oak.storage.StorageService/Commit") { oak::storage::StorageChannelCommitRequest commit_req; - // Assume the type of the embedded request is correct. - grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" + - commit_req.GetDescriptor()->full_name()); - if (!grpc_req->req_msg().UnpackTo(&commit_req)) { - return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request"); + if (!commit_req.ParseFromString(grpc_req->req_msg())) { + return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request"); } OAK_RETURN_IF_ERROR( storage_processor_.Commit(commit_req.storage_name(), commit_req.transaction_id())); } else if (method_name == "/oak.storage.StorageService/Rollback") { oak::storage::StorageChannelRollbackRequest rollback_req; - // Assume the type of the embedded request is correct. - grpc_req->mutable_req_msg()->set_type_url("type.googleapis.com/" + - rollback_req.GetDescriptor()->full_name()); - if (!grpc_req->req_msg().UnpackTo(&rollback_req)) { - return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to unpack request"); + if (!rollback_req.ParseFromString(grpc_req->req_msg())) { + return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to parse request"); } OAK_RETURN_IF_ERROR( storage_processor_.Rollback(rollback_req.storage_name(), rollback_req.transaction_id())); diff --git a/sdk/rust/oak/src/grpc/mod.rs b/sdk/rust/oak/src/grpc/mod.rs index 61c7b5e495a..88f50749434 100644 --- a/sdk/rust/oak/src/grpc/mod.rs +++ b/sdk/rust/oak/src/grpc/mod.rs @@ -72,13 +72,15 @@ impl ChannelResponseWriter { ) -> std::result::Result<(), OakError> { // Put the serialized response into a GrpcResponse message wrapper and // serialize it into the channel. - let mut grpc_rsp = GrpcResponse::default(); - let mut any = prost_types::Any::default(); - rsp.encode(&mut any.value)?; - grpc_rsp.rsp_msg = Some(any); - grpc_rsp.last = match mode { - WriteMode::KeepOpen => false, - WriteMode::Close => true, + let mut bytes = Vec::new(); + rsp.encode(&mut bytes)?; + let grpc_rsp = GrpcResponse { + rsp_msg: bytes, + status: None, + last: match mode { + WriteMode::KeepOpen => false, + WriteMode::Close => true, + }, }; self.sender.send(&grpc_rsp)?; if mode == WriteMode::Close { @@ -90,11 +92,13 @@ impl ChannelResponseWriter { /// Write an empty gRPC response and optionally close out the method /// invocation. Any errors from the channel are silently dropped. pub fn write_empty(&mut self, mode: WriteMode) -> std::result::Result<(), OakError> { - let mut grpc_rsp = GrpcResponse::default(); - grpc_rsp.rsp_msg = Some(prost_types::Any::default()); - grpc_rsp.last = match mode { - WriteMode::KeepOpen => false, - WriteMode::Close => true, + let grpc_rsp = GrpcResponse { + rsp_msg: Vec::new(), + status: None, + last: match mode { + WriteMode::KeepOpen => false, + WriteMode::Close => true, + }, }; self.sender.send(&grpc_rsp)?; if mode == WriteMode::Close { @@ -160,10 +164,7 @@ impl crate::Node for T { } self.invoke( &req.method_name, - req.req_msg - .map(|any| any.value) - .unwrap_or_default() - .as_slice(), + req.req_msg.as_slice(), ChannelResponseWriter::new(invocation.response_sender), ); Ok(()) @@ -178,8 +179,7 @@ pub fn decap_response(grpc_rsp: GrpcResponse) -> Re if status.code != rpc::Code::Ok as i32 { return Err(status); } - let bytes = grpc_rsp.rsp_msg.unwrap_or_default().value; - let rsp = T::decode(bytes.as_slice()).map_err(|proto_err| { + let rsp = T::decode(grpc_rsp.rsp_msg.as_slice()).map_err(|proto_err| { build_status( rpc::Code::InvalidArgument, &format!("message parsing failed: {}", proto_err), @@ -194,7 +194,6 @@ pub fn decap_response(grpc_rsp: GrpcResponse) -> Re pub fn invoke_grpc_method_stream( method_name: &str, req: &R, - req_type_url: Option<&str>, invocation_channel: &crate::io::Sender, ) -> Result> where @@ -206,8 +205,8 @@ where // Put the request in a GrpcRequest wrapper and send it into the request // message channel. - let req = oak_abi::grpc::encap_request(req, req_type_url, method_name) - .expect("failed to serialize GrpcRequest"); + let req = + oak_abi::grpc::encap_request(req, method_name).expect("failed to serialize GrpcRequest"); req_sender.send(&req).expect("failed to write to channel"); req_sender.close().expect("failed to close channel"); @@ -236,15 +235,13 @@ where pub fn invoke_grpc_method( method_name: &str, req: &R, - req_type_url: Option<&str>, invocation_channel: &crate::io::Sender, ) -> Result where R: prost::Message, Q: prost::Message + Default, { - let rsp_receiver = - invoke_grpc_method_stream(method_name, req, req_type_url, invocation_channel)?; + let rsp_receiver = invoke_grpc_method_stream(method_name, req, invocation_channel)?; // Read a single encapsulated response. let result = rsp_receiver.receive(); rsp_receiver.close().expect("failed to close channel"); diff --git a/sdk/rust/oak_tests/src/lib.rs b/sdk/rust/oak_tests/src/lib.rs index abd85a1d6a0..a828c6a73a8 100644 --- a/sdk/rust/oak_tests/src/lib.rs +++ b/sdk/rust/oak_tests/src/lib.rs @@ -104,7 +104,7 @@ where { // Put the request in a GrpcRequest wrapper and serialize into a message. let grpc_req = - oak_abi::grpc::encap_request(req, None, method_name).expect("failed to build GrpcRequest"); + oak_abi::grpc::encap_request(req, method_name).expect("failed to build GrpcRequest"); let mut req_msg = oak_runtime::Message { data: vec![], channels: vec![], @@ -170,8 +170,8 @@ where if status.code != oak::grpc::Code::Ok as i32 { return Err(status); } - let rsp = Q::decode(rsp.rsp_msg.unwrap_or_default().value.as_slice()) - .expect("Failed to parse response protobuf message"); + let rsp = + Q::decode(rsp.rsp_msg.as_slice()).expect("Failed to parse response protobuf message"); return Ok(rsp); } } diff --git a/sdk/rust/oak_utils/src/lib.rs b/sdk/rust/oak_utils/src/lib.rs index 6fb68d1b84c..9c3912969cb 100644 --- a/sdk/rust/oak_utils/src/lib.rs +++ b/sdk/rust/oak_utils/src/lib.rs @@ -105,7 +105,7 @@ impl prost_build::ServiceGenerator for OakServiceGenerator { false => quote! { pub fn #method_name(&self, req: #input_type) -> #oak_package::grpc::Result<#output_type> { - #oak_package::grpc::invoke_grpc_method(#method_name_string, &req, None, &self.0.invocation_sender) + #oak_package::grpc::invoke_grpc_method(#method_name_string, &req, &self.0.invocation_sender) } }, true => @@ -115,7 +115,7 @@ impl prost_build::ServiceGenerator for OakServiceGenerator { // the underlying handle. quote! { pub fn #method_name(&self, req: #input_type) -> #oak_package::grpc::Result<#oak_package::io::Receiver<#oak_package::grpc::GrpcResponse>> { - #oak_package::grpc::invoke_grpc_method_stream(#method_name_string, &req, None, &self.0.invocation_sender) + #oak_package::grpc::invoke_grpc_method_stream(#method_name_string, &req, &self.0.invocation_sender) } }, }