Add 3D parallel_for
Swap indices
williamfgc committed Jun 28, 2024
1 parent f62f003 commit 2e2d170
Showing 7 changed files with 218 additions and 22 deletions.
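
For reference, a minimal usage sketch of the 3D entry point added here, mirroring the Add-3D tests introduced later in this commit (the array sizes and the add! kernel are illustrative):

using JACC

function add!(i, j, k, A, B, C)
    @inbounds C[i, j, k] = A[i, j, k] + B[i, j, k]
end

L, M, N = 10, 10, 10
A = JACC.Array(ones(Float32, L, M, N))
B = JACC.Array(ones(Float32, L, M, N))
C = JACC.Array(zeros(Float32, L, M, N))

# The new method dispatches on a 3-tuple of extents and calls add!(i, j, k, A, B, C)
# once per index in 1:L x 1:M x 1:N on the active back end.
JACC.parallel_for((L, M, N), add!, A, B, C)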
23 changes: 23 additions & 0 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
@@ -29,6 +29,21 @@ function JACC.parallel_for(
AMDGPU.synchronize()
end

function JACC.parallel_for(
(L, M, N)::Tuple{I, I, I}, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 16
Lthreads = min(L, numThreads)
Mthreads = min(M, numThreads)
Nthreads = 1
Lblocks = ceil(Int, L / Lthreads)
Mblocks = ceil(Int, M / Mthreads)
Nblocks = ceil(Int, N / Nthreads)
@roc groupsize=(Lthreads, Mthreads, Nthreads) gridsize=(
Lblocks, Mblocks, Nblocks) _parallel_for_amdgpu_LMN(
f, x...)
AMDGPU.synchronize()
end

function JACC.parallel_reduce(
N::I, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 512
@@ -76,6 +91,14 @@ function _parallel_for_amdgpu_MN(f, x...)
return nothing
end

function _parallel_for_amdgpu_LMN(f, x...)
k = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
j = (workgroupIdx().y - 1) * workgroupDim().y + workitemIdx().y
i = (workgroupIdx().z - 1) * workgroupDim().z + workitemIdx().z
f(i, j, k, x...)
return nothing
end

function _parallel_reduce_amdgpu(N, ret, f, x...)
shared_mem = @ROCStaticLocalArray(Float64, 512)
i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
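Not part of this commit: the ceil-division launch above rounds the grid up, so when L, M, or N is not a multiple of the workgroup size some work items land outside the index space. A minimal sketch of a guarded kernel variant, assuming the bounds are passed as leading arguments and keeping this file's dimension mapping (x spans L and is read as k, y spans M as j, z spans N as i):

function _parallel_for_amdgpu_LMN_guarded(L, M, N, f, x...)
    k = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
    j = (workgroupIdx().y - 1) * workgroupDim().y + workitemIdx().y
    i = (workgroupIdx().z - 1) * workgroupDim().z + workitemIdx().z
    # Skip work items created by rounding the grid size up.
    if k <= L && j <= M && i <= N
        f(i, j, k, x...)
    end
    return nothing
end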
34 changes: 31 additions & 3 deletions ext/JACCCUDA/JACCCUDA.jl
@@ -37,6 +37,26 @@ function JACC.parallel_for(
# f, x...)
end

function JACC.parallel_for(
(L, M, N)::Tuple{I, I, I}, f::F, x...) where {
I <: Integer, F <: Function}
#To use JACC.shared, it is recommended to use a high number of threads per block to maximize the
# potential benefit from using shared memory.
#numThreads = 32
numThreads = 16
Lthreads = min(L, numThreads)
Mthreads = min(M, numThreads)
Nthreads = 1
Lblocks = ceil(Int, L / Lthreads)
Mblocks = ceil(Int, M / Mthreads)
Nblocks = ceil(Int, N / Nthreads)
CUDA.@sync @cuda threads=(Lthreads, Mthreads, Nthreads) blocks=(Lblocks,
Mblocks, Nblocks) _parallel_for_cuda_LMN(f, x...)
# To use JACC.shared, we need to define the shmem size using the dynamic shared memory API; the size should be the maximum shared memory available on the GPU.
#CUDA.@sync @cuda threads=(Mthreads, Nthreads) blocks=(Mblocks, Nblocks) shmem = 4 * numThreads * numThreads * sizeof(Float64) _parallel_for_cuda_MN(
# f, x...)
end

function JACC.parallel_reduce(
N::I, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 512
@@ -79,12 +99,20 @@ function _parallel_for_cuda(N, f, x...)
end

function _parallel_for_cuda_MN(f, x...)
j = (blockIdx().x - 1) * blockDim().x + threadIdx().x
i = (blockIdx().y - 1) * blockDim().y + threadIdx().y
f(i, j, x...)
return nothing
end

function _parallel_for_cuda_LMN(f, x...)
k = (blockIdx().x - 1) * blockDim().x + threadIdx().x
j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
i = (blockIdx().z - 1) * blockDim().z + threadIdx().z
f(i, j, k, x...)
return nothing
end

function _parallel_reduce_cuda(N, ret, f, x...)
shared_mem = @cuDynamicSharedMem(Float64, 512)
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
@@ -147,7 +175,7 @@ function reduce_kernel_cuda(N, red, ret)
ii += 512
end
elseif (i <= N)
tmp = @inbounds red[i]
end
shared_mem[threadIdx().x] = tmp
sync_threads()
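Not part of this commit: the commented-out launch in the 3D JACC.parallel_for method above passes a shmem size so that JACC.shared can be backed by dynamic shared memory. A minimal, self-contained sketch of that pattern with CUDA.jl (the kernel and array names are illustrative, not JACC API):

using CUDA

function _shared_copy_kernel(src, dst)
    # Dynamic shared memory: the element count is fixed by the size passed at launch via shmem=...
    tile = @cuDynamicSharedMem(Float64, blockDim().x)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(src)
        tile[threadIdx().x] = src[i]
    end
    sync_threads()
    if i <= length(dst)
        dst[i] = tile[threadIdx().x]
    end
    return nothing
end

n = 1_000
threads = 256
blocks = cld(n, threads)
src = CUDA.ones(Float64, n)
dst = CUDA.zeros(Float64, n)
CUDA.@sync @cuda threads=threads blocks=blocks shmem=threads * sizeof(Float64) _shared_copy_kernel(src, dst)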
26 changes: 24 additions & 2 deletions ext/JACCONEAPI/JACCONEAPI.jl
@@ -28,6 +28,20 @@ function JACC.parallel_for(
f, x...)
end

function JACC.parallel_for(
(L, M, N)::Tuple{I, I, I}, f::F, x...) where {I <: Integer, F <: Function}
maxPossibleItems = 16
Litems = min(L, maxPossibleItems)
Mitems = min(M, maxPossibleItems)
Nitems = 1
Lgroups = ceil(Int, L / Litems)
Mgroups = ceil(Int, M / Mitems)
Ngroups = ceil(Int, N / Nitems)
oneAPI.@sync @oneapi items=(Litems, Mitems, Nitems) groups=(
Lgroups, Mgroups, Ngroups) _parallel_for_oneapi_LMN(
f, x...)
end

function JACC.parallel_reduce(
N::I, f::F, x...) where {I <: Integer, F <: Function}
numItems = 256
@@ -64,12 +78,20 @@ function _parallel_for_oneapi(f, x...)
end

function _parallel_for_oneapi_MN(f, x...)
j = get_global_id(0)
i = get_global_id(1)
f(i, j, x...)
return nothing
end

function _parallel_for_oneapi_LMN(f, x...)
k = get_global_id(0)
j = get_global_id(1)
i = get_global_id(2)
f(i, j, k, x...)
return nothing
end

function _parallel_reduce_oneapi(N, ret, f, x...)
#shared_mem = oneLocalArray(Float32, 256)
shared_mem = oneLocalArray(Float64, 256)
13 changes: 13 additions & 0 deletions src/JACC.jl
@@ -34,6 +34,19 @@ function parallel_for(
end
end

function parallel_for(
(L, M, N)::Tuple{I, I, I}, f::F, x...) where {
I <: Integer, F <: Function}
# only threaded at the first level (no collapse equivalent)
@maybe_threaded for k in 1:N
for j in 1:M
for i in 1:L
f(i, j, k, x...)
end
end
end
end

function parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
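Not part of this commit: the threaded fallback above parallelizes only the outer k loop, as its comment notes. One possible way to approximate an OpenMP-style collapse, sketched here with Base threading and illustrative names, is to flatten the L*M*N iteration space and recover (i, j, k) from the linear index:

function parallel_for_collapsed((L, M, N)::Tuple{I, I, I}, f::F,
        x...) where {I <: Integer, F <: Function}
    Threads.@threads for idx in 1:(L * M * N)
        # Column-major style decomposition: i varies fastest, then j, then k.
        i = (idx - 1) % L + 1
        j = ((idx - 1) ÷ L) % M + 1
        k = (idx - 1) ÷ (L * M) + 1
        f(i, j, k, x...)
    end
end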
45 changes: 40 additions & 5 deletions test/tests_amdgpu.jl
@@ -104,28 +104,63 @@ end
# @inbounds x[i] += alpha * y[i]
# end
# end

# function seq_dot(N, x, y)
# r = 0.0
# for i in 1:N
# @inbounds r += x[i] * y[i]
# end
# return r
# end

# x = ones(1_000)
# y = ones(1_000)
# jx = JACC.ones(1_000)
# jy = JACC.ones(1_000)
# alpha = 2.0

# seq_axpy(1_000, alpha, x, y)
# ref_result = seq_dot(1_000, x, y)

# JACC.BLAS.axpy(1_000, alpha, jx, jy)
# jresult = JACC.BLAS.dot(1_000, jx, jy)
# result = Array(jresult)

# @test result[1]≈ref_result rtol=1e-8

#end

@testset "Add-2D" begin
function add!(i, j, A, B, C)
@inbounds C[i, j] = A[i, j] + B[i, j]
end

M = 10
N = 10
A = JACC.Array(ones(Float32, M, N))
B = JACC.Array(ones(Float32, M, N))
C = JACC.Array(zeros(Float32, M, N))

JACC.parallel_for((M, N), add!, A, B, C)

C_expected = Float32(2.0) .* ones(Float32, M, N)
@test Array(C)≈C_expected rtol=1e-5
end

@testset "Add-3D" begin
function add!(i, j, k, A, B, C)
@inbounds C[i, j, k] = A[i, j, k] + B[i, j, k]
end

L = 10
M = 10
N = 10
A = JACC.Array(ones(Float32, L, M, N))
B = JACC.Array(ones(Float32, L, M, N))
C = JACC.Array(zeros(Float32, L, M, N))

JACC.parallel_for((L, M, N), add!, A, B, C)

C_expected = Float32(2.0) .* ones(Float32, L, M, N)
@test Array(C)≈C_expected rtol=1e-5
end
47 changes: 40 additions & 7 deletions test/tests_cuda.jl
@@ -135,34 +135,67 @@ end
# end

@testset "JACC.BLAS" begin

function seq_axpy(N, alpha, x, y)
for i in 1:N
@inbounds x[i] += alpha * y[i]
end
end

function seq_dot(N, x, y)
r = 0.0
for i in 1:N
@inbounds r += x[i] * y[i]
end
return r
end

x = ones(1_000)
y = ones(1_000)
jx = JACC.ones(1_000)
jy = JACC.ones(1_000)
alpha = 2.0

seq_axpy(1_000, alpha, x, y)
ref_result = seq_dot(1_000, x, y)

JACC.BLAS.axpy(1_000, alpha, jx, jy)
jresult = JACC.BLAS.dot(1_000, jx, jy)
result = Array(jresult)

@test result[1]≈ref_result rtol=1e-8
end

@testset "Add-2D" begin
function add!(i, j, A, B, C)
@inbounds C[i, j] = A[i, j] + B[i, j]
end

M = 10
N = 10
A = JACC.Array(ones(Float32, M, N))
B = JACC.Array(ones(Float32, M, N))
C = JACC.Array(zeros(Float32, M, N))

JACC.parallel_for((M, N), add!, A, B, C)

C_expected = Float32(2.0) .* ones(Float32, M, N)
@test Array(C)≈C_expected rtol=1e-5
end

@testset "Add-3D" begin
function add!(i, j, k, A, B, C)
@inbounds C[i, j, k] = A[i, j, k] + B[i, j, k]
end

L = 10
M = 10
N = 10
A = JACC.Array(ones(Float32, L, M, N))
B = JACC.Array(ones(Float32, L, M, N))
C = JACC.Array(zeros(Float32, L, M, N))

JACC.parallel_for((L, M, N), add!, A, B, C)

C_expected = Float32(2.0) .* ones(Float32, L, M, N)
@test Array(C)≈C_expected rtol=1e-5
end
52 changes: 47 additions & 5 deletions test/tests_threads.jl
@@ -279,7 +279,6 @@ end
end

@testset "JACC.BLAS" begin

x = ones(1_000)
y = ones(1_000)
jx = JACC.ones(1_000)
@@ -291,22 +290,65 @@ end
@inbounds x[i] += alpha * y[i]
end
end

function seq_dot(N, x, y)
r = 0.0
for i in 1:N
@inbounds r += x[i] * y[i]
end
return r
end

seq_axpy(1_000, alpha, x, y)
ref_result = seq_dot(1_000, x, y)

JACC.BLAS.axpy(1_000, alpha, jx, jy)
jresult = JACC.BLAS.dot(1_000, jx, jy)
result = jresult[1]

@test result≈ref_result rtol=1e-8
end

# 2D
@testset "Add-2D" begin
function add!(i, j, A, B, C)
@inbounds C[i, j] = A[i, j] + B[i, j]
end

M = 10
N = 10
A = JACC.Array(ones(Float32, M, N))
B = JACC.Array(ones(Float32, M, N))
C = JACC.Array(zeros(Float32, M, N))

JACC.parallel_for((M, N), add!, A, B, C)

C_expected = Float32(2.0) .* ones(Float32, M, N)
@test C≈C_expected rtol=1e-5
end

@testset "Add-3D" begin
function add!(i, j, k, A, B, C)
@inbounds C[i, j, k] = A[i, j, k] + B[i, j, k]
end

L = 10
M = 10
N = 10
A = JACC.Array(ones(Float32, L, M, N))
B = JACC.Array(ones(Float32, L, M, N))
C = JACC.Array(zeros(Float32, L, M, N))

for i in 1:L
for j in 1:M
for k in 1:N
C[i, j, k] = A[i, j, k] + B[i, j, k]
end
end
end

JACC.parallel_for((L, M, N), add!, A, B, C)

C_expected = Float32(2.0) .* ones(Float32, L, M, N)
@test C≈C_expected rtol=1e-5
end
