Skip to content

Commit

Permalink
[VTA][SIM] Allow debug mode in simulator to skip execution (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and tmoreau89 committed Dec 1, 2018
1 parent 9b32883 commit 35f8b96
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 88 deletions.
13 changes: 13 additions & 0 deletions vta/python/vta/testing/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def clear_stats():
if f:
f()

# debug flag to skip execution.
DEBUG_SKIP_EXEC = 1

def debug_mode(flag):
"""Set debug mode
Paramaters
----------
flag : int
The debug flag, 0 means clear all flags.
"""
tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag)


def stats():
"""Clear profiler statistics
Expand Down
81 changes: 0 additions & 81 deletions vta/python/vta/top/arm_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,87 +5,6 @@
from topi.nn import conv2d, conv2d_alter_layout
from topi import generic

_WORKLOADS = [
# resnet 18
Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('int8', 'int32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),

# mobilenet float32
Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
Workload('float32', 'float32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),

# mobilenet int8
Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
]

_SCHEDULES = [
# float32 imagenet
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 7, 4, 2, 4, True),
SpatialPack(1, 4, 8, 4, 1, True),
SpatialPack(1, 4, 4, 1, 16, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(1, 7, 4, 3, 8, True),
SpatialPack(1, 2, 8, 1, 8, True),
SpatialPack(2, 1, 16, 1, 4, True),
SpatialPack(1, 7, 4, 1, 1, True),
Im2ColPack(7, 4, 1, 16, True),
Im2ColPack(7, 4, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),

# float32 mobilenet
SpatialPack(2, 2, 4, 28, 1, True),
SpatialPack(1, 4, 8, 14, 1, False),
SpatialPack(1, 2, 16, 8, 1, True),
SpatialPack(1, 4, 8, 8, 8, True),
SpatialPack(2, 2, 8, 1, 1, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(2, 2, 8, 1, 4, False),
SpatialPack(2, 2, 8, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),
Im2ColPack(7, 4, 1, 4, True),

# int8 mobilenet
SpatialPack(2, 2, 4, 28, 1, True),
SpatialPack(1, 4, 8, 14, 1, False),
SpatialPack(1, 2, 16, 8, 1, True),
SpatialPack(1, 4, 8, 8, 8, True),
SpatialPack(2, 2, 8, 1, 1, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(2, 2, 8, 1, 4, False),
SpatialPack(2, 2, 8, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),
Im2ColPack(7, 4, 1, 4, True),
]

@conv2d.register(["vtacpu", "vta"])
def compute(*args, **kwargs):
Expand Down
38 changes: 32 additions & 6 deletions vta/src/sim/sim_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
namespace vta {
namespace sim {

/*! \brief debug flag for skipping computation */
enum DebugFlagMask {
kSkipExec = 1
};

/*!
* \brief Helper class to pack and unpack bits
* Applies truncation when pack to low level bits.
Expand Down Expand Up @@ -234,8 +239,12 @@ class SRAM {
return &(data_[index]);
}
// Execute the load instruction on this SRAM
void Load(const VTAMemInsn* op, DRAM* dram, uint64_t* load_counter) {
void Load(const VTAMemInsn* op,
DRAM* dram,
uint64_t* load_counter,
bool skip_exec) {
load_counter[0] += (op->x_size * op->y_size) * kElemBytes;
if (skip_exec) return;
DType* sram_ptr = data_ + op->sram_base;
uint8_t* dram_ptr = static_cast<uint8_t*>(dram->GetAddr(
op->dram_base * kElemBytes));
Expand Down Expand Up @@ -306,6 +315,8 @@ class Profiler {
uint64_t gemm_counter{0};
/*! \brief instr counter for ALU ops */
uint64_t alu_counter{0};
/*! \brief set debug mode */
int64_t debug_flag{0};
/*! \brief clear the profiler */
void Clear() {
inp_load_nbytes = 0;
Expand All @@ -316,6 +327,10 @@ class Profiler {
gemm_counter = 0;
alu_counter = 0;
}
/*! \return Whether we should skip execution. */
bool SkipExec() const {
return (debug_flag & DebugFlagMask::kSkipExec) != 0;
}

std::string AsJSON() {
std::ostringstream os;
Expand Down Expand Up @@ -379,13 +394,15 @@ class Device {
void RunLoad(const VTAMemInsn* op) {
if (op->x_size == 0) return;
if (op->memory_type == VTA_MEM_ID_INP) {
inp_.Load(op, dram_, &(prof_->inp_load_nbytes));
inp_.Load(op, dram_, &(prof_->inp_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_WGT) {
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes));
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_ACC) {
acc_.Load(op, dram_, &(prof_->acc_load_nbytes));
acc_.Load(op, dram_, &(prof_->acc_load_nbytes), prof_->SkipExec());
} else if (op->memory_type == VTA_MEM_ID_UOP) {
uop_.Load(op, dram_, &(prof_->uop_load_nbytes));
// always load in uop, since uop is stateful
// subsequent non-debug mode exec can depend on it.
uop_.Load(op, dram_, &(prof_->uop_load_nbytes), false);
} else {
LOG(FATAL) << "Unknown memory_type=" << op->memory_type;
}
Expand All @@ -397,7 +414,9 @@ class Device {
op->memory_type == VTA_MEM_ID_UOP) {
prof_->out_store_nbytes += (
op->x_size * op->y_size * VTA_BATCH * VTA_BLOCK_OUT * VTA_OUT_WIDTH / 8);
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
if (!prof_->SkipExec()) {
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
}
} else {
LOG(FATAL) << "Store do not support memory_type="
<< op->memory_type;
Expand All @@ -407,6 +426,7 @@ class Device {
void RunGEMM(const VTAGemInsn* op) {
if (!op->reset_reg) {
prof_->gemm_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn);
if (prof_->SkipExec()) return;
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
Expand Down Expand Up @@ -440,6 +460,7 @@ class Device {
}
}
} else {
if (prof_->SkipExec()) return;
// reset
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
Expand Down Expand Up @@ -506,6 +527,7 @@ class Device {
template<bool use_imm, typename F>
void RunALULoop(const VTAAluInsn* op, F func) {
prof_->alu_counter += op->iter_out * op->iter_in * op->uop_end - op->uop_bgn;
if (prof_->SkipExec()) return;
for (int y = 0; y < op->iter_out; ++y) {
for (int x = 0; x < op->iter_in; ++x) {
for (int k = op->uop_bgn; k < op->uop_end; ++k) {
Expand Down Expand Up @@ -548,6 +570,10 @@ TVM_REGISTER_GLOBAL("vta.simulator.profiler_clear")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Profiler::ThreadLocal()->Clear();
});
TVM_REGISTER_GLOBAL("vta.simulator.profiler_debug_mode")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Profiler::ThreadLocal()->debug_flag = args[0];
});
TVM_REGISTER_GLOBAL("vta.simulator.profiler_status")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Profiler::ThreadLocal()->AsJSON();
Expand Down
10 changes: 9 additions & 1 deletion vta/tests/python/unittest/test_vta_insn.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,16 @@ def verify(s):

if env.TARGET == "sim":
simulator.clear_stats()
simulator.debug_mode(simulator.DEBUG_SKIP_EXEC)
f(x_nd, w_nd, y_nd)
print(simulator.stats())
stat1 = simulator.stats()
simulator.clear_stats()
simulator.debug_mode(0)
f(x_nd, w_nd, y_nd)
stat2 = simulator.stats()
for k, v in stat1.items():
if k != "uop_load_nbytes":
assert stat1[k] == stat2[k]
else:
f(x_nd, w_nd, y_nd)

Expand Down

0 comments on commit 35f8b96

Please sign in to comment.