diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc
index eba80b955d96..6d7f4bdf7533 100644
--- a/apps/extension/src/tvm_ext.cc
+++ b/apps/extension/src/tvm_ext.cc
@@ -60,4 +60,9 @@ TVM_REGISTER_GLOBAL("tvm_ext.sym_add")
     Var b = args[1];
     *rv = a + b;
   });
+
+TVM_REGISTER_GLOBAL("device_api.ext_dev")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+    *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
+  });
 }  // namespace tvm_ext
diff --git a/apps/extension/tests/test_ext.py b/apps/extension/tests/test_ext.py
index 1cd867d4bb54..0bbfff14eeef 100644
--- a/apps/extension/tests/test_ext.py
+++ b/apps/extension/tests/test_ext.py
@@ -1,5 +1,6 @@
 import tvm_ext
 import tvm
+import numpy as np
 
 def test_bind_add():
     def add(a, b):
@@ -7,6 +8,24 @@ def add(a, b):
     f = tvm_ext.bind_add(add, 1)
     assert f(2) == 3
 
+def test_ext_dev():
+    n = 10
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.compute((n,), lambda *i: A(*i) + 1.0, name='B')
+    s = tvm.create_schedule(B.op)
+    def check_llvm():
+        if not tvm.module.enabled("llvm"):
+            return
+        f = tvm.build(s, [A, B], "ext_dev", "llvm")
+        ctx = tvm.ext_dev(0)
+        # launch the kernel.
+        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
+        f(a, b)
+        np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)
+    check_llvm()
+
+
 def test_sym_add():
     a = tvm.var('a')
     b = tvm.var('b')
@@ -26,6 +45,7 @@ def ivec_cb(v2):
         tvm.convert(ivec_cb)(ivec)
 
 if __name__ == "__main__":
+    test_ext_dev()
     test_ext_vec()
     test_bind_add()
     test_sym_add()
diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h
index 5bf5b72492c4..91175f671a56 100644
--- a/include/tvm/runtime/c_runtime_api.h
+++ b/include/tvm/runtime/c_runtime_api.h
@@ -55,6 +55,9 @@ typedef int64_t tvm_index_t;
 
 /*! \brief Extension device types in TVM */
 typedef enum {
+  // Extension DRAM type, used to quickly test an extension device.
+  // The device API can differ depending on which xpu driver is registered.
+  kExtDev = 12
   // AddExtraTVMType which is not in DLPack here
 } TVMDeviceExtType;
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 4bc20aac7dd6..90ac45988bcb 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -17,7 +17,7 @@
 from . import target
 from . import ndarray as nd
-from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm
+from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm, ext_dev
 from ._ffi.runtime_ctypes import TypeCode
 from ._ffi.function import Function
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 9ea8ef579e10..596e13f1f3fe 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -96,7 +96,8 @@ class TVMContext(ctypes.Structure):
         4 : 'opencl',
         8 : 'metal',
         9 : 'vpi',
-        10: 'rocm'
+        10: 'rocm',
+        12: 'ext_dev',
     }
     STR2MASK = {
         'cpu': 1,
@@ -106,7 +107,8 @@ class TVMContext(ctypes.Structure):
         'opencl': 4,
         'metal': 8,
         'vpi': 9,
-        'rocm': 10
+        'rocm': 10,
+        'ext_dev': 12,
     }
     def __init__(self, device_type, device_id):
         super(TVMContext, self).__init__()
diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
index d08aa8ad0e85..621f8a727274 100644
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -345,7 +345,7 @@ def build(sch,
         else:
             raise ValueError("unknown function type %d" % func.func_type)
 
-    if not target.startswith("llvm") and target != "stackvm" and not fdevice:
+    if not target.startswith("llvm") and target not in ("stackvm", "ext_dev") and not fdevice:
         warnings.warn(
             "Specified target %s, but cannot find device code, did you do bind?" % target)
diff --git a/python/tvm/contrib/rpc.py b/python/tvm/contrib/rpc.py
index 3ad77cb34743..7b29b1ddac01 100644
--- a/python/tvm/contrib/rpc.py
+++ b/python/tvm/contrib/rpc.py
@@ -247,6 +247,10 @@ def metal(self, dev_id=0):
         """Construct remote Metal device."""
         return self.context(8, dev_id)
 
+    def ext_dev(self, dev_id=0):
+        """Construct remote extension device."""
+        return self.context(12, dev_id)
+
     def upload(self, data, target=None):
         """Upload file to remote runtime temp folder
diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py
index 578f63f08d5d..1556c4912a35 100644
--- a/python/tvm/ndarray.py
+++ b/python/tvm/ndarray.py
@@ -120,6 +120,27 @@ def vpi(dev_id=0):
     """
     return TVMContext(9, dev_id)
 
+def ext_dev(dev_id=0):
+    """Construct an extension device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+
+    Note
+    ----
+    This API is reserved for quick testing of a new
+    device by plugging in a device API as ext_dev.
+ """ + return TVMContext(12, dev_id) + + cl = opencl mtl = metal diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 9ff24f4b761d..ce4a65dc79e2 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -31,6 +31,7 @@ inline std::string DeviceName(int type) { case kMetal: return "metal"; case kVPI: return "vpi"; case kROCM: return "rocm"; + case kExtDev: return "ext_dev"; default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; } } diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 3edb1c67c4d4..4fa91dcbeb3f 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -134,7 +134,7 @@ class ROCMDeviceAPI final : public DeviceAPI { typedef dmlc::ThreadLocalStore ROCMThreadStore; ROCMThreadEntry::ROCMThreadEntry() - : pool(kGPU, ROCMDeviceAPI::Global()) { + : pool(kROCM, ROCMDeviceAPI::Global()) { } ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index c4fbb4eccac1..bbdd65e4be1c 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -7,7 +7,7 @@ def test_add_pipeline(): A = tvm.placeholder((n,), name='A') B = tvm.placeholder((n,), name='B') C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') - D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='C') + D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D') s = tvm.create_schedule(D.op) # GPU schedule have to split by gridIdx and threadIdx @@ -26,11 +26,11 @@ def test_add_pipeline(): stmt = tvm.schedule.ScheduleOps(s, bounds) Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') - Cb = tvm.decl_buffer(C.shape, C.dtype, name='C') + Db = tvm.decl_buffer(D.shape, D.dtype, name='D') stmt = tvm.ir_pass.LoopPartition(stmt) - stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64) + stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, D:Db}, 64) stmt = tvm.ir_pass.Simplify(stmt) - fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0, True) + fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True) fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)] fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0]) @@ -49,10 +49,10 @@ def check_target(device, host="stackvm"): n = 1027 a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx) - c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx) - f(a, b, c) + d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx) + f(a, b, d) np.testing.assert_allclose( - c.asnumpy(), a.asnumpy() + b.asnumpy()) + d.asnumpy(), a.asnumpy() + b.asnumpy() + 1) def check_module_save(device, host="stackvm"): if not tvm.module.enabled(host): @@ -73,10 +73,10 @@ def check_module_save(device, host="stackvm"): n = 1027 a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx) - c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx) - f(a, b, c) + d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx) + f(a, b, d) np.testing.assert_allclose( - c.asnumpy(), a.asnumpy() + b.asnumpy()) + d.asnumpy(), a.asnumpy() + b.asnumpy() + 1) check_target("cuda", host="stackvm") check_target("cuda", host="llvm")