diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 5a20480a222b..cded2e136cd1 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -34,25 +34,6 @@ register_strategy("topk", strategy.topk_strategy) register_pattern("topk", OpPattern.OPAQUE) -@script -def _topk_shape_func_input_data(data, k, axis): - ndim = len(data.shape) - val_out = output_tensor((ndim,), "int64") - indices_out = output_tensor((ndim,), "int64") - - for i in const_range(ndim): - if i != axis: - val_out[i] = int64(data.shape[i]) - indices_out[i] = int64(data.shape[i]) - else: - if k[0] < 1: - val_out[i] = int64(data.shape[i]) - indices_out[i] = int64(data.shape[i]) - else: - val_out[i] = int64(k[0]) - indices_out[i] = int64(k[0]) - return val_out, indices_out - @script def _topk_shape_func_input_shape(data_shape, k, axis): ndim = data_shape.shape[0] @@ -72,22 +53,16 @@ def _topk_shape_func_input_shape(data_shape, k, axis): indices_out[i] = int64(k) return val_out, indices_out -@_reg.register_shape_func("topk", True) +@_reg.register_shape_func("topk", False) def topk_shape_func(attrs, inputs, _): """ Shape func for topk. """ axis = attrs.axis - if attrs.k is not None: - if axis < 0: - axis += inputs[0].shape[0] - val_out, indices_out = \ - _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis)) - else: - if axis < 0: - axis += len(inputs[0].shape) - val_out, indices_out = \ - _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis)) + if axis < 0: + axis += inputs[0].shape[0] + val_out, indices_out = \ + _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis)) ret_type = attrs.ret_type if ret_type == "both": ret = [val_out, indices_out] diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index d31e89a49f43..5aeb7e647b4e 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -16,8 +16,10 @@ # under the License. """Classic algorithm operation""" from __future__ import absolute_import as _abs +import numpy as np from . import _make -from ..expr import TupleWrapper, const +from .dyn import _make as _dyn_make +from ..expr import TupleWrapper, Expr, Constant def argsort(data, axis=-1, is_ascend=1, dtype="int32"): """Performs sorting along the given axis and returns an array of indicies @@ -82,9 +84,12 @@ def topk(data, k=1, axis=-1, ret_type="both", out : relay.Expr or List[relay.Expr] The computed result. """ - if isinstance(k, int): - k = const(k, "int64") - out = _make.topk(data, k, axis, ret_type, is_ascend, dtype) + if isinstance(k, Constant): + k = np.asscalar(k.data.asnumpy()) + if isinstance(k, Expr): + out = _dyn_make.topk(data, k, axis, ret_type, is_ascend, dtype) + else: + out = _make.topk(data, k, axis, ret_type, is_ascend, dtype) if ret_type == "both": return TupleWrapper(out, 2) return out diff --git a/python/tvm/relay/op/dyn/__init__.py b/python/tvm/relay/op/dyn/__init__.py index d659203e27e1..f4d47a6d780c 100644 --- a/python/tvm/relay/op/dyn/__init__.py +++ b/python/tvm/relay/op/dyn/__init__.py @@ -17,4 +17,5 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay namespace containing dynamic ops.""" +from . import _algorithm from . import _transform diff --git a/python/tvm/relay/op/dyn/_algorithm.py b/python/tvm/relay/op/dyn/_algorithm.py new file mode 100644 index 000000000000..b98b7753f403 --- /dev/null +++ b/python/tvm/relay/op/dyn/_algorithm.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"Definition of classic algorithms" +# pylint: disable=invalid-name,unused-argument +from __future__ import absolute_import + +from tvm.te.hybrid import script +from tvm.runtime import convert + +from .. import strategy +from .. import op as _reg +from ..op import OpPattern, register_pattern +from ..op import register_strategy + +# topk +register_strategy("dyn.topk", strategy.topk_strategy) +register_pattern("dyn.topk", OpPattern.OPAQUE) + +@script +def _topk_shape_func_input_data(data, k, axis): + ndim = len(data.shape) + val_out = output_tensor((ndim,), "int64") + indices_out = output_tensor((ndim,), "int64") + + for i in const_range(ndim): + if i != axis: + val_out[i] = int64(data.shape[i]) + indices_out[i] = int64(data.shape[i]) + else: + if k[0] < 1: + val_out[i] = int64(data.shape[i]) + indices_out[i] = int64(data.shape[i]) + else: + val_out[i] = int64(k[0]) + indices_out[i] = int64(k[0]) + return val_out, indices_out + +@_reg.register_shape_func("dyn.topk", True) +def topk_shape_func(attrs, inputs, _): + """ + Shape func for topk. + """ + axis = attrs.axis + if axis < 0: + axis += len(inputs[0].shape) + val_out, indices_out = \ + _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis)) + + ret_type = attrs.ret_type + if ret_type == "both": + ret = [val_out, indices_out] + elif ret_type == "values": + ret = [val_out] + else: + ret = [indices_out] + + return ret diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 632445b71bd2..db0577cf8bdf 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -656,9 +656,10 @@ def argsort_strategy(attrs, inputs, out_type, target): def wrap_compute_topk(topi_compute): """Wrap topk compute""" def _compute_topk(attrs, inputs, out_type): - k = inputs[1] if attrs.k is not None: k = attrs.k + else: + k = inputs[1] axis = get_const_int(attrs.axis) ret_type = attrs.ret_type is_ascend = bool(get_const_int(attrs.is_ascend)) diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index 10c226e17f7a..c8dbb49e15db 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -448,14 +448,7 @@ bool IsDataDependant(const CallNode* call) { return false; } - if (op->name == "topk") { - if (const auto* attrs = call->attrs.as()) { - if (attrs->k) { - // If k attribute exists, it isn't data dependant. - return false; - } - } - } else if (op->name == "strided_slice") { + if (op->name == "strided_slice") { if (const auto* attrs = call->attrs.as()) { if (attrs->begin && attrs->end && attrs->strides) { // not data dependant if begin, end and strides exist diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index b02fe86f6baa..14308dd592d6 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -27,7 +27,6 @@ namespace tvm { namespace relay { -using tir::make_const; TVM_REGISTER_NODE_TYPE(TopKAttrs); @@ -35,7 +34,7 @@ bool TopKRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] const TopKAttrs* param = attrs.as(); - CHECK_EQ(types.size(), 3); + CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); CHECK(data); int ndim = data->shape.size(); @@ -48,42 +47,38 @@ bool TopKRel(const Array& types, int num_inputs, const Attrs& attrs, for (int i = 0; i < ndim; ++i) { if (i != axis) { out_shape.push_back(data->shape[i]); - } else if (param->k) { + } else { const Integer& ck = param->k.value(); if (ck->value < 1) { out_shape.push_back(data->shape[i]); } else { out_shape.push_back(ck); } - } else { - out_shape.push_back(Any()); } } auto values_ty = TensorType(out_shape, data->dtype); auto indices_ty = TensorType(out_shape, param->dtype); if (param->ret_type == "both") { - reporter->Assign(types[2], TupleType({values_ty, indices_ty})); + reporter->Assign(types[1], TupleType({values_ty, indices_ty})); } else if (param->ret_type == "values") { - reporter->Assign(types[2], values_ty); + reporter->Assign(types[1], values_ty); } else if (param->ret_type == "indices") { - reporter->Assign(types[2], indices_ty); + reporter->Assign(types[1], indices_ty); } else { LOG(FATAL) << "Unsupported ret type: " << param->ret_type; } return true; } -Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) { +Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype) { auto attrs = make_object(); - if (const auto& ck = k.as()) { - attrs->k = tvm::Integer(reinterpret_cast(ck->data->data)[0]); - } + attrs->k = Integer(k); attrs->axis = axis; attrs->ret_type = ret_type; attrs->is_ascend = is_ascend; attrs->dtype = dtype; static const Op& op = Op::Get("topk"); - return Call(op, {data, k}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK); @@ -91,10 +86,9 @@ TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK); RELAY_REGISTER_OP("topk") .describe(R"doc(Get the top k elements in an input tensor along the given axis. )doc" TVM_ADD_FILELINE) - .set_num_inputs(2) + .set_num_inputs(1) .set_attrs_type() .add_argument("data", "Tensor", "Input data.") - .add_argument("k", "Tensor", "Number of top elements.") .set_support_level(6) .add_type_rel("TopK", TopKRel); diff --git a/src/relay/op/dyn/algorithm/topk.cc b/src/relay/op/dyn/algorithm/topk.cc new file mode 100644 index 000000000000..1c88730a5463 --- /dev/null +++ b/src/relay/op/dyn/algorithm/topk.cc @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file topk.cc + * \brief TopK operators + */ +#include +#include +#include + +namespace tvm { +namespace relay { +namespace dyn { + +bool TopKRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, k, result] + const TopKAttrs* param = attrs.as(); + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* k = types[1].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "tile: expect input type to be TensorType but get " << types[0]; + return false; + } + if (k == nullptr) { + CHECK(types[1].as()) + << "tile: expect input type to be TensorType but get " << types[1]; + return false; + } + CHECK(k->shape.size() <= 1) << "Parameter k must be a Scalar or a Tensor of shape (1, )"; + if (k->shape.size() == 1) { + const IntImmNode* k_shape = k->shape[0].as(); + CHECK(k_shape) << "Parameter k must have static shape"; + CHECK_EQ(k_shape->value, 1) << "Parameter k must be a Scalar or a Tensor of shape (1, )"; + } + int ndim = data->shape.size(); + int axis = param->axis; + if (axis < 0) { + axis += ndim; + } + CHECK(axis >= 0 && axis < ndim); + Array out_shape; + for (int i = 0; i < ndim; ++i) { + if (i != axis) { + out_shape.push_back(data->shape[i]); + } else { + out_shape.push_back(Any()); + } + } + auto values_ty = TensorType(out_shape, data->dtype); + auto indices_ty = TensorType(out_shape, param->dtype); + if (param->ret_type == "both") { + reporter->Assign(types[2], TupleType({values_ty, indices_ty})); + } else if (param->ret_type == "values") { + reporter->Assign(types[2], values_ty); + } else if (param->ret_type == "indices") { + reporter->Assign(types[2], indices_ty); + } else { + LOG(FATAL) << "Unsupported ret type: " << param->ret_type; + } + return true; +} + +Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) { + auto attrs = make_object(); + attrs->axis = axis; + attrs->ret_type = ret_type; + attrs->is_ascend = is_ascend; + attrs->dtype = dtype; + static const Op& op = Op::Get("dyn.topk"); + return Call(op, {data, k}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.dyn._make.topk").set_body_typed(MakeTopK); + +RELAY_REGISTER_OP("dyn.topk") + .describe(R"doc(Get the top k elements in an input tensor along the given axis. +)doc" TVM_ADD_FILELINE) + .set_num_inputs(2) + .set_attrs_type() + .add_argument("data", "Tensor", "Input data.") + .add_argument("k", "Tensor", "Number of top elements.") + .set_support_level(6) + .add_type_rel("DynTopK", TopKRel); + +} // namespace dyn +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index d09230ac30d7..dced5020ca0b 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -22,6 +22,7 @@ * \file dynamic_to_static.cc * \brief Rewrite Dynamic Operations to Static operations where possible */ +#include #include #include @@ -33,7 +34,9 @@ namespace relay { class DynamicToStaticMutator : public MixedModeMutator { public: DynamicToStaticMutator() - : dyn_reshape_op_(Op::Get("dyn.reshape")), dyn_tile_op_(Op::Get("dyn.tile")) {} + : dyn_reshape_op_(Op::Get("dyn.reshape")), + dyn_tile_op_(Op::Get("dyn.tile")), + dyn_topk_op_(Op::Get("dyn.topk")) {} private: Expr Rewrite_(const CallNode* pre, const Expr& post) override { @@ -55,6 +58,20 @@ class DynamicToStaticMutator : public MixedModeMutator { static const Op& op = Op::Get("tile"); return Call(op, {call_node->args[0]}, Attrs(attrs), {}); } + } else if (call_node->op == dyn_topk_op_) { + if (const ConstantNode* k = call_node->args[1].as()) { + const TopKAttrs* param = call_node->attrs.as(); + CHECK(param); + auto attrs = make_object(); + attrs->k = Integer(ToScalar(k->data, 0)); + std::cout << attrs->k << std::endl; + attrs->axis = param->axis; + attrs->ret_type = param->ret_type; + attrs->is_ascend = param->is_ascend; + attrs->dtype = param->dtype; + static const Op& op = Op::Get("topk"); + return Call(op, {call_node->args[0]}, Attrs(attrs), {}); + } } return post; } @@ -68,6 +85,7 @@ class DynamicToStaticMutator : public MixedModeMutator { const Op& dyn_reshape_op_; const Op& dyn_tile_op_; + const Op& dyn_topk_op_; }; Expr DynamicToStatic(Function f, IRModule m) { diff --git a/tests/python/relay/dyn/test_dynamic_op_level6.py b/tests/python/relay/dyn/test_dynamic_op_level6.py new file mode 100644 index 000000000000..60a1433bd108 --- /dev/null +++ b/tests/python/relay/dyn/test_dynamic_op_level6.py @@ -0,0 +1,76 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Support level6 operator test cases. +""" +import numpy as np +import tvm +from tvm import te +from tvm import relay +from tvm.relay.testing import ctx_list + +def test_dynamic_topk(): + def verify_topk(k, axis, ret_type, is_ascend, dtype): + shape = (20, 100) + x = relay.var("x", relay.TensorType(shape, "float32")) + k_var = relay.var("x", relay.TensorType((1,), "float32")) + out = relay.topk(x, k_var, axis, ret_type, is_ascend, dtype) + if isinstance(out, relay.expr.TupleWrapper): + out = out.astuple() + func = relay.Function([x, k_var], out) + + np_data = np.random.uniform(size=shape).astype("float32") + if is_ascend: + np_indices = np.argsort(np_data, axis=axis) + else: + np_indices = np.argsort(-np_data, axis=axis) + kk = k if k >= 1 else shape[axis] + if axis == 0: + np_indices = np_indices[:kk, :] + np_values = np.zeros(np_indices.shape).astype("float32") + for i in range(shape[1]): + np_values[:, i] = np_data[np_indices[:, i], i] + else: + np_indices = np_indices[:, :kk] + np_values = np.zeros(np_indices.shape).astype("float32") + for i in range(shape[0]): + np_values[i, :] = np_data[i, np_indices[i, :]] + np_indices = np_indices.astype(dtype) + + for target, ctx in ctx_list(): + if "llvm" not in target: continue + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(np_data, np.array([k]).astype("float32")) + if ret_type == "both": + tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values) + tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices) + elif ret_type == "values": + tvm.testing.assert_allclose(op_res.asnumpy(), np_values) + else: + tvm.testing.assert_allclose(op_res.asnumpy(), np_indices) + np.random.seed(0) + for k in [0, 1, 5]: + for axis in [0, -1, 1]: + for ret_type in ["both", "values", "indices"]: + verify_topk(k, axis, ret_type, True, "int64") + verify_topk(k, axis, ret_type, False, "float32") + + +if __name__ == "__main__": + test_topk() diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 3415ce01d5fd..bcd8a644e807 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -129,9 +129,62 @@ def verify_tile(shape, reps, oshape): verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20)) verify_tile((4, 7), (4, 2), (16, 14)) +def test_dynamic_to_static_topk(): + def verify_topk(k, axis, ret_type, is_ascend, dtype): + shape = (20, 100) + x = relay.var("x", relay.TensorType(shape, "float32")) + k_var = relay.const(k) + out = relay.topk(x, k_var, axis, ret_type, is_ascend, dtype) + if isinstance(out, relay.expr.TupleWrapper): + out = out.astuple() + func = relay.Function([x], out) + + np_data = np.random.uniform(size=shape).astype("float32") + if is_ascend: + np_indices = np.argsort(np_data, axis=axis) + else: + np_indices = np.argsort(-np_data, axis=axis) + kk = k if k >= 1 else shape[axis] + if axis == 0: + np_indices = np_indices[:kk, :] + np_values = np.zeros(np_indices.shape).astype("float32") + for i in range(shape[1]): + np_values[:, i] = np_data[np_indices[:, i], i] + else: + np_indices = np_indices[:, :kk] + np_values = np.zeros(np_indices.shape).astype("float32") + for i in range(shape[0]): + np_values[i, :] = np_data[i, np_indices[i, :]] + np_indices = np_indices.astype(dtype) + + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + zz = func2.body + assert isinstance(zz, relay.Call) + assert zz.op == relay.op.get("topk") + + for target, ctx in ctx_list(): + if "llvm" not in target: continue + for kind in ["graph", "vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func2) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(np_data) + if ret_type == "both": + tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values) + tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices) + elif ret_type == "values": + tvm.testing.assert_allclose(op_res.asnumpy(), np_values) + else: + tvm.testing.assert_allclose(op_res.asnumpy(), np_indices) + np.random.seed(0) + for k in [0, 1, 5]: + for axis in [0, -1, 1]: + for ret_type in ["both", "values", "indices"]: + verify_topk(k, axis, ret_type, True, "int64") + verify_topk(k, axis, ret_type, False, "float32") if __name__=="__main__": test_dynamic_to_static_reshape() test_dynamic_to_static_double_reshape() test_dynamic_to_static_quad_reshape() test_dynamic_to_static_tile() + test_dynamic_to_static_topk()