Skip to content

Commit

Permalink
compile engine dump tir and shape funcs (apache#7552)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxiang2713 authored and trevor-m committed May 11, 2021
1 parent b4630b8 commit fda0de9
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
33 changes: 33 additions & 0 deletions python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,18 @@ def items(self):
assert len(res) % 2 == 0
return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)]

def shape_func_items(self):
"""List items in the shape_func_cache.
Returns
-------
item_list : List[Tuple[CCacheKey, CCacheValue]]
The list of shape_func_items.
"""
res = _backend._CompileEngineListShapeFuncItems(self)
assert len(res) % 2 == 0
return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)]

def get_current_ccache_key(self):
return _backend._CompileEngineGetCurrentCCacheKey(self)

Expand All @@ -405,7 +417,28 @@ def dump(self):
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
res += "func_name={}\n".format(v.cached_func.func_name)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
res += "inputs={}\n".format(v.cached_func.inputs)
res += "outputs={}\n".format(v.cached_func.outputs)
res += "function: \n"
res += v.cached_func.funcs.astext() + "\n"
res += "===================================\n"
shape_func_items = self.shape_func_items()
res += "%d shape_func_items cached\n" % len(shape_func_items)
for k, v in shape_func_items:
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
res += "func_name={}\n".format(v.cached_func.func_name)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
res += "inputs={}\n".format(v.cached_func.inputs)
res += "outputs={}\n".format(v.cached_func.outputs)
res += "function: \n"
res += v.cached_func.funcs.astext() + "\n"
res += "===================================\n"
return res

Expand Down
18 changes: 18 additions & 0 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,17 @@ class CompileEngineImpl : public CompileEngineNode {
return items;
}

// List all items in the shape_func_cache.
Array<ObjectRef> ListShapeFuncItems() {
std::lock_guard<std::mutex> lock(mutex_);
Array<ObjectRef> items;
for (auto& kv : shape_func_cache_) {
items.push_back(kv.first);
items.push_back(kv.second);
}
return items;
}

/*!
* \brief Get the cache key of the function that is being lowered currently
* \return the cache key
Expand Down Expand Up @@ -882,6 +893,13 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](C
return ptr->ListItems();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListShapeFuncItems")
.set_body_typed([](CompileEngine self) {
CompileEngineImpl* ptr = dynamic_cast<CompileEngineImpl*>(self.operator->());
ICHECK(ptr != nullptr);
return ptr->ListShapeFuncItems();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGetCurrentCCacheKey")
.set_body_typed([](CompileEngine self) {
CompileEngineImpl* ptr = dynamic_cast<CompileEngineImpl*>(self.operator->());
Expand Down

0 comments on commit fda0de9

Please sign in to comment.