Skip to content

Commit

Permalink
[AUTOMATION] Infrastructure to use hardware and schedules from vta-ex…
Browse files Browse the repository at this point in the history
…periments (#7)
  • Loading branch information
tmoreau89 committed Dec 1, 2018
1 parent 35f8b96 commit aa3696a
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 18 deletions.
21 changes: 11 additions & 10 deletions vta/config/pynq_sample.json
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
{
"TARGET" : "pynq",
"HW_VER" : "0.0.1",
"HW_VER" : "0.0.2",
"HW_FREQ" : 100,
"HW_CLK_TARGET" : 8,
"ALU" : true,
"GEMM_II" : 2,
"TALU_II" : 4,
"HW_CLK_TARGET" : 7,
"ALU_EN" : true,
"MUL_EN" : true,
"GEMM_II" : 1,
"TALU_II" : 2,
"LOG_INP_WIDTH" : 3,
"LOG_WGT_WIDTH" : 1,
"LOG_WGT_WIDTH" : 3,
"LOG_ACC_WIDTH" : 5,
"LOG_OUT_WIDTH" : 3,
"LOG_BATCH" : 0,
"LOG_BLOCK_IN" : 5,
"LOG_BLOCK_OUT" : 5,
"LOG_BLOCK_IN" : 4,
"LOG_BLOCK_OUT" : 4,
"LOG_UOP_BUFF_SIZE" : 15,
"LOG_INP_BUFF_SIZE" : 17,
"LOG_WGT_BUFF_SIZE" : 17,
"LOG_INP_BUFF_SIZE" : 15,
"LOG_WGT_BUFF_SIZE" : 18,
"LOG_ACC_BUFF_SIZE" : 17
}
6 changes: 3 additions & 3 deletions vta/hardware/xilinx/sim/vta_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ int main(void) {
#endif // ALU_EN

// Run blocked GEMM test
status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, true, 2);
// status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, true, 2);
status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, false, 2);
status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, true, 1);
// status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, true, 1);
status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, false, 1);

// Simple GEMM unit test
status |= gemm_test(4 * VTA_BATCH, 4 * VTA_BLOCK_OUT, 4 * VTA_BLOCK_IN, true);
status |= gemm_test(4 * VTA_BATCH, 4 * VTA_BLOCK_OUT, 4 * VTA_BLOCK_IN, false);

return status;
}
1 change: 0 additions & 1 deletion vta/python/vta/bitstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def get_bitstream_path():

# Derive destination path
cache_dir = os.getenv("VTA_CACHE_PATH", os.path.join(os.getenv("HOME"), ".vta_cache/"))
cache_dir = os.path.join(cache_dir, env.TARGET)
# Create the directory if it didn't exist
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
Expand Down
2 changes: 1 addition & 1 deletion vta/python/vta/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(self, cfg):
self._dev_ctx = None
self._last_env = None
# derive bitstream name
self.BITSTREAM = "{}_{}x{}x{}_a{}w{}o{}_{}_{}_{}_{}_{}MHz_{}ns_gii{}".format(
self.BITSTREAM = "{}/bitstreams_valid/{}x{}x{}_a{}w{}o{}_{}_{}_{}_{}_{}MHz_{}ns_gii{}".format(
self.HW_VER.replace('.', '_'),
self.BATCH,
self.BLOCK_IN,
Expand Down
21 changes: 19 additions & 2 deletions vta/python/vta/top/vta_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import tvm
import topi
import re

from nnvm.top import registry as reg, OpPattern
from nnvm.top import nn as _nn
Expand Down Expand Up @@ -340,7 +341,7 @@ def _get_workload(data, pad_data, kernel, output):
w_str = (i_w + w_pad*2 - k_w) // (o_w - 1)
return Workload(i_b, i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str)

def schedule_packed_conv2d(outs, plan=None, skip_load_inp=False, skip_load_wgt=False,
def schedule_packed_conv2d(outs, planStr=None, skip_load_inp=False, skip_load_wgt=False,
skip_load_acc=False, skip_store_out=False, skip_alu=False,
skip_gemm=False):
""" Schedule the packed conv2d.
Expand Down Expand Up @@ -377,7 +378,23 @@ def _traverse(op):
else:
pad_data = None
wrkld = _get_workload(data, pad_data, kernel, output)
if plan is None:
if planStr:
matchObj = re.match( r'b(\d+)oc(\d+)ic(\d+)h(\d+)w(\d+)oc_t(\d+)h_t(\d+)', sched_str)
b_factor = int(matchObj.group(1))
oc_factor = int(matchObj.group(2))
ic_factor = int(matchObj.group(3))
h_factor = int(matchObj.group(4))
w_factor = int(matchObj.group(5))
oc_nthread = int(matchObj.group(6))
h_nthread = int(matchObj.group(7))
plan = Schedule(b_factor=b_factor,
oc_factor=oc_factor,
ic_factor=ic_factor,
h_factor=h_factor,
w_factor=w_factor,
oc_nthread=oc_nthread,
h_nthread=h_nthread)
else:
plan = find_schedules(wrkld, vt_only=True, best_only=True)[0]
logging.info("Trying to find plan for %s", wrkld)
env = get_env()
Expand Down
2 changes: 1 addition & 1 deletion vta/src/sim/sim_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ class Device {

template<bool use_imm, typename F>
void RunALULoop(const VTAAluInsn* op, F func) {
prof_->alu_counter += op->iter_out * op->iter_in * op->uop_end - op->uop_bgn;
prof_->alu_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn);
if (prof_->SkipExec()) return;
for (int y = 0; y < op->iter_out; ++y) {
for (int x = 0; x < op->iter_in; ++x) {
Expand Down

0 comments on commit aa3696a

Please sign in to comment.