Skip to content

Commit

Permalink
Use optimized findall on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Oct 10, 2019
1 parent 803bc9c commit df1671b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ end
next_condition = get_condition(integrator, callback, abst)
@. next_sign = sign(next_condition)

event_idx = findall(x-> ((prev_sign[x] < 0 && callback.affect! !== nothing) || (prev_sign[x] > 0 && callback.affect_neg! !== nothing)) && prev_sign[x]*next_sign[x]<=0, keys(prev_sign))
event_idx = findall_events(affect!,affect_neg!,prev_sign,next_sign)
if length(event_idx) != 0
event_occurred = true
interp_index = callback.interp_points
Expand All @@ -310,7 +310,7 @@ end
for i in 2:length(Θs)
abst = integrator.tprev+integrator.dt*Θs[i]
new_sign = get_condition(integrator, callback, abst)
_event_idx = findall(x -> ((prev_sign[x] < 0 && callback.affect! !== nothing) || (prev_sign[x] > 0 && callback.affect_neg! !== nothing)) && prev_sign[x]*new_sign[x]<0, keys(prev_sign))
_event_idx = findall_events(affect!,affect_neg!,prev_sign,new_sign)
if length(_event_idx) != 0
event_occurred = true
event_idx = _event_idx
Expand Down Expand Up @@ -398,6 +398,11 @@ end
event_occurred,interp_index,Θs,prev_sign,prev_sign_index,event_idx
end

## Different definition for GPUs
function findall_events(affect!,affect_neg,prev_sign,next_sign)
findall(x-> ((prev_sign[x] < 0 && affect! !== nothing) || (prev_sign[x] > 0 && affect_neg! !== nothing)) && prev_sign[x]*next_sign[x]<=0, keys(prev_sign))
end

function find_callback_time(integrator,callback::ContinuousCallback,counter)
event_occurred,interp_index,Θs,prev_sign,prev_sign_index,event_idx = determine_event_occurance(integrator,callback,counter)
if event_occurred
Expand Down
8 changes: 8 additions & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ function __init__()
function LinearAlgebra.ldiv!(x::CuArrays.CuArray,_qr::CuArrays.CUSOLVER.CuQR,b::CuArrays.CuArray)
_x = UpperTriangular(_qr.R) \ (_qr.Q' * reshape(b,length(b),1))
x .= vec(_x)
unsafe_free!(_x)
end
function findall_events(affect!,affect_neg,prev_sign::CuArrays.CuArray,next_sign::CuArrays.CuArray)
f = (prev_sign,next_sign)-> ((prev_sign < 0 && affect! !== nothing) || (prev_sign > 0 && affect_neg! !== nothing)) && prev_sign*next_sign<=0
A = map(f,prev_sign,next_sign)
out = findall(A)
unsafe_free!(A)
out
end
end
end

0 comments on commit df1671b

Please sign in to comment.