From 2a6514c7c9acd503f399ea80d3f6394178effe56 Mon Sep 17 00:00:00 2001
From: Tristan Konolige
Date: Thu, 5 Nov 2020 09:57:07 -0800
Subject: [PATCH] formatting

---
 python/tvm/relay/backend/compile_engine.py    |  6 ++--
 python/tvm/relay/op/_tensor_grad.py           |  7 ----
 python/tvm/relay/op/strategy/generic.py       |  5 ++-
 python/tvm/te/operation.py                    |  6 +++-
 python/tvm/testing.py                         | 29 -----------------
 python/tvm/topi/scatter.py                    | 32 +++++++++++++------
 python/tvm/topi/testing/common.py             | 29 +++++++++++++++++
 tests/python/relay/test_op_grad_level3.py     |  9 ------
 tests/python/topi/python/test_topi_scatter.py |  7 ++--
 9 files changed, 68 insertions(+), 62 deletions(-)

diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py
index 14e1f5d85d9fd..43643a7be7455 100644
--- a/python/tvm/relay/backend/compile_engine.py
+++ b/python/tvm/relay/backend/compile_engine.py
@@ -121,8 +121,10 @@ def get_valid_implementations(op, attrs, inputs, out_type, target):
         The list of all valid op implementations.
     """
     fstrategy = op.get_attr("FTVMStrategy")
-    assert fstrategy is not None, "%s doesn't have an FTVMStrategy registered. You can register " \
-        "one in python with `tvm.relay.op.register_strategy`." % op.name
+    assert fstrategy is not None, (
+        "%s doesn't have an FTVMStrategy registered. You can register "
+        "one in python with `tvm.relay.op.register_strategy`." % op.name
+    )
     with target:
         strategy = fstrategy(attrs, inputs, out_type, target)
     analyzer = tvm.arith.Analyzer()
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index b200aa1252705..9c84411352f2d 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -62,7 +62,6 @@
     squeeze,
     strided_set,
     arange,
-    gather_nd,
     scatter_nd,
 )
 
@@ -811,9 +810,3 @@ def arange_grad(orig, grad):
 def gather_nd_grad(orig, grad):
     data, indices = orig.args
     return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
-
-
-# @register_gradient("scatter_nd")
-# def scatter_nd_grad(orig, grad):
-#     data, indices = orig.args
-#     return [gather_nd(grad, indices), zeros_like(indices)]
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 6d7db7bc98eea..89a08e77f38fe 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1059,17 +1059,20 @@ def schedule_scatter_add(attrs, outs, target):
     with target:
         return topi.generic.schedule_scatter_add(outs)
 
+
 # scatter_nd
 @override_native_generic_func("scatter_nd_strategy")
 def scatter_nd_strategy(attrs, inputs, out_type, target):
     """scatter_nd generic strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_scatter_nd(topi.scatter_nd), wrap_topi_schedule(topi.generic.schedule_extern),
+        wrap_compute_scatter_nd(topi.scatter_nd),
+        wrap_topi_schedule(topi.generic.schedule_extern),
         name="scatter_nd.generic",
     )
     return strategy
 
+
 def wrap_compute_scatter_nd(topi_compute):
     """Wrap scatter_nd topi compute"""
 
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index a924c8b0c0dbb..0f3457af0f10f 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -317,7 +317,11 @@ def extern(
     if isinstance(body, tvm.tir.PrimExpr):
         body = tvm.tir.Evaluate(body)
     if not isinstance(body, tvm.tir.Stmt):
-        raise ValueError("Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format(fcompute.__name__, type(body)))
+        raise ValueError(
+            "Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format(
+                fcompute.__name__, type(body)
+            )
+        )
 
     op = _ffi_api.ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body)
     res = [op.output(i) for i in range(len(output_placeholders))]
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index 7286aa4bbbd9c..e5b17f3d7b534 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -714,33 +714,4 @@ def func(f):
     return wrap(args)
 
 
-def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule):
-    """Compare a numpy inputs and output of a function to the results of the TVM version.
-
-    Parameters
-    ----------
-    inputs : Sequence[numpy.nd.array]
-        List of input numpy arrays to pass to the function.
-    output : numpy.nd.array
-        Verified correct function output.
-    target : tvm.target.Target
-        Target to run on.
-    ctx : tvm.TVMContext
-        Context to run on.
-    compute : callable
-        Topi compute function to test against.
-    schedule : callable
-        Topi scheduling function to test against.
-    """
-    te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs]
-    te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx)
-    with tvm.target.Target(target):
-        out = compute(*te_inputs)
-        s = schedule([out])
-        func = tvm.build(s, te_inputs + [out])
-        arys = [tvm.nd.array(x, ctx=ctx) for x in inputs]
-        func(*(arys + [te_out]))
-        assert_allclose(output, te_out.asnumpy(), atol=1e-4, rtol=1e-4)
-
-
 tvm._ffi._init_api("testing", __name__)
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index 382a91e790c68..6c1d1ab39176d 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -18,7 +18,6 @@
 """Scatter operator"""
 from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate
 from ..te import extern, hybrid
-from . import full
 
 
 @hybrid.script
@@ -233,14 +232,23 @@ def scatter_nd(data, indices, shape):
     Returns
     -------
     ret : tvm.te.Tensor
     """
-    assert indices.shape[0] <= len(shape), f"The first dimension of the indices ({indices.shape[0]}) must be less than or equal to the length of the shape of the output ({len(shape)})."
-    for i in range(len(indices.shape)-1):
-        assert indices.shape[i+1] == data.shape[i], f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of data[{i}] ({data.shape[i]})."
+    assert indices.shape[0] <= len(shape), (
+        f"The first dimension of the indices ({indices.shape[0]}) must be less than or equal to "
+        f"the length of the shape of the output ({len(shape)})."
+    )
+    for i in range(len(indices.shape) - 1):
+        assert indices.shape[i + 1] == data.shape[i], (
+            f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
+            f"data[{i}] ({data.shape[i]})."
+        )
     for i in range(int(indices.shape[0]), len(shape)):
-        assert data.shape[i] == out_shape[i], f"Dimension of data[{i}] must equal dimension of out_shape[{i}]"
-
-        assert "int" in indices.dtype, f"Indices must be a tensor of integers, but its elements are {indices.dtype}"
+        assert (
+            data.shape[i] == shape[i]
+        ), f"Dimension of data[{i}] must equal dimension of shape[{i}]"
+    assert (
+        "int" in indices.dtype
+    ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}"
 
     def gen_ir(data_ptr, indices_ptr, out_ptr):
         ib = ir_builder.create()
@@ -254,7 +262,7 @@ def gen_ir(data_ptr, indices_ptr, out_ptr):
         fused_shape = 1
         for i in shape:
             fused_shape *= i
-        with ib.for_range(0, fused_shape):
+        with ib.for_range(0, fused_shape) as i:
             out[i] = Cast(data_ptr.dtype, 0)
 
         # We combine all the indices dimensions but the first one into a single
@@ -277,7 +285,13 @@ def gen_ir(data_ptr, indices_ptr, out_ptr):
             for l in reversed(range(indices_ptr.shape[0].value)):
                 # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
                 index += offset * indices[i + l * fused_indices_dimension]
-                ib.emit(AssertStmt(indices[i + l * fused_indices_dimension] < shape[l], StringImm("index out of bounds"), Evaluate(0)))
+                ib.emit(
+                    AssertStmt(
+                        indices[i + l * fused_indices_dimension] < shape[l],
+                        StringImm("index out of bounds"),
+                        Evaluate(0),
+                    )
+                )
                 offset *= shape[l]
             out[index] = data[i * fused_data_dimension + j]
 
diff --git a/python/tvm/topi/testing/common.py b/python/tvm/topi/testing/common.py
index 51ea19afe7ce6..35a6040fa25a7 100644
--- a/python/tvm/topi/testing/common.py
+++ b/python/tvm/topi/testing/common.py
@@ -77,3 +77,32 @@ def get_reduce_schedule(target):
 
 def get_conv2d_nchw_implement(target):
     return dispatch(target, _conv2d_nchw_implement)
+
+
+def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule):
+    """Compare numpy inputs and output of a function to the results of the TVM version.
+
+    Parameters
+    ----------
+    inputs : Sequence[numpy.ndarray]
+        List of input numpy arrays to pass to the function.
+    output : numpy.ndarray
+        Verified correct function output.
+    target : tvm.target.Target
+        Target to run on.
+    ctx : tvm.TVMContext
+        Context to run on.
+    compute : callable
+        Topi compute function to test against.
+    schedule : callable
+        Topi scheduling function to test against.
+ """ + te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs] + te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx) + with tvm.target.Target(target): + out = compute(*te_inputs) + s = schedule([out]) + func = tvm.build(s, te_inputs + [out]) + arys = [tvm.nd.array(x, ctx=ctx) for x in inputs] + func(*(arys + [te_out])) + assert_allclose(output, te_out.asnumpy(), atol=1e-4, rtol=1e-4) diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index 358441738ecee..a5cb916da6135 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -126,14 +126,5 @@ def test_gather_nd_grad(): check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[indices_np]) -# def test_scatter_nd_grad(): -# data = relay.var("data", relay.TensorType((2, 2), "float64")) -# indices = relay.var("indices", relay.TensorType((2, 2), "int64")) -# fwd = relay.Function([data, indices], relay.scatter_nd(data, indices, (2, 2))) -# data_np = np.array([[0, 1], [2, 3]]).astype("float64") -# indices_np = np.array([[1, 0], [1, 1]]) -# check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[indices_np]) - - if __name__ == "__main__": pytest.main() diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index c8c5d0f400049..ef8f946094717 100644 --- a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -25,12 +25,10 @@ def test_scatter_nd(ctx, target): def check_scatter_nd(data, indices, shape, out): implementations = { - "generic": (lambda x,y: topi.scatter_nd(x,y,shape), topi.generic.schedule_extern), + "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) - tvm.testing.compare_numpy_tvm( - [data, indices], out, target, ctx, fcompute, fschedule - ) + tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, ctx, fcompute, fschedule) data = np.array([2, 3, 0]) indices = np.array([[1, 1, 0], [0, 1, 0]]) @@ -44,5 +42,6 @@ def check_scatter_nd(data, indices, shape, out): out = np.array([[[[0, 0], [1, 2]], [[0, 0], [3, 4]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]) check_scatter_nd(data, indices, shape, out) + if __name__ == "__main__": test_scatter_nd(tvm.context("cpu"), tvm.target.Target("llvm"))
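
For downstream users, the visible change in this patch is the new home of compare_numpy_tvm: it now lives in tvm.topi.testing rather than tvm.testing, so out-of-tree tests need the same one-line migration made in test_topi_scatter.py above. The following sketch (not part of the patch) mirrors the first scatter_nd case in that test; the llvm target, cpu context, and the numpy reference loop are illustrative assumptions and require an LLVM-enabled TVM build.

import numpy as np
import tvm
from tvm import topi
import tvm.topi.testing

# Inputs mirroring the first case in test_topi_scatter.py above.
data = np.array([2, 3, 0])
indices = np.array([[1, 1, 0], [0, 1, 0]])
shape = (2, 2)

# "Verified correct function output" computed with a plain numpy reference:
# each column of `indices` addresses one output coordinate that receives data[k].
out = np.zeros(shape, dtype=data.dtype)
for k in range(indices.shape[-1]):
    out[tuple(indices[:, k])] = data[k]

# compare_numpy_tvm builds placeholders, schedules, compiles, runs, and
# asserts allclose against `out` (see python/tvm/topi/testing/common.py).
tvm.topi.testing.compare_numpy_tvm(
    [data, indices],
    out,
    tvm.target.Target("llvm"),  # assumed target
    tvm.context("cpu"),  # assumed context
    lambda x, y: topi.scatter_nd(x, y, shape),
    topi.generic.schedule_extern,
)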