Skip to content

Commit

Permalink
Refactor to use IsOp utility
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Dec 27, 2019
1 parent a55d119 commit fdcee3e
Show file tree
Hide file tree
Showing 16 changed files with 71 additions and 84 deletions.
17 changes: 15 additions & 2 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,10 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,


/*!
* \brief Check that an expression is a "primtive operator".
* \brief Check that an expression is a "primitive operator".
*
* Will return true if the expression is an operator which
* matches the form of primtive operators registered directly
* matches the form of primitive operators registered directly
* by the Relay codebase.
*
* That is the arguments are all type variables, and there is a single
Expand All @@ -610,6 +610,19 @@ inline bool IsPrimitiveOp(const Expr& expr) {
return op != nullptr && op->IsPrimitiveOp();
}

/*!
* \brief Check if an op is the one with the provided name.
*
* \param ref The op to be checked.
* \param name The given operator name.
*
* \return True if the op is the same with the given name. Otherwise, false.
*/
inline bool IsOp(const NodeRef& ref, const std::string& name) {
CHECK(ref.defined());
return ref->IsInstance<OpNode>() && Downcast<Op>(ref) == Op::Get(name);
}

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_H_
6 changes: 4 additions & 2 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
* \file relay/backend/compile_engine.cc
* \brief Internal compialtion engine.
*/
#include "compile_engine.h"

#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/operation.h>
Expand All @@ -29,6 +31,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
#include <utility>
Expand All @@ -38,7 +41,6 @@
#include <vector>
#include <unordered_map>
#include "../ir/type_functor.h"
#include "compile_engine.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -251,7 +253,7 @@ class ScheduleGetter :
<< "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
// Check if the op is a device copy op.
bool is_copy_op = op.same_as(Op::Get("device_copy"));
bool is_copy_op = IsOp(op, "device_copy");
Array<Tensor> outputs;
// Skip fcompute for device copy operators as it is not registered.
if (is_copy_op) {
Expand Down
7 changes: 4 additions & 3 deletions src/relay/backend/contrib/codegen_c/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
Expand Down Expand Up @@ -56,11 +57,11 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
// Make function declaration
macro_stream << "CSOURCE_BINARY_OP_" << call->args.size() << "D(" << func_name << ", ";

if (IsOp(call, "add")) {
if (IsOp(call->op, "add")) {
macro_stream << "+";
} else if (IsOp(call, "subtract")) {
} else if (IsOp(call->op, "subtract")) {
macro_stream << "-";
} else if (IsOp(call, "multiply")) {
} else if (IsOp(call->op, "multiply")) {
macro_stream << "*";
} else {
LOG(FATAL) << "Unrecognized op";
Expand Down
16 changes: 0 additions & 16 deletions src/relay/backend/contrib/codegen_c/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,22 +183,6 @@ class CodegenCBase {
return shape;
}

/*!
* \brief Check if a call has the provided name.
*
* \param call A Relay call node.
* \param op_name The name of the expected call.
*
* \return true if the call's name is equivalent to the given name. Otherwise,
* false.
*/
bool IsOp(const CallNode* call, std::string op_name) const {
const auto* op_node = call->op.as<OpNode>();
CHECK(op_node) << "Expects a single op.";
Op op = GetRef<Op>(op_node);
return op == Op::Get(op_name);
}

/*!
* \brief A common interface that is used by various external runtime to
* generate the wrapper to invoke external kernels.
Expand Down
11 changes: 6 additions & 5 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
Expand Down Expand Up @@ -61,19 +62,19 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
std::vector<std::string> args;

// Get the arguments for various DNNL kernels.
if (IsOp(call, "nn.conv2d")) {
if (IsOp(call->op, "nn.conv2d")) {
decl_stream << "dnnl_conv2d";
args = Conv2d(call);
} else if (IsOp(call, "nn.dense")) {
} else if (IsOp(call->op, "nn.dense")) {
decl_stream << "dnnl_dense";
args = Dense(call);
} else if (IsOp(call, "nn.relu")) {
} else if (IsOp(call->op, "nn.relu")) {
decl_stream << "dnnl_relu";
args = Relu(call);
} else if (IsOp(call, "nn.batch_norm")) {
} else if (IsOp(call->op, "nn.batch_norm")) {
decl_stream << "dnnl_bn";
args = BatchNorm(call);
} else if (IsOp(call, "add")) {
} else if (IsOp(call->op, "add")) {
decl_stream << "dnnl_add";
args = Add(call);
} else {
Expand Down
5 changes: 3 additions & 2 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/transform.h>
Expand Down Expand Up @@ -456,7 +457,7 @@ class Interpreter :
const Array<Value>& args) {
auto call_node = func->body.as<CallNode>();

if (call_node && call_node->op == Op::Get("debug")) {
if (call_node && IsOp(call_node->op, "debug")) {
auto dattrs = call_node->attrs.as<DebugAttrs>();
auto interp_state = this->get_state(call_node->args[0]);

Expand Down Expand Up @@ -540,7 +541,7 @@ class Interpreter :
Array<Shape> out_shapes;
auto ret_type = func->body->checked_type();
bool is_dyn = IsDynamic(func->checked_type());
if (call_node->op == Op::Get("shape_of")) {
if (IsOp(call_node->op, "shape_of")) {
// The output shape of shape_of must be static since Relay doesn't support
// dynamic rank tensors.
is_dyn = false;
Expand Down
6 changes: 2 additions & 4 deletions src/relay/pass/canonicalize_cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,10 @@ class CastCanonicalizer : public ExprMutator {

Expr GetNewCallArg(const Expr& e) {
// if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor

static auto& cast = Op::Get("cast");
Expr new_expr = this->VisitExpr(e);

if (const CallNode* call = e.as<CallNode>()) {
if (call->op.same_as(cast)) {
if (IsOp(call->op, "cast")) {
auto attrs = call->attrs.as<CastAttrs>();
const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
CHECK(from_type);
Expand All @@ -108,7 +106,7 @@ class CastCanonicalizer : public ExprMutator {
if (++ref_counter_[call] > 1) {
const CallNode* new_call = new_expr.as<CallNode>();
CHECK(new_call);
CHECK(new_call->op.same_as(cast));
CHECK(IsOp(new_call->op, "cast"));
return CallNode::make(new_call->op, new_call->args, new_call->attrs,
new_call->type_args);
}
Expand Down
4 changes: 2 additions & 2 deletions src/relay/pass/canonicalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include "pattern_util.h"
Expand All @@ -34,9 +35,8 @@ namespace relay {
class BiasAddSimplifier : public ExprMutator {
public:
Expr VisitExpr_(const CallNode* n) {
static const Op& bias_add = Op::Get("nn.bias_add");
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op.same_as(bias_add)) {
if (IsOp(n->op, "nn.bias_add")) {
Call call = Downcast<Call>(new_n);
CHECK_EQ(call->args.size(), 2);
const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>();
Expand Down
13 changes: 7 additions & 6 deletions src/relay/pass/combine_parallel_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <algorithm>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h"
#include "expr_subst.h"
#include "pattern_util.h"
#include "combine_parallel_op.h"


namespace tvm {
Expand All @@ -48,16 +51,14 @@ BranchGroupFinder::BranchGroupFinder(const std::string& op_name,
}

std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
const Op& op = Op::Get(op_name_);

this->VisitExpr(expr);

std::vector<Group> groups;
for (const auto& root : op_roots_) {
const auto& children = children_map_.at(root);
size_t ngroups = groups.size();
for (const CallNode* child : children) {
if (!child->op.same_as(op)) continue;
if (!IsOp(child->op, op_name_)) continue;

auto&& branch = CreateBranch(child);
// add the branch to a group, or create a new group
Expand Down
11 changes: 6 additions & 5 deletions src/relay/pass/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/attrs/transform.h>
Expand Down Expand Up @@ -119,15 +120,15 @@ class ConstantFolder : public ExprMutator {
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
// Try to evaluate shape_of op
if (call->op.same_as(Op::Get("shape_of"))) {
if (IsOp(call->op, "shape_of")) {
return EvaluateShapeOf(res, origin_args, call->attrs);
}

// We should think about potentially constant evaluation over these ops too.
if (call->op.same_as(Op::Get("memory.invoke_tvm_op")) ||
call->op.same_as(Op::Get("memory.shape_func")) ||
call->op.same_as(Op::Get("memory.alloc_tensor")) ||
call->op.same_as(Op::Get("memory.alloc_storage"))) {
if (IsOp(call->op, "memory.invoke_tvm_op") ||
IsOp(call->op, "memory.shape_func") ||
IsOp(call->op, "memory.alloc_tensor") ||
IsOp(call->op, "memory.alloc_storage")) {
return GetRef<Call>(call);
}

Expand Down
3 changes: 1 addition & 2 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,6 @@ class FuseMutator : private ExprMutator {

// Transform calls.
Expr VisitExpr_(const CallNode* call) {
static const Op& stop_fusion = Op::Get("annotation.stop_fusion");
if (call->op.as<OpNode>()) {
static auto fnoncomputational =
Op::GetAttr<TNonComputational>("TNonComputational");
Expand All @@ -872,7 +871,7 @@ class FuseMutator : private ExprMutator {
// If it is a primitive op call
// then we must have a group assignment for it already.
CHECK(gmap_.count(call));
if (call->op.same_as(stop_fusion)) {
if (IsOp(call->op, "annotation.stop_fusion")) {
return ExprMutator::VisitExpr(call->args[0]);
}
auto* ret_group = gmap_.at(call)->FindRoot();
Expand Down
18 changes: 7 additions & 11 deletions src/relay/pass/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,11 @@ struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {

TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);

Op WithFuncIdOp() {
static const Op& op = Op::Get("annotation.with_funcid");
return op;
}

Expr MkWithFuncId(const Expr& expr, FuncId fid) {
auto attrs = make_node<WithFuncIdAttrs>();
static const Op& op = Op::Get("annotation.with_funcid");
attrs->fid = fid;
return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {});
return CallNode::make(op, {expr}, Attrs(attrs), {});
}

RELAY_REGISTER_OP("annotation.with_funcid")
Expand All @@ -582,7 +578,7 @@ Function AsFunc(const Expr& e) {
if (e.as<FunctionNode>()) {
return Downcast<Function>(e);
} else if (const CallNode* c = e.as<CallNode>()) {
CHECK(c->op.same_as(WithFuncIdOp()));
CHECK(IsOp(c->op, "annotation.with_funcid"));
CHECK_EQ(c->args.size(), 1);
return AsFunc(c->args[0]);
} else {
Expand All @@ -604,7 +600,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>

PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) {
if (const CallNode* c = e.as<CallNode>()) {
if (c->op.same_as(WithFuncIdOp())) {
if (IsOp(c->op, "annotation.with_funcid")) {
CHECK_EQ(c->args.size(), 1);
return VisitExpr(c->args[0], ll, name);
}
Expand Down Expand Up @@ -722,7 +718,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}

PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
if (op->op.same_as(WithFuncIdOp())) {
if (IsOp(op->op, "annotation.with_funcid")) {
CHECK_EQ(op->args.size(), 1);
return VisitExpr(op->args[0], ll);
}
Expand Down Expand Up @@ -1096,7 +1092,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }

void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(WithFuncIdOp())) {
if (IsOp(op->op, "annotation.with_funcid")) {
CHECK_EQ(op->args.size(), 1);
CHECK(op->attrs.defined());
CHECK(op->attrs.as<WithFuncIdAttrs>());
Expand Down Expand Up @@ -1194,7 +1190,7 @@ Expr Remap(const Expr& e) {
Expr StripWithFuncId(const Expr& e) {
struct StripWithFuncIdMutator : ExprMutator, PatternMutator {
Expr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(WithFuncIdOp())) {
if (IsOp(op->op, "annotation.with_funcid")) {
CHECK_EQ(op->args.size(), 1);
return VisitExpr(op->args[0]);
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/pass/quantize/calibrate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include "./quantize.h"


Expand All @@ -47,11 +48,10 @@ class StatsCollector : private ExprMutator {
Array<Expr> profile_data_;

Expr VisitExpr_(const CallNode* call) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
Expr new_e = ExprMutator::VisitExpr_(call);
const CallNode* new_call = new_e.as<CallNode>();
CHECK(new_call);
if (new_call->op.same_as(simulated_quantize)) {
if (IsOp(new_call->op, "relay.op.annotation.simulated_quantize")) {
auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
// rewrite the annotation
auto new_attrs = make_node<SimulatedQuantizeAttrs>();
Expand Down
Loading

0 comments on commit fdcee3e

Please sign in to comment.