Skip to content

Commit

Permalink
Add platform timer to microTVM.
Browse files Browse the repository at this point in the history
  • Loading branch information
areusch committed Dec 11, 2020
1 parent 94b2e44 commit b0dc298
Show file tree
Hide file tree
Showing 15 changed files with 358 additions and 71 deletions.
6 changes: 6 additions & 0 deletions apps/bundle_deploy/bundle.c
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,9 @@ tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLContext ctx, void*
tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLContext ctx) {
return g_memory_manager->Free(g_memory_manager, ptr, ctx);
}

tvm_crt_error_t TVMPlatformTimerStart() { return kTvmErrorFunctionCallNotImplemented; }

tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
return kTvmErrorFunctionCallNotImplemented;
}
6 changes: 6 additions & 0 deletions apps/bundle_deploy/bundle_static.c
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,9 @@ tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLContext ctx, void*
tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLContext ctx) {
return g_memory_manager->Free(g_memory_manager, ptr, ctx);
}

tvm_crt_error_t TVMPlatformTimerStart() { return kTvmErrorFunctionCallNotImplemented; }

tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
return kTvmErrorFunctionCallNotImplemented;
}
7 changes: 7 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,13 @@ TVM_DLL int TVMObjectRetain(TVMObjectHandle obj);
*/
TVM_DLL int TVMObjectFree(TVMObjectHandle obj);

/*!
* \brief Free a TVMByteArray returned from TVMFuncCall, and associated memory.
* \param arr The TVMByteArray instance.
* \return 0 on success, -1 on failure.
*/
TVM_DLL int TVMByteArrayFree(TVMByteArray* arr);

/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/runtime/crt/error_codes.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ typedef enum {
kTvmErrorCategoryGenerated = 6,
kTvmErrorCategoryGraphRuntime = 7,
kTvmErrorCategoryFunctionCall = 8,
kTvmErrorCategoryTimeEvaluator = 9,
} tvm_crt_error_category_t;

typedef enum {
Expand Down Expand Up @@ -77,6 +78,7 @@ typedef enum {
kTvmErrorPlatformMemoryManagerInitialized = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 1),
kTvmErrorPlatformShutdown = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 2),
kTvmErrorPlatformNoMemory = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 3),
kTvmErrorPlatformTimerBadState = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 4),

// Common error codes returned from generated functions.
kTvmErrorGeneratedInvalidStorageId = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryGenerated, 0),
Expand All @@ -91,6 +93,9 @@ typedef enum {
kTvmErrorFunctionCallWrongArgType = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 1),
kTvmErrorFunctionCallNotImplemented = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 2),

// Time Evaluator - times functions for use with debug runtime.
kTvmErrorTimeEvaluatorBadHandle = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryTimeEvaluator, 0),

// System errors are always negative integers; this mask indicates presence of a system error.
// Cast tvm_crt_error_t to a signed integer to interpret the negative error code.
kTvmErrorSystemErrorMask = (1 << (sizeof(int) * 4 - 1)),
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/runtime/crt/platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,25 @@ tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLContext ctx, void*
* \return kTvmErrorNoError if successful; a descriptive error code otherwise.
*/
tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLContext ctx);

/*! \brief Start a device timer.
*
* The device timer used must not be running.
*
* \return kTvmErrorNoError if successful; a descriptive error code otherwise.
*/
tvm_crt_error_t TVMPlatformTimerStart();

/*! \brief Stop the running device timer and get the elapsed time (in microseconds).
*
* The device timer used must be running.
*
* \param elapsed_time_seconds Pointer to write elapsed time into.
*
* \return kTvmErrorNoError if successful; a descriptive error code otherwise.
*/
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds);

#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
1 change: 1 addition & 0 deletions python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..error import register_error
from .._ffi import get_global_func
from ..contrib import graph_runtime
from ..contrib.debugger import debug_runtime
from ..rpc import RPCSession
from .transport import IoTimeoutError
from .transport import TransportLogger
Expand Down
9 changes: 9 additions & 0 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,15 @@ int TVMFuncFree(TVMFunctionHandle func) {
API_END();
}

int TVMByteArrayFree(TVMByteArray* arr) {
if (arr == &TVMAPIRuntimeStore::Get()->ret_bytes) {
return 0; // Thread-local storage does not need explicit deleting.
}

delete arr;
return 0;
}

int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args,
TVMValue* ret_val, int* ret_type_code) {
API_BEGIN();
Expand Down
150 changes: 142 additions & 8 deletions src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ static const TVMModule* registered_modules[TVM_CRT_MAX_REGISTERED_MODULES];
/*! \brief Passed as `module_index` to EncodeFunctionHandle. */
static const tvm_module_index_t kGlobalFuncModuleIndex = TVM_CRT_MAX_REGISTERED_MODULES;

/*! \brief Special module handle for retur values from RPCTimeEvaluator. */
static const tvm_module_index_t kTimeEvaluatorModuleIndex = 0x7fff;

static int DecodeModuleHandle(TVMModuleHandle handle, tvm_module_index_t* out_module_index) {
tvm_module_index_t module_index;

Expand Down Expand Up @@ -185,20 +188,36 @@ static int DecodeFunctionHandle(TVMFunctionHandle handle, tvm_module_index_t* mo
(tvm_module_index_t)(((uintptr_t)handle) >> (sizeof(tvm_function_index_t) * 8));
unvalidated_module_index &= ~0x8000;

if (unvalidated_module_index > kGlobalFuncModuleIndex) {
TVMAPIErrorf("invalid module handle: index=%08x", unvalidated_module_index);
return -1;
} else if (unvalidated_module_index < kGlobalFuncModuleIndex &&
registered_modules[unvalidated_module_index] == NULL) {
TVMAPIErrorf("unregistered module: index=%08x", unvalidated_module_index);
return -1;
if (unvalidated_module_index != kTimeEvaluatorModuleIndex) {
if (unvalidated_module_index > kGlobalFuncModuleIndex) {
TVMAPIErrorf("invalid module handle: index=%08x", unvalidated_module_index);
return -1;
} else if (unvalidated_module_index < kGlobalFuncModuleIndex &&
registered_modules[unvalidated_module_index] == NULL) {
TVMAPIErrorf("unregistered module: index=%08x", unvalidated_module_index);
return -1;
}
}

*function_index = ((uint32_t)((uintptr_t)handle)) & ~0x8000;
*module_index = unvalidated_module_index;
return 0;
}

int TVMByteArrayFree(TVMByteArray* arr) {
DLContext ctx = {kDLCPU, 0};
int to_return = TVMPlatformMemoryFree((void*)arr->data, ctx);
if (to_return != 0) {
return to_return;
}

return TVMPlatformMemoryFree((void*)arr, ctx);
}

tvm_crt_error_t RunTimeEvaluator(tvm_function_index_t function_index, TVMValue* args,
int* type_codes, int num_args, TVMValue* ret_val,
int* ret_type_code);

int TVMFuncCall(TVMFunctionHandle func_handle, TVMValue* arg_values, int* type_codes, int num_args,
TVMValue* ret_val, int* ret_type_code) {
tvm_module_index_t module_index;
Expand All @@ -211,7 +230,10 @@ int TVMFuncCall(TVMFunctionHandle func_handle, TVMValue* arg_values, int* type_c
return -1;
}

if (module_index == kGlobalFuncModuleIndex) {
if (module_index == kTimeEvaluatorModuleIndex) {
return RunTimeEvaluator(function_index, arg_values, type_codes, num_args, ret_val,
ret_type_code);
} else if (module_index == kGlobalFuncModuleIndex) {
resource_handle = NULL;
registry = &global_func_registry.registry;
} else {
Expand Down Expand Up @@ -315,6 +337,8 @@ int TVMFuncFree(TVMFunctionHandle func) {
return 0;
}

int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
int* ret_type_code);
tvm_crt_error_t TVMInitializeRuntime() {
int idx = 0;
tvm_crt_error_t error = kTvmErrorNoError;
Expand Down Expand Up @@ -351,10 +375,120 @@ tvm_crt_error_t TVMInitializeRuntime() {
error = TVMFuncRegisterGlobal("tvm.rpc.server.ModuleGetFunction", &ModuleGetFunction, 0);
}

if (error == kTvmErrorNoError) {
error = TVMFuncRegisterGlobal("runtime.RPCTimeEvaluator", &RPCTimeEvaluator, 0);
}

if (error != kTvmErrorNoError) {
TVMPlatformMemoryFree(registry_backing_memory, ctx);
TVMPlatformMemoryFree(func_registry_memory, ctx);
}

return error;
}

typedef struct {
uint16_t function_index;
TVMFunctionHandle func_to_time;
TVMContext ctx;
int number;
int repeat;
int min_repeat_ms;
} time_evaluator_state_t;

static time_evaluator_state_t g_time_evaluator_state;

int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
int* ret_type_code) {
ret_val[0].v_handle = NULL;
ret_type_code[0] = kTVMNullptr;
if (num_args < 8) {
TVMAPIErrorf("not enough args");
return kTvmErrorFunctionCallNumArguments;
}
if (type_codes[0] != kTVMModuleHandle || type_codes[1] != kTVMStr ||
type_codes[2] != kTVMArgInt || type_codes[3] != kTVMArgInt || type_codes[4] != kTVMArgInt ||
type_codes[5] != kTVMArgInt || type_codes[6] != kTVMArgInt || type_codes[7] != kTVMStr) {
TVMAPIErrorf("one or more invalid arg types");
return kTvmErrorFunctionCallWrongArgType;
}

TVMModuleHandle mod = (TVMModuleHandle)args[0].v_handle;
const char* name = args[1].v_str;
g_time_evaluator_state.ctx.device_type = args[2].v_int64;
g_time_evaluator_state.ctx.device_id = args[3].v_int64;
g_time_evaluator_state.number = args[4].v_int64;
g_time_evaluator_state.repeat = args[5].v_int64;
g_time_evaluator_state.min_repeat_ms = args[6].v_int64;

int ret_code =
TVMModGetFunction(mod, name, /* query_imports */ 0, &g_time_evaluator_state.func_to_time);
if (ret_code != 0) {
return ret_code;
}

g_time_evaluator_state.function_index++;
ret_val[0].v_handle =
EncodeFunctionHandle(kTimeEvaluatorModuleIndex, g_time_evaluator_state.function_index);
ret_type_code[0] = kTVMPackedFuncHandle;
return kTvmErrorNoError;
}

tvm_crt_error_t RunTimeEvaluator(tvm_function_index_t function_index, TVMValue* args,
int* type_codes, int num_args, TVMValue* ret_val,
int* ret_type_code) {
if (function_index != g_time_evaluator_state.function_index) {
return kTvmErrorTimeEvaluatorBadHandle;
}

// TODO(areusch): should *really* rethink needing to return doubles
DLContext result_byte_ctx = {kDLCPU, 0};
TVMByteArray* result_byte_arr;
tvm_crt_error_t err =
TVMPlatformMemoryAllocate(sizeof(TVMByteArray), result_byte_ctx, (void*)&result_byte_arr);
if (err != kTvmErrorNoError) {
return err;
}
size_t data_size = sizeof(double) * g_time_evaluator_state.repeat;
err = TVMPlatformMemoryAllocate(data_size, result_byte_ctx, (void*)&result_byte_arr->data);
if (err != kTvmErrorNoError) {
return err;
}
result_byte_arr->size = data_size;
double min_repeat_seconds = ((double)g_time_evaluator_state.min_repeat_ms) / 1000;
double* iter = (double*)result_byte_arr->data;
for (int i = 0; i < g_time_evaluator_state.repeat; i++) {
double repeat_res_seconds = 0.0;
int exec_count = 0;
// do-while structure ensures we run even when `min_repeat_ms` isn't set (i.e., is 0).
do {
tvm_crt_error_t ret_code = TVMPlatformTimerStart();
if (ret_code != kTvmErrorNoError) {
return ret_code;
}

for (int j = 0; j < g_time_evaluator_state.number; j++) {
ret_code = TVMFuncCall(g_time_evaluator_state.func_to_time, args, type_codes, num_args,
ret_val, ret_type_code);
if (ret_code != 0) {
return ret_code;
}
}
exec_count += g_time_evaluator_state.number;

double curr_res_seconds;
ret_code = TVMPlatformTimerStop(&curr_res_seconds);
if (ret_code != kTvmErrorNoError) {
return ret_code;
}
repeat_res_seconds += curr_res_seconds;
} while (repeat_res_seconds < min_repeat_seconds);
double mean_exec_seconds = repeat_res_seconds / exec_count;
*iter = mean_exec_seconds;
iter++;
}

*ret_type_code = kTVMBytes;
ret_val->v_handle = result_byte_arr;
return kTvmErrorNoError;
}
2 changes: 1 addition & 1 deletion src/runtime/crt/host/crt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
#define TVM_CRT_MAX_REGISTERED_MODULES 2

/*! Size of the global function registry, in bytes. */
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256

/*! Maximum packet size, in bytes, including the length header. */
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 64000
Expand Down
23 changes: 12 additions & 11 deletions src/runtime/crt/host/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,29 +68,30 @@ tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLContext ctx) {
return memory_manager->Free(memory_manager, ptr, ctx);
}

high_resolution_clock::time_point g_utvm_start_time;
steady_clock::time_point g_utvm_start_time;
int g_utvm_timer_running = 0;

int TVMPlatformTimerStart() {
tvm_crt_error_t TVMPlatformTimerStart() {
if (g_utvm_timer_running) {
std::cerr << "timer already running" << std::endl;
return -1;
return kTvmErrorPlatformTimerBadState;
}
g_utvm_start_time = high_resolution_clock::now();
g_utvm_start_time = std::chrono::steady_clock::now();
g_utvm_timer_running = 1;
return 0;
return kTvmErrorNoError;
}

int TVMPlatformTimerStop(double* res_us) {
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
if (!g_utvm_timer_running) {
std::cerr << "timer not running" << std::endl;
return -1;
return kTvmErrorPlatformTimerBadState;
}
auto utvm_stop_time = high_resolution_clock::now();
duration<double, std::micro> time_span(utvm_stop_time - g_utvm_start_time);
*res_us = time_span.count();
auto utvm_stop_time = std::chrono::steady_clock::now();
std::chrono::microseconds time_span =
std::chrono::duration_cast<std::chrono::microseconds>(utvm_stop_time - g_utvm_start_time);
*elapsed_time_seconds = static_cast<double>(time_span.count()) / 1e6;
g_utvm_timer_running = 0;
return 0;
return kTvmErrorNoError;
}
}

Expand Down
Loading

0 comments on commit b0dc298

Please sign in to comment.