Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA device API & VerifyGPUCode pass update #5898

Merged
merged 4 commits into from
Jun 25, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ enum DeviceAttrKind : int {
kMaxClockRate = 6,
kMultiProcessorCount = 7,
kMaxThreadDimensions = 8,
kGcnArch = 9
kMaxRegistersPerBlock = 9,
kGcnArch = 10
};

/*! \brief Number of bytes each allocation must align to */
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ class CUDADeviceAPI final : public DeviceAPI {
*rv = ss.str();
return;
}
case kMaxRegistersPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxRegistersPerBlock, ctx.device_id));
break;
}
case kGcnArch:
return;
}
Expand Down
4 changes: 3 additions & 1 deletion src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@
case kMaxThreadDimensions:
return;
case kExist:
break;
return;
case kMaxRegistersPerBlock:
return;
case kGcnArch:
return;
}
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
*rv = ss.str();
break;
}
case kMaxRegistersPerBlock:
return;
case kGcnArch:
return;
}
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
*rv = ss.str();
return;
}
case kMaxRegistersPerBlock:
return;
case kGcnArch: {
hipDeviceProp_t prop;
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
*rv = ss.str();
break;
}
case kMaxRegistersPerBlock:
return;
case kGcnArch:
return;
}
Expand Down
42 changes: 32 additions & 10 deletions src/tir/analysis/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,22 @@
namespace tvm {
namespace tir {

class GPUCodeVerifier : public StmtVisitor {
class GPUCodeVerifier : public StmtExprVisitor {
public:
bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
int64_t max_thread_z) {
int64_t max_thread_z, int64_t max_vector_bytes) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
max_thread_x_ = static_cast<size_t>(max_thread_x);
max_thread_y_ = static_cast<size_t>(max_thread_y);
max_thread_z_ = static_cast<size_t>(max_thread_z);
max_vector_bytes_ = static_cast<size_t>(max_vector_bytes);

Reset_();

// TODO(jcf94): Add support of detecting CUDA Misaligned Address error
this->VisitStmt(stmt);

return valid_;
Expand All @@ -62,6 +64,9 @@ class GPUCodeVerifier : public StmtVisitor {
size_t size = static_cast<size_t>(op->constant_allocation_size());
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
}
if (op->dtype.lanes() > 1) {
valid_ &= op->dtype.lanes() * op->dtype.bytes() <= static_cast<int>(max_vector_bytes_);
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
}
}

void VisitStmt_(const AttrStmtNode* op) final {
Expand Down Expand Up @@ -129,6 +134,17 @@ class GPUCodeVerifier : public StmtVisitor {
}
}

void VisitExpr_(const LoadNode* op) {
// Currently not able to check out: If the index expression failed
// to be simplified to a RampNode
if (op->index->IsInstance<RampNode>()) {
if (op->dtype.lanes() > 1) {
valid_ &= op->dtype.lanes() * op->dtype.bytes() <= static_cast<int>(max_vector_bytes_);
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
}
}
ExprVisitor::VisitExpr_(op);
}

private:
int nest_level_{0};

Expand All @@ -146,6 +162,7 @@ class GPUCodeVerifier : public StmtVisitor {
size_t max_shared_memory_per_block_;
size_t max_threads_per_block_;
size_t max_thread_x_, max_thread_y_, max_thread_z_;
size_t max_vector_bytes_;

bool valid_{true};

Expand All @@ -169,27 +186,32 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
int64_t max_thread_x = INT64_MAX;
int64_t max_thread_y = INT64_MAX;
int64_t max_thread_z = INT64_MAX;
int64_t max_vector_bytes = INT64_MAX;

for (auto iter : constraints) {
const IntImmNode* val = iter.second.as<IntImmNode>();
if (iter.first == "max_local_memory_per_block")
if (iter.first == "max_local_memory_per_block") {
max_local_memory_per_block = val->value;
else if (iter.first == "max_shared_memory_per_block")
} else if (iter.first == "max_shared_memory_per_block") {
max_shared_memory_per_block = val->value;
else if (iter.first == "max_threads_per_block")
} else if (iter.first == "max_threads_per_block") {
max_threads_per_block = val->value;
else if (iter.first == "max_thread_x")
} else if (iter.first == "max_thread_x") {
max_thread_x = val->value;
else if (iter.first == "max_thread_y")
} else if (iter.first == "max_thread_y") {
max_thread_y = val->value;
else if (iter.first == "max_thread_z")
} else if (iter.first == "max_thread_z") {
max_thread_z = val->value;
else
} else if (iter.first == "max_vector_bytes") {
max_vector_bytes = val->value;
} else {
LOG(FATAL) << "Invalid check item: " << iter.first;
}
}

return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block,
max_threads_per_block, max_thread_x, max_thread_y, max_thread_z);
max_threads_per_block, max_thread_x, max_thread_y, max_thread_z,
max_vector_bytes);
}

TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_tir_analysis_verify_gpu_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,35 @@ def test_wrong_bind():
tvm.build(s, [A, B], target)
assert not valid[0]

def test_vectorize():
N = 1024

A = te.placeholder((N, N), name='A')
B = te.compute((N, N), lambda i, j: A[i, j])

s = te.create_schedule([B.op])

i, j = s[B].op.axis

s[B].bind(i, te.thread_axis("blockIdx.x"))
jo, ji = s[B].split(j, factor=64)
s[B].bind(jo, te.thread_axis("threadIdx.x"))
s[B].vectorize(ji)

for target in ['opencl', 'cuda']:
if not tvm.context(target).exist:
continue

valid = [None]
with tvm.transform.PassContext(config={"tir.add_lower_pass": [
(2, get_verify_pass(valid, max_vector_bytes=16))]}):
tvm.lower(s, [A, B])
assert not valid[0]

if __name__ == "__main__":
test_local_memory()
test_shared_memory()
test_num_thread()
test_multiple_kernels()
test_wrong_bind()
test_vectorize()