Skip to content

Commit

Permalink
sdk: commonize C++ module entrypoint code
Browse files Browse the repository at this point in the history
Have a common `grpc_oak_main()` that expects to invoke an externally
provided `process_invocation()` function.  Implemented as a header
file rather than a .cc file to work around some linker script
weirdness.
  • Loading branch information
daviddrysdale committed May 26, 2020
1 parent 820fa28 commit 64fc9ce
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 199 deletions.
1 change: 1 addition & 0 deletions examples/hello_world/module/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ cc_binary(
srcs = ["hello_world.cc"],
deps = [
"//oak/module:oak_abi",
"//oak/module:oak_main",
# TODO(#422): Sort out inclusion of protobuf files
# "//oak/proto:oak_api_cc_proto",
],
Expand Down
106 changes: 17 additions & 89 deletions examples/hello_world/module/cpp/hello_world.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,93 +19,21 @@

#include "oak/module/oak_abi.h"

// TODO(#422): Sort out inclusion of protobuf files
// #include "oak/proto/oak_api.pb.h"

// Local copy of oak_api.pb.h contents for now.
namespace oak {

} // namespace oak

WASM_EXPORT void grpc_oak_main(oak_abi::Handle _handle) {
// Create a channel to the gRPC server pseudo-Node.
oak_abi::Handle write_handle;
oak_abi::Handle read_handle;
oak_abi::OakStatus result = channel_create(&write_handle, &read_handle, nullptr, 0);
if (result != oak_abi::OakStatus::OK) {
return;
}

// Create a gRPC server pseudo-Node
char config_name[] = "grpc-server";
result = node_create((uint8_t*)config_name, sizeof(config_name) - 1, nullptr, 0, nullptr, 0,
read_handle);
if (result != oak_abi::OakStatus::OK) {
return;
}
channel_close(read_handle);

// Create a separate channel for receiving invocations and pass it to the gRPC pseudo-Node.
oak_abi::Handle grpc_out_handle;
oak_abi::Handle grpc_in_handle;
result = channel_create(&grpc_out_handle, &grpc_in_handle, nullptr, 0);
if (result != oak_abi::OakStatus::OK) {
return;
}
result = channel_write(write_handle, nullptr, 0, (uint8_t*)&grpc_out_handle, 1);
if (result != oak_abi::OakStatus::OK) {
return;
}
channel_close(grpc_out_handle);
channel_close(write_handle);

// TODO(#744): Add C++ helpers for dealing with handle notification space.
uint8_t handle_space[9] = {
static_cast<uint8_t>(grpc_in_handle & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 8) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 16) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 24) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 32) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 40) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 48) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 56) & 0xff),
0x00, // read ready?
};

while (true) {
result = wait_on_channels(handle_space, 1);
if (result != oak_abi::OakStatus::OK) {
return;
}

// Reading from main channel should give no data and a (read, write) pair of handles.
uint32_t actual_size;
uint32_t handle_count;
oak_abi::Handle handles[2];
channel_read(grpc_in_handle, nullptr, 0, &actual_size, handles, 2, &handle_count);
if ((actual_size != 0) || (handle_count != 2)) {
return;
}
oak_abi::Handle req_handle = handles[0];
oak_abi::Handle rsp_handle = handles[1];

// Read an incoming request from the read handle, expecting data but no handles.
// (However, ignore its contents for now).
uint8_t buf[256];
channel_read(req_handle, buf, sizeof(buf), &actual_size, nullptr, 0, &handle_count);
channel_close(req_handle);

// Manually create an encapsulated GrpcResponse protobuf and send it back.
// 0a b00001.010 = tag 1 (GrpcResponse.rsp_msg), length-delimited field
// 0b length=11
// 0A b00001.010 = tag 1 (HelloResponse.reply), length-delimited field
// 07 length=7
// 74657374696e67 "testing"
// 18 b00011.000 = tag 3 (GrpcResponse.last), varint
// 01 true
uint8_t rsp_buf[] = "\x0a\x0b\x0A\x07\x74\x65\x73\x74\x69\x6e\x67\x18\x01";
// TODO(#422): replace with use of message type and serialization.
channel_write(rsp_handle, rsp_buf, sizeof(rsp_buf) - 1, nullptr, 0);
channel_close(rsp_handle);
}
// Include standard C++ placeholder oak_main() implementation.
#include "oak/module/oak_main.h"

extern "C" void process_invocation(const uint8_t* _req_buf, uint32_t _req_size,
oak_abi::Handle rsp_handle) {
// Ignore the contents of the incoming request.
// Manually create an encapsulated GrpcResponse protobuf and send it back.
// 0a b00001.010 = tag 1 (GrpcResponse.rsp_msg), length-delimited field
// 0b length=11
// 0A b00001.010 = tag 1 (HelloResponse.reply), length-delimited field
// 07 length=7
// 74657374696e67 "testing"
// 18 b00011.000 = tag 3 (GrpcResponse.last), varint
// 01 true
uint8_t rsp_buf[] = "\x0a\x0b\x0A\x07\x74\x65\x73\x74\x69\x6e\x67\x18\x01";
// TODO(#422): replace with use of message type and serialization.
channel_write(rsp_handle, rsp_buf, sizeof(rsp_buf) - 1, nullptr, 0);
}
2 changes: 1 addition & 1 deletion examples/tensorflow/config/config.textproto
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ node_configs {
log_config {}
}
initial_node_config_name: "app"
initial_entrypoint_name: "oak_main"
initial_entrypoint_name: "grpc_oak_main"
6 changes: 4 additions & 2 deletions examples/tensorflow/module/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ cc_binary(
"-s AUTO_ARCHIVE_INDEXES=0", # Necessary to remove console spam.
"-s DISABLE_EXCEPTION_CATCHING=1",
"-s ERROR_ON_UNDEFINED_SYMBOLS=0",
"-s EXPORTED_FUNCTIONS='[\"_oak_main\"]'",
"-s EXPORTED_FUNCTIONS='[\"_grpc_oak_main\"]'",
"-s FILESYSTEM=0",
"-s MALLOC=emmalloc",
"-s STANDALONE_WASM=1", # WASM file should run without JavaScript.
Expand All @@ -45,6 +45,7 @@ cc_binary(
],
deps = [
"//oak/module:oak_abi",
"//oak/module:oak_main",
"//oak/module:placeholders",
# TODO(#422): Sort out inclusion of protobuf files
# "//oak/proto:oak_api_cc_proto",
Expand All @@ -68,7 +69,7 @@ cc_binary(
"-s AUTO_ARCHIVE_INDEXES=0", # Necessary to remove console spam.
"-s DISABLE_EXCEPTION_CATCHING=1",
"-s ERROR_ON_UNDEFINED_SYMBOLS=0",
"-s EXPORTED_FUNCTIONS='[\"_oak_main\"]'",
"-s EXPORTED_FUNCTIONS='[\"_grpc_oak_main\"]'",
"-s FILESYSTEM=0",
"-s MALLOC=emmalloc",
"-s STANDALONE_WASM=1", # WASM file should run without JavaScript.
Expand All @@ -78,6 +79,7 @@ cc_binary(
],
deps = [
"//oak/module:oak_abi",
"//oak/module:oak_main",
"//oak/module:placeholders",
# TODO(#422): Sort out inclusion of protobuf files
# "//oak/proto:oak_api_cc_proto",
Expand Down
72 changes: 19 additions & 53 deletions examples/tensorflow/module/cpp/tensorflow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"

// Include standard C++ placeholder oak_main() implementation.
#include "oak/module/oak_main.h"

const char kModelBuffer[] = {
0x18, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x04,
0x00, 0x08, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x03, 0x00,
Expand Down Expand Up @@ -65,58 +68,21 @@ std::string init_tensorflow() {
return std::string("Success: Model was loaded correctly");
}

WASM_EXPORT void oak_main(oak_abi::Handle grpc_in_handle) {
// TODO(#744): Add C++ helpers for dealing with handle notification space.
uint8_t handle_space[9] = {
static_cast<uint8_t>(grpc_in_handle & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 8) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 16) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 24) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 32) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 40) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 48) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 56) & 0xff),
0x00, // read ready?
};

while (true) {
oak_abi::OakStatus result = wait_on_channels(handle_space, 1);
if (result != oak::OakStatus::OK) {
return;
}

// Reading from main channel should return no bytes and a (read, write) pair of handles.
uint32_t actual_size;
uint32_t handle_count;
oak_abi::Handle handles[2];
channel_read(grpc_in_handle, nullptr, 0, &actual_size, handles, 2, &handle_count);
if ((actual_size != 0) || (handle_count != 2)) {
return;
}
oak_abi::Handle req_handle = handles[0];
oak_abi::Handle rsp_handle = handles[1];
extern "C" void process_invocation(const uint8_t* _req_buf, uint32_t _req_size,
oak_abi::Handle rsp_handle) {
init_tensorflow();

// Read an incoming request from the read handle, expecting data but no handles.
// (However, ignore its contents for now).
uint8_t buf[256];
channel_read(req_handle, buf, sizeof(buf), &actual_size, nullptr, 0, &handle_count);
channel_close(req_handle);

init_tensorflow();

// Manually create an encapsulated GrpcResponse protobuf and send it back.
// 0a b00001.010 = tag 1 (GrpcResponse.rsp_msg), length-delimited field
// 0b length=11
// 12 b00010.010 = tag 2 (Any.value), length-delimited field
// 09 length=9
// 0A b00001.010 = tag 1 (HelloResponse.reply), length-delimited field
// 07 length=7
// 74657374696e67 "testing"
// 18 b00011.000 = tag 3 (GrpcResponse.last), varint
// 01 true
uint8_t rsp_buf[] = "\x0a\x0b\x12\x09\x0A\x07\x74\x65\x73\x74\x69\x6e\x67\x18\x01";
// TODO(#422): replace with use of message type and serialization.
channel_write(rsp_handle, rsp_buf, sizeof(rsp_buf) - 1, nullptr, 0);
channel_close(rsp_handle);
}
// Manually create an encapsulated GrpcResponse protobuf and send it back.
// 0a b00001.010 = tag 1 (GrpcResponse.rsp_msg), length-delimited field
// 0b length=11
// 12 b00010.010 = tag 2 (Any.value), length-delimited field
// 09 length=9
// 0A b00001.010 = tag 1 (HelloResponse.reply), length-delimited field
// 07 length=7
// 74657374696e67 "testing"
// 18 b00011.000 = tag 3 (GrpcResponse.last), varint
// 01 true
uint8_t rsp_buf[] = "\x0a\x0b\x12\x09\x0A\x07\x74\x65\x73\x74\x69\x6e\x67\x18\x01";
// TODO(#422): replace with use of message type and serialization.
channel_write(rsp_handle, rsp_buf, sizeof(rsp_buf) - 1, nullptr, 0);
}
74 changes: 20 additions & 54 deletions examples/tensorflow/module/cpp/tensorflow_micro.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
#include "tensorflow/lite/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/micro/micro_interpreter.h"

// Include standard C++ placeholder oak_main() implementation.
#include "oak/module/oak_main.h"

// Constants were taken from the TFLite exapmles:
// https://github.com/tensorflow/tensorflow/blob/11bed638b14898cdde967f6b108e45732aa4798a/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc#L25
// https://github.com/tensorflow/tensorflow/blob/11bed638b14898cdde967f6b108e45732aa4798a/tensorflow/lite/micro/examples/network_tester/network_model.h#L16-L64
Expand Down Expand Up @@ -94,58 +97,21 @@ std::string init_tensorflow() {
return std::string("Success: Model was loaded correctly");
}

WASM_EXPORT void oak_main(oak_abi::Handle grpc_in_handle) {
// TODO(#744): Add C++ helpers for dealing with handle notification space.
uint8_t handle_space[9] = {
static_cast<uint8_t>(grpc_in_handle & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 8) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 16) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 24) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 32) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 40) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 48) & 0xff),
static_cast<uint8_t>((grpc_in_handle >> 56) & 0xff),
0x00, // read ready?
};

while (true) {
oak_abi::OakStatus result = wait_on_channels(handle_space, 1);
if (result != oak::OakStatus::OK) {
return;
}

// Reading from main channel should return no bytes and a (read, write) pair of handles.
uint32_t actual_size;
uint32_t handle_count;
oak_abi::Handle handles[2];
channel_read(grpc_in_handle, nullptr, 0, &actual_size, handles, 2, &handle_count);
if ((actual_size != 0) || (handle_count != 2)) {
return;
}
oak_abi::Handle req_handle = handles[0];
oak_abi::Handle rsp_handle = handles[1];

// Read an incoming request from the read handle, expecting data but no handles.
// (However, ignore its contents for now).
uint8_t buf[256];
channel_read(req_handle, buf, sizeof(buf), &actual_size, nullptr, 0, &handle_count);
channel_close(req_handle);

init_tensorflow();

// Manually create an encapsulated GrpcResponse protobuf and send it back.
// 0a b00001.010 = tag 1 (GrpcResponse.rsp_msg), length-delimited field
// 0b length=11
// 12 b00010.010 = tag 2 (Any.value), length-delimited field
// 09 length=9
// 0A b00001.010 = tag 1 (HelloResponse.reply), length-delimited field
// 07 length=7
// 74657374696e67 "testing"
// 18 b00011.000 = tag 3 (GrpcResponse.last), varint
// 01 true
uint8_t rsp_buf[] = "\x0a\x0b\x12\x09\x0A\x07\x74\x65\x73\x74\x69\x6e\x67\x18\x01";
// TODO(#422): replace with use of message type and serialization.
channel_write(rsp_handle, rsp_buf, sizeof(rsp_buf) - 1, nullptr, 0);
channel_close(rsp_handle);
}
extern "C" void process_invocation(const uint8_t* _req_buf, uint32_t _req_size,
oak_abi::Handle rsp_handle) {
init_tensorflow();

// Manually create an encapsulated GrpcResponse protobuf and send it back.
// 0a b00001.010 = tag 1 (GrpcResponse.rsp_msg), length-delimited field
// 0b length=11
// 12 b00010.010 = tag 2 (Any.value), length-delimited field
// 09 length=9
// 0A b00001.010 = tag 1 (HelloResponse.reply), length-delimited field
// 07 length=7
// 74657374696e67 "testing"
// 18 b00011.000 = tag 3 (GrpcResponse.last), varint
// 01 true
uint8_t rsp_buf[] = "\x0a\x0b\x12\x09\x0A\x07\x74\x65\x73\x74\x69\x6e\x67\x18\x01";
// TODO(#422): replace with use of message type and serialization.
channel_write(rsp_handle, rsp_buf, sizeof(rsp_buf) - 1, nullptr, 0);
}
7 changes: 7 additions & 0 deletions oak/module/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,10 @@ cc_library(
hdrs = ["placeholders.h"],
visibility = ["//visibility:public"],
)

cc_library(
name = "oak_main",
srcs = ["oak_main.h"],
visibility = ["//visibility:public"],
deps = [":oak_abi"],
)
Loading

0 comments on commit 64fc9ce

Please sign in to comment.