Skip to content

Commit

Permalink
[Relay] Expose vm OptimizeModule to Python (#4800)
Browse files Browse the repository at this point in the history
* Expose VM OptimizeModule to python

* added missing imports

* fix import
  • Loading branch information
masahi authored Feb 2, 2020
1 parent cf173fd commit 73a9e99
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 0 deletions.
38 changes: 38 additions & 0 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np

import tvm
import tvm.ndarray as _nd
from tvm import autotvm, container
from tvm.object import Object
from tvm.relay import expr as _expr
Expand Down Expand Up @@ -409,6 +410,8 @@ def __init__(self):
self._codegen = self.mod["codegen"]
self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]
self._optimize = self.mod["optimize"]

def set_params(self, params):
"""Set constant parameters for the model.
Expand All @@ -426,6 +429,14 @@ def set_params(self, params):
inputs[name] = _expr.const(param)
self._set_params_func(inputs)

def get_params(self):
"""Return the updated weights."""
params = self._get_params_func()
ret = {}
for key, value in params.items():
ret[key] = value.data
return ret

def lower(self, mod, target=None, target_host=None):
"""Lower the module to VM bytecode.
Expand Down Expand Up @@ -458,6 +469,33 @@ def codegen(self):
"""Generate the kernel library."""
self._codegen()

def optimize(self, mod, target=None, params=None):
"""Helper method that optimizes a Relay module via VM.
Parameters
----------
mod : relay.Module
target : str, :any:`tvm.target.Target`, or dict of str (i.e.
device/context name) to str/tvm.target.Target, optional
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
mod : relay.Module
The optimized relay module.
params : dict
The parameters of the final module.
"""
target = self._update_target(target)
if params:
self.set_params(params)
return self._optimize(mod, target), self.get_params()

def get_exec(self):
"""Get the VM executable.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/scope_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""The scope builder interface."""
from __future__ import absolute_import

from . import ty as _ty
from . import expr as _expr
from .._ffi import base as _base

Expand Down
13 changes: 13 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,19 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
this->SetParam(kv.first, kv.second->data);
}
});
} else if (name == "get_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Map<std::string, Constant> ret;
for (const auto& kv : params_) {
ret.Set(kv.first, ConstantNode::make(kv.second));
}
*rv = ret;
});
} else if (name == "optimize") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2);
*rv = this->OptimizeModule(args[0], args[1]);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
Expand Down
5 changes: 5 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.testing.config import ctx_list
from tvm.relay.prelude import Prelude
from tvm.relay import testing
import pytest

def check_result(args, expected_result, mod=None):
Expand Down Expand Up @@ -570,6 +571,10 @@ def test_add_op_broadcast():
mod["main"] = func
check_result([x_data, y_data], x_data + y_data, mod=mod)

def test_vm_optimize():
mod, params = testing.resnet.get_workload(batch_size=1, num_layers=18)
comp = relay.backend.vm.VMCompiler()
opt_mod, _ = comp.optimize(mod, "llvm", params)

if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 73a9e99

Please sign in to comment.