[PTX] Intrinsics for async copy from global to shared (SM80) (#11368)
* register ptx builtin for async copy

* add basic codegen

* add test

* update codegen

* wip

* codegen bug fixed, test working

* add commit group

* add doc
masahi authored May 20, 2022
1 parent a6a3404 commit 7e99d30
Showing 6 changed files with 149 additions and 0 deletions.
19 changes: 19 additions & 0 deletions include/tvm/tir/builtin.h
@@ -632,6 +632,25 @@ TVM_DLL const Op& ptx_mma_sp();
*/
TVM_DLL const Op& ptx_ldmatrix();

/*!
 * \brief tvm intrinsics for ptx async copy from global to shared memory
 *
 * void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t
 * bytes);
 *
 */
TVM_DLL const Op& ptx_cp_async();

/*!
 * \brief tvm intrinsics for ptx async copy commit and wait.
 *
 * void ptx_commit_group();
 * void ptx_wait_group(int num);
 *
 */
TVM_DLL const Op& ptx_commit_group();
TVM_DLL const Op& ptx_wait_group();

// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
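For orientation, these three intrinsics map one-to-one onto Ampere's cp.async PTX instruction family. A minimal hand-written CUDA sketch of the pattern they model (hypothetical helper names, assuming sm_80 or newer and a 16-byte copy) might look like this:

// Hypothetical hand-written counterpart of the new intrinsics (sm_80+).
// `dst` must point into shared memory.
__device__ void cp_async_16B(void* dst, const void* src) {
  unsigned int addr;
  // Convert the generic shared-memory pointer to a 32-bit .shared address.
  __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"(dst));
  // Kick off the 16-byte asynchronous global -> shared copy.
  __asm__ __volatile__(
      "cp.async.cg.shared.global [%0], [%1], %2;"
      :: "r"(addr), "l"(src), "n"(16));
}

// ptx_commit_group / ptx_wait_group counterparts.
__device__ void cp_async_commit_and_wait_all() {
  __asm__ __volatile__("cp.async.commit_group;");  // seal the in-flight group
  __asm__ __volatile__("cp.async.wait_group 0;");  // wait until all groups land
}

ptx_cp_async corresponds to the first helper; ptx_commit_group and ptx_wait_group to the two statements in the second.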
12 changes: 12 additions & 0 deletions src/target/source/codegen_cuda.cc
@@ -821,6 +821,18 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string smem_elem_offset = this->PrintExpr(op->args[6]);
    this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
                                            smem_ptr, smem_elem_offset);
  } else if (op->op.same_as(builtin::ptx_cp_async())) {
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    std::string N = this->PrintExpr(op->args[0]);
    this->stream << "__asm__ __volatile__(\"cp.async.wait_group " + N + ";\");\n\n";
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
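Concretely, a TIR body that evaluates ptx_commit_group() followed by ptx_wait_group(0) makes this visitor append the following statements to the generated CUDA kernel (a sketch of the emitted source, assuming the wait count prints as 0):

__asm__ __volatile__("cp.async.commit_group;");

__asm__ __volatile__("cp.async.wait_group 0;");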
26 changes: 26 additions & 0 deletions src/target/source/ptx.cc
@@ -638,5 +638,31 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
  return asm_code;
}

std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
                                 const std::string& shared_elem_offset,
                                 const std::string& global_ptr,
                                 const std::string& global_elem_offset, const std::string& bytes) {
  std::string asm_code = R"(
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)({smem_addr}))
    );
    __asm__ __volatile__(
      "cp.async.cg.shared.global [%0], [%1], %2;"
      :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
    );
  }
)";
  Replacer replacer;
  replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
  replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
  replacer.register_rule("{bytes}", bytes);
  asm_code = replacer.rewrite(asm_code);
  return asm_code;
}

} // namespace codegen
} // namespace tvm
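For illustration, a call such as PrintCpAsyncAssembly("A_shared", "0", "A", "0", "16") (hypothetical argument strings) would return the template above with its three placeholders substituted, roughly:

  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)(A_shared + 0))
    );
    __asm__ __volatile__(
      "cp.async.cg.shared.global [%0], [%1], %2;"
      :: "r"(addr), "l"((void*)(A + 0)), "n"(16)
    );
  }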
13 changes: 13 additions & 0 deletions src/target/source/ptx.h
@@ -79,6 +79,19 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
                                    const std::string& smem_ptr,
                                    const std::string& smem_elem_offset);

/*!
* \brief Print ptx cp.async assembly string given parameters.
* \param shared_ptr: The pointer to the destination shared memory.
* \param shared_elem_offset: The offset into the shared memory.
* \param global_ptr: The pointer to the global memory.
* \param global_elem_offset: The offset into the global memory.
* \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
*/
std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
                                 const std::string& shared_elem_offset,
                                 const std::string& global_ptr,
                                 const std::string& global_elem_offset, const std::string& bytes);

} // namespace codegen
} // namespace tvm

9 changes: 9 additions & 0 deletions src/tir/op/builtin.cc
@@ -247,6 +247,15 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));

70 changes: 70 additions & 0 deletions tests/python/unittest/test_tir_ptx_cp_async.py
@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.script import tir as T
import numpy as np
import tvm.testing


@T.prim_func
def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"]) -> None:
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
        T.reads(A[0:32, 0:128])
        T.writes(B[0:32, 0:128])

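        # Each cp.async transfer moves 16 bytes (8 float16 elements); the 16
        # iterations below cover one 128-element row per thread.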
        for i in range(16):
            T.evaluate(
                T.ptx_cp_async(
                    A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16"
                )
            )

        # TODO(masahi): Remove dtype requirement from TVMScript parser
        T.evaluate(T.ptx_commit_group(dtype="float16"))
        T.evaluate(T.ptx_wait_group(0, dtype="float16"))

        for i in range(128):
            B[tx, i] = A_shared[tx, i]


@tvm.testing.requires_cuda
def test_ptx_cp_async():
    f = ptx_cp_async
    arch = tvm.contrib.nvcc.get_target_compute_version()
    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
    if major < 8:
        # Require at least SM80
        return

    mod = tvm.build(f, target="cuda")
    A_np = np.random.rand(32, 128).astype("float16")
    B_np = np.zeros((32, 128)).astype("float16")
    dev = tvm.cuda(0)
    A_nd = tvm.nd.array(A_np, device=dev)
    B_nd = tvm.nd.array(B_np, device=dev)
    mod(A_nd, B_nd)
    tvm.testing.assert_allclose(B_nd.numpy(), A_np)


if __name__ == "__main__":
    test_ptx_cp_async()
