Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Merge pull request #446 from JuliaGPU/tb/findall
Browse files Browse the repository at this point in the history
Implement findall
  • Loading branch information
maleadt authored Oct 10, 2019
2 parents 51c33d7 + 9459fce commit a794963
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 23 deletions.
42 changes: 23 additions & 19 deletions src/accumulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,33 +25,37 @@ function Base._accumulate!(op::Function, vout::CuVector{T}, v::CuVector, dims::N
Δ = 1 # Δ = 2^d
n = ceil(Int, log2(length(v)))

num_threads = 256
num_blocks = ceil(Int, length(v) / num_threads)
# partial in-place accumulation
function kernel(op, vout, vin, Δ)
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x

for d in 0:n # passes through data
@cuda blocks=num_blocks threads=num_threads _partial_accumulate!(op, vout, vin, Δ)
@inbounds if i <= length(vin)
if i > Δ
vout[i] = op(vin[i - Δ], vin[i])
else
vout[i] = vin[i]
end
end

vin, vout = vout, vin
Δ *= 2
return
end

return vin
end
function configurator(kernel)
fun = kernel.fun
config = launch_configuration(fun)
blocks = cld(length(v), config.threads)

function _partial_accumulate!(op, vout, vin, Δ)
@inbounds begin
k = threadIdx().x + (blockIdx().x - 1) * blockDim().x
return (threads=config.threads, blocks=blocks)
end

if k <= length(vin)
if k > Δ
vout[k] = op(vin[k - Δ], vin[k])
else
vout[k] = vin[k]
end
end
for d in 0:n # passes through data
@cuda config=configurator kernel(op, vout, vin, Δ)

vin, vout = vout, vin
Δ *= 2
end

return
return vin
end

"""
    accumulate_pairwise!(op, result::CuVector, v::CuVector)

Forward pairwise accumulation on GPU vectors to the regular `accumulate!`
method; the GPU scan implementation does not need a separate pairwise variant.
"""
function Base.accumulate_pairwise!(op, result::CuVector, v::CuVector)
    return accumulate!(op, result, v)
end
63 changes: 59 additions & 4 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ function Base.getindex(xs::CuArray{T}, bools::CuArray{Bool}) where {T}
ys = CuArray{T}(undef, n)

if n > 0
num_threads = min(n, 256)
num_blocks = ceil(Int, length(indices) / num_threads)

function kernel(ys::CuDeviceArray{T}, xs::CuDeviceArray{T}, bools, indices)
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x

Expand All @@ -38,8 +35,66 @@ function Base.getindex(xs::CuArray{T}, bools::CuArray{Bool}) where {T}
return
end

@cuda blocks=num_blocks threads=num_threads kernel(ys, xs, bools, indices)
function configurator(kernel)
fun = kernel.fun
config = launch_configuration(fun)
blocks = cld(length(indices), config.threads)

return (threads=config.threads, blocks=blocks)
end

@cuda config=configurator kernel(ys, xs, bools, indices)
end

unsafe_free!(indices)

return ys
end


## findall

"""
    findall(bools::CuArray{Bool})

Return a `CuArray{Int}` containing the linear indices of all `true` entries of
`bools`, in ascending order.

Implemented as a `cumsum` prefix scan (giving each hit its output slot),
followed by a scatter kernel that writes the matching indices.
"""
function Base.findall(bools::CuArray{Bool})
    # indices[i] holds the 1-based output position for element i (when bools[i] is true)
    indices = cumsum(bools)

    # the last prefix-sum value is the total number of matches
    # NOTE(review): assumes `bools` is non-empty — `_getindex(indices, 0)` on an
    # empty input would be out of bounds; confirm callers never pass an empty array.
    n = _getindex(indices, length(indices))
    ys = CuArray{Int}(undef, n)

    if n > 0
        # scatter kernel: each thread checks one element and, if it is a hit,
        # writes its linear index into the precomputed slot
        function kernel(ys::CuDeviceArray{Int}, bools, indices)
            i = threadIdx().x + (blockIdx().x - 1) * blockDim().x

            if i <= length(bools) && bools[i]
                b = indices[i] # new position
                ys[b] = i
            end

            return
        end

        # derive launch dimensions from the occupancy API instead of hard-coding them
        function configurator(kernel)
            fun = kernel.fun
            config = launch_configuration(fun)
            blocks = cld(length(indices), config.threads)

            return (threads=config.threads, blocks=blocks)
        end

        @cuda config=configurator kernel(ys, bools, indices)
    end

    unsafe_free!(indices)

    return ys
end

"""
    findall(f::Function, A::CuArray)

Return the indices of `A` at which `f` evaluates to `true`, by materializing
the predicate as a boolean `CuArray` and delegating to `findall` on that mask.
"""
function Base.findall(f::Function, A::CuArray)
    mask = map(f, A)
    result = findall(mask)
    unsafe_free!(mask)
    return result
end
5 changes: 5 additions & 0 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,8 @@ end
inds = rand(1:100, 150, 150)
@test testf(x->permutedims(view(x, inds, :), (3, 2, 1)), rand(100, 100))
end

@testset "findall" begin
    # pass `findall` directly instead of wrapping it in a needless lambda
    @test testf(findall, rand(Bool, 100))
    # predicate form: indices of entries above 0.5
    @test testf(a -> findall(v -> v > 0.5, a), rand(100))
end

0 comments on commit a794963

Please sign in to comment.