Skip to content

Commit

Permalink
fix string downcast
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 10, 2020
1 parent 38f8564 commit 8801a51
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,15 +248,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {

GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee,
const CallNode* caller) {
const auto pattern_name = callee->GetAttr<tir::StringImm>(attr::kComposite);
const auto pattern_name = callee->GetAttr<tvm::runtime::String>(attr::kComposite);
CHECK(pattern_name.defined()) << "Only functions with composite attribute supported";

if (pattern_name->value == "dnnl.conv2d_bias_relu") {
if (pattern_name == "dnnl.conv2d_bias_relu") {
const auto* conv_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
return GenerateBody(conv_call, "dnnl_fused_conv2d_bias_relu", GetArgumentNames(caller),
Conv2d(conv_call));
} else if (pattern_name->value == "dnnl.conv2d_relu") {
} else if (pattern_name == "dnnl.conv2d_relu") {
const auto* conv_call = GetRootCall(callee->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
return GenerateBody(conv_call, "dnnl_fused_conv2d_relu", GetArgumentNames(caller),
Conv2d(conv_call));
Expand Down
1 change: 0 additions & 1 deletion tests/python/relay/test_pass_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from tvm import relay, runtime
from tvm.relay import transform
from tvm.contrib import util
from tvm.relay import transform
from tvm.relay.backend import compile_engine
from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end
Expand Down

0 comments on commit 8801a51

Please sign in to comment.