Skip to content

Commit

Permalink
Fix pylint 2.2.2 gripes. (apache#2642)
Browse files Browse the repository at this point in the history
  • Loading branch information
mshawcroft authored and tqchen committed Feb 21, 2019
1 parent 282c063 commit 5c49b07
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 26 deletions.
5 changes: 2 additions & 3 deletions python/vta/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,9 @@ def target_host(self):
"""The target host"""
if self.TARGET == "pynq":
return "llvm -target=armv7-none-linux-gnueabihf"
elif self.TARGET == "sim":
if self.TARGET == "sim":
return "llvm"
else:
raise ValueError("Unknown target %s" % self.TARGET)
raise ValueError("Unknown target %s" % self.TARGET)


def get_env():
Expand Down
2 changes: 1 addition & 1 deletion python/vta/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def _clean_cast(node, target_type):
op_name = node.attr("op_name")
if op_name == "cast":
return _clean_cast(node.get_children(), target_type)
elif op_name == "relu":
if op_name == "relu":
data, has_clip = _clean_cast(
node.get_children(), target_type)
data = nnvm.sym.relu(data)
Expand Down
2 changes: 1 addition & 1 deletion python/vta/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def instr(index):
dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
dev.vta_push_uop)
if index == 0 or index == 2:
if index in (0, 2):
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
0, 0,
Expand Down
29 changes: 12 additions & 17 deletions python/vta/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,9 @@ def _post_order(op):
args.append(m[1])
args += op.args[base_args+3:]
return tvm.call_extern("int32", "VTAUopPush", *args)
else:
if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
raise RuntimeError("unexpected op %s" % op)
return op
if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
raise RuntimeError("unexpected op %s" % op)
return op

ret = tvm.ir_pass.IRTransform(
stmt.body, None, _post_order, ["Call"])
Expand Down Expand Up @@ -165,22 +164,21 @@ def _post_order(op):
op.condition, let_stmt)
del rw_info[buffer_var]
return alloc
elif isinstance(op, tvm.expr.Load):
if isinstance(op, tvm.expr.Load):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = tvm.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.make.Load(op.dtype, new_var, op.index)
elif isinstance(op, tvm.stmt.Store):
if isinstance(op, tvm.stmt.Store):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = tvm.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.make.Store(new_var, op.value, op.index)
else:
raise RuntimeError("not reached")
raise RuntimeError("not reached")
stmt = tvm.ir_pass.IRTransform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
for buffer_var, new_var in rw_info.items():
Expand Down Expand Up @@ -233,23 +231,20 @@ def _pre_order(op):
if op.attr_key == "virtual_thread":
lift_stmt.append([])

return None

def _post_order(op):
if isinstance(op, tvm.stmt.Allocate):
lift_stmt[-1].append(op)
return op.body
elif isinstance(op, tvm.stmt.AttrStmt):
if isinstance(op, tvm.stmt.AttrStmt):
if op.attr_key == "storage_scope":
lift_stmt[-1].append(op)
return op.body
elif op.attr_key == "virtual_thread":
if op.attr_key == "virtual_thread":
return _merge_block(lift_stmt.pop() + [op], op.body)
return op
elif isinstance(op, tvm.stmt.For):
if isinstance(op, tvm.stmt.For):
return _merge_block(lift_stmt.pop() + [op], op.body)
else:
raise RuntimeError("not reached")
raise RuntimeError("not reached")
stmt = tvm.ir_pass.IRTransform(
stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
assert len(lift_stmt) == 1
Expand Down Expand Up @@ -297,7 +292,7 @@ def _do_fold(stmt):
sync = tvm.make.Call(
"int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0)
return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync))
elif _match_pragma(stmt, "trim_loop"):
if _match_pragma(stmt, "trim_loop"):
op = stmt.body
assert isinstance(op, tvm.stmt.For)
return tvm.make.For(
Expand Down Expand Up @@ -584,7 +579,7 @@ def _do_fold(stmt):
tvm.make.StringImm("VTAPushALUOp"))
irb.emit(stmt)
return irb.get()
elif _match_pragma(stmt, "skip_alu"):
if _match_pragma(stmt, "skip_alu"):
return tvm.make.Evaluate(0)
return stmt

Expand Down
7 changes: 3 additions & 4 deletions python/vta/top/vta_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _build(funcs, target, target_host):
tvm_t = tvm.target.create(target)
if tvm_t.device_name == "vta":
return tvm.build(funcs, target="ext_dev", target_host=target_host)
elif tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu":
if tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu":
return tvm.build(funcs, target=target_host)
return tvm.build(funcs, target=target)

Expand Down Expand Up @@ -279,10 +279,9 @@ def schedule_conv2d(attrs, outs, target):
target = tvm.target.create(target)
if target.device_name == "vta":
return schedule_packed_conv2d(outs)
elif str(target).startswith("llvm"):
if str(target).startswith("llvm"):
return tvm.create_schedule([x.op for x in outs])
else:
raise RuntimeError("not support target %s" % target)
raise RuntimeError("not support target %s" % target)
return _nn.schedule_conv2d(attrs, outs, target)


Expand Down

0 comments on commit 5c49b07

Please sign in to comment.