Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows Support for cpp_rpc #4857

Merged
merged 11 commits into from
Apr 15, 2020
18 changes: 17 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
cmake_minimum_required(VERSION 3.2)
if(WIN32)
cmake_minimum_required(VERSION 3.9)
else()
cmake_minimum_required(VERSION 3.2)
endif()

project(tvm C CXX)

# Utility functions
Expand Down Expand Up @@ -63,6 +68,8 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)

tvm_option(USE_CXX_RPC "Build CXX RPC" OFF)
tvm_option(USE_TFLITE "Build with tflite support" OFF)
tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none)

Expand Down Expand Up @@ -272,8 +279,15 @@ endif()

add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS})
if(WIN32)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()

add_library(tvm_runtime SHARED ${RUNTIME_SRCS})

if(USE_CXX_RPC STREQUAL "ON")
add_subdirectory("apps/cpp_rpc")
endif()

if(USE_RELAY_DEBUG)
message(STATUS "Building Relay in debug mode...")
Expand Down Expand Up @@ -391,6 +405,8 @@ endif(INSTALL_DEV)

# More target definitions
if(MSVC)
set_property(TARGET tvm PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
set_property(TARGET tvm_topi PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
endif()
27 changes: 27 additions & 0 deletions apps/cpp_rpc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
set(TVM_RPC_SOURCES
main.cc
rpc_env.cc
rpc_server.cc
)

if(WIN32)
list(APPEND TVM_RPC_SOURCES win32_process.cc)
endif()

# Set output to same directory as the other TVM libs
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
add_executable(tvm_rpc ${TVM_RPC_SOURCES})
set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE)

if(WIN32)
target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX)
endif()

target_include_directories(
tvm_rpc
PUBLIC "../../include"
PUBLIC DLPACK_PATH
PUBLIC DMLC_PATH
)

target_link_libraries(tvm_rpc tvm_runtime)
95 changes: 68 additions & 27 deletions apps/cpp_rpc/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
* \file rpc_server.cc
* \brief RPC Server for TVM.
*/
#include <stdlib.h>
#include <signal.h>
#include <stdio.h>
#include <cstdlib>
#include <csignal>
#include <cstdio>
#if defined(__linux__) || defined(__ANDROID__)
#include <unistd.h>
#endif
#include <dmlc/logging.h>
#include <iostream>
#include <cstring>
Expand All @@ -35,11 +37,15 @@
#include "../../src/support/socket.h"
#include "rpc_server.h"

#if defined(_WIN32)
#include "win32_process.h"
#endif

using namespace std;
using namespace tvm::runtime;
using namespace tvm::support;

static const string kUSAGE = \
static const string kUsage = \
"Command line usage\n" \
" server - Start the server\n" \
"--host - The hostname of the server, Default=0.0.0.0\n" \
Expand Down Expand Up @@ -73,13 +79,16 @@ struct RpcServerArgs {
string key;
string custom_addr;
bool silent = false;
#if defined(WIN32)
std::string mmap_path;
#endif
};

/*!
* \brief PrintArgs print the contents of RpcServerArgs
* \param args RpcServerArgs structure
*/
void PrintArgs(struct RpcServerArgs args) {
void PrintArgs(const RpcServerArgs& args) {
LOG(INFO) << "host = " << args.host;
LOG(INFO) << "port = " << args.port;
LOG(INFO) << "port_end = " << args.port_end;
Expand All @@ -89,6 +98,7 @@ void PrintArgs(struct RpcServerArgs args) {
LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False"));
}

#if defined(__linux__) || defined(__ANDROID__)
/*!
* \brief CtrlCHandler, exits if Ctrl+C is pressed
* \param s signal
Expand All @@ -109,7 +119,7 @@ void HandleCtrlC() {
sigIntHandler.sa_flags = 0;
sigaction(SIGINT, &sigIntHandler, nullptr);
}

#endif
/*!
* \brief GetCmdOption Parse and find the command option.
* \param argc arg counter
Expand All @@ -129,7 +139,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) {
}
// We assume "=" is the end of option.
CHECK_EQ(*option.rbegin(), '=');
cmd = arg.substr(arg.find("=") + 1);
cmd = arg.substr(arg.find('=') + 1);
return cmd;
}
}
Expand All @@ -156,41 +166,41 @@ bool ValidateTracker(string &tracker) {
* \brief ParseCmdArgs parses the command line arguments.
* \param argc arg counter
* \param argv arg values
* \param args, the output structure which holds the parsed values
* \param args the output structure which holds the parsed values
*/
void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
string silent = GetCmdOption(argc, argv, "--silent", true);
const string silent = GetCmdOption(argc, argv, "--silent", true);
if (!silent.empty()) {
args.silent = true;
// Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2");
}

string host = GetCmdOption(argc, argv, "--host=");
const string host = GetCmdOption(argc, argv, "--host=");
if (!host.empty()) {
if (!ValidateIP(host)) {
LOG(WARNING) << "Wrong host address format.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.host = host;
}

string port = GetCmdOption(argc, argv, "--port=");
const string port = GetCmdOption(argc, argv, "--port=");
if (!port.empty()) {
if (!IsNumber(port) || stoi(port) > 65535) {
LOG(WARNING) << "Wrong port number.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.port = stoi(port);
}

string port_end = GetCmdOption(argc, argv, "--port_end=");
const string port_end = GetCmdOption(argc, argv, "--port_end=");
if (!port_end.empty()) {
if (!IsNumber(port_end) || stoi(port_end) > 65535) {
LOG(WARNING) << "Wrong port_end number.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.port_end = stoi(port_end);
Expand All @@ -200,26 +210,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
if (!tracker.empty()) {
if (!ValidateTracker(tracker)) {
LOG(WARNING) << "Wrong tracker address format.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.tracker = tracker;
}

string key = GetCmdOption(argc, argv, "--key=");
const string key = GetCmdOption(argc, argv, "--key=");
if (!key.empty()) {
args.key = key;
}

string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
const string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
if (!custom_addr.empty()) {
if (!ValidateIP(custom_addr)) {
LOG(WARNING) << "Wrong custom address format.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.custom_addr = custom_addr;
}
#if defined(WIN32)
const string mmap_path = GetCmdOption(argc, argv, "--child_proc=");
if(!mmap_path.empty()) {
args.mmap_path = mmap_path;
dmlc::InitLogging("--minloglevel=0");
}
#endif

}

/*!
Expand All @@ -229,17 +247,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
* \return result of operation.
*/
int RpcServer(int argc, char * argv[]) {
struct RpcServerArgs args;
RpcServerArgs args;

/* parse the command line args */
ParseCmdArgs(argc, argv, args);
PrintArgs(args);

// Ctrl+C handler
LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop.";
#if defined(__linux__) || defined(__ANDROID__)
// Ctrl+C handler
HandleCtrlC();
tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
args.key, args.custom_addr, args.silent);
#endif

#if defined(WIN32)
if(!args.mmap_path.empty()) {
int ret = 0;

try {
ChildProcSocketHandler(args.mmap_path);
} catch (const std::exception&) {
ret = -1;
}

return ret;
}
#endif

RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
args.key, args.custom_addr, args.silent);
return 0;
}

Expand All @@ -251,15 +286,21 @@ int RpcServer(int argc, char * argv[]) {
*/
int main(int argc, char * argv[]) {
if (argc <= 1) {
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
return 0;
}

// Runs WSAStartup on Win32, no-op on POSIX
Socket::Startup();
#if defined(_WIN32)
SetEnvironmentVariableA("CUDA_CACHE_DISABLE", "1");
#endif

if (0 == strcmp(argv[1], "server")) {
RpcServer(argc, argv);
} else {
LOG(INFO) << kUSAGE;
return RpcServer(argc, argv);
}

LOG(INFO) << kUsage;

return 0;
}
Loading