Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make pseudo-Nodes use Handles not raw channel references #767

Merged
merged 15 commits into from
Mar 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions oak/server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@ cc_library(
":channel",
"//oak/common:handles",
"//oak/common:logging",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)

cc_library(
name = "handle_closer",
hdrs = ["handle_closer.h"],
deps = [":oak_node"],
)

cc_library(
name = "wasm_node",
srcs = [
Expand All @@ -47,7 +54,6 @@ cc_library(
"wasm_node.h",
],
deps = [
":channel",
"//oak/common:handles",
"//oak/common:logging",
"//oak/proto:grpc_encap_cc_proto",
Expand Down Expand Up @@ -109,9 +115,9 @@ cc_library(
"oak_grpc_node.h",
],
deps = [
":channel",
":oak_node",
"//oak/common:app_config",
"//oak/common:handles",
"//oak/common:logging",
"//oak/common:policy",
"//oak/proto:grpc_encap_cc_proto",
Expand Down Expand Up @@ -191,7 +197,7 @@ cc_library(
srcs = ["logging_node.cc"],
hdrs = ["logging_node.h"],
deps = [
":channel",
":handle_closer",
":node_thread",
"//oak/common:handles",
"//oak/common:logging",
Expand All @@ -205,7 +211,9 @@ cc_library(
srcs = ["invocation.cc"],
hdrs = ["invocation.h"],
deps = [
":channel",
":handle_closer",
":oak_node",
"//oak/common:handles",
"//oak/common:logging",
"@com_google_absl//absl/memory",
],
Expand All @@ -217,7 +225,6 @@ cc_library(
hdrs = ["grpc_client_node.h"],
deps = [
":base_runtime",
":channel",
":invocation",
":node_thread",
"//oak/common:handles",
Expand Down
21 changes: 15 additions & 6 deletions oak/server/channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,24 +79,33 @@ void MessageChannel::Write(std::unique_ptr<Message> msg) {
ReadResult MessageChannel::Read(uint32_t max_size, uint32_t max_channels) {
absl::MutexLock lock(&mu_);
if (msgs_.empty()) {
return ReadResult{0};
ReadResult result(OakStatus::OK);
return result;
}
return ReadLocked(max_size, max_channels);
}

ReadResult MessageChannel::ReadLocked(uint32_t max_size, uint32_t max_channels) {
ReadResult result = {0};
Message* next_msg = msgs_.front().get();
size_t actual_size = next_msg->data.size();
size_t actual_count = next_msg->channels.size();
if (actual_size > max_size || actual_count > max_channels) {
OAK_LOG(INFO) << "Next message of size " << actual_size << " with " << actual_count
<< " channels, read limited to size " << max_size << " and " << max_channels
<< " channels";
if (actual_size > max_size) {
OAK_LOG(INFO) << "Next message of size " << actual_size << ", read limited to size "
<< max_size;
ReadResult result(OakStatus::ERR_BUFFER_TOO_SMALL);
result.required_size = actual_size;
result.required_channels = actual_count;
return result;
}
if (actual_count > max_channels) {
OAK_LOG(INFO) << "Next message with " << actual_count << " handles, read limited to "
<< max_channels << " handles";
ReadResult result(OakStatus::ERR_HANDLE_SPACE_TOO_SMALL);
result.required_size = actual_size;
result.required_channels = actual_count;
return result;
}
ReadResult result(OakStatus::OK);
result.msg = std::move(msgs_.front());
msgs_.pop_front();
OAK_LOG(INFO) << "Read message of size " << result.msg->data.size() << " with " << actual_count
Expand Down
16 changes: 8 additions & 8 deletions oak/server/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ struct Message {
oak::policy::Label label;
};

// Result of a read operation. If the operation would have produced a message
// bigger than the requested maximum size, then |required_size| will be non-zero
// and indicates the required size for the message. If the operation would have
// been accompanied by more than the requested maximum channel count, then
// |required_channels| will be non-zero and indicates the required channel count
// for the message. Otherwise, |required_size| and |required_channels| will be
// zero and |data| holds the message (transferring ownership).
// Result of a read operation.
struct ReadResult {
explicit ReadResult(OakStatus s) : status(s), required_size(0), required_channels(0) {}
OakStatus status;
// The following fields are filled in if the status is ERR_BUFFER_TOO_SMALL
// or ERR_HANDLE_SPACE_TOO_SMALL, indicating the required size and number
// of handles needed to read the message.
uint32_t required_size;
uint32_t required_channels;
// The following field is filled in if the status is OK.
std::unique_ptr<Message> msg;
};

Expand Down Expand Up @@ -99,7 +99,7 @@ class MessageChannel {
// Count indicates the number of pending messages.
size_t Count() const LOCKS_EXCLUDED(mu_);

// Read returns the first message on the channel, subject to |max_size| checks.
// Read returns the first message on the channel, subject to size checks.
ReadResult Read(uint32_t max_size, uint32_t max_channels) LOCKS_EXCLUDED(mu_);

// BlockingRead behaves like Read but blocks until a message is available.
Expand Down
12 changes: 12 additions & 0 deletions oak/server/channel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,20 @@ TEST(MessageChannel, BasicOperation) {
ASSERT_EQ(true, read_half->CanRead());

ReadResult result1 = read_half->Read(1, 0); // too small
ASSERT_EQ(OakStatus::ERR_BUFFER_TOO_SMALL, result1.status);
ASSERT_EQ(3, result1.required_size);
ASSERT_EQ(nullptr, result1.msg);

ReadResult result2 = read_half->Read(3, 0); // just right
ASSERT_EQ(OakStatus::OK, result2.status);
EXPECT_NE(result2.msg, nullptr);
ASSERT_EQ(3, result2.msg->data.size());
ASSERT_EQ(0x01, (result2.msg->data)[0]);

ASSERT_EQ(false, read_half->CanRead());

ReadResult result3 = read_half->Read(10000, 0);
ASSERT_EQ(OakStatus::OK, result3.status);
ASSERT_EQ(nullptr, result3.msg);
ASSERT_EQ(0, result3.required_size);

Expand All @@ -65,15 +68,18 @@ TEST(MessageChannel, BasicOperation) {
ASSERT_EQ(true, read_half->CanRead());

ReadResult result4 = read_half->Read(3000, 0);
ASSERT_EQ(OakStatus::OK, result4.status);
EXPECT_NE(result4.msg, nullptr);
ASSERT_EQ(3, result4.msg->data.size());
ASSERT_EQ(0x11, (result4.msg->data)[0]);

ReadResult result5 = read_half->Read(0, 0);
ASSERT_EQ(OakStatus::ERR_BUFFER_TOO_SMALL, result5.status);
ASSERT_EQ(3, result5.required_size);
ASSERT_EQ(nullptr, result5.msg);

ReadResult result6 = read_half->Read(10, 0);
ASSERT_EQ(OakStatus::OK, result6.status);
EXPECT_NE(result6.msg, nullptr);
ASSERT_EQ(3, result6.msg->data.size());
ASSERT_EQ(0x21, (result6.msg->data)[0]);
Expand All @@ -94,11 +100,13 @@ TEST(MessageChannel, ChannelTransfer) {
ASSERT_EQ(true, read_half->CanRead());

ReadResult result1 = read_half->Read(1000, 0); // no space for channels
ASSERT_EQ(OakStatus::ERR_HANDLE_SPACE_TOO_SMALL, result1.status);
ASSERT_EQ(3, result1.required_size);
ASSERT_EQ(1, result1.required_channels);
ASSERT_EQ(nullptr, result1.msg);

ReadResult result2 = read_half->Read(1000, 1); // just right
ASSERT_EQ(OakStatus::OK, result2.status);
EXPECT_NE(result2.msg, nullptr);
ASSERT_EQ(3, result2.msg->data.size());
ASSERT_EQ(0x01, (result2.msg->data)[0]);
Expand All @@ -107,6 +115,7 @@ TEST(MessageChannel, ChannelTransfer) {
ASSERT_EQ(false, read_half->CanRead());

ReadResult result3 = read_half->Read(10000, 10);
ASSERT_EQ(OakStatus::OK, result3.status);
ASSERT_EQ(nullptr, result3.msg);
ASSERT_EQ(0, result3.required_size);
ASSERT_EQ(0, result3.required_channels);
Expand All @@ -124,17 +133,20 @@ TEST(MessageChannel, ChannelTransfer) {
ASSERT_EQ(true, read_half->CanRead());

ReadResult result4 = read_half->Read(3000, 10);
ASSERT_EQ(OakStatus::OK, result4.status);
EXPECT_NE(result4.msg, nullptr);
ASSERT_EQ(3, result4.msg->data.size());
ASSERT_EQ(0x11, (result4.msg->data)[0]);
ASSERT_EQ(2, result4.msg->channels.size());

ReadResult result5 = read_half->Read(100, 0);
ASSERT_EQ(OakStatus::ERR_HANDLE_SPACE_TOO_SMALL, result5.status);
ASSERT_EQ(3, result5.required_size);
ASSERT_EQ(3, result5.required_channels);
ASSERT_EQ(nullptr, result5.msg);

ReadResult result6 = read_half->Read(10, 10);
ASSERT_EQ(OakStatus::OK, result6.status);
EXPECT_NE(result6.msg, nullptr);
ASSERT_EQ(3, result6.msg->data.size());
ASSERT_EQ(0x21, (result6.msg->data)[0]);
Expand Down
29 changes: 11 additions & 18 deletions oak/server/grpc_client_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,21 @@ GrpcClientNode::GrpcClientNode(BaseRuntime* runtime, const std::string& name,
OAK_LOG(INFO) << "Created gRPC client node for " << grpc_address;
}

bool GrpcClientNode::HandleInvocation(MessageChannelReadHalf* invocation_channel) {
std::unique_ptr<Invocation> invocation(Invocation::ReceiveFromChannel(invocation_channel));
bool GrpcClientNode::HandleInvocation(Handle invocation_handle) {
std::unique_ptr<Invocation> invocation(Invocation::ReceiveFromChannel(this, invocation_handle));
if (invocation == nullptr) {
OAK_LOG(ERROR) << "Failed to create invocation";
return false;
}

// Expect to read a single request out of the request channel.
// TODO(#97): support client-side streaming
ReadResult req_result = invocation->req_channel->Read(INT_MAX, INT_MAX);
if (req_result.required_size > 0) {
OAK_LOG(ERROR) << "Message size too large: " << req_result.required_size;
NodeReadResult req_result = ChannelRead(invocation->req_handle.get(), INT_MAX, INT_MAX);
if (req_result.status != OakStatus::OK) {
OAK_LOG(ERROR) << "Failed to read invocation message: " << req_result.status;
return false;
}
if (req_result.msg->channels.size() != 0) {
if (req_result.msg->handles.size() != 0) {
OAK_LOG(ERROR) << "Unexpectedly received channel handles in request channel";
return false;
}
Expand Down Expand Up @@ -119,14 +119,14 @@ bool GrpcClientNode::HandleInvocation(MessageChannelReadHalf* invocation_channel
any->set_value(rsp_data.data(), rsp_data.size());
grpc_rsp.set_allocated_rsp_msg(any);

std::unique_ptr<Message> rsp_msg = absl::make_unique<Message>();
auto rsp_msg = absl::make_unique<NodeMessage>();
size_t serialized_size = grpc_rsp.ByteSizeLong();
rsp_msg->data.resize(serialized_size);
grpc_rsp.SerializeToArray(rsp_msg->data.data(), rsp_msg->data.size());

// Write the encapsulated response Message to the response channel.
OAK_LOG(INFO) << "Write gRPC response message to response channel";
invocation->rsp_channel->Write(std::move(rsp_msg));
ChannelWrite(invocation->rsp_handle.get(), std::move(rsp_msg));
}

OAK_LOG(INFO) << "Finish invocation method " << method_name;
Expand All @@ -143,14 +143,14 @@ bool GrpcClientNode::HandleInvocation(MessageChannelReadHalf* invocation_channel
grpc_rsp.mutable_status()->set_code(status.error_code());
grpc_rsp.mutable_status()->set_message(status.error_message());

std::unique_ptr<Message> rsp_msg = absl::make_unique<Message>();
auto rsp_msg = absl::make_unique<NodeMessage>();
size_t serialized_size = grpc_rsp.ByteSizeLong();
rsp_msg->data.resize(serialized_size);
grpc_rsp.SerializeToArray(rsp_msg->data.data(), rsp_msg->data.size());

OAK_LOG(INFO) << "Write final gRPC status of (" << status.error_code() << ", '"
<< status.error_message() << "') to response channel";
invocation->rsp_channel->Write(std::move(rsp_msg));
ChannelWrite(invocation->rsp_handle.get(), std::move(rsp_msg));
}

// References to the per-invocation request/response channels will be dropped
Expand All @@ -159,12 +159,6 @@ bool GrpcClientNode::HandleInvocation(MessageChannelReadHalf* invocation_channel
}

void GrpcClientNode::Run(Handle invocation_handle) {
// Borrow a pointer to the relevant channel half.
MessageChannelReadHalf* invocation_channel = BorrowReadChannel(invocation_handle);
if (invocation_channel == nullptr) {
OAK_LOG(ERROR) << "Required channel not available; handle: " << invocation_handle;
return;
}
std::vector<std::unique_ptr<ChannelStatus>> channel_status;
channel_status.push_back(absl::make_unique<ChannelStatus>(invocation_handle));
while (true) {
Expand All @@ -173,11 +167,10 @@ void GrpcClientNode::Run(Handle invocation_handle) {
return;
}

if (!HandleInvocation(invocation_channel)) {
if (!HandleInvocation(invocation_handle)) {
OAK_LOG(ERROR) << "Invocation processing failed!";
}
}
// Drop reference to the invocation channel on exit.
}

} // namespace oak
3 changes: 1 addition & 2 deletions oak/server/grpc_client_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "include/grpcpp/grpcpp.h"
#include "oak/common/handles.h"
#include "oak/server/base_runtime.h"
#include "oak/server/channel.h"
#include "oak/server/node_thread.h"

namespace oak {
Expand All @@ -34,7 +33,7 @@ class GrpcClientNode final : public NodeThread {
GrpcClientNode(BaseRuntime* runtime, const std::string& name, const std::string& grpc_address);

private:
bool HandleInvocation(MessageChannelReadHalf* invocation_channel);
bool HandleInvocation(Handle invocation_handle);
void Run(Handle handle) override;

std::shared_ptr<grpc::ChannelInterface> channel_;
Expand Down
40 changes: 40 additions & 0 deletions oak/server/handle_closer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2020 The Project Oak Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef OAK_SERVER_HANDLE_CLOSER_H_
#define OAK_SERVER_HANDLE_CLOSER_H_

#include "oak/server/oak_node.h"

namespace oak {

// RAII class to hold a Node's handle and auto-close it on object destruction.
class HandleCloser {
public:
explicit HandleCloser(OakNode* node, Handle handle) : node_(node), handle_(handle) {}

~HandleCloser() { node_->ChannelClose(handle_); }

Handle get() const { return handle_; }

private:
OakNode* node_;
const Handle handle_;
};

} // namespace oak

#endif // OAK_SERVER_HANDLE_CLOSER_H_
Loading