Skip to content

Commit

Permalink
[VM] Allow serialization of function attrs which are strings (#8485)
Browse files Browse the repository at this point in the history
* [VM] Allow serialization of function attrs which are strings

* add test
  • Loading branch information
tkonolige authored Jul 17, 2021
1 parent c95d16e commit 5ecd6cd
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
45 changes: 28 additions & 17 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ TVMByteArray Executable::Save() {
}

void Executable::SaveGlobalSection(dmlc::Stream* strm) {
std::vector<std::pair<std::string, Index> > globals(this->global_map.begin(),
this->global_map.end());
std::vector<std::pair<std::string, Index>> globals(this->global_map.begin(),
this->global_map.end());
auto comp = [](const std::pair<std::string, Index>& a, const std::pair<std::string, Index>& b) {
return a.second < b.second;
};
Expand Down Expand Up @@ -273,13 +273,20 @@ void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) {
primitive_names[packed_index] = it.first;
}
strm->Write(primitive_names);
// TODO(tkonolige): cannot serialize ObjectRefs with dmlc's serializer.
// std::vector<std::pair<size_t, Map<String, ObjectRef>>> primitive_attrs;
// for (const auto& it : this->op_attrs) {
// auto packed_index = static_cast<size_t>(it.first);
// primitive_attrs.push_back({packed_index, it.second});
// }
// strm->Write(primitive_attrs);
std::map<size_t, std::map<std::string, std::string>> primitive_attrs;
for (const auto& it : this->op_attrs) {
auto packed_index = static_cast<size_t>(it.first);
std::map<std::string, std::string> attrs;
for (const auto& elem : it.second) {
// TODO(tkonolige): cannot serialize ObjectRefs with dmlc's serializer, so we just serialize
// strings for now
if (elem.second.as<StringObj>()) {
attrs[elem.first] = Downcast<String>(elem.second);
}
}
primitive_attrs[packed_index] = attrs;
}
strm->Write(primitive_attrs);
}

// Serialize a virtual machine instruction. It creates a list that contains the
Expand Down Expand Up @@ -576,12 +583,16 @@ void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) {
for (size_t i = 0; i < primitive_names.size(); i++) {
this->primitive_map.insert({primitive_names[i], i});
}
// TODO(tkonolige): cannot serialize ObjectRefs with dmlc's serializer.
// std::vector<std::pair<size_t, Map<String, ObjectRef>>> primitive_attrs;
// STREAM_CHECK(strm->Read(&primitive_attrs), "primitive attrs");
// for (auto p : primitive_attrs) {
// this->op_attrs.insert(p);
// }

std::map<size_t, std::map<std::string, std::string>> primitive_attrs;
STREAM_CHECK(strm->Read(&primitive_attrs), "primitive attrs");
for (const auto& fn : primitive_attrs) {
std::vector<std::pair<String, ObjectRef>> attrs;
for (const auto& elem : fn.second) {
attrs.push_back({elem.first, String(elem.second)});
}
this->op_attrs[fn.first] = Map<String, ObjectRef>(attrs.begin(), attrs.end());
}
}

// Extract the `cnt` number of fields started at `start` from the list
Expand Down Expand Up @@ -864,8 +875,8 @@ TVM_REGISTER_GLOBAL("runtime.GetGlobalFields").set_body([](TVMArgs args, TVMRetV
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
ICHECK(exec);
int idx = args[1];
std::vector<std::pair<std::string, Index> > globals(exec->global_map.begin(),
exec->global_map.end());
std::vector<std::pair<std::string, Index>> globals(exec->global_map.begin(),
exec->global_map.end());
auto comp = [](const std::pair<std::string, Index>& a, const std::pair<std::string, Index>& b) {
return a.second < b.second;
};
Expand Down
4 changes: 3 additions & 1 deletion tests/python/unittest/test_runtime_vm_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def test_basic(dev, target):
return

exe = relay.vm.compile(mod, target, params=params)
vm = profiler_vm.VirtualMachineProfiler(exe, dev)
code, lib = exe.save()
des_exe = tvm.runtime.vm.Executable.load_exec(code, lib)
vm = profiler_vm.VirtualMachineProfiler(des_exe, dev)

data = np.random.rand(1, 1, 28, 28).astype("float32")
res = vm.profile(tvm.nd.array(data), func_name="main")
Expand Down

0 comments on commit 5ecd6cd

Please sign in to comment.