Skip to content

Commit

Permalink
handle runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Apr 2, 2020
1 parent bac5c24 commit a53ae94
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,6 @@ def lower(sch,
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)


@tvm.tir.transform.prim_func_pass(opt_level=0)
class BindTarget:
def __init__(self, target):
self.target = target

# pylint: disable=unused-argument
def transform_function(self, func, mod, ctx):
return func.with_attr("target", self.target)


def _build_for_device(flist, target, target_host):
"""Build the lowered functions for a device with the given compilation
target.
Expand All @@ -231,6 +221,15 @@ def _build_for_device(flist, target, target_host):
mdev : tvm.module
A module that contains device code.
"""
@tvm.tir.transform.prim_func_pass(opt_level=0)
class BindTarget:
def __init__(self, target):
self.target = target

# pylint: disable=unused-argument
def transform_function(self, func, mod, ctx):
return func.with_attr("target", self.target)

target = _target.create(target)
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
Expand Down

0 comments on commit a53ae94

Please sign in to comment.