Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][RELAY] move fallback_device to config #5690

Merged
merged 2 commits into from
May 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ class PassContextNode : public Object {
/*! \brief The default optimization level. */
int opt_level{2};

/*! \brief CPU is the default fallback device for heterogeneous execution. */
int fallback_device{static_cast<int>(kDLCPU)};

/*! \brief The list of required passes. */
Array<String> required_pass;
/*! \brief The list of disabled passes. */
Expand Down Expand Up @@ -139,7 +136,6 @@ class PassContextNode : public Object {

void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
v->Visit("disabled_pass", &disabled_pass);
v->Visit("config", &config);
Expand All @@ -157,7 +153,6 @@ class PassContextNode : public Object {
*
* auto new_ctx = PassContext::Create();
* ctx->opt_level = 2;
* ctx->fallback_device = kDLCPU;
* With<PassContext> scope(ctx);
* // pass context in effect.
*
Expand Down
18 changes: 1 addition & 17 deletions python/tvm/ir/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
import functools

import tvm._ffi

import tvm.runtime
from tvm.runtime import ndarray as _nd

from . import _ffi_transform_api

Expand Down Expand Up @@ -61,10 +59,6 @@ class PassContext(tvm.runtime.Object):
opt_level : Optional[int]
The optimization level of this pass.

fallback_device : Optional[Union[int, str, TVMContext]]
The fallback device type. It is also used as the default device for
operators that are not annotated during heterogeneous execution.

required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are required by a certain pass.

Expand All @@ -76,19 +70,10 @@ class PassContext(tvm.runtime.Object):
"""
def __init__(self,
opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None,
trace=None,
config=None):
if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, tvm.runtime.TVMContext):
fallback_device = fallback_device.device_type
if not isinstance(fallback_device, int):
raise TypeError("fallback_device is expected to be the type of " +
"int/str/TVMContext.")

required = list(required_pass) if required_pass else []
if not isinstance(required, (list, tuple)):
raise TypeError("required_pass is expected to be the type of " +
Expand All @@ -101,8 +86,7 @@ def __init__(self,

config = config if config else None
self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level,
fallback_device, required,
disabled, trace, config)
required, disabled, trace, config)

def __enter__(self):
_ffi_transform_api.EnterPassContext(self)
Expand Down
10 changes: 2 additions & 8 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@


def build_config(opt_level=2,
fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None,
trace=None):
Expand Down Expand Up @@ -59,10 +58,6 @@ def build_config(opt_level=2,
"FastMath": 4
}

fallback_device : int, str, or tvmContext, optional
The fallback device. It is also used as the default device for
operators without specified device during heterogeneous execution.

required_pass: set of str, optional
Optimization passes that are required regardless of optimization level.

Expand All @@ -77,9 +72,8 @@ def build_config(opt_level=2,
pass_context: PassContext
The pass context for optimizations.
"""
return tvm.ir.transform.PassContext(
opt_level, fallback_device, required_pass,
disabled_pass, trace)
return tvm.ir.transform.PassContext(opt_level, required_pass,
disabled_pass, trace)


@tvm._ffi.register_object("relay.FunctionPass")
Expand Down
7 changes: 2 additions & 5 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(PassContextNode);

TVM_REGISTER_GLOBAL("transform.PassContext")
.set_body_typed([](int opt_level, int fallback_device, Array<String> required,
Array<String> disabled, TraceFunc trace_func,
Optional<Map<std::string, ObjectRef>> config) {
.set_body_typed([](int opt_level, Array<String> required, Array<String> disabled,
TraceFunc trace_func, Optional<Map<std::string, ObjectRef>> config) {
auto pctx = PassContext::Create();
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;

pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
Expand All @@ -477,7 +475,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "Pass context information: "
<< "\n";
p->stream << "\topt_level: " << node->opt_level << "\n";
p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n";

p->stream << "\trequired passes: [";
for (const auto& it : node->required_pass) {
Expand Down
7 changes: 6 additions & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,12 @@ class RelayBuildModule : public runtime::ModuleNode {
// Handle heterogeneous compilation.
transform::PassContext pass_ctx = PassContext::Current();
if (targets_.size() > 1) {
relay_module = RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device);
Optional<IntImm> opt_fallback_dev =
pass_ctx->GetConfig("relay.fallback_device_type",
IntImm(runtime::DataType::Int(32), static_cast<int>(kDLCPU)));
zhiics marked this conversation as resolved.
Show resolved Hide resolved
auto fallback_dev = opt_fallback_dev.value();
CHECK_GT(fallback_dev->value, 0U);
relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev->value);
}

// Fuse the operations if it is needed.
Expand Down
2 changes: 2 additions & 0 deletions src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ namespace tvm {
namespace relay {
namespace transform {

TVM_REGISTER_PASS_CONFIG_OPTION("relay.fallback_device_type", IntImm);

class FunctionPass;

/*!
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/relay_transform_sequential.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ TEST(Relay, Sequential) {
auto mod = IRModule::FromExpr(func);
auto pass_ctx = relay::transform::PassContext::Create();
pass_ctx->opt_level = 3;
pass_ctx->fallback_device = 1;
pass_ctx->config.Set("relay.fallback_device_type", IntImm(DataType::Int(32), 1));
zhiics marked this conversation as resolved.
Show resolved Hide resolved
{
tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
tvm::With<tvm::Target> tctx(tvm::Target::Create("llvm"));
Expand Down
12 changes: 6 additions & 6 deletions tests/python/relay/test_pass_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,10 @@ def get_func():
def test_runtime(target, device, func, fallback_device=None,
expected_index=None):
params = {"x": x_data, "y": y_data}
config = {"opt_level": 1}
config = {}
if fallback_device:
config["fallback_device"] = fallback_device
with relay.build_config(**config):
config["relay.fallback_device_type"] = fallback_device.device_type
with tvm.transform.PassContext(opt_level=1, config=config):
graph, lib, params = relay.build(
func,
target,
Expand Down Expand Up @@ -538,9 +538,9 @@ def expected():
expected_index = [2, 2, 2, 1, 1, 1, 2, 2]
check_annotated_graph(annotated_func, expected_func)
params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data}
config = {"opt_level": 0}
config["fallback_device"] = fallback_device
with relay.build_config(**config):
with tvm.transform.PassContext(opt_level=0,
config={"relay.fallback_device_type":
fallback_device.device_type}):
graph, lib, params = relay.build(annotated_func, target, params=params)
contexts = [tvm.cpu(0), tvm.context(dev)]
graph_json = json.loads(graph)
Expand Down