From 263ae9e1e4242295ab3eabca3d150cfecc56c2a3 Mon Sep 17 00:00:00 2001 From: wangxiang2713 <49302617+wangxiang2713@users.noreply.github.com> Date: Thu, 4 Mar 2021 06:24:52 +0800 Subject: [PATCH] compile engine dump tir and shape funcs (#7552) --- python/tvm/relay/backend/compile_engine.py | 33 ++++++++++++++++++++++ src/relay/backend/compile_engine.cc | 18 ++++++++++++ 2 files changed, 51 insertions(+) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index a39f72e2e61f..68397cc0cef6 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -386,6 +386,18 @@ def items(self): assert len(res) % 2 == 0 return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] + def shape_func_items(self): + """List items in the shape_func_cache. + + Returns + ------- + item_list : List[Tuple[CCacheKey, CCacheValue]] + The list of shape_func_items. + """ + res = _backend._CompileEngineListShapeFuncItems(self) + assert len(res) % 2 == 0 + return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] + def get_current_ccache_key(self): return _backend._CompileEngineGetCurrentCCacheKey(self) @@ -405,7 +417,28 @@ def dump(self): res += "target={}\n".format(k.target) res += "use_count={}\n".format(v.use_count) res += "func_name={}\n".format(v.cached_func.func_name) + res += "----relay function----\n" + res += k.source_func.astext() + "\n" + res += "----tir function----- \n" + res += "inputs={}\n".format(v.cached_func.inputs) + res += "outputs={}\n".format(v.cached_func.outputs) + res += "function: \n" + res += v.cached_func.funcs.astext() + "\n" + res += "===================================\n" + shape_func_items = self.shape_func_items() + res += "%d shape_func_items cached\n" % len(shape_func_items) + for k, v in shape_func_items: + res += "------------------------------------\n" + res += "target={}\n".format(k.target) + res += "use_count={}\n".format(v.use_count) + res += "func_name={}\n".format(v.cached_func.func_name) + res += "----relay function----\n" res += k.source_func.astext() + "\n" + res += "----tir function----- \n" + res += "inputs={}\n".format(v.cached_func.inputs) + res += "outputs={}\n".format(v.cached_func.outputs) + res += "function: \n" + res += v.cached_func.funcs.astext() + "\n" res += "===================================\n" return res diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index ed09e4f6eb32..ae975a5f3240 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -692,6 +692,17 @@ class CompileEngineImpl : public CompileEngineNode { return items; } + // List all items in the shape_func_cache. + Array ListShapeFuncItems() { + std::lock_guard lock(mutex_); + Array items; + for (auto& kv : shape_func_cache_) { + items.push_back(kv.first); + items.push_back(kv.second); + } + return items; + } + /*! * \brief Get the cache key of the function that is being lowered currently * \return the cache key @@ -882,6 +893,13 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](C return ptr->ListItems(); }); +TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListShapeFuncItems") + .set_body_typed([](CompileEngine self) { + CompileEngineImpl* ptr = dynamic_cast(self.operator->()); + ICHECK(ptr != nullptr); + return ptr->ListShapeFuncItems(); + }); + TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGetCurrentCCacheKey") .set_body_typed([](CompileEngine self) { CompileEngineImpl* ptr = dynamic_cast(self.operator->());