Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME] Enable ext_dev type for quick plugin of device #542

Merged
merged 2 commits into from
Oct 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions apps/extension/src/tvm_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,9 @@ TVM_REGISTER_GLOBAL("tvm_ext.sym_add")
Var b = args[1];
*rv = a + b;
});

// Register the device API for the "ext_dev" extension device.
// It forwards to the registry entry "device_api.cpu", so the quick-plugin
// extension device reuses the host (CPU) device API implementation.
TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
});
} // namespace tvm_ext
20 changes: 20 additions & 0 deletions apps/extension/tests/test_ext.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
import tvm_ext
import tvm
import numpy as np

def test_bind_add():
    """bind_add should curry its extra argument into the wrapped callback."""
    def plus(lhs, rhs):
        return lhs + rhs

    bound = tvm_ext.bind_add(plus, 1)
    assert bound(2) == 3

def test_ext_dev():
    """Build an add-one kernel for the "ext_dev" target and run it.

    Exercises the extension-device plugin path: the device API for
    "ext_dev" is registered by the extension module, so the kernel
    should execute and produce A + 1.
    """
    num_elems = 10
    A = tvm.placeholder((num_elems,), name='A')
    B = tvm.compute((num_elems,), lambda *i: A(*i) + 1.0, name='B')
    sched = tvm.create_schedule(B.op)

    def run_check():
        # Host code is generated with LLVM; skip if it is not enabled.
        if not tvm.module.enabled("llvm"):
            return
        func = tvm.build(sched, [A, B], "ext_dev", "llvm")
        ctx = tvm.ext_dev(0)
        # Launch the kernel on the extension device context.
        src = tvm.nd.array(np.random.uniform(size=num_elems).astype(A.dtype), ctx)
        dst = tvm.nd.array(np.zeros(num_elems, dtype=B.dtype), ctx)
        func(src, dst)
        np.testing.assert_allclose(dst.asnumpy(), src.asnumpy() + 1)

    run_check()


def test_sym_add():
a = tvm.var('a')
b = tvm.var('b')
Expand All @@ -26,6 +45,7 @@ def ivec_cb(v2):
tvm.convert(ivec_cb)(ivec)

if __name__ == "__main__":
    # Run all extension smoke tests when executed as a script.
    test_ext_dev()
    test_ext_vec()
    test_bind_add()
    test_sym_add()
3 changes: 3 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ typedef int64_t tvm_index_t;

/*! \brief Extension device types in TVM */
typedef enum {
// Extension device type backed by host DRAM, used to quickly test an
// extension device. The actual device API can differ depending on the
// driver registered for it.
kExtDev = 12
// Add extra TVM device types (which are not in DLPack) here.
} TVMDeviceExtType;

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from . import target

from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm
from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm, ext_dev

from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import Function
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class TVMContext(ctypes.Structure):
4 : 'opencl',
8 : 'metal',
9 : 'vpi',
10: 'rocm'
10: 'rocm',
12: 'ext_dev',
}
STR2MASK = {
'cpu': 1,
Expand All @@ -106,7 +107,8 @@ class TVMContext(ctypes.Structure):
'opencl': 4,
'metal': 8,
'vpi': 9,
'rocm': 10
'rocm': 10,
'ext_dev': 12,
}
def __init__(self, device_type, device_id):
super(TVMContext, self).__init__()
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def build(sch,
else:
raise ValueError("unknown function type %d" % func.func_type)

if not target.startswith("llvm") and target != "stackvm" and not fdevice:
if not target.startswith("llvm") and target not in ("stackvm", "ext_dev") and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do bind?" % target)

Expand Down
4 changes: 4 additions & 0 deletions python/tvm/contrib/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ def metal(self, dev_id=0):
"""Construct remote Metal device."""
return self.context(8, dev_id)

def ext_dev(self, dev_id=0):
    """Construct a remote extension device context.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id.
    """
    # 12 is the kExtDev device type code.
    ext_dev_type = 12
    return self.context(ext_dev_type, dev_id)

def upload(self, data, target=None):
"""Upload file to remote runtime temp folder

Expand Down
21 changes: 21 additions & 0 deletions python/tvm/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,27 @@ def vpi(dev_id=0):
"""
return TVMContext(9, dev_id)

def ext_dev(dev_id=0):
    """Construct an extension device context.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id

    Returns
    -------
    ctx : TVMContext
        The created context

    Note
    ----
    This API is reserved for quick testing of new
    device by plugin device API as ext_dev.
    """
    # 12 matches the kExtDev device type code.
    ext_dev_type = 12
    return TVMContext(ext_dev_type, dev_id)


# Short aliases for the opencl and metal context constructors.
cl = opencl
mtl = metal

Expand Down
1 change: 1 addition & 0 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ inline std::string DeviceName(int type) {
case kMetal: return "metal";
case kVPI: return "vpi";
case kROCM: return "rocm";
case kExtDev: return "ext_dev";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;

ROCMThreadEntry::ROCMThreadEntry()
: pool(kGPU, ROCMDeviceAPI::Global()) {
: pool(kROCM, ROCMDeviceAPI::Global()) {
}

ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
Expand Down
20 changes: 10 additions & 10 deletions tests/python/unittest/test_codegen_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def test_add_pipeline():
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='C')
D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D')
s = tvm.create_schedule(D.op)

# GPU schedule have to split by gridIdx and threadIdx
Expand All @@ -26,11 +26,11 @@ def test_add_pipeline():
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
Db = tvm.decl_buffer(D.shape, D.dtype, name='D')
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, D:Db}, 64)
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0, True)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])

Expand All @@ -49,10 +49,10 @@ def check_target(device, host="stackvm"):
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
f(a, b, c)
d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
f(a, b, d)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

def check_module_save(device, host="stackvm"):
if not tvm.module.enabled(host):
Expand All @@ -73,10 +73,10 @@ def check_module_save(device, host="stackvm"):
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
f(a, b, c)
d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
f(a, b, d)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)

check_target("cuda", host="stackvm")
check_target("cuda", host="llvm")
Expand Down