Skip to content

Commit

Permalink
fix relay.build to not change the module argument in place (#5822)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Jun 16, 2020
1 parent 85d2cb1 commit 8931cfa
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ class RelayBuildModule : public runtime::ModuleNode {
GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
auto new_main = BindParamsByName(main_func, params);
relay_module->Update(main_glb_var, new_main);
IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite();
relay_module_ptr->Update(main_glb_var, new_main);
}

Array<Pass> pass_seqs;
Expand Down
7 changes: 6 additions & 1 deletion tests/python/relay/test_cpp_build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ def test_basic_build():
targets = {
tvm.tir.IntImm("int32", ctx.device_type): tgt
}
g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets, "llvm", params=params)
mod = tvm.IRModule.from_expr(func)
func_in_mod = mod["main"]
assert mod["main"] == func_in_mod, "cannot compare function to itself"

g_json, mmod, params = relay.build(mod, targets, "llvm", params=params)
assert mod["main"] == func_in_mod, "relay.build changed module in-place"

# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
Expand Down

0 comments on commit 8931cfa

Please sign in to comment.