Skip to content

Commit

Permalink
[VM] Add get_input_index support. (apache#8661)
Browse files Browse the repository at this point in the history
  • Loading branch information
huajsj authored and mehrdadh committed Aug 11, 2021
1 parent f758bda commit 4222a3f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 0 deletions.
17 changes: 17 additions & 0 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def __init__(self, exe, device, memory_cfg=None):
self._invoke_stateful = self.module["invoke_stateful"]
self._get_output = self.module["get_output"]
self._get_num_outputs = self.module["get_num_outputs"]
self._get_input_index = self.module["get_input_index"]
self._set_input = self.module["set_input"]
self._setup_device(device, memory_cfg)

Expand Down Expand Up @@ -490,3 +491,19 @@ def get_outputs(self):
outputs : List[NDArray]
"""
return [self._get_output(i) for i in range(self._get_num_outputs())]

def get_input_index(self, input_name, func_name="main"):
"""Get inputs index via input name.
Parameters
----------
name : str
The input key name
func_name : str
The function name
Returns
-------
index: int
The input index. -1 will be returned if the given input name is not found.
"""
return self._get_input_index(input_name, func_name)
15 changes: 15 additions & 0 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,21 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
return 1;
}
});
} else if (name == "get_input_index") {
return TypedPackedFunc<int64_t(std::string, std::string)>(
[this](std::string input_name, std::string func_name) {
auto gvit = exec_->global_map.find(func_name);
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second;
const auto& vm_func = exec_->functions[func_index];
const auto& param_names = vm_func.params;
for (uint64_t i = 0; i < param_names.size(); i++) {
if (input_name == param_names[i]) {
return static_cast<int64_t>(i);
}
}
return static_cast<int64_t>(-1);
});
} else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size() % 3, 0);
Expand Down
17 changes: 17 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,5 +937,22 @@ def test_get_output_multiple():
np.testing.assert_allclose(outputs[1].numpy(), inp)


def test_get_input_index():
target = tvm.target.Target("llvm")

# Build a IRModule.
data_0, data_1 = ["d1", "d2"]
x, y = [relay.var(c, shape=(10,)) for c in [data_0, data_1]]
f = relay.Function([x, y], x + y)
mod = IRModule.from_expr(f)

# Compile to VMExecutable.
vm_exec = vm.compile(mod, target=target)
vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu())
assert vm_factory.get_input_index(data_1) == 1
assert vm_factory.get_input_index(data_0) == 0
assert vm_factory.get_input_index("invalid") == -1


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 4222a3f

Please sign in to comment.