diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index ca9b22a9099d..c2dc0307e166 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -234,8 +234,8 @@ TVMByteArray Executable::Save() { } void Executable::SaveGlobalSection(dmlc::Stream* strm) { - std::vector > globals(this->global_map.begin(), - this->global_map.end()); + std::vector> globals(this->global_map.begin(), + this->global_map.end()); auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; @@ -273,13 +273,20 @@ void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { primitive_names[packed_index] = it.first; } strm->Write(primitive_names); - // TODO(tkonolige): cannot serialize ObjectRefs with dmlc's serializer. - // std::vector>> primitive_attrs; - // for (const auto& it : this->op_attrs) { - // auto packed_index = static_cast(it.first); - // primitive_attrs.push_back({packed_index, it.second}); - // } - // strm->Write(primitive_attrs); + std::map> primitive_attrs; + for (const auto& it : this->op_attrs) { + auto packed_index = static_cast(it.first); + std::map attrs; + for (const auto& elem : it.second) { + // TODO(tkonolige): cannot serialize ObjectRefs with dmlc's serializer, so we just serialize + // strings for now + if (elem.second.as()) { + attrs[elem.first] = Downcast(elem.second); + } + } + primitive_attrs[packed_index] = attrs; + } + strm->Write(primitive_attrs); } // Serialize a virtual machine instruction. It creates a list that contains the @@ -576,12 +583,16 @@ void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { for (size_t i = 0; i < primitive_names.size(); i++) { this->primitive_map.insert({primitive_names[i], i}); } - // TODO(tkonolige): cannot serialize ObjectRefs with dmlc's serializer. - // std::vector>> primitive_attrs; - // STREAM_CHECK(strm->Read(&primitive_attrs), "primitive attrs"); - // for (auto p : primitive_attrs) { - // this->op_attrs.insert(p); - // } + + std::map> primitive_attrs; + STREAM_CHECK(strm->Read(&primitive_attrs), "primitive attrs"); + for (const auto& fn : primitive_attrs) { + std::vector> attrs; + for (const auto& elem : fn.second) { + attrs.push_back({elem.first, String(elem.second)}); + } + this->op_attrs[fn.first] = Map(attrs.begin(), attrs.end()); + } } // Extract the `cnt` number of fields started at `start` from the list @@ -864,8 +875,8 @@ TVM_REGISTER_GLOBAL("runtime.GetGlobalFields").set_body([](TVMArgs args, TVMRetV const auto* exec = dynamic_cast(mod.operator->()); ICHECK(exec); int idx = args[1]; - std::vector > globals(exec->global_map.begin(), - exec->global_map.end()); + std::vector> globals(exec->global_map.begin(), + exec->global_map.end()); auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index e3422ae45945..75b61d281840 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -29,7 +29,9 @@ def test_basic(dev, target): return exe = relay.vm.compile(mod, target, params=params) - vm = profiler_vm.VirtualMachineProfiler(exe, dev) + code, lib = exe.save() + des_exe = tvm.runtime.vm.Executable.load_exec(code, lib) + vm = profiler_vm.VirtualMachineProfiler(des_exe, dev) data = np.random.rand(1, 1, 28, 28).astype("float32") res = vm.profile(tvm.nd.array(data), func_name="main")