diff --git a/ext/JACCAMDGPU/JACCAMDGPU.jl b/ext/JACCAMDGPU/JACCAMDGPU.jl
index dafc620..8690c7f 100644
--- a/ext/JACCAMDGPU/JACCAMDGPU.jl
+++ b/ext/JACCAMDGPU/JACCAMDGPU.jl
@@ -29,6 +29,21 @@ function JACC.parallel_for(
     AMDGPU.synchronize()
 end
 
+function JACC.parallel_for(
+        (L, M, N)::Tuple{I, I, I}, f::F, x...) where {I <: Integer, F <: Function}
+    numThreads = 16
+    Lthreads = min(L, numThreads)
+    Mthreads = min(M, numThreads)
+    Nthreads = 1
+    Lblocks = ceil(Int, L / Lthreads)
+    Mblocks = ceil(Int, M / Mthreads)
+    Nblocks = ceil(Int, N / Nthreads)
+    @roc groupsize=(Lthreads, Mthreads, Nthreads) gridsize=(
+        Lblocks, Mblocks, Nblocks) _parallel_for_amdgpu_LMN(
+        f, x...)
+    AMDGPU.synchronize()
+end
+
 function JACC.parallel_reduce(
         N::I, f::F, x...) where {I <: Integer, F <: Function}
     numThreads = 512
@@ -76,6 +91,14 @@ function _parallel_for_amdgpu_MN(f, x...)
     return nothing
 end
 
+function _parallel_for_amdgpu_LMN(f, x...)
+    k = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
+    j = (workgroupIdx().y - 1) * workgroupDim().y + workitemIdx().y
+    i = (workgroupIdx().z - 1) * workgroupDim().z + workitemIdx().z
+    f(i, j, k, x...)
+    return nothing
+end
+
 function _parallel_reduce_amdgpu(N, ret, f, x...)
     shared_mem = @ROCStaticLocalArray(Float64, 512)
     i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
diff --git a/ext/JACCCUDA/JACCCUDA.jl b/ext/JACCCUDA/JACCCUDA.jl
index 21b5dc9..8574f49 100644
--- a/ext/JACCCUDA/JACCCUDA.jl
+++ b/ext/JACCCUDA/JACCCUDA.jl
@@ -37,6 +37,26 @@ function JACC.parallel_for(
     #    f, x...)
 end
 
+function JACC.parallel_for(
+        (L, M, N)::Tuple{I, I, I}, f::F, x...) where {
+        I <: Integer, F <: Function}
+    # To use JACC.shared, it is recommended to use a high number of threads per block
+    # to maximize the potential benefit from using shared memory.
+    #numThreads = 32
+    numThreads = 16
+    Lthreads = min(L, numThreads)
+    Mthreads = min(M, numThreads)
+    Nthreads = 1
+    Lblocks = ceil(Int, L / Lthreads)
+    Mblocks = ceil(Int, M / Mthreads)
+    Nblocks = ceil(Int, N / Nthreads)
+    CUDA.@sync @cuda threads=(Lthreads, Mthreads, Nthreads) blocks=(Lblocks,
+        Mblocks, Nblocks) _parallel_for_cuda_LMN(f, x...)
+    # To use JACC.shared, we need to define the shmem size using the dynamic shared
+    # memory API. The size should be the largest shared memory size available on the GPU.
+    #CUDA.@sync @cuda threads=(Mthreads, Nthreads) blocks=(Mblocks, Nblocks) shmem = 4 * numThreads * numThreads * sizeof(Float64) _parallel_for_cuda_MN(
+    #    f, x...)
+end
+
 function JACC.parallel_reduce(
         N::I, f::F, x...) where {I <: Integer, F <: Function}
     numThreads = 512
@@ -79,12 +99,20 @@ function _parallel_for_cuda(N, f, x...)
 end
 
 function _parallel_for_cuda_MN(f, x...)
-    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
-    j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
+    j = (blockIdx().x - 1) * blockDim().x + threadIdx().x
+    i = (blockIdx().y - 1) * blockDim().y + threadIdx().y
     f(i, j, x...)
     return nothing
 end
 
+function _parallel_for_cuda_LMN(f, x...)
+    k = (blockIdx().x - 1) * blockDim().x + threadIdx().x
+    j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
+    i = (blockIdx().z - 1) * blockDim().z + threadIdx().z
+    f(i, j, k, x...)
+    return nothing
+end
+
 function _parallel_reduce_cuda(N, ret, f, x...)
     shared_mem = @cuDynamicSharedMem(Float64, 512)
     i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
@@ -147,7 +175,7 @@ function reduce_kernel_cuda(N, red, ret)
             ii += 512
         end
     elseif (i <= N)
-        tmp = @inbounds red[i]
+        tmp = @inbounds red[i]
     end
     shared_mem[threadIdx().x] = tmp
     sync_threads()
diff --git a/ext/JACCONEAPI/JACCONEAPI.jl b/ext/JACCONEAPI/JACCONEAPI.jl
index f138383..f4bbc08 100644
--- a/ext/JACCONEAPI/JACCONEAPI.jl
+++ b/ext/JACCONEAPI/JACCONEAPI.jl
@@ -28,6 +28,20 @@ function JACC.parallel_for(
         f, x...)
 end
 
+function JACC.parallel_for(
+        (L, M, N)::Tuple{I, I, I}, f::F, x...) where {I <: Integer, F <: Function}
+    maxPossibleItems = 16
+    Litems = min(L, maxPossibleItems)
+    Mitems = min(M, maxPossibleItems)
+    Nitems = 1
+    Lgroups = ceil(Int, L / Litems)
+    Mgroups = ceil(Int, M / Mitems)
+    Ngroups = ceil(Int, N / Nitems)
+    oneAPI.@sync @oneapi items=(Litems, Mitems, Nitems) groups=(
+        Lgroups, Mgroups, Ngroups) _parallel_for_oneapi_LMN(
+        f, x...)
+end
+
 function JACC.parallel_reduce(
         N::I, f::F, x...) where {I <: Integer, F <: Function}
     numItems = 256
@@ -64,12 +78,20 @@ function _parallel_for_oneapi(f, x...)
 end
 
 function _parallel_for_oneapi_MN(f, x...)
-    i = get_global_id(0)
-    j = get_global_id(1)
+    j = get_global_id(0)
+    i = get_global_id(1)
     f(i, j, x...)
     return nothing
 end
 
+function _parallel_for_oneapi_LMN(f, x...)
+    k = get_global_id(0)
+    j = get_global_id(1)
+    i = get_global_id(2)
+    f(i, j, k, x...)
+    return nothing
+end
+
 function _parallel_reduce_oneapi(N, ret, f, x...)
     #shared_mem = oneLocalArray(Float32, 256)
     shared_mem = oneLocalArray(Float64, 256)
diff --git a/src/JACC.jl b/src/JACC.jl
index bb7b474..8d56370 100644
--- a/src/JACC.jl
+++ b/src/JACC.jl
@@ -34,6 +34,19 @@ function parallel_for(
     end
 end
 
+function parallel_for(
+        (L, M, N)::Tuple{I, I, I}, f::F, x...) where {
+        I <: Integer, F <: Function}
+    # only threaded at the first level (no collapse equivalent)
+    @maybe_threaded for k in 1:N
+        for j in 1:M
+            for i in 1:L
+                f(i, j, k, x...)
+            end
+        end
+    end
+end
+
 function parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
     tmp = zeros(Threads.nthreads())
     ret = zeros(1)
diff --git a/test/tests_amdgpu.jl b/test/tests_amdgpu.jl
index f5bd322..44e9a63 100644
--- a/test/tests_amdgpu.jl
+++ b/test/tests_amdgpu.jl
@@ -104,7 +104,7 @@ end
 #         @inbounds x[i] += alpha * y[i]
 #     end
 # end
-
+
 # function seq_dot(N, x, y)
 #     r = 0.0
 #     for i in 1:N
@@ -112,20 +112,55 @@ end
 #     end
 #     return r
 # end
-
+
 # x = ones(1_000)
 # y = ones(1_000)
 # jx = JACC.ones(1_000)
 # jy = JACC.ones(1_000)
 # alpha = 2.0
-
+
 # seq_axpy(1_000, alpha, x, y)
 # ref_result = seq_dot(1_000, x, y)
-
+
 # JACC.BLAS.axpy(1_000, alpha, jx, jy)
 # jresult = JACC.BLAS.dot(1_000, jx, jy)
 # result = Array(jresult)
-
+
 # @test result[1]≈ref_result rtol=1e-8
 #end
+
+@testset "Add-2D" begin
+    function add!(i, j, A, B, C)
+        @inbounds C[i, j] = A[i, j] + B[i, j]
+    end
+
+    M = 10
+    N = 10
+    A = JACC.Array(ones(Float32, M, N))
+    B = JACC.Array(ones(Float32, M, N))
+    C = JACC.Array(zeros(Float32, M, N))
+
+    JACC.parallel_for((M, N), add!, A, B, C)
+
+    C_expected = Float32(2.0) .* ones(Float32, M, N)
+    @test Array(C)≈C_expected rtol=1e-5
+end
+
+@testset "Add-3D" begin
+    function add!(i, j, k, A, B, C)
+        @inbounds C[i, j, k] = A[i, j, k] + B[i, j, k]
+    end
+
+    L = 10
+    M = 10
+    N = 10
+    A = JACC.Array(ones(Float32, L, M, N))
+    B = JACC.Array(ones(Float32, L, M, N))
+    C = JACC.Array(zeros(Float32, L, M, N))
+
+    JACC.parallel_for((L, M, N), add!, A, B, C)
+
+    C_expected = Float32(2.0) .* ones(Float32, L, M, N)
+    @test Array(C)≈C_expected rtol=1e-5
+end
\ No newline at end of file
diff --git a/test/tests_cuda.jl b/test/tests_cuda.jl
index a1686be..f29ccee 100644
--- a/test/tests_cuda.jl
+++ b/test/tests_cuda.jl
@@ -135,13 +135,12 @@ end
 # end
 
 @testset "JACC.BLAS" begin
-
     function seq_axpy(N, alpha, x, y)
         for i in 1:N
             @inbounds x[i] += alpha * y[i]
         end
     end
-
+
     function seq_dot(N, x, y)
         r = 0.0
         for i in 1:N
@@ -149,20 +148,54 @@ end
         end
        return r
     end
-
+
     x = ones(1_000)
     y = ones(1_000)
     jx = JACC.ones(1_000)
     jy = JACC.ones(1_000)
     alpha = 2.0
-
+
     seq_axpy(1_000, alpha, x, y)
     ref_result = seq_dot(1_000, x, y)
-
+
     JACC.BLAS.axpy(1_000, alpha, jx, jy)
     jresult = JACC.BLAS.dot(1_000, jx, jy)
-    result = Array(jresult)
-
+    result = Array(jresult)
+
     @test result[1]≈ref_result rtol=1e-8
 end
+
+@testset "Add-2D" begin
+    function add!(i, j, A, B, C)
+        @inbounds C[i, j] = A[i, j] + B[i, j]
+    end
+
+    M = 10
+    N = 10
+    A = JACC.Array(ones(Float32, M, N))
+    B = JACC.Array(ones(Float32, M, N))
+    C = JACC.Array(zeros(Float32, M, N))
+
+    JACC.parallel_for((M, N), add!, A, B, C)
+
+    C_expected = Float32(2.0) .* ones(Float32, M, N)
+    @test Array(C)≈C_expected rtol=1e-5
+end
+
+@testset "Add-3D" begin
+    function add!(i, j, k, A, B, C)
+        @inbounds C[i, j, k] = A[i, j, k] + B[i, j, k]
+    end
+
+    L = 10
+    M = 10
+    N = 10
+    A = JACC.Array(ones(Float32, L, M, N))
+    B = JACC.Array(ones(Float32, L, M, N))
+    C = JACC.Array(zeros(Float32, L, M, N))
+
+    JACC.parallel_for((L, M, N), add!, A, B, C)
+    C_expected = Float32(2.0) .* ones(Float32, L, M, N)
+    @test Array(C)≈C_expected rtol=1e-5
+end
diff --git a/test/tests_threads.jl b/test/tests_threads.jl
index 80953a0..404bdcf 100644
--- a/test/tests_threads.jl
+++ b/test/tests_threads.jl
@@ -279,7 +279,6 @@ end
 end
 
 @testset "JACC.BLAS" begin
-
     x = ones(1_000)
     y = ones(1_000)
     jx = JACC.ones(1_000)
     jy = JACC.ones(1_000)
     alpha = 2.0
@@ -291,7 +290,7 @@ end
             @inbounds x[i] += alpha * y[i]
         end
     end
-
+
     function seq_dot(N, x, y)
         r = 0.0
         for i in 1:N
@@ -299,14 +298,49 @@ end
         end
         return r
     end
-
+
     seq_axpy(1_000, alpha, x, y)
     ref_result = seq_dot(1_000, x, y)
 
     JACC.BLAS.axpy(1_000, alpha, jx, jy)
     jresult = JACC.BLAS.dot(1_000, jx, jy)
-    result = jresult[1]
-
+    result = jresult[1]
+
     @test result≈ref_result rtol=1e-8
 end
+
+# 2D
+@testset "Add-2D" begin
+    function add!(i, j, A, B, C)
+        @inbounds C[i, j] = A[i, j] + B[i, j]
+    end
+
+    M = 10
+    N = 10
+    A = JACC.Array(ones(Float32, M, N))
+    B = JACC.Array(ones(Float32, M, N))
+    C = JACC.Array(zeros(Float32, M, N))
+
+    JACC.parallel_for((M, N), add!, A, B, C)
+
+    C_expected = Float32(2.0) .* ones(Float32, M, N)
+    @test C≈C_expected rtol=1e-5
+end
+
+@testset "Add-3D" begin
+    function add!(i, j, k, A, B, C)
+        @inbounds C[i, j, k] = A[i, j, k] + B[i, j, k]
+    end
+
+    L = 10
+    M = 10
+    N = 10
+    A = JACC.Array(ones(Float32, L, M, N))
+    B = JACC.Array(ones(Float32, L, M, N))
+    C = JACC.Array(zeros(Float32, L, M, N))
+
+    JACC.parallel_for((L, M, N), add!, A, B, C)
+    C_expected = Float32(2.0) .* ones(Float32, L, M, N)
+    @test C≈C_expected rtol=1e-5
+end
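
Note on usage (not part of the patch): a minimal sketch of how the tuple-based 3D overload added above is called from user code, modeled on the Add-3D tests; the kernel name scale! and the sizes below are hypothetical and only for illustration.

using JACC

# Hypothetical three-index kernel: scales each element of a 3D array in place.
function scale!(i, j, k, alpha, T)
    @inbounds T[i, j, k] = alpha * T[i, j, k]
end

L, M, N = 8, 16, 32
T = JACC.Array(ones(Float64, L, M, N))

# Passing an (L, M, N) tuple dispatches to the new 3D parallel_for overload;
# the kernel is invoked once per (i, j, k) index triple on whichever backend
# (Threads, CUDA, AMDGPU, oneAPI) is active.
JACC.parallel_for((L, M, N), scale!, 2.0, T)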