From 35f8b96b1f65833659179fe78c6d48264b3ff910 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 5 Aug 2018 16:18:49 -0700 Subject: [PATCH] [VTA][SIM] Allow debug mode in simulator to skip execution (#6) --- vta/python/vta/testing/simulator.py | 13 ++++ vta/python/vta/top/arm_conv2d.py | 81 ---------------------- vta/src/sim/sim_driver.cc | 38 ++++++++-- vta/tests/python/unittest/test_vta_insn.py | 10 ++- 4 files changed, 54 insertions(+), 88 deletions(-) diff --git a/vta/python/vta/testing/simulator.py b/vta/python/vta/testing/simulator.py index 5695c9419c8bb..bf98bbbf97fc8 100644 --- a/vta/python/vta/testing/simulator.py +++ b/vta/python/vta/testing/simulator.py @@ -27,6 +27,19 @@ def clear_stats(): if f: f() +# debug flag to skip execution. +DEBUG_SKIP_EXEC = 1 + +def debug_mode(flag): + """Set debug mode + + Paramaters + ---------- + flag : int + The debug flag, 0 means clear all flags. + """ + tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag) + def stats(): """Clear profiler statistics diff --git a/vta/python/vta/top/arm_conv2d.py b/vta/python/vta/top/arm_conv2d.py index e3acb7a202df5..012c16b098ed4 100644 --- a/vta/python/vta/top/arm_conv2d.py +++ b/vta/python/vta/top/arm_conv2d.py @@ -5,87 +5,6 @@ from topi.nn import conv2d, conv2d_alter_layout from topi import generic -_WORKLOADS = [ - # resnet 18 - Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2), - Workload('int8', 'int32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2), - Workload('int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), - Workload('int8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), - Workload('int8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), - Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), - Workload('int8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), - Workload('int8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), - Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), - Workload('int8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), - Workload('int8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), - Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), - Workload('int8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), - - # mobilenet float32 - Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2), - Workload('float32', 'float32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1), - - # mobilenet int8 - Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2), - Workload('int8', 'int32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1), - Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1), - Workload('int8', 'int32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1), - Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1), - Workload('int8', 'int32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1), - Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1), - Workload('int8', 'int32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1), - Workload('int8', 'int32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1), - Workload('int8', 'int32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1), -] - -_SCHEDULES = [ - # float32 imagenet - SpatialPack(1, 8, 4, 1, 4, True), - SpatialPack(1, 8, 4, 1, 4, True), - SpatialPack(1, 7, 4, 2, 4, True), - SpatialPack(1, 4, 8, 4, 1, True), - SpatialPack(1, 4, 4, 1, 16, False), - SpatialPack(1, 4, 8, 4, 8, False), - SpatialPack(1, 7, 4, 3, 8, True), - SpatialPack(1, 2, 8, 1, 8, True), - SpatialPack(2, 1, 16, 1, 4, True), - SpatialPack(1, 7, 4, 1, 1, True), - Im2ColPack(7, 4, 1, 16, True), - Im2ColPack(7, 4, 1, 8, False), - Im2ColPack(7, 4, 1, 16, False), - - # float32 mobilenet - SpatialPack(2, 2, 4, 28, 1, True), - SpatialPack(1, 4, 8, 14, 1, False), - SpatialPack(1, 2, 16, 8, 1, True), - SpatialPack(1, 4, 8, 8, 8, True), - SpatialPack(2, 2, 8, 1, 1, False), - SpatialPack(1, 4, 8, 4, 8, False), - SpatialPack(2, 2, 8, 1, 4, False), - SpatialPack(2, 2, 8, 1, 8, False), - Im2ColPack(7, 4, 1, 16, False), - Im2ColPack(7, 4, 1, 4, True), - - # int8 mobilenet - SpatialPack(2, 2, 4, 28, 1, True), - SpatialPack(1, 4, 8, 14, 1, False), - SpatialPack(1, 2, 16, 8, 1, True), - SpatialPack(1, 4, 8, 8, 8, True), - SpatialPack(2, 2, 8, 1, 1, False), - SpatialPack(1, 4, 8, 4, 8, False), - SpatialPack(2, 2, 8, 1, 4, False), - SpatialPack(2, 2, 8, 1, 8, False), - Im2ColPack(7, 4, 1, 16, False), - Im2ColPack(7, 4, 1, 4, True), -] @conv2d.register(["vtacpu", "vta"]) def compute(*args, **kwargs): diff --git a/vta/src/sim/sim_driver.cc b/vta/src/sim/sim_driver.cc index 58d322a07c8ad..ef1f62842f01a 100644 --- a/vta/src/sim/sim_driver.cc +++ b/vta/src/sim/sim_driver.cc @@ -16,6 +16,11 @@ namespace vta { namespace sim { +/*! \brief debug flag for skipping computation */ +enum DebugFlagMask { + kSkipExec = 1 +}; + /*! * \brief Helper class to pack and unpack bits * Applies truncation when pack to low level bits. @@ -234,8 +239,12 @@ class SRAM { return &(data_[index]); } // Execute the load instruction on this SRAM - void Load(const VTAMemInsn* op, DRAM* dram, uint64_t* load_counter) { + void Load(const VTAMemInsn* op, + DRAM* dram, + uint64_t* load_counter, + bool skip_exec) { load_counter[0] += (op->x_size * op->y_size) * kElemBytes; + if (skip_exec) return; DType* sram_ptr = data_ + op->sram_base; uint8_t* dram_ptr = static_cast(dram->GetAddr( op->dram_base * kElemBytes)); @@ -306,6 +315,8 @@ class Profiler { uint64_t gemm_counter{0}; /*! \brief instr counter for ALU ops */ uint64_t alu_counter{0}; + /*! \brief set debug mode */ + int64_t debug_flag{0}; /*! \brief clear the profiler */ void Clear() { inp_load_nbytes = 0; @@ -316,6 +327,10 @@ class Profiler { gemm_counter = 0; alu_counter = 0; } + /*! \return Whether we should skip execution. */ + bool SkipExec() const { + return (debug_flag & DebugFlagMask::kSkipExec) != 0; + } std::string AsJSON() { std::ostringstream os; @@ -379,13 +394,15 @@ class Device { void RunLoad(const VTAMemInsn* op) { if (op->x_size == 0) return; if (op->memory_type == VTA_MEM_ID_INP) { - inp_.Load(op, dram_, &(prof_->inp_load_nbytes)); + inp_.Load(op, dram_, &(prof_->inp_load_nbytes), prof_->SkipExec()); } else if (op->memory_type == VTA_MEM_ID_WGT) { - wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes)); + wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes), prof_->SkipExec()); } else if (op->memory_type == VTA_MEM_ID_ACC) { - acc_.Load(op, dram_, &(prof_->acc_load_nbytes)); + acc_.Load(op, dram_, &(prof_->acc_load_nbytes), prof_->SkipExec()); } else if (op->memory_type == VTA_MEM_ID_UOP) { - uop_.Load(op, dram_, &(prof_->uop_load_nbytes)); + // always load in uop, since uop is stateful + // subsequent non-debug mode exec can depend on it. + uop_.Load(op, dram_, &(prof_->uop_load_nbytes), false); } else { LOG(FATAL) << "Unknown memory_type=" << op->memory_type; } @@ -397,7 +414,9 @@ class Device { op->memory_type == VTA_MEM_ID_UOP) { prof_->out_store_nbytes += ( op->x_size * op->y_size * VTA_BATCH * VTA_BLOCK_OUT * VTA_OUT_WIDTH / 8); - acc_.TruncStore(op, dram_); + if (!prof_->SkipExec()) { + acc_.TruncStore(op, dram_); + } } else { LOG(FATAL) << "Store do not support memory_type=" << op->memory_type; @@ -407,6 +426,7 @@ class Device { void RunGEMM(const VTAGemInsn* op) { if (!op->reset_reg) { prof_->gemm_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn); + if (prof_->SkipExec()) return; for (uint32_t y = 0; y < op->iter_out; ++y) { for (uint32_t x = 0; x < op->iter_in; ++x) { for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) { @@ -440,6 +460,7 @@ class Device { } } } else { + if (prof_->SkipExec()) return; // reset for (uint32_t y = 0; y < op->iter_out; ++y) { for (uint32_t x = 0; x < op->iter_in; ++x) { @@ -506,6 +527,7 @@ class Device { template void RunALULoop(const VTAAluInsn* op, F func) { prof_->alu_counter += op->iter_out * op->iter_in * op->uop_end - op->uop_bgn; + if (prof_->SkipExec()) return; for (int y = 0; y < op->iter_out; ++y) { for (int x = 0; x < op->iter_in; ++x) { for (int k = op->uop_bgn; k < op->uop_end; ++k) { @@ -548,6 +570,10 @@ TVM_REGISTER_GLOBAL("vta.simulator.profiler_clear") .set_body([](TVMArgs args, TVMRetValue* rv) { Profiler::ThreadLocal()->Clear(); }); +TVM_REGISTER_GLOBAL("vta.simulator.profiler_debug_mode") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Profiler::ThreadLocal()->debug_flag = args[0]; + }); TVM_REGISTER_GLOBAL("vta.simulator.profiler_status") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = Profiler::ThreadLocal()->AsJSON(); diff --git a/vta/tests/python/unittest/test_vta_insn.py b/vta/tests/python/unittest/test_vta_insn.py index b9803684b22f2..5c5cab096f3bd 100644 --- a/vta/tests/python/unittest/test_vta_insn.py +++ b/vta/tests/python/unittest/test_vta_insn.py @@ -183,8 +183,16 @@ def verify(s): if env.TARGET == "sim": simulator.clear_stats() + simulator.debug_mode(simulator.DEBUG_SKIP_EXEC) f(x_nd, w_nd, y_nd) - print(simulator.stats()) + stat1 = simulator.stats() + simulator.clear_stats() + simulator.debug_mode(0) + f(x_nd, w_nd, y_nd) + stat2 = simulator.stats() + for k, v in stat1.items(): + if k != "uop_load_nbytes": + assert stat1[k] == stat2[k] else: f(x_nd, w_nd, y_nd)