From f4145fe453a668d3a04a8826679138ae49a0f9c0 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 23 Jun 2020 17:42:57 +0800 Subject: [PATCH 1/4] Add kMaxRegistersPerBlock device api for cuda --- include/tvm/runtime/device_api.h | 3 ++- src/runtime/cuda/cuda_device_api.cc | 4 ++++ src/runtime/metal/metal_device_api.mm | 4 +++- src/runtime/opencl/opencl_device_api.cc | 2 ++ src/runtime/rocm/rocm_device_api.cc | 2 ++ src/runtime/vulkan/vulkan.cc | 2 ++ 6 files changed, 15 insertions(+), 2 deletions(-) diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 421811a52c3b..3cf5566f3231 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -44,7 +44,8 @@ enum DeviceAttrKind : int { kMaxClockRate = 6, kMultiProcessorCount = 7, kMaxThreadDimensions = 8, - kGcnArch = 9 + kMaxRegistersPerBlock = 9, + kGcnArch = 10 }; /*! \brief Number of bytes each allocation must align to */ diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index a6d4a5499469..ccd8e91e0c5d 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -92,6 +92,10 @@ class CUDADeviceAPI final : public DeviceAPI { *rv = ss.str(); return; } + case kMaxRegistersPerBlock: { + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxRegistersPerBlock, ctx.device_id)); + break; + } case kGcnArch: return; } diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 3bad2c3e9deb..a64f35ced2c2 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -64,7 +64,9 @@ case kMaxThreadDimensions: return; case kExist: - break; + return; + case kMaxRegistersPerBlock: + return; case kGcnArch: return; } diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 6d9835e6231c..72d03fb6a4fc 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -107,6 +107,8 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* *rv = ss.str(); break; } + case kMaxRegistersPerBlock: + return; case kGcnArch: return; } diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 475c4fbffadc..e3dbef5ff42a 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -102,6 +102,8 @@ class ROCMDeviceAPI final : public DeviceAPI { *rv = ss.str(); return; } + case kMaxRegistersPerBlock: + return; case kGcnArch: { hipDeviceProp_t prop; ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id)); diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 44810116c3c2..ade4ddca9376 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -413,6 +413,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* *rv = ss.str(); break; } + case kMaxRegistersPerBlock: + return; case kGcnArch: return; } From 53f43a5d71bbbd28b3f80a2be8575340045d4c63 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 23 Jun 2020 18:11:31 +0800 Subject: [PATCH 2/4] Add vectorize check to verify_gpu_code --- src/tir/analysis/verify_gpu_code.cc | 43 ++++++++++++++----- .../test_tir_analysis_verify_gpu_code.py | 25 +++++++++++ 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index 1fbae0fd2dcd..987ba89df46a 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -33,20 +33,22 @@ namespace tvm { namespace tir { -class GPUCodeVerifier : public StmtVisitor { +class GPUCodeVerifier : public StmtExprVisitor { public: bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y, - int64_t max_thread_z) { + int64_t max_thread_z, int64_t max_vector_bytes) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); max_threads_per_block_ = static_cast(max_threads_per_block); max_thread_x_ = static_cast(max_thread_x); max_thread_y_ = static_cast(max_thread_y); max_thread_z_ = static_cast(max_thread_z); + max_vector_bytes_ = static_cast(max_vector_bytes); Reset_(); + // TODO(jcf94): Add support of detecting CUDA Misaligned Address error this->VisitStmt(stmt); return valid_; @@ -62,6 +64,9 @@ class GPUCodeVerifier : public StmtVisitor { size_t size = static_cast(op->constant_allocation_size()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } + if (op->dtype.lanes() > 1) { + valid_ &= op->dtype.lanes() * op->dtype.bytes() <= static_cast(max_vector_bytes_); + } } void VisitStmt_(const AttrStmtNode* op) final { @@ -129,6 +134,18 @@ class GPUCodeVerifier : public StmtVisitor { } } + void VisitExpr_(const LoadNode* op) { + // Currently not able to check out: If the index expression failed + // to be simplified to a RampNode + if (op->index->IsInstance()) { + if (op->dtype.lanes() > 1) { + valid_ &= op->dtype.lanes() * op->dtype.bytes() <= + static_cast(max_vector_bytes_); + } + } + ExprVisitor::VisitExpr_(op); + } + private: int nest_level_{0}; @@ -146,6 +163,7 @@ class GPUCodeVerifier : public StmtVisitor { size_t max_shared_memory_per_block_; size_t max_threads_per_block_; size_t max_thread_x_, max_thread_y_, max_thread_z_; + size_t max_vector_bytes_; bool valid_{true}; @@ -169,27 +187,32 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { int64_t max_thread_x = INT64_MAX; int64_t max_thread_y = INT64_MAX; int64_t max_thread_z = INT64_MAX; + int64_t max_vector_bytes = INT64_MAX; for (auto iter : constraints) { const IntImmNode* val = iter.second.as(); - if (iter.first == "max_local_memory_per_block") + if (iter.first == "max_local_memory_per_block") { max_local_memory_per_block = val->value; - else if (iter.first == "max_shared_memory_per_block") + } else if (iter.first == "max_shared_memory_per_block") { max_shared_memory_per_block = val->value; - else if (iter.first == "max_threads_per_block") + } else if (iter.first == "max_threads_per_block") { max_threads_per_block = val->value; - else if (iter.first == "max_thread_x") + } else if (iter.first == "max_thread_x") { max_thread_x = val->value; - else if (iter.first == "max_thread_y") + } else if (iter.first == "max_thread_y") { max_thread_y = val->value; - else if (iter.first == "max_thread_z") + } else if (iter.first == "max_thread_z") { max_thread_z = val->value; - else + } else if (iter.first == "max_vector_bytes") { + max_vector_bytes = val->value; + } else { LOG(FATAL) << "Invalid check item: " << iter.first; + } } return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, - max_threads_per_block, max_thread_x, max_thread_y, max_thread_z); + max_threads_per_block, max_thread_x, max_thread_y, max_thread_z, + max_vector_bytes); } TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); diff --git a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py index 11960cad04d4..ece8402a77ce 100644 --- a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py +++ b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py @@ -208,6 +208,30 @@ def test_wrong_bind(): tvm.build(s, [A, B], target) assert not valid[0] +def test_vectorize(): + N = 1024 + + A = te.placeholder((N, N), name='A') + B = te.compute((N, N), lambda i, j: A[i, j]) + + s = te.create_schedule([B.op]) + + i, j = s[B].op.axis + + s[B].bind(i, te.thread_axis("blockIdx.x")) + jo, ji = s[B].split(j, factor=64) + s[B].bind(jo, te.thread_axis("threadIdx.x")) + s[B].vectorize(ji) + + for target in ['opencl', 'cuda']: + if not tvm.context(target).exist: + continue + + valid = [None] + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ + (2, get_verify_pass(valid, max_vector_bytes=16))]}): + tvm.lower(s, [A, B]) + assert not valid[0] if __name__ == "__main__": test_local_memory() @@ -215,3 +239,4 @@ def test_wrong_bind(): test_num_thread() test_multiple_kernels() test_wrong_bind() + test_vectorize() From 736441aee08b60f56ce7db93cd8246c1ab42892c Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Jun 2020 13:35:29 +0800 Subject: [PATCH 3/4] Lint fix --- src/tir/analysis/verify_gpu_code.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index 987ba89df46a..0bf3fff33367 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -139,8 +139,7 @@ class GPUCodeVerifier : public StmtExprVisitor { // to be simplified to a RampNode if (op->index->IsInstance()) { if (op->dtype.lanes() > 1) { - valid_ &= op->dtype.lanes() * op->dtype.bytes() <= - static_cast(max_vector_bytes_); + valid_ &= op->dtype.lanes() * op->dtype.bytes() <= static_cast(max_vector_bytes_); } } ExprVisitor::VisitExpr_(op); From ca60784c17a55dbf95dcf7601270c6f1d465e7b1 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 25 Jun 2020 09:34:54 +0800 Subject: [PATCH 4/4] Cast fix --- src/tir/analysis/verify_gpu_code.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index 0bf3fff33367..9477e044fc33 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -65,7 +65,7 @@ class GPUCodeVerifier : public StmtExprVisitor { shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } if (op->dtype.lanes() > 1) { - valid_ &= op->dtype.lanes() * op->dtype.bytes() <= static_cast(max_vector_bytes_); + valid_ &= static_cast(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_; } } @@ -139,7 +139,7 @@ class GPUCodeVerifier : public StmtExprVisitor { // to be simplified to a RampNode if (op->index->IsInstance()) { if (op->dtype.lanes() > 1) { - valid_ &= op->dtype.lanes() * op->dtype.bytes() <= static_cast(max_vector_bytes_); + valid_ &= static_cast(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_; } } ExprVisitor::VisitExpr_(op);