diff --git a/oak/server/BUILD b/oak/server/BUILD index 4cc335c2688..9b2b62925d5 100644 --- a/oak/server/BUILD +++ b/oak/server/BUILD @@ -32,10 +32,17 @@ cc_library( ":channel", "//oak/common:handles", "//oak/common:logging", + "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", ], ) +cc_library( + name = "handle_closer", + hdrs = ["handle_closer.h"], + deps = [":oak_node"], +) + cc_library( name = "wasm_node", srcs = [ @@ -47,7 +54,6 @@ cc_library( "wasm_node.h", ], deps = [ - ":channel", "//oak/common:handles", "//oak/common:logging", "//oak/proto:grpc_encap_cc_proto", @@ -109,9 +115,9 @@ cc_library( "oak_grpc_node.h", ], deps = [ - ":channel", ":oak_node", "//oak/common:app_config", + "//oak/common:handles", "//oak/common:logging", "//oak/common:policy", "//oak/proto:grpc_encap_cc_proto", @@ -191,7 +197,7 @@ cc_library( srcs = ["logging_node.cc"], hdrs = ["logging_node.h"], deps = [ - ":channel", + ":handle_closer", ":node_thread", "//oak/common:handles", "//oak/common:logging", @@ -205,7 +211,9 @@ cc_library( srcs = ["invocation.cc"], hdrs = ["invocation.h"], deps = [ - ":channel", + ":handle_closer", + ":oak_node", + "//oak/common:handles", "//oak/common:logging", "@com_google_absl//absl/memory", ], @@ -217,7 +225,6 @@ cc_library( hdrs = ["grpc_client_node.h"], deps = [ ":base_runtime", - ":channel", ":invocation", ":node_thread", "//oak/common:handles", diff --git a/oak/server/channel.cc b/oak/server/channel.cc index 5464ddbf13b..6caa003a1ff 100644 --- a/oak/server/channel.cc +++ b/oak/server/channel.cc @@ -79,24 +79,33 @@ void MessageChannel::Write(std::unique_ptr msg) { ReadResult MessageChannel::Read(uint32_t max_size, uint32_t max_channels) { absl::MutexLock lock(&mu_); if (msgs_.empty()) { - return ReadResult{0}; + ReadResult result(OakStatus::OK); + return result; } return ReadLocked(max_size, max_channels); } ReadResult MessageChannel::ReadLocked(uint32_t max_size, uint32_t max_channels) { - ReadResult result = {0}; Message* next_msg = msgs_.front().get(); size_t actual_size = next_msg->data.size(); size_t actual_count = next_msg->channels.size(); - if (actual_size > max_size || actual_count > max_channels) { - OAK_LOG(INFO) << "Next message of size " << actual_size << " with " << actual_count - << " channels, read limited to size " << max_size << " and " << max_channels - << " channels"; + if (actual_size > max_size) { + OAK_LOG(INFO) << "Next message of size " << actual_size << ", read limited to size " + << max_size; + ReadResult result(OakStatus::ERR_BUFFER_TOO_SMALL); + result.required_size = actual_size; + result.required_channels = actual_count; + return result; + } + if (actual_count > max_channels) { + OAK_LOG(INFO) << "Next message with " << actual_count << " handles, read limited to " + << max_channels << " handles"; + ReadResult result(OakStatus::ERR_HANDLE_SPACE_TOO_SMALL); result.required_size = actual_size; result.required_channels = actual_count; return result; } + ReadResult result(OakStatus::OK); result.msg = std::move(msgs_.front()); msgs_.pop_front(); OAK_LOG(INFO) << "Read message of size " << result.msg->data.size() << " with " << actual_count diff --git a/oak/server/channel.h b/oak/server/channel.h index 0cd479130ce..df1df09f779 100644 --- a/oak/server/channel.h +++ b/oak/server/channel.h @@ -50,16 +50,16 @@ struct Message { oak::policy::Label label; }; -// Result of a read operation. If the operation would have produced a message -// bigger than the requested maximum size, then |required_size| will be non-zero -// and indicates the required size for the message. If the operation would have -// been accompanied by more than the requested maximum channel count, then -// |required_channels| will be non-zero and indicates the required channel count -// for the message. Otherwise, |required_size| and |required_channels| will be -// zero and |data| holds the message (transferring ownership). +// Result of a read operation. struct ReadResult { + explicit ReadResult(OakStatus s) : status(s), required_size(0), required_channels(0) {} + OakStatus status; + // The following fields are filled in if the status is ERR_BUFFER_TOO_SMALL + // or ERR_HANDLE_SPACE_TOO_SMALL, indicating the required size and number + // of handles needed to read the message. uint32_t required_size; uint32_t required_channels; + // The following field is filled in if the status is OK. std::unique_ptr msg; }; @@ -99,7 +99,7 @@ class MessageChannel { // Count indicates the number of pending messages. size_t Count() const LOCKS_EXCLUDED(mu_); - // Read returns the first message on the channel, subject to |max_size| checks. + // Read returns the first message on the channel, subject to size checks. ReadResult Read(uint32_t max_size, uint32_t max_channels) LOCKS_EXCLUDED(mu_); // BlockingRead behaves like Read but blocks until a message is available. diff --git a/oak/server/channel_test.cc b/oak/server/channel_test.cc index 1d6e64098b7..1c44886d67a 100644 --- a/oak/server/channel_test.cc +++ b/oak/server/channel_test.cc @@ -43,10 +43,12 @@ TEST(MessageChannel, BasicOperation) { ASSERT_EQ(true, read_half->CanRead()); ReadResult result1 = read_half->Read(1, 0); // too small + ASSERT_EQ(OakStatus::ERR_BUFFER_TOO_SMALL, result1.status); ASSERT_EQ(3, result1.required_size); ASSERT_EQ(nullptr, result1.msg); ReadResult result2 = read_half->Read(3, 0); // just right + ASSERT_EQ(OakStatus::OK, result2.status); EXPECT_NE(result2.msg, nullptr); ASSERT_EQ(3, result2.msg->data.size()); ASSERT_EQ(0x01, (result2.msg->data)[0]); @@ -54,6 +56,7 @@ TEST(MessageChannel, BasicOperation) { ASSERT_EQ(false, read_half->CanRead()); ReadResult result3 = read_half->Read(10000, 0); + ASSERT_EQ(OakStatus::OK, result3.status); ASSERT_EQ(nullptr, result3.msg); ASSERT_EQ(0, result3.required_size); @@ -65,15 +68,18 @@ TEST(MessageChannel, BasicOperation) { ASSERT_EQ(true, read_half->CanRead()); ReadResult result4 = read_half->Read(3000, 0); + ASSERT_EQ(OakStatus::OK, result4.status); EXPECT_NE(result4.msg, nullptr); ASSERT_EQ(3, result4.msg->data.size()); ASSERT_EQ(0x11, (result4.msg->data)[0]); ReadResult result5 = read_half->Read(0, 0); + ASSERT_EQ(OakStatus::ERR_BUFFER_TOO_SMALL, result5.status); ASSERT_EQ(3, result5.required_size); ASSERT_EQ(nullptr, result5.msg); ReadResult result6 = read_half->Read(10, 0); + ASSERT_EQ(OakStatus::OK, result6.status); EXPECT_NE(result6.msg, nullptr); ASSERT_EQ(3, result6.msg->data.size()); ASSERT_EQ(0x21, (result6.msg->data)[0]); @@ -94,11 +100,13 @@ TEST(MessageChannel, ChannelTransfer) { ASSERT_EQ(true, read_half->CanRead()); ReadResult result1 = read_half->Read(1000, 0); // no space for channels + ASSERT_EQ(OakStatus::ERR_HANDLE_SPACE_TOO_SMALL, result1.status); ASSERT_EQ(3, result1.required_size); ASSERT_EQ(1, result1.required_channels); ASSERT_EQ(nullptr, result1.msg); ReadResult result2 = read_half->Read(1000, 1); // just right + ASSERT_EQ(OakStatus::OK, result2.status); EXPECT_NE(result2.msg, nullptr); ASSERT_EQ(3, result2.msg->data.size()); ASSERT_EQ(0x01, (result2.msg->data)[0]); @@ -107,6 +115,7 @@ TEST(MessageChannel, ChannelTransfer) { ASSERT_EQ(false, read_half->CanRead()); ReadResult result3 = read_half->Read(10000, 10); + ASSERT_EQ(OakStatus::OK, result3.status); ASSERT_EQ(nullptr, result3.msg); ASSERT_EQ(0, result3.required_size); ASSERT_EQ(0, result3.required_channels); @@ -124,17 +133,20 @@ TEST(MessageChannel, ChannelTransfer) { ASSERT_EQ(true, read_half->CanRead()); ReadResult result4 = read_half->Read(3000, 10); + ASSERT_EQ(OakStatus::OK, result4.status); EXPECT_NE(result4.msg, nullptr); ASSERT_EQ(3, result4.msg->data.size()); ASSERT_EQ(0x11, (result4.msg->data)[0]); ASSERT_EQ(2, result4.msg->channels.size()); ReadResult result5 = read_half->Read(100, 0); + ASSERT_EQ(OakStatus::ERR_HANDLE_SPACE_TOO_SMALL, result5.status); ASSERT_EQ(3, result5.required_size); ASSERT_EQ(3, result5.required_channels); ASSERT_EQ(nullptr, result5.msg); ReadResult result6 = read_half->Read(10, 10); + ASSERT_EQ(OakStatus::OK, result6.status); EXPECT_NE(result6.msg, nullptr); ASSERT_EQ(3, result6.msg->data.size()); ASSERT_EQ(0x21, (result6.msg->data)[0]); diff --git a/oak/server/grpc_client_node.cc b/oak/server/grpc_client_node.cc index 23584205f6c..99cf2cafbcd 100644 --- a/oak/server/grpc_client_node.cc +++ b/oak/server/grpc_client_node.cc @@ -37,8 +37,8 @@ GrpcClientNode::GrpcClientNode(BaseRuntime* runtime, const std::string& name, OAK_LOG(INFO) << "Created gRPC client node for " << grpc_address; } -bool GrpcClientNode::HandleInvocation(MessageChannelReadHalf* invocation_channel) { - std::unique_ptr invocation(Invocation::ReceiveFromChannel(invocation_channel)); +bool GrpcClientNode::HandleInvocation(Handle invocation_handle) { + std::unique_ptr invocation(Invocation::ReceiveFromChannel(this, invocation_handle)); if (invocation == nullptr) { OAK_LOG(ERROR) << "Failed to create invocation"; return false; @@ -46,12 +46,12 @@ bool GrpcClientNode::HandleInvocation(MessageChannelReadHalf* invocation_channel // Expect to read a single request out of the request channel. // TODO(#97): support client-side streaming - ReadResult req_result = invocation->req_channel->Read(INT_MAX, INT_MAX); - if (req_result.required_size > 0) { - OAK_LOG(ERROR) << "Message size too large: " << req_result.required_size; + NodeReadResult req_result = ChannelRead(invocation->req_handle.get(), INT_MAX, INT_MAX); + if (req_result.status != OakStatus::OK) { + OAK_LOG(ERROR) << "Failed to read invocation message: " << req_result.status; return false; } - if (req_result.msg->channels.size() != 0) { + if (req_result.msg->handles.size() != 0) { OAK_LOG(ERROR) << "Unexpectedly received channel handles in request channel"; return false; } @@ -119,14 +119,14 @@ bool GrpcClientNode::HandleInvocation(MessageChannelReadHalf* invocation_channel any->set_value(rsp_data.data(), rsp_data.size()); grpc_rsp.set_allocated_rsp_msg(any); - std::unique_ptr rsp_msg = absl::make_unique(); + auto rsp_msg = absl::make_unique(); size_t serialized_size = grpc_rsp.ByteSizeLong(); rsp_msg->data.resize(serialized_size); grpc_rsp.SerializeToArray(rsp_msg->data.data(), rsp_msg->data.size()); // Write the encapsulated response Message to the response channel. OAK_LOG(INFO) << "Write gRPC response message to response channel"; - invocation->rsp_channel->Write(std::move(rsp_msg)); + ChannelWrite(invocation->rsp_handle.get(), std::move(rsp_msg)); } OAK_LOG(INFO) << "Finish invocation method " << method_name; @@ -143,14 +143,14 @@ bool GrpcClientNode::HandleInvocation(MessageChannelReadHalf* invocation_channel grpc_rsp.mutable_status()->set_code(status.error_code()); grpc_rsp.mutable_status()->set_message(status.error_message()); - std::unique_ptr rsp_msg = absl::make_unique(); + auto rsp_msg = absl::make_unique(); size_t serialized_size = grpc_rsp.ByteSizeLong(); rsp_msg->data.resize(serialized_size); grpc_rsp.SerializeToArray(rsp_msg->data.data(), rsp_msg->data.size()); OAK_LOG(INFO) << "Write final gRPC status of (" << status.error_code() << ", '" << status.error_message() << "') to response channel"; - invocation->rsp_channel->Write(std::move(rsp_msg)); + ChannelWrite(invocation->rsp_handle.get(), std::move(rsp_msg)); } // References to the per-invocation request/response channels will be dropped @@ -159,12 +159,6 @@ bool GrpcClientNode::HandleInvocation(MessageChannelReadHalf* invocation_channel } void GrpcClientNode::Run(Handle invocation_handle) { - // Borrow a pointer to the relevant channel half. - MessageChannelReadHalf* invocation_channel = BorrowReadChannel(invocation_handle); - if (invocation_channel == nullptr) { - OAK_LOG(ERROR) << "Required channel not available; handle: " << invocation_handle; - return; - } std::vector> channel_status; channel_status.push_back(absl::make_unique(invocation_handle)); while (true) { @@ -173,11 +167,10 @@ void GrpcClientNode::Run(Handle invocation_handle) { return; } - if (!HandleInvocation(invocation_channel)) { + if (!HandleInvocation(invocation_handle)) { OAK_LOG(ERROR) << "Invocation processing failed!"; } } - // Drop reference to the invocation channel on exit. } } // namespace oak diff --git a/oak/server/grpc_client_node.h b/oak/server/grpc_client_node.h index d2e06f622e0..e6d3ed52b71 100644 --- a/oak/server/grpc_client_node.h +++ b/oak/server/grpc_client_node.h @@ -24,7 +24,6 @@ #include "include/grpcpp/grpcpp.h" #include "oak/common/handles.h" #include "oak/server/base_runtime.h" -#include "oak/server/channel.h" #include "oak/server/node_thread.h" namespace oak { @@ -34,7 +33,7 @@ class GrpcClientNode final : public NodeThread { GrpcClientNode(BaseRuntime* runtime, const std::string& name, const std::string& grpc_address); private: - bool HandleInvocation(MessageChannelReadHalf* invocation_channel); + bool HandleInvocation(Handle invocation_handle); void Run(Handle handle) override; std::shared_ptr channel_; diff --git a/oak/server/handle_closer.h b/oak/server/handle_closer.h new file mode 100644 index 00000000000..74faa1af6f1 --- /dev/null +++ b/oak/server/handle_closer.h @@ -0,0 +1,40 @@ +/* + * Copyright 2020 The Project Oak Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OAK_SERVER_HANDLE_CLOSER_H_ +#define OAK_SERVER_HANDLE_CLOSER_H_ + +#include "oak/server/oak_node.h" + +namespace oak { + +// RAII class to hold a Node's handle and auto-close it on object destruction. +class HandleCloser { + public: + explicit HandleCloser(OakNode* node, Handle handle) : node_(node), handle_(handle) {} + + ~HandleCloser() { node_->ChannelClose(handle_); } + + Handle get() const { return handle_; } + + private: + OakNode* node_; + const Handle handle_; +}; + +} // namespace oak + +#endif // OAK_SERVER_HANDLE_CLOSER_H_ diff --git a/oak/server/invocation.cc b/oak/server/invocation.cc index 88a0fa8f1f8..3a4393228c5 100644 --- a/oak/server/invocation.cc +++ b/oak/server/invocation.cc @@ -16,47 +16,33 @@ #include "oak/server/invocation.h" -#include "absl/memory/memory.h" #include "oak/common/logging.h" namespace oak { -std::unique_ptr Invocation::ReceiveFromChannel( - MessageChannelReadHalf* invocation_channel) { - // Expect to receive a pair of channel references: - // - Reference to the read half of a channel that holds the request, serialized +std::unique_ptr Invocation::ReceiveFromChannel(OakNode* node, + Handle invocation_handle) { + // Expect to receive a pair of channel handles: + // - Handle for the read half of a channel that holds the request, serialized // as a GrpcRequest. - // - Reference to the write half of a channel to send responses down, each + // - Handle for the write half of a channel to send responses down, each // serialized as a GrpcResponse. - ReadResult invocation = invocation_channel->Read(INT_MAX, INT_MAX); - if (invocation.required_size > 0) { - OAK_LOG(ERROR) << "Message size too large: " << invocation.required_size; + NodeReadResult invocation = node->ChannelRead(invocation_handle, INT_MAX, INT_MAX); + if (invocation.status != OakStatus::OK) { + OAK_LOG(ERROR) << "Failed to read invocation message: " << invocation.status; return nullptr; } if (invocation.msg->data.size() != 0) { OAK_LOG(ERROR) << "Unexpectedly received data in invocation"; return nullptr; } - if (invocation.msg->channels.size() != 2) { - OAK_LOG(ERROR) << "Wrong number of channels " << invocation.msg->channels.size() + if (invocation.msg->handles.size() != 2) { + OAK_LOG(ERROR) << "Wrong number of handles " << invocation.msg->handles.size() << " in invocation"; return nullptr; } - - std::unique_ptr half0 = std::move(invocation.msg->channels[0]); - auto channel0 = absl::get_if>(half0.get()); - if (channel0 == nullptr) { - OAK_LOG(ERROR) << "First channel accompanying invocation is write-direction"; - return nullptr; - } - - std::unique_ptr half1 = std::move(invocation.msg->channels[1]); - auto channel1 = absl::get_if>(half1.get()); - if (channel1 == nullptr) { - OAK_LOG(ERROR) << "Second channel accompanying invocation is read-direction"; - return nullptr; - } - return absl::make_unique(std::move(*channel0), std::move(*channel1)); + return absl::make_unique(node, invocation.msg->handles[0], + invocation.msg->handles[1]); } } // namespace oak diff --git a/oak/server/invocation.h b/oak/server/invocation.h index 00684704d07..7f42cfb0e54 100644 --- a/oak/server/invocation.h +++ b/oak/server/invocation.h @@ -20,22 +20,23 @@ #include #include -#include "oak/server/channel.h" +#include "oak/common/handles.h" +#include "oak/server/handle_closer.h" +#include "oak/server/oak_node.h" namespace oak { -// An Invocation holds the channel references used in gRPC method invocation +// An Invocation holds the channel handles used in gRPC method invocation // processing. struct Invocation { // Build an Invocation from the data arriving on the given channel. - static std::unique_ptr ReceiveFromChannel(MessageChannelReadHalf* invocation_channel); + static std::unique_ptr ReceiveFromChannel(OakNode* node, Handle invocation_handle); - Invocation(std::unique_ptr req, - std::unique_ptr rsp) - : req_channel(std::move(req)), rsp_channel(std::move(rsp)) {} + Invocation(OakNode* node, Handle req, Handle rsp) + : req_handle(node, req), rsp_handle(node, rsp) {} - std::unique_ptr req_channel; - std::unique_ptr rsp_channel; + HandleCloser req_handle; + HandleCloser rsp_handle; }; } // namespace oak diff --git a/oak/server/logging_node.cc b/oak/server/logging_node.cc index fe8a3dadef1..eba6a19064f 100644 --- a/oak/server/logging_node.cc +++ b/oak/server/logging_node.cc @@ -19,35 +19,29 @@ #include "absl/memory/memory.h" #include "oak/common/logging.h" #include "oak/proto/log.pb.h" +#include "oak/server/handle_closer.h" namespace oak { void LoggingNode::Run(Handle handle) { - // Borrow pointer to the channel half. - MessageChannelReadHalf* channel = BorrowReadChannel(handle); - if (channel == nullptr) { - OAK_LOG(ERROR) << "{" << name_ << "} No channel available!"; - return; - } + HandleCloser handle_closer(this, handle); std::vector> status; status.push_back(absl::make_unique(handle)); bool done = false; while (!done) { if (!WaitOnChannels(&status)) { - OAK_LOG(WARNING) << "{" << name_ << "} Node termination requested, " << channel->Count() - << " log messages pending"; + OAK_LOG(WARNING) << "{" << name_ << "} Node termination requested"; done = true; } - ReadResult result; while (true) { - result = channel->Read(INT_MAX, INT_MAX); - if (result.required_size > 0) { - OAK_LOG(ERROR) << "{" << name_ << "} Message size too large: " << result.required_size; - return; - } - if (result.msg == nullptr) { + NodeReadResult result = ChannelRead(handle, INT_MAX, INT_MAX); + if (result.status == OakStatus::ERR_CHANNEL_EMPTY) { break; } + if (result.status != OakStatus::OK) { + OAK_LOG(ERROR) << "{" << name_ << "} Failed to read message: " << result.status; + return; + } oak::log::LogMessage log_msg; bool successful_parse = log_msg.ParseFromArray(result.msg->data.data(), result.msg->data.size()); @@ -61,14 +55,12 @@ void LoggingNode::Run(Handle handle) { } else { OAK_LOG(ERROR) << "{" << name_ << "} Could not parse LogMessage."; } - // Any channel references included with the message will be dropped. + // Drop any handles that erroneously came along with the message. + for (Handle handle : result.msg->handles) { + ChannelClose(handle); + } } } - if (CloseChannel(handle)) { - OAK_LOG(INFO) << "{" << name_ << "} Closed channel handle: " << handle; - } else { - OAK_LOG(WARNING) << "{" << name_ << "} Invalid channel handle: " << handle; - } } } // namespace oak diff --git a/oak/server/module_invocation.cc b/oak/server/module_invocation.cc index 3613442e772..d1b0c932bd1 100644 --- a/oak/server/module_invocation.cc +++ b/oak/server/module_invocation.cc @@ -27,14 +27,14 @@ namespace oak { namespace { -// Copy the data from a gRPC ByteBuffer into a Message. -std::unique_ptr Unwrap(const grpc::ByteBuffer& buffer) { +// Copy the data from a gRPC ByteBuffer into a NodeMessage. +std::unique_ptr Unwrap(const grpc::ByteBuffer& buffer) { std::vector<::grpc::Slice> slices; grpc::Status status = buffer.Dump(&slices); if (!status.ok()) { OAK_LOG(FATAL) << "Could not unwrap buffer"; } - auto msg = absl::make_unique(); + auto msg = absl::make_unique(); for (const auto& slice : slices) { msg->data.insert(msg->data.end(), slice.begin(), slice.end()); } @@ -76,7 +76,7 @@ void ModuleInvocation::ProcessRequest(bool ok) { delete this; return; } - std::unique_ptr request_msg = Unwrap(request_); + std::unique_ptr request_msg = Unwrap(request_); OAK_LOG(INFO) << "invocation#" << stream_id_ << " ProcessRequest: handling gRPC call: " << context_.method(); @@ -99,7 +99,7 @@ void ModuleInvocation::ProcessRequest(bool ok) { grpc_request.set_allocated_req_msg(any); grpc_request.set_last(true); - std::unique_ptr req_msg = absl::make_unique(); + auto req_msg = absl::make_unique(); size_t serialized_size = grpc_request.ByteSizeLong(); req_msg->data.resize(serialized_size); grpc_request.SerializeToArray(req_msg->data.data(), req_msg->data.size()); @@ -131,28 +131,31 @@ void ModuleInvocation::ProcessRequest(bool ok) { // Create a pair of channels for communication corresponding to this // particular method invocation, one for sending in requests, and one for // receiving responses. - MessageChannel::ChannelHalves req_halves = MessageChannel::Create(); - req_half_ = std::move(req_halves.write); - MessageChannel::ChannelHalves rsp_halves = MessageChannel::Create(); - rsp_half_ = std::move(rsp_halves.read); + std::pair req_handles = grpc_node_->ChannelCreate(); + req_handle_ = req_handles.first; + std::pair rsp_handles = grpc_node_->ChannelCreate(); + rsp_handle_ = rsp_handles.second; // Build a notification message that just holds references to these two // newly-created channels. - std::unique_ptr notify_msg = absl::make_unique(); - notify_msg->channels.push_back(absl::make_unique(std::move(req_halves.read))); - notify_msg->channels.push_back(absl::make_unique(std::move(rsp_halves.write))); + auto notify_msg = absl::make_unique(); + notify_msg->handles.push_back(req_handles.second); + notify_msg->handles.push_back(rsp_handles.first); // Write the request message to the just-created request channel. - req_half_->Write(std::move(req_msg)); + grpc_node_->ChannelWrite(req_handle_, std::move(req_msg)); OAK_LOG(INFO) << "invocation#" << stream_id_ - << " ProcessRequest: Wrote encapsulated request to new gRPC request channel"; + << " ProcessRequest: Wrote encapsulated request to new gRPC request channel " + << req_handle_; - // Write the notification message to the gRPC input channel, which the runtime - // connected to the Node. - MessageChannelWriteHalf* notify_half = grpc_node_->BorrowWriteChannel(); - notify_half->Write(std::move(notify_msg)); + // Write the notification message to the gRPC notification channel, then close + // our copy of the transferred handles. + grpc_node_->ChannelWrite(grpc_node_->handle_, std::move(notify_msg)); OAK_LOG(INFO) << "invocation#" << stream_id_ - << " ProcessRequest: Wrote notification request to gRPC input channel"; + << " ProcessRequest: Wrote notification request to gRPC notification handle " + << grpc_node_->handle_; + grpc_node_->ChannelClose(req_handles.second); + grpc_node_->ChannelClose(rsp_handles.first); // Move straight onto sending first response. SendResponse(true); @@ -173,15 +176,22 @@ void ModuleInvocation::SendResponse(bool ok) { } void ModuleInvocation::BlockingSendResponse() { - ReadResult rsp_result; // Block until we can read a single queued GrpcResponse message (in serialized form) from the // gRPC output channel. OAK_LOG(INFO) << "invocation#" << stream_id_ << " SendResponse: do blocking-read on grpc channel"; - rsp_result = rsp_half_->BlockingRead(INT_MAX, INT_MAX); - if (rsp_result.required_size > 0) { + + std::vector> status; + status.push_back(absl::make_unique(rsp_handle_)); + if (!grpc_node_->WaitOnChannels(&status)) { + OAK_LOG(ERROR) << "invocation#" << stream_id_ << " SendResponse: Failed to wait for message"; + FinishAndCleanUp(grpc::Status(grpc::StatusCode::INTERNAL, "Message wait failed")); + return; + } + NodeReadResult rsp_result = grpc_node_->ChannelRead(rsp_handle_, INT_MAX, INT_MAX); + if (rsp_result.status != OakStatus::OK) { OAK_LOG(ERROR) << "invocation#" << stream_id_ - << " SendResponse: Message size too large: " << rsp_result.required_size; - FinishAndCleanUp(grpc::Status(grpc::StatusCode::INTERNAL, "Message size too large")); + << " SendResponse: Failed to read message: " << rsp_result.status; + FinishAndCleanUp(grpc::Status(grpc::StatusCode::INTERNAL, "Failed to read response message")); return; } diff --git a/oak/server/module_invocation.h b/oak/server/module_invocation.h index 7f4fe3f29fd..38807acff2c 100644 --- a/oak/server/module_invocation.h +++ b/oak/server/module_invocation.h @@ -21,7 +21,7 @@ #include "include/grpcpp/generic/async_generic_service.h" #include "include/grpcpp/grpcpp.h" -#include "oak/server/channel.h" +#include "oak/common/handles.h" #include "oak/server/oak_grpc_node.h" namespace oak { @@ -41,7 +41,10 @@ class ModuleInvocation { stream_id_(grpc_node_->NextStreamID()) {} // This object deletes itself. - ~ModuleInvocation() = default; + ~ModuleInvocation() { + grpc_node_->ChannelClose(req_handle_); + grpc_node_->ChannelClose(rsp_handle_); + } // Starts the asynchronous gRPC flow, which calls ReadRequest when the next // Oak Module invocation request arrives. @@ -75,10 +78,10 @@ class ModuleInvocation { // Borrowed references to gRPC Node that this invocation is on behalf of. OakGrpcNode* grpc_node_; - // Channel references for the two channels that are used for communication + // Channel handles for the two channels that are used for communication // related to this method invocation. - std::unique_ptr req_half_; - std::unique_ptr rsp_half_; + Handle req_handle_; + Handle rsp_handle_; grpc::GenericServerContext context_; grpc::GenericServerAsyncReaderWriter stream_; diff --git a/oak/server/node_thread.cc b/oak/server/node_thread.cc index 74f65351afb..ce82cbac8d9 100644 --- a/oak/server/node_thread.cc +++ b/oak/server/node_thread.cc @@ -22,20 +22,16 @@ namespace oak { NodeThread::~NodeThread() { StopThread(); } -void NodeThread::Start() { +void NodeThread::Start(Handle handle) { if (thread_.joinable()) { OAK_LOG(ERROR) << "Attempt to Start() an already-running NodeThread"; return; } - if (runtime_->TerminationPending()) { + if (TerminationPending()) { OAK_LOG(ERROR) << "Attempt to Start() an already-terminated NodeThread"; return; } - // At Node start-up, there should be a single registered handle, which gets - // passed to the Run() method. - Handle handle = SingleHandle(); - OAK_LOG(INFO) << "Executing new {" << name_ << "} node thread with handle " << handle; thread_ = std::thread(&oak::NodeThread::Run, this, handle); OAK_LOG(INFO) << "Started {" << name_ << "} node thread"; diff --git a/oak/server/node_thread.h b/oak/server/node_thread.h index 108b6deffcd..fb7b2033b13 100644 --- a/oak/server/node_thread.h +++ b/oak/server/node_thread.h @@ -33,7 +33,7 @@ class NodeThread : public OakNode { virtual ~NodeThread(); // Start kicks off a separate thread that invokes the Run() method. - void Start() override; + void Start(Handle handle) override; // Stop terminates the thread associated with the pseudo-node. void Stop() override; diff --git a/oak/server/oak_grpc_node.cc b/oak/server/oak_grpc_node.cc index c9b5e9c05d9..cc875c22be3 100644 --- a/oak/server/oak_grpc_node.cc +++ b/oak/server/oak_grpc_node.cc @@ -53,7 +53,9 @@ std::unique_ptr OakGrpcNode::Create( return node; } -void OakGrpcNode::Start() { +void OakGrpcNode::Start(Handle handle) { + handle_ = handle; + OAK_LOG(INFO) << "{" << name_ << "} Using handle " << handle_ << " for sending invocations"; // Start a new thread to process the gRPC completion queue. queue_thread_ = std::thread(&OakGrpcNode::CompletionQueueLoop, this); } @@ -69,7 +71,7 @@ void OakGrpcNode::CompletionQueueLoop() { bool ok; void* tag; if (!completion_queue_->Next(&tag, &ok)) { - if (!runtime_->TerminationPending()) { + if (!TerminationPending()) { OAK_LOG(FATAL) << "{" << name_ << "} Failure reading from completion queue"; } OAK_LOG(INFO) << "{" << name_ diff --git a/oak/server/oak_grpc_node.h b/oak/server/oak_grpc_node.h index f5137b2a9b3..1e2afab22ce 100644 --- a/oak/server/oak_grpc_node.h +++ b/oak/server/oak_grpc_node.h @@ -22,7 +22,6 @@ #include "absl/synchronization/mutex.h" #include "include/grpcpp/grpcpp.h" #include "oak/common/app_config.h" -#include "oak/server/channel.h" #include "oak/server/oak_node.h" namespace oak { @@ -36,7 +35,7 @@ class OakGrpcNode final : public OakNode { std::shared_ptr grpc_credentials, const uint16_t port = 0); virtual ~OakGrpcNode(){}; - void Start() override; + void Start(Handle handle) override; void Stop() override; int GetPort() { return port_; }; @@ -50,14 +49,10 @@ class OakGrpcNode final : public OakNode { friend class ModuleInvocation; OakGrpcNode(BaseRuntime* runtime, const std::string& name) - : OakNode(runtime, name), next_stream_id_(1) {} + : OakNode(runtime, name), next_stream_id_(1), handle_(kInvalidHandle) {} OakGrpcNode(const OakGrpcNode&) = delete; OakGrpcNode& operator=(const OakGrpcNode&) = delete; - MessageChannelWriteHalf* BorrowWriteChannel() const { - return OakNode::BorrowWriteChannel(SingleHandle()); - } - // Consumes gRPC events from the completion queue in an infinite loop. void CompletionQueueLoop(); @@ -74,6 +69,7 @@ class OakGrpcNode final : public OakNode { absl::Mutex id_mu_; // protects next_stream_id_ int32_t next_stream_id_ GUARDED_BY(id_mu_); + Handle handle_; // const after Start() }; } // namespace oak diff --git a/oak/server/oak_node.cc b/oak/server/oak_node.cc index c104bf49145..a7f0d53c88f 100644 --- a/oak/server/oak_node.cc +++ b/oak/server/oak_node.cc @@ -16,11 +16,125 @@ #include "oak/server/oak_node.h" +#include "absl/memory/memory.h" #include "oak/common/logging.h" #include "oak/server/notification.h" namespace oak { +NodeReadResult OakNode::ChannelRead(Handle handle, uint32_t max_size, uint32_t max_channels) { + // Borrowing a reference to the channel is safe because the node is single + // threaded and so cannot invoke channel_close while channel_read is + // ongoing. + MessageChannelReadHalf* channel = BorrowReadChannel(handle); + if (channel == nullptr) { + OAK_LOG(WARNING) << "{" << name_ << "} Invalid channel handle: " << handle; + return NodeReadResult(OakStatus::ERR_BAD_HANDLE); + } + ReadResult result_internal = channel->Read(max_size, max_channels); + NodeReadResult result(result_internal.status); + result.required_size = result_internal.required_size; + result.required_channels = result_internal.required_channels; + if (result.status == OakStatus::OK) { + if (result_internal.msg != nullptr) { + // Move data and label across into Node-relative message. + result.msg = absl::make_unique(); + result.msg->data = std::move(result_internal.msg->data); + result.msg->label = std::move(result_internal.msg->label); + // Transmute channel references to handles in this Node's handle space. + for (size_t ii = 0; ii < result_internal.msg->channels.size(); ii++) { + Handle handle = AddChannel(std::move(result_internal.msg->channels[ii])); + OAK_LOG(INFO) << "{" << name_ << "} Transferred channel has new handle " << handle; + result.msg->handles.push_back(handle); + } + } else { + // Nothing available to read. + if (channel->Orphaned()) { + OAK_LOG(INFO) << "{" << name_ << "} channel_read[" << handle << "]: no writers left"; + result.status = OakStatus::ERR_CHANNEL_CLOSED; + } else { + result.status = OakStatus::ERR_CHANNEL_EMPTY; + } + } + } + return result; +} + +OakStatus OakNode::ChannelWrite(Handle handle, std::unique_ptr msg) { + // Borrowing a reference to the channel is safe because the Node is single + // threaded and so cannot invoke channel_close while channel_write is + // ongoing. + MessageChannelWriteHalf* channel = BorrowWriteChannel(handle); + if (channel == nullptr) { + OAK_LOG(WARNING) << "{" << name_ << "} Invalid channel handle: " << handle; + return OakStatus::ERR_BAD_HANDLE; + } + + if (channel->Orphaned()) { + OAK_LOG(INFO) << "{" << name_ << "} channel_write[" << handle << "]: no readers left"; + return OakStatus::ERR_CHANNEL_CLOSED; + } + auto msg_internal = absl::make_unique(); + msg_internal->data = std::move(msg->data); + msg_internal->label = std::move(msg->label); + for (Handle h : msg->handles) { + ChannelHalf* half = BorrowChannel(h); + if (half == nullptr) { + OAK_LOG(WARNING) << "{" << name_ << "} Invalid transferred channel handle: " << h; + return OakStatus::ERR_BAD_HANDLE; + } + msg_internal->channels.push_back(CloneChannelHalf(half)); + } + + channel->Write(std::move(msg_internal)); + return OakStatus::OK; +} + +std::pair OakNode::ChannelCreate() { + MessageChannel::ChannelHalves halves = MessageChannel::Create(); + Handle write_handle = AddChannel(absl::make_unique(std::move(halves.write))); + Handle read_handle = AddChannel(absl::make_unique(std::move(halves.read))); + OAK_LOG(INFO) << "{" << name_ << "} Created new channel with handles write=" << write_handle + << ", read=" << read_handle; + return std::pair(write_handle, read_handle); +} + +OakStatus OakNode::ChannelClose(Handle handle) { + absl::MutexLock lock(&mu_); + auto it = channel_halves_.find(handle); + if (it == channel_halves_.end()) { + return OakStatus::ERR_BAD_HANDLE; + } + channel_halves_.erase(it); + return OakStatus::OK; +} + +OakStatus OakNode::NodeCreate(Handle handle, const std::string& config_name, + const std::string& entrypoint_name) { + // Check that the handle identifies the read half of a channel. + ChannelHalf* borrowed_half = BorrowChannel(handle); + if (borrowed_half == nullptr) { + OAK_LOG(WARNING) << "{" << name_ << "} Invalid channel handle: " << handle; + return OakStatus::ERR_BAD_HANDLE; + } + if (!absl::holds_alternative>(*borrowed_half)) { + OAK_LOG(WARNING) << "{" << name_ << "} Wrong direction channel handle: " << handle; + return OakStatus::ERR_BAD_HANDLE; + } + std::unique_ptr half = CloneChannelHalf(borrowed_half); + + OAK_LOG(INFO) << "Create a new node with config '" << config_name << "' and entrypoint '" + << entrypoint_name << "'"; + + std::string node_name; + if (!runtime_->CreateAndRunNode(config_name, entrypoint_name, std::move(half), &node_name)) { + return OakStatus::ERR_INVALID_ARGS; + } else { + OAK_LOG(INFO) << "Created new node named {" << node_name << "}"; + return OakStatus::OK; + } +} + Handle OakNode::NextHandle() { std::uniform_int_distribution distribution; while (true) { @@ -42,16 +156,6 @@ Handle OakNode::AddChannel(std::unique_ptr half) { return handle; } -bool OakNode::CloseChannel(Handle handle) { - absl::MutexLock lock(&mu_); - auto it = channel_halves_.find(handle); - if (it == channel_halves_.end()) { - return false; - } - channel_halves_.erase(it); - return true; -} - ChannelHalf* OakNode::BorrowChannel(Handle handle) const { absl::ReaderMutexLock lock(&mu_); auto it = channel_halves_.find(handle); @@ -147,12 +251,4 @@ bool OakNode::WaitOnChannels(std::vector>* status } } -Handle OakNode::SingleHandle() const { - absl::ReaderMutexLock lock(&mu_); - if (channel_halves_.size() != 1) { - return kInvalidHandle; - } - return channel_halves_.begin()->first; -} - } // namespace oak diff --git a/oak/server/oak_node.h b/oak/server/oak_node.h index 8be2ab826e7..c2e73371b12 100644 --- a/oak/server/oak_node.h +++ b/oak/server/oak_node.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "absl/base/thread_annotations.h" @@ -31,30 +32,56 @@ namespace oak { +// Representation of a message transferred over a channel, relative to +// a particular Node. This is equivalent to the Message object, but +// using channel handles (which are relative to a particular OakNode) rather +// than raw channel references. +struct NodeMessage { + std::vector data; + std::vector handles; + oak::policy::Label label; +}; + +// Result of a read operation relative to a Node. Equivalent to ReadResult but +// holds a NodeMessage rather than a Message. +struct NodeReadResult { + explicit NodeReadResult(OakStatus s) : status(s), required_size(0), required_channels(0) {} + OakStatus status; + // The following fields are filled in if the status is ERR_BUFFER_TOO_SMALL + // or ERR_HANDLE_SPACE_TOO_SMALL, indicating the required size and number + // of handles needed to read the message. + uint32_t required_size; + uint32_t required_channels; + // The following field is filled in if the status is OK. + std::unique_ptr msg; +}; + class OakNode { public: OakNode(BaseRuntime* runtime, const std::string& name) - : runtime_(runtime), name_(name), prng_engine_() {} + : name_(name), runtime_(runtime), prng_engine_() {} virtual ~OakNode() {} - virtual void Start() = 0; + virtual void Start(Handle handle) = 0; virtual void Stop() = 0; - // Take ownership of the given channel half, returning a channel handle that - // the node can use to refer to it in future. - Handle AddChannel(std::unique_ptr half) LOCKS_EXCLUDED(mu_); + // ChannelRead returns the first message on the channel identified by the + // handle, subject to size checks. + NodeReadResult ChannelRead(Handle handle, uint32_t max_size, uint32_t max_channels); - // Close the given channel half. Returns true if the channel was found and closed, - // false if the channel was not found. - bool CloseChannel(Handle handle) LOCKS_EXCLUDED(mu_); + // ChannelWrite passes ownership of a message to the channel identified by the + // handle. + OakStatus ChannelWrite(Handle handle, std::unique_ptr msg); - // Return a borrowed reference to the channel half identified by the given - // handle (or nullptr if the handle is not recognized). Caller is responsible - // for ensuring that the borrowed reference does not out-live the owned - // channel. - ChannelHalf* BorrowChannel(Handle handle) const LOCKS_EXCLUDED(mu_); - MessageChannelReadHalf* BorrowReadChannel(Handle handle) const LOCKS_EXCLUDED(mu_); - MessageChannelWriteHalf* BorrowWriteChannel(Handle handle) const LOCKS_EXCLUDED(mu_); + // Create a channel and return the (write, read) handles for it. + std::pair ChannelCreate(); + + // Close the given channel half. + OakStatus ChannelClose(Handle handle) LOCKS_EXCLUDED(mu_); + + // Create a new Node. + OakStatus NodeCreate(Handle handle, const std::string& config_name, + const std::string& entrypoint_name); // Wait on the given channel handles, modifying the contents of the passed-in // vector. Returns a boolean indicating whether the wait finished due to a @@ -65,23 +92,33 @@ class OakNode { bool WaitOnChannels(std::vector>* statuses) const; protected: - // If the Node has a single registered handle, return it; otherwise, return - // kInvalidHandle. This is a convenience method for initial execution of a - // Node, which should always start with exactly one handle (for a read half) - // registered in its channel_halves_ table; this handle is passed as the - // parameter to the Node's oak_main() entrypoint. - Handle SingleHandle() const LOCKS_EXCLUDED(mu_); - - // Runtime instance that owns this Node. - BaseRuntime* const runtime_; + bool TerminationPending() const { return runtime_->TerminationPending(); } const std::string name_; private: + // Allow the Runtime to use internal methods. + friend class OakRuntime; + + // Take ownership of the given channel half, returning a channel handle that + // the Node can use to refer to it in future. + Handle AddChannel(std::unique_ptr half) LOCKS_EXCLUDED(mu_); + Handle NextHandle() EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Return a borrowed reference to the channel half identified by the given + // handle (or nullptr if the handle is not recognized). Caller is responsible + // for ensuring that the borrowed reference does not out-live the owned + // channel. + ChannelHalf* BorrowChannel(Handle handle) const LOCKS_EXCLUDED(mu_); + MessageChannelReadHalf* BorrowReadChannel(Handle handle) const LOCKS_EXCLUDED(mu_); + MessageChannelWriteHalf* BorrowWriteChannel(Handle handle) const LOCKS_EXCLUDED(mu_); + using ChannelHalfTable = std::unordered_map>; + // Runtime instance that owns this Node. + BaseRuntime* const runtime_; + mutable absl::Mutex mu_; // protects prng_engine_, channel_halves_ std::random_device prng_engine_ GUARDED_BY(mu_); diff --git a/oak/server/oak_runtime.cc b/oak/server/oak_runtime.cc index 02ba4cdbf51..b0aaaf09a27 100644 --- a/oak/server/oak_runtime.cc +++ b/oak/server/oak_runtime.cc @@ -79,9 +79,9 @@ grpc::Status OakRuntime::Initialize(const ApplicationConfiguration& config, // Create the initial Application Node. std::string node_name; - OakNode* app_node = + app_node_ = CreateNode(config.initial_node_config_name(), config.initial_entrypoint_name(), &node_name); - if (app_node == nullptr) { + if (app_node_ == nullptr) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Failed to create initial Oak Node"); } OAK_LOG(INFO) << "Created Wasm node named {" << node_name << "}"; @@ -89,11 +89,10 @@ grpc::Status OakRuntime::Initialize(const ApplicationConfiguration& config, // Create an initial channel from gRPC pseudo-Node to Application Node. // Both of the initial nodes have exactly one registered handle. MessageChannel::ChannelHalves halves = MessageChannel::Create(); - Handle grpc_handle = - grpc_node_->AddChannel(absl::make_unique(std::move(halves.write))); - Handle app_handle = app_node->AddChannel(absl::make_unique(std::move(halves.read))); - OAK_LOG(INFO) << "Created initial channel from Wasm node {" << grpc_name << "}." << grpc_handle - << " to {" << node_name << "}." << app_handle; + grpc_handle_ = grpc_node_->AddChannel(absl::make_unique(std::move(halves.write))); + app_handle_ = app_node_->AddChannel(absl::make_unique(std::move(halves.read))); + OAK_LOG(INFO) << "Created initial channel from Wasm node {" << grpc_name << "}." << grpc_handle_ + << " to {" << node_name << "}." << app_handle_; return grpc::Status::OK; } @@ -161,19 +160,16 @@ bool OakRuntime::CreateAndRunNode(const std::string& config_name, Handle handle = node->AddChannel(std::move(half)); OAK_LOG(INFO) << "Start node named {" << *node_name << "} with initial handle " << handle; - node->Start(); + node->Start(handle); return true; } grpc::Status OakRuntime::Start() { OAK_LOG(INFO) << "Starting runtime"; - absl::MutexLock lock(&mu_); - // Now all dependencies are running, start the Nodes running. - for (auto& named_node : nodes_) { - OAK_LOG(INFO) << "Starting node " << named_node.first; - named_node.second->Start(); - } + // Now all dependencies are running, start the initial pair of Nodes running. + grpc_node_->Start(grpc_handle_); + app_node_->Start(app_handle_); return grpc::Status::OK; } diff --git a/oak/server/oak_runtime.h b/oak/server/oak_runtime.h index f81d31355a9..63ccfb671ad 100644 --- a/oak/server/oak_runtime.h +++ b/oak/server/oak_runtime.h @@ -39,7 +39,12 @@ namespace oak { class OakRuntime : public BaseRuntime { public: - OakRuntime() : termination_pending_(false), grpc_node_(nullptr) {} + OakRuntime() + : grpc_node_(nullptr), + grpc_handle_(kInvalidHandle), + app_node_(nullptr), + app_handle_(kInvalidHandle), + termination_pending_(false) {} virtual ~OakRuntime() = default; // Initializes an OakRuntime with a user-provided ApplicationConfiguration. This @@ -77,6 +82,17 @@ class OakRuntime : public BaseRuntime { // Config names that refer to a gRPC client node. std::map> grpc_client_config_; + // Convenience (non-owning) reference to gRPC pseudo-node. + OakGrpcNode* grpc_node_; + // Handle for the write half of the gRPC server notification channel, relative + // to the gRPC server pseudo-Node + Handle grpc_handle_; + // Convenience (non-owning) reference to initial Application Wasm node; + OakNode* app_node_; + // Handle for the read half of the gRPC server notification channel, relative + // to the initial Application Wasm Node. + Handle app_handle_; + // Next index for node name generation. mutable absl::Mutex mu_; // protects nodes_, next_index_; std::map next_index_ GUARDED_BY(mu_); @@ -87,8 +103,6 @@ class OakRuntime : public BaseRuntime { // unique but is not visible to the running Application in any way. std::map> nodes_ GUARDED_BY(mu_); - // Convenience (non-owning) reference to gRPC pseudo-node; const after Initialize() called. - OakGrpcNode* grpc_node_; }; // class OakRuntime } // namespace oak diff --git a/oak/server/storage/BUILD b/oak/server/storage/BUILD index b6494cb2c47..49483566786 100644 --- a/oak/server/storage/BUILD +++ b/oak/server/storage/BUILD @@ -52,7 +52,6 @@ cc_library( "//oak/common:logging", "//oak/proto:grpc_encap_cc_proto", "//oak/proto:storage_channel_cc_proto", - "//oak/server:channel", "//oak/server:invocation", "//oak/server:node_thread", "//third_party/asylo:statusor", diff --git a/oak/server/storage/storage_node.cc b/oak/server/storage/storage_node.cc index 3d524736a48..7648b331c76 100644 --- a/oak/server/storage/storage_node.cc +++ b/oak/server/storage/storage_node.cc @@ -32,12 +32,6 @@ StorageNode::StorageNode(BaseRuntime* runtime, const std::string& name, : NodeThread(runtime, name), storage_processor_(storage_address) {} void StorageNode::Run(Handle invocation_handle) { - // Borrow a pointer to the relevant channel half. - MessageChannelReadHalf* invocation_channel = BorrowReadChannel(invocation_handle); - if (invocation_channel == nullptr) { - OAK_LOG(ERROR) << "Required channel not available; handle: " << invocation_handle; - return; - } std::vector> channel_status; channel_status.push_back(absl::make_unique(invocation_handle)); while (true) { @@ -46,19 +40,19 @@ void StorageNode::Run(Handle invocation_handle) { return; } - std::unique_ptr invocation(Invocation::ReceiveFromChannel(invocation_channel)); + std::unique_ptr invocation(Invocation::ReceiveFromChannel(this, invocation_handle)); if (invocation == nullptr) { OAK_LOG(ERROR) << "Failed to create invocation"; return; } // Expect to read a single request out of the request channel. - ReadResult req_result = invocation->req_channel->Read(INT_MAX, INT_MAX); - if (req_result.required_size > 0) { - OAK_LOG(ERROR) << "Message size too large: " << req_result.required_size; + NodeReadResult req_result = ChannelRead(invocation->req_handle.get(), INT_MAX, INT_MAX); + if (req_result.status != OakStatus::OK) { + OAK_LOG(ERROR) << "Failed to read message: " << req_result.status; return; } - if (req_result.msg->channels.size() != 0) { + if (req_result.msg->handles.size() != 0) { OAK_LOG(ERROR) << "Unexpectedly received channel handles in request channel"; return; } @@ -78,11 +72,11 @@ void StorageNode::Run(Handle invocation_handle) { } grpc_rsp->set_last(true); - std::unique_ptr rsp_msg = absl::make_unique(); + auto rsp_msg = absl::make_unique(); size_t serialized_size = grpc_rsp->ByteSizeLong(); rsp_msg->data.resize(serialized_size); grpc_rsp->SerializeToArray(rsp_msg->data.data(), rsp_msg->data.size()); - invocation->rsp_channel->Write(std::move(rsp_msg)); + ChannelWrite(invocation->rsp_handle.get(), std::move(rsp_msg)); // The response channel reference is dropped here. } diff --git a/oak/server/wasm_node.cc b/oak/server/wasm_node.cc index 10b64bf6cbb..498650a1712 100644 --- a/oak/server/wasm_node.cc +++ b/oak/server/wasm_node.cc @@ -29,7 +29,6 @@ #include "oak/common/logging.h" #include "oak/proto/oak_api.pb.h" #include "oak/server/base_runtime.h" -#include "oak/server/channel.h" #include "oak/server/wabt_output.h" #include "src/binary-reader.h" #include "src/error-formatter.h" @@ -162,7 +161,6 @@ std::unique_ptr WasmNode::Create(BaseRuntime* runtime, const std::stri std::unique_ptr node = absl::WrapUnique(new WasmNode(runtime, name, main_entrypoint)); node->InitEnvironment(&node->env_); - OAK_LOG(INFO) << "Runtime at: " << (void*)node->runtime_; OAK_LOG(INFO) << "Host func count: " << node->env_.GetFuncCount(); wabt::Errors errors; @@ -313,63 +311,29 @@ wabt::interp::HostFunc::Callback WasmNode::OakChannelRead(wabt::interp::Environm return wabt::interp::Result::Ok; } - // Borrowing a reference to the channel is safe because the node is single - // threaded and so cannot invoke channel_close while channel_read is - // ongoing. - MessageChannelReadHalf* channel = BorrowReadChannel(channel_handle); - if (channel == nullptr) { - OAK_LOG(WARNING) << "{" << name_ << "} Invalid channel handle: " << channel_handle; - results[0].set_i32(OakStatus::ERR_BAD_HANDLE); - return wabt::interp::Result::Ok; - } - - ReadResult result = channel->Read(size, handle_space_count); - if (result.required_size > 0) { - OAK_LOG(INFO) << "{" << name_ << "} channel_read[" << channel_handle - << "]: buffer too small: " << size << " < " << result.required_size; - WriteI32(env, size_offset, result.required_size); - WriteI32(env, handle_count_offset, result.required_channels); - results[0].set_i32(OakStatus::ERR_BUFFER_TOO_SMALL); - return wabt::interp::Result::Ok; - } else if (result.required_channels > 0) { - OAK_LOG(INFO) << "{" << name_ << "} channel_read[" << channel_handle - << "]: handle space too small: " << handle_space_count << " < " - << result.required_channels; - WriteI32(env, size_offset, result.required_size); - WriteI32(env, handle_count_offset, result.required_channels); - results[0].set_i32(OakStatus::ERR_HANDLE_SPACE_TOO_SMALL); - return wabt::interp::Result::Ok; - } else if (result.msg == nullptr) { + NodeReadResult result = ChannelRead(channel_handle, size, handle_space_count); + OAK_LOG(INFO) << "{" << name_ << "} channel_read[" << channel_handle + << "]: gives status: " << result.status << " with required size " + << result.required_size << ", count " << result.required_channels; + WriteI32(env, size_offset, result.required_size); + WriteI32(env, handle_count_offset, result.required_channels); + results[0].set_i32(result.status); + + if ((result.status == OakStatus::OK) && (result.msg != nullptr)) { + // Transfer message and handles to Node. OAK_LOG(INFO) << "{" << name_ << "} channel_read[" << channel_handle - << "]: no message available"; - WriteI32(env, size_offset, 0); - WriteI32(env, handle_count_offset, 0); - - if (channel->Orphaned()) { - OAK_LOG(INFO) << "{" << name_ << "} channel_read[" << channel_handle - << "]: no writers left"; - results[0].set_i32(OakStatus::ERR_CHANNEL_CLOSED); - } else { - results[0].set_i32(OakStatus::ERR_CHANNEL_EMPTY); + << "]: read message of size " << result.msg->data.size() << " with " + << result.msg->handles.size() << " attached handles"; + WriteI32(env, size_offset, result.msg->data.size()); + WriteMemory(env, offset, absl::Span(result.msg->data.data(), result.msg->data.size())); + + WriteI32(env, handle_count_offset, result.msg->handles.size()); + for (size_t ii = 0; ii < result.msg->handles.size(); ii++) { + Handle handle = result.msg->handles[ii]; + OAK_LOG(INFO) << "{" << name_ << "} Transferred new handle " << handle; + WriteU64(env, handle_space_offset + ii * sizeof(Handle), handle); } - return wabt::interp::Result::Ok; - } - - OAK_LOG(INFO) << "{" << name_ << "} channel_read[" << channel_handle - << "]: read message of size " << result.msg->data.size() << " with " - << result.msg->channels.size() << " attached channels"; - WriteI32(env, size_offset, result.msg->data.size()); - WriteMemory(env, offset, absl::Span(result.msg->data.data(), result.msg->data.size())); - WriteI32(env, handle_count_offset, result.msg->channels.size()); - - // Convert any accompanying channels into handles relative to the receiving node. - for (size_t ii = 0; ii < result.msg->channels.size(); ii++) { - Handle handle = AddChannel(std::move(result.msg->channels[ii])); - OAK_LOG(INFO) << "{" << name_ << "} Transferred channel has new handle " << handle; - WriteU64(env, handle_space_offset + ii * sizeof(Handle), handle); } - - results[0].set_i32(OakStatus::OK); return wabt::interp::Result::Ok; }; } @@ -393,46 +357,23 @@ wabt::interp::HostFunc::Callback WasmNode::OakChannelWrite(wabt::interp::Environ return wabt::interp::Result::Ok; } - // Borrowing a reference to the channel is safe because the Node is single - // threaded and so cannot invoke channel_close while channel_write is - // ongoing. - MessageChannelWriteHalf* channel = BorrowWriteChannel(channel_handle); - if (channel == nullptr) { - OAK_LOG(WARNING) << "{" << name_ << "} Invalid channel handle: " << channel_handle; - results[0].set_i32(OakStatus::ERR_BAD_HANDLE); - return wabt::interp::Result::Ok; - } - - if (channel->Orphaned()) { - OAK_LOG(INFO) << "{" << name_ << "} channel_write[" << channel_handle << "]: no readers left"; - results[0].set_i32(OakStatus::ERR_CHANNEL_CLOSED); - return wabt::interp::Result::Ok; - } - // Copy the data from the Wasm linear memory. absl::Span origin = ReadMemory(env, offset, size); - auto msg = absl::make_unique(); + auto msg = absl::make_unique(); msg->data.insert(msg->data.end(), origin.begin(), origin.end()); OAK_LOG(INFO) << "{" << name_ << "} channel_write[" << channel_handle << "]: write message of size " << size; - // Find any handles and clone the corresponding write channels. + // Find any handles and clone the corresponding channels. std::vector handles; handles.reserve(handle_count); for (uint32_t ii = 0; ii < handle_count; ii++) { Handle handle = ReadU64(env, handle_offset + (ii * sizeof(Handle))); OAK_LOG(INFO) << "{" << name_ << "} Transfer channel handle " << handle; - ChannelHalf* half = BorrowChannel(handle); - if (half == nullptr) { - OAK_LOG(WARNING) << "{" << name_ << "} Invalid transferred channel handle: " << handle; - results[0].set_i32(OakStatus::ERR_BAD_HANDLE); - return wabt::interp::Result::Ok; - } - msg->channels.push_back(CloneChannelHalf(half)); + msg->handles.push_back(handle); } - channel->Write(std::move(msg)); - - results[0].set_i32(OakStatus::OK); + OakStatus status = ChannelWrite(channel_handle, std::move(msg)); + results[0].set_i32(status); return wabt::interp::Result::Ok; }; } @@ -474,7 +415,7 @@ wabt::interp::HostFunc::Callback WasmNode::OakWaitOnChannels(wabt::interp::Envir } if (wait_success) { results[0].set_i32(OakStatus::OK); - } else if (runtime_->TerminationPending()) { + } else if (TerminationPending()) { results[0].set_i32(OakStatus::ERR_TERMINATED); } else { results[0].set_i32(OakStatus::ERR_BAD_HANDLE); @@ -497,14 +438,9 @@ wabt::interp::HostFunc::Callback WasmNode::OakChannelCreate(wabt::interp::Enviro return wabt::interp::Result::Ok; } - MessageChannel::ChannelHalves halves = MessageChannel::Create(); - Handle write_handle = AddChannel(absl::make_unique(std::move(halves.write))); - Handle read_handle = AddChannel(absl::make_unique(std::move(halves.read))); - OAK_LOG(INFO) << "{" << name_ << "} Created new channel with handles write=" << write_handle - << ", read=" << read_handle; - - WriteU64(env, write_half_offset, write_handle); - WriteU64(env, read_half_offset, read_handle); + std::pair handles = ChannelCreate(); + WriteU64(env, write_half_offset, handles.first); + WriteU64(env, read_half_offset, handles.second); results[0].set_i32(OakStatus::OK); return wabt::interp::Result::Ok; @@ -517,14 +453,10 @@ wabt::interp::HostFunc::Callback WasmNode::OakChannelClose(wabt::interp::Environ LogHostFunctionCall(name_, func, args); Handle channel_handle = args[0].get_i64(); - - if (CloseChannel(channel_handle)) { - OAK_LOG(INFO) << "{" << name_ << "} Closed channel handle: " << channel_handle; - results[0].set_i32(OakStatus::OK); - } else { - OAK_LOG(WARNING) << "{" << name_ << "} Invalid channel handle: " << channel_handle; - results[0].set_i32(OakStatus::ERR_BAD_HANDLE); - } + OakStatus status = ChannelClose(channel_handle); + OAK_LOG(INFO) << "{" << name_ << "} Close channel handle " << channel_handle << " status " + << status; + results[0].set_i32(status); return wabt::interp::Result::Ok; }; } @@ -549,35 +481,13 @@ wabt::interp::HostFunc::Callback WasmNode::OakNodeCreate(wabt::interp::Environme results[0].set_i32(OakStatus::ERR_INVALID_ARGS); return wabt::interp::Result::Ok; } - - // Check that the handle identifies the read half of a channel. - ChannelHalf* borrowed_half = BorrowChannel(channel_handle); - if (borrowed_half == nullptr) { - OAK_LOG(WARNING) << "{" << name_ << "} Invalid channel handle: " << channel_handle; - results[0].set_i32(OakStatus::ERR_BAD_HANDLE); - return wabt::interp::Result::Ok; - } - if (!absl::holds_alternative>(*borrowed_half)) { - OAK_LOG(WARNING) << "{" << name_ << "} Wrong direction channel handle: " << channel_handle; - results[0].set_i32(OakStatus::ERR_BAD_HANDLE); - return wabt::interp::Result::Ok; - } - std::unique_ptr half = CloneChannelHalf(borrowed_half); - auto config_base = env->GetMemory(0)->data.begin() + config_offset; std::string config_name(config_base, config_base + config_size); auto entrypoint_base = env->GetMemory(0)->data.begin() + entrypoint_offset; std::string entrypoint_name(entrypoint_base, entrypoint_base + entrypoint_size); - OAK_LOG(INFO) << "Create a new node with config '" << config_name << "' and entrypoint '" - << entrypoint_name << "'"; - std::string node_name; - if (!runtime_->CreateAndRunNode(config_name, entrypoint_name, std::move(half), &node_name)) { - results[0].set_i32(OakStatus::ERR_INVALID_ARGS); - } else { - OAK_LOG(INFO) << "Created new node named {" << node_name << "}"; - results[0].set_i32(OakStatus::OK); - } + OakStatus status = NodeCreate(channel_handle, config_name, entrypoint_name); + results[0].set_i32(status); return wabt::interp::Result::Ok; }; }