
Commit

formatting
tkonolige committed Nov 5, 2020
1 parent 9a0f7a6 commit 2a6514c
Showing 9 changed files with 68 additions and 62 deletions.
6 changes: 4 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
@@ -121,8 +121,10 @@ def get_valid_implementations(op, attrs, inputs, out_type, target):
         The list of all valid op implementations.
     """
     fstrategy = op.get_attr("FTVMStrategy")
-    assert fstrategy is not None, "%s doesn't have an FTVMStrategy registered. You can register " \
-        "one in python with `tvm.relay.op.register_strategy`." % op.name
+    assert fstrategy is not None, (
+        "%s doesn't have an FTVMStrategy registered. You can register "
+        "one in python with `tvm.relay.op.register_strategy`." % op.name
+    )
     with target:
         strategy = fstrategy(attrs, inputs, out_type, target)
     analyzer = tvm.arith.Analyzer()
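For context, the error message points users at `tvm.relay.op.register_strategy`. Below is a minimal sketch of that registration path, mirroring the `override_native_generic_func` pattern used in strategy/generic.py later in this commit; the op name "my_op" and the negative/injective pairing are illustrative assumptions, not code from this commit.

    # Hypothetical sketch: registering an FTVMStrategy for an op named "my_op".
    from tvm import topi
    from tvm.relay import op as _op
    from tvm.relay.op.strategy.generic import wrap_topi_schedule
    from tvm.target import override_native_generic_func

    @override_native_generic_func("my_op_strategy")
    def my_op_strategy(attrs, inputs, out_type, target):
        strategy = _op.OpStrategy()
        strategy.add_implementation(
            # compute: (attrs, inputs, out_type) -> list of output tensors
            lambda attrs, inputs, out_type: [topi.negative(inputs[0])],
            wrap_topi_schedule(topi.generic.schedule_injective),
            name="my_op.generic",
        )
        return strategy

    _op.register_strategy("my_op", my_op_strategy)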
7 changes: 0 additions & 7 deletions python/tvm/relay/op/_tensor_grad.py
@@ -62,7 +62,6 @@
     squeeze,
     strided_set,
     arange,
-    gather_nd,
     scatter_nd,
 )

@@ -811,9 +810,3 @@ def arange_grad(orig, grad):
 def gather_nd_grad(orig, grad):
     data, indices = orig.args
     return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
-
-
-# @register_gradient("scatter_nd")
-# def scatter_nd_grad(orig, grad):
-#     data, indices = orig.args
-#     return [gather_nd(grad, indices), zeros_like(indices)]
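The surviving gather_nd_grad leans on scattering being the adjoint of gathering: gradient values are scattered back to the positions they were gathered from, accumulating where indices repeat. A hedged NumPy model of that identity (helper names are illustrative):

    import numpy as np

    def gather_nd_ref(data, indices):
        # TVM-style layout: indices has shape (M, ...), indexing the first M axes.
        return data[tuple(indices)]

    def scatter_nd_add_ref(updates, indices, shape):
        out = np.zeros(shape, dtype=updates.dtype)
        np.add.at(out, tuple(indices), updates)  # accumulate duplicate indices
        return out

    data = np.arange(6.0).reshape(2, 3)
    indices = np.array([[0, 1], [2, 0]])  # gathers data[0, 2] and data[1, 0]
    grad = np.ones(2)
    # Gradient of sum(gather_nd(data, indices)) w.r.t. data:
    print(scatter_nd_add_ref(grad, indices, data.shape))
    # [[0. 0. 1.]
    #  [1. 0. 0.]]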
5 changes: 4 additions & 1 deletion python/tvm/relay/op/strategy/generic.py
@@ -1059,17 +1059,20 @@ def schedule_scatter_add(attrs, outs, target):
     with target:
         return topi.generic.schedule_scatter_add(outs)

+
 # scatter_nd
 @override_native_generic_func("scatter_nd_strategy")
 def scatter_nd_strategy(attrs, inputs, out_type, target):
     """scatter_nd generic strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_scatter_nd(topi.scatter_nd), wrap_topi_schedule(topi.generic.schedule_extern),
+        wrap_compute_scatter_nd(topi.scatter_nd),
+        wrap_topi_schedule(topi.generic.schedule_extern),
         name="scatter_nd.generic",
     )
     return strategy

+
 def wrap_compute_scatter_nd(topi_compute):
     """Wrap scatter_nd topi compute"""
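The diff truncates the wrapper's body. Plausibly it closes over the TOPI compute and adapts it to the (attrs, inputs, out_type) signature that add_implementation expects; a sketch, assuming the output shape is carried in attrs.out_shape:

    def wrap_compute_scatter_nd(topi_compute):
        """Wrap scatter_nd topi compute"""

        def _compute_scatter_nd(attrs, inputs, out_type):
            # inputs[0] is data, inputs[1] is indices; the target shape comes
            # from the op's attributes (assumed to be named out_shape here).
            return [topi_compute(inputs[0], inputs[1], attrs.out_shape)]

        return _compute_scatter_nd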
6 changes: 5 additions & 1 deletion python/tvm/te/operation.py
@@ -317,7 +317,11 @@ def extern(
     if isinstance(body, tvm.tir.PrimExpr):
         body = tvm.tir.Evaluate(body)
     if not isinstance(body, tvm.tir.Stmt):
-        raise ValueError("Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format(fcompute.__name__, type(body)))
+        raise ValueError(
+            "Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format(
+                fcompute.__name__, type(body)
+            )
+        )

     op = _ffi_api.ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body)
     res = [op.output(i) for i in range(len(output_placeholders))]
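The reformatted ValueError guards te.extern's callback contract: the fcompute passed to extern must hand back a PrimExpr or a Stmt. A self-contained sketch of the compliant ir_builder pattern (the identity-copy kernel and its names are illustrative assumptions):

    import tvm
    from tvm import te

    def copy_ir(src, dst, n):
        ib = tvm.tir.ir_builder.create()
        src = ib.buffer_ptr(src)
        dst = ib.buffer_ptr(dst)
        with ib.for_range(0, n) as i:
            dst[i] = src[i]
        return ib.get()  # a tvm.tir.Stmt, so the isinstance check above passes

    n = 16
    a = te.placeholder((n,), name="a", dtype="float32")
    out = te.extern(
        (n,),
        [a],
        lambda ins, outs: copy_ir(ins[0], outs[0], n),
        dtype="float32",
        name="identity_extern",
    )
    s = te.create_schedule(out.op)
    f = tvm.build(s, [a, out], "llvm")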
29 changes: 0 additions & 29 deletions python/tvm/testing.py
@@ -714,33 +714,4 @@ def func(f):
     return wrap(args)


-def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule):
-    """Compare numpy inputs and output of a function to the results of the TVM version.
-
-    Parameters
-    ----------
-    inputs : Sequence[numpy.ndarray]
-        List of input numpy arrays to pass to the function.
-    output : numpy.ndarray
-        Verified correct function output.
-    target : tvm.target.Target
-        Target to run on.
-    ctx : tvm.TVMContext
-        Context to run on.
-    compute : callable
-        Topi compute function to test against.
-    schedule : callable
-        Topi scheduling function to test against.
-    """
-    te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs]
-    te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx)
-    with tvm.target.Target(target):
-        out = compute(*te_inputs)
-        s = schedule([out])
-        func = tvm.build(s, te_inputs + [out])
-        arys = [tvm.nd.array(x, ctx=ctx) for x in inputs]
-        func(*(arys + [te_out]))
-        assert_allclose(output, te_out.asnumpy(), atol=1e-4, rtol=1e-4)
-
-
 tvm._ffi._init_api("testing", __name__)
32 changes: 23 additions & 9 deletions python/tvm/topi/scatter.py
@@ -18,7 +18,6 @@
 """Scatter operator"""
 from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate
 from ..te import extern, hybrid
-from . import full


 @hybrid.script
@@ -233,14 +232,23 @@ def scatter_nd(data, indices, shape):
     -------
     ret : tvm.te.Tensor
     """
-    assert indices.shape[0] <= len(shape), f"The first dimension of the indices ({indices.shape[0]}) must be less than or equal to the length of the shape of the output ({len(shape)})."
-    for i in range(len(indices.shape)-1):
-        assert indices.shape[i+1] == data.shape[i], f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of data[{i}] ({data.shape[i]})."
+    assert indices.shape[0] <= len(shape), (
+        f"The first dimension of the indices ({indices.shape[0]}) must be less than or equal to "
+        f"the length of the shape of the output ({len(shape)})."
+    )
+    for i in range(len(indices.shape) - 1):
+        assert indices.shape[i + 1] == data.shape[i], (
+            f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
+            f"data[{i}] ({data.shape[i]})."
+        )
     for i in range(int(indices.shape[0]), len(shape)):
-        assert data.shape[i] == out_shape[i], f"Dimension of data[{i}] must equal dimension of out_shape[{i}]"
-
-    assert "int" in indices.dtype, f"Indices must be a tensor of integers, but its elements are {indices.dtype}"
+        assert (
+            data.shape[i] == out_shape[i]
+        ), f"Dimension of data[{i}] must equal dimension of out_shape[{i}]"
+
+    assert (
+        "int" in indices.dtype
+    ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}"

     def gen_ir(data_ptr, indices_ptr, out_ptr):
         ib = ir_builder.create()
@@ -254,7 +262,7 @@ def gen_ir(data_ptr, indices_ptr, out_ptr):
         fused_shape = 1
         for i in shape:
             fused_shape *= i
-        with ib.for_range(0, fused_shape):
+        with ib.for_range(0, fused_shape) as i:
             out[i] = Cast(data_ptr.dtype, 0)

         # We combine all the indices dimensions but the first one into a single
@@ -277,7 +285,13 @@ def gen_ir(data_ptr, indices_ptr, out_ptr):
                 for l in reversed(range(indices_ptr.shape[0].value)):
                     # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
                     index += offset * indices[i + l * fused_indices_dimension]
-                    ib.emit(AssertStmt(indices[i + l * fused_indices_dimension] < shape[l], StringImm("index out of bounds"), Evaluate(0)))
+                    ib.emit(
+                        AssertStmt(
+                            indices[i + l * fused_indices_dimension] < shape[l],
+                            StringImm("index out of bounds"),
+                            Evaluate(0),
+                        )
+                    )
                     offset *= shape[l]
                 out[index] = data[i * fused_data_dimension + j]
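For orientation, a hedged NumPy model of the semantics the assertions above spell out: indices is (M, Y_0, ..., Y_{K-1}), data is (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and later updates overwrite earlier ones, matching the plain out[index] = data[...] store in gen_ir. The helper is an illustration, not the TOPI code:

    import numpy as np

    def scatter_nd_ref(data, indices, shape):
        out = np.zeros(shape, dtype=data.dtype)
        m = indices.shape[0]  # number of output axes being indexed
        for y in np.ndindex(*indices.shape[1:]):
            coord = tuple(indices[(l,) + y] for l in range(m))
            out[coord] = data[y]  # plain store: later writes win
        return out

    # Agrees with the first case in test_topi_scatter.py below:
    scatter_nd_ref(np.array([2, 3, 0]), np.array([[1, 1, 0], [0, 1, 0]]), (2, 2))
    # -> [[0, 0],
    #     [2, 3]]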
29 changes: 29 additions & 0 deletions python/tvm/topi/testing/common.py
@@ -77,3 +77,32 @@ def get_reduce_schedule(target):

 def get_conv2d_nchw_implement(target):
     return dispatch(target, _conv2d_nchw_implement)
+
+
+def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule):
+    """Compare numpy inputs and output of a function to the results of the TVM version.
+
+    Parameters
+    ----------
+    inputs : Sequence[numpy.ndarray]
+        List of input numpy arrays to pass to the function.
+    output : numpy.ndarray
+        Verified correct function output.
+    target : tvm.target.Target
+        Target to run on.
+    ctx : tvm.TVMContext
+        Context to run on.
+    compute : callable
+        Topi compute function to test against.
+    schedule : callable
+        Topi scheduling function to test against.
+    """
+    te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs]
+    te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx)
+    with tvm.target.Target(target):
+        out = compute(*te_inputs)
+        s = schedule([out])
+        func = tvm.build(s, te_inputs + [out])
+        arys = [tvm.nd.array(x, ctx=ctx) for x in inputs]
+        func(*(arys + [te_out]))
+        assert_allclose(output, te_out.asnumpy(), atol=1e-4, rtol=1e-4)
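A hedged usage sketch of the relocated helper; the topi.negative compute and the llvm/cpu pairing are illustrative choices, not part of this commit:

    import numpy as np
    import tvm
    from tvm import topi
    import tvm.topi.testing

    data = np.random.rand(4, 4).astype("float32")
    tvm.topi.testing.compare_numpy_tvm(
        [data],
        np.negative(data),            # verified-correct reference output
        "llvm",
        tvm.context("cpu"),
        topi.negative,
        topi.generic.schedule_injective,
    )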
9 changes: 0 additions & 9 deletions tests/python/relay/test_op_grad_level3.py
@@ -126,14 +126,5 @@ def test_gather_nd_grad():
     check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[indices_np])


-# def test_scatter_nd_grad():
-#     data = relay.var("data", relay.TensorType((2, 2), "float64"))
-#     indices = relay.var("indices", relay.TensorType((2, 2), "int64"))
-#     fwd = relay.Function([data, indices], relay.scatter_nd(data, indices, (2, 2)))
-#     data_np = np.array([[0, 1], [2, 3]]).astype("float64")
-#     indices_np = np.array([[1, 0], [1, 1]])
-#     check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[indices_np])
-
-
 if __name__ == "__main__":
     pytest.main()
7 changes: 3 additions & 4 deletions tests/python/topi/python/test_topi_scatter.py
@@ -25,12 +25,10 @@
 def test_scatter_nd(ctx, target):
     def check_scatter_nd(data, indices, shape, out):
         implementations = {
-            "generic": (lambda x,y: topi.scatter_nd(x,y,shape), topi.generic.schedule_extern),
+            "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern),
         }
         fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
-        tvm.testing.compare_numpy_tvm(
-            [data, indices], out, target, ctx, fcompute, fschedule
-        )
+        tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, ctx, fcompute, fschedule)

     data = np.array([2, 3, 0])
     indices = np.array([[1, 1, 0], [0, 1, 0]])
@@ -44,5 +42,6 @@ def check_scatter_nd(data, indices, shape, out):
     out = np.array([[[[0, 0], [1, 2]], [[0, 0], [3, 4]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]])
     check_scatter_nd(data, indices, shape, out)

+
 if __name__ == "__main__":
     test_scatter_nd(tvm.context("cpu"), tvm.target.Target("llvm"))
