diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index e748c297c76f..2f133e1a422d 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -345,6 +345,7 @@ def __init__(self, exe, device, memory_cfg=None): self._invoke_stateful = self.module["invoke_stateful"] self._get_output = self.module["get_output"] self._get_num_outputs = self.module["get_num_outputs"] + self._get_input_index = self.module["get_input_index"] self._set_input = self.module["set_input"] self._setup_device(device, memory_cfg) @@ -490,3 +491,19 @@ def get_outputs(self): outputs : List[NDArray] """ return [self._get_output(i) for i in range(self._get_num_outputs())] + + def get_input_index(self, input_name, func_name="main"): + """Get inputs index via input name. + Parameters + ---------- + name : str + The input key name + func_name : str + The function name + + Returns + ------- + index: int + The input index. -1 will be returned if the given input name is not found. + """ + return self._get_input_index(input_name, func_name) diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index c96364108a2a..925e867f2e1b 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -159,6 +159,21 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, return 1; } }); + } else if (name == "get_input_index") { + return TypedPackedFunc( + [this](std::string input_name, std::string func_name) { + auto gvit = exec_->global_map.find(func_name); + ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name; + auto func_index = gvit->second; + const auto& vm_func = exec_->functions[func_index]; + const auto& param_names = vm_func.params; + for (uint64_t i = 0; i < param_names.size(); i++) { + if (input_name == param_names[i]) { + return static_cast(i); + } + } + return static_cast(-1); + }); } else if (name == "init") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.size() % 3, 0); diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 6c229064b094..38dd3a9fafae 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -941,5 +941,22 @@ def test_get_output_multiple(): np.testing.assert_allclose(outputs[1].numpy(), inp) +def test_get_input_index(): + target = tvm.target.Target("llvm") + + # Build a IRModule. + data_0, data_1 = ["d1", "d2"] + x, y = [relay.var(c, shape=(10,)) for c in [data_0, data_1]] + f = relay.Function([x, y], x + y) + mod = IRModule.from_expr(f) + + # Compile to VMExecutable. + vm_exec = vm.compile(mod, target=target) + vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu()) + assert vm_factory.get_input_index(data_1) == 1 + assert vm_factory.get_input_index(data_0) == 0 + assert vm_factory.get_input_index("invalid") == -1 + + if __name__ == "__main__": pytest.main([__file__])