Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Mar 12, 2019
1 parent eb8aa51 commit 2b63f24
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 47 deletions.
40 changes: 21 additions & 19 deletions topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,17 @@ def get_valid_counts_gpu(data, score_threshold=0):
tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
lambda ins, outs: get_valid_counts_pre(
ins[0], outs[0], outs[1], score_threshold),
dtype=["int32", "int32"],
out_buffers=[temp_flag_buf, temp_idx_buf],
name="get_valid_counts_phase_one")
dtype=["int32", "int32"],
out_buffers=[temp_flag_buf, temp_idx_buf],
name="get_valid_counts_phase_one")

valid_count, out_tensor = \
tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx],
lambda ins, outs: get_valid_counts_ir(
ins[0], ins[1], ins[2], outs[0], outs[1]),
dtype=["int32", data.dtype],
in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
tag="get_valid_counts")
lambda ins, outs: get_valid_counts_ir(
ins[0], ins[1], ins[2], outs[0], outs[1]),
dtype=["int32", data.dtype],
in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
tag="get_valid_counts")

return [valid_count, out_tensor]

Expand Down Expand Up @@ -333,7 +333,8 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
with ib.for_range(0, nkeep) as j:
with ib.if_scope(k < box_data_length):
out[(base_idx + j * box_data_length + k)] = \
data[(base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)]
data[(base_idx + sorted_index[i * num_anchors + j] \
* box_data_length + k)]
box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
with ib.for_range(0, valid_count[i] - nkeep) as j:
Expand All @@ -351,11 +352,11 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
out[base_idx + offset_j] == \
out[base_idx + offset_k]))):
# When force_suppress == True or class_id equals
iou = calculate_overlap(out, base_idx + offset_k + 2,
base_idx + offset_j + 2)
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_k] = -1.0
box_indices[i * num_anchors + k] = -1
iou = calculate_overlap(out, base_idx + offset_k + 2,
base_idx + offset_j + 2)
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_k] = -1.0
box_indices[i * num_anchors + k] = -1
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
Expand Down Expand Up @@ -494,7 +495,8 @@ def invalid_to_bottom_ir(data, flag, idx, out):
out[base_idx + j * 6 + k] = -1.0
with ib.if_scope(flag[i * num_anchors + j] > 0):
with ib.for_range(0, elem_length) as k:
out[base_idx + (idx[i * num_anchors + j] - 1) * 6 + k] = data[base_idx + j * 6 + k]
out[base_idx + (idx[i * num_anchors + j] - 1) * 6 + k] \
= data[base_idx + j * 6 + k]
return ib.get()


Expand Down Expand Up @@ -566,7 +568,6 @@ def non_max_supression_gpu(data, valid_count, max_output_size=-1,
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
elem_length = data.shape[2]

valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
Expand Down Expand Up @@ -608,8 +609,8 @@ def non_max_supression_gpu(data, valid_count, max_output_size=-1,
max_output_size, iou_threshold, force_suppress,
top_k, id_index),
dtype=[data.dtype, "int32"],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
out_buffers=[out_buf, box_indices_buf],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
out_buffers=[out_buf, box_indices_buf],
name="nms",
tag="nms")

Expand All @@ -623,7 +624,8 @@ def non_max_supression_gpu(data, valid_count, max_output_size=-1,
(batch_size, num_anchors,), valid_count_dtype, "temp_flag", data_alignment=8)
temp_idx_buf = api.decl_buffer(
(batch_size, num_anchors,), valid_count_dtype, "temp_idx", data_alignment=8)
temp_flag, temp_idx = tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [out],
temp_flag, temp_idx = tvm.extern([(batch_size, num_anchors,), \
(batch_size, num_anchors,)], [out],
lambda ins, outs: invalid_to_bottom_pre(
ins[0], outs[0], outs[1]),
dtype=["int32", "int32"],
Expand Down
50 changes: 22 additions & 28 deletions topi/python/topi/cuda/ssd/multibox.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, too-many-function-args
"""SSD multibox operators"""
from __future__ import absolute_import as _abs
import tvm
import math
import tvm

from tvm import api
from tvm.intrin import exp, if_then_else
from tvm.intrin import if_then_else

import topi

Expand Down Expand Up @@ -79,10 +79,9 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):

for k in range(num_sizes + num_ratios - 1):
w = if_then_else(k < num_sizes,
size_ratio_concat[
k] * in_height / in_width / 2.0,
size_ratio_concat[0] * in_height / in_width *
math.sqrt(size_ratio_concat[k + 1]) / 2.0)
size_ratio_concat[k] * in_height / in_width / 2.0,
size_ratio_concat[0] * in_height / in_width *
math.sqrt(size_ratio_concat[k + 1]) / 2.0)
h = if_then_else(
k < num_sizes, size_ratio_concat[k] / 2.0,
size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
Expand Down Expand Up @@ -174,13 +173,11 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
ib = tvm.ir_builder.create()

cls_prob = ib.buffer_ptr(cls_prob)
cls_id= ib.buffer_ptr(temp_cls_id)
cls_id = ib.buffer_ptr(temp_cls_id)
valid_count = ib.buffer_ptr(valid_count)
temp_valid_count = ib.buffer_ptr(temp_valid_count)
score = ib.buffer_ptr(temp_score)

box_coord = ib.allocate("float32", (4,), name="box_coord", scope="local")
pred_coord = ib.allocate("float32", (4,), name="pred_coord", scope="local")
threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold)

max_threads = int(
Expand All @@ -201,9 +198,11 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
cls_id[i * num_anchors + j] = 0
with ib.for_range(0, num_classes-1) as k:
temp = cls_prob[i * num_classes * num_anchors + (k + 1) * num_anchors + j]
cls_id[i * num_anchors + j] = if_then_else(temp > score[i * num_anchors + j], k + 1, cls_id[i * num_anchors + j])
cls_id[i * num_anchors + j] = if_then_else(temp > score[i * num_anchors + j], \
k + 1, cls_id[i * num_anchors + j])
score[i * num_anchors + j] = tvm.max(temp, score[i * num_anchors + j])
with ib.if_scope(tvm.all(cls_id[i * num_anchors + j] > 0, score[i * num_anchors + j] < threshold)):
with ib.if_scope(tvm.all(cls_id[i * num_anchors + j] > 0, \
score[i * num_anchors + j] < threshold)):
cls_id[i * num_anchors + j] = 0
with ib.if_scope(cls_id[i * num_anchors + j] > 0):
temp_valid_count[i * num_anchors + j] = 1
Expand All @@ -213,14 +212,14 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
with ib.if_scope(tid < batch_size):
with ib.for_range(0, num_anchors) as k:
with ib.if_scope(k > 0):
temp_valid_count[tid * num_anchors +
k] += temp_valid_count[tid * num_anchors + k - 1]
temp_valid_count[tid * num_anchors +k] += \
temp_valid_count[tid * num_anchors + k - 1]
valid_count[i] = temp_valid_count[tid * num_anchors + num_anchors - 1]

return ib.get()

def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score, out, \
clip, variances, batch_size, num_classes, num_anchors):
clip, variances, batch_size, num_anchors):
"""Low level IR routing for transform location in multibox_detection operator.
Parameters
Expand Down Expand Up @@ -252,9 +251,6 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
batch_size : int
Batch size
num_classes : int
Number of classes
num_anchors : int
Number of anchors
Expand Down Expand Up @@ -296,9 +292,6 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw,
score = ib.buffer_ptr(temp_score)
out_loc = ib.buffer_ptr(out)

box_coord = ib.allocate("float32", (4,), name="box_coord", scope="local")
pred_coord = ib.allocate("float32", (4,), name="pred_coord", scope="local")

max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
nthread_tx = max_threads
Expand All @@ -321,8 +314,8 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw,
out_loc[out_base_idx + 1] = score[tid]
out_loc[out_base_idx + 2], out_loc[out_base_idx + 3], out_loc[out_base_idx + 4], \
out_loc[out_base_idx + 5] = transform_loc(loc_pred, tid * 4,
anchor, j * 4, clip, variances[0],
variances[1], variances[2], variances[3])
anchor, j * 4, clip, variances[0],
variances[1], variances[2], variances[3])

return ib.get()

Expand Down Expand Up @@ -363,7 +356,6 @@ def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \
1-D tensor with shape (batch_size,), number of valid anchor boxes.
"""
batch_size = cls_prob.shape[0]
num_classes = cls_prob.shape[1]
num_anchors = cls_prob.shape[2]
oshape = (batch_size, num_anchors, 6)
# Define data alignment for intermediate buffer
Expand All @@ -383,20 +375,21 @@ def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \
(batch_size, num_anchors,), cls_prob.dtype, "temp_score", data_alignment=8)

valid_count, temp_valid_count, temp_cls_id, temp_score = \
tvm.extern([(batch_size,), (batch_size, num_anchors,), (batch_size, num_anchors,), (batch_size, num_anchors,)],
[cls_prob],
tvm.extern([(batch_size,), (batch_size, num_anchors,), (batch_size, num_anchors,), \
(batch_size, num_anchors,)], [cls_prob],
lambda ins, outs: transform_loc_pre(
ins[0], outs[0], outs[1], outs[2], outs[3], threshold),
dtype=[valid_count_dtype, valid_count_dtype, valid_count_dtype, cls_prob.dtype],
out_buffers=[valid_count_buf, temp_valid_count_buf, temp_cls_id_buf, temp_score_buf],
out_buffers=[valid_count_buf, temp_valid_count_buf, \
temp_cls_id_buf, temp_score_buf],
tag="multibox_transform_loc_phase_one")

out_loc = \
tvm.extern([oshape],
[loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score],
lambda ins, outs: transform_loc_ir(
ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, variances, \
batch_size, num_classes, num_anchors),
batch_size, num_anchors),
dtype=[out_loc_dtype],
out_buffers=[out_loc_buf],
tag="multibox_transform_loc")
Expand Down Expand Up @@ -446,5 +439,6 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
out = non_max_suppression(
inter_out[0], inter_out[1], -1, nms_threshold, force_suppress, nms_topk, return_indices=False)
inter_out[0], inter_out[1], -1, nms_threshold, force_suppress, \
nms_topk, return_indices=False)
return out

0 comments on commit 2b63f24

Please sign in to comment.