Skip to content

Commit

Permalink
WorkerPool and AsyncCollector - minor fixes (#16325)
Browse files Browse the repository at this point in the history
* workerpool fixes - test workers and default pool when master is a worker.
* AsyncCollector - error detection, default # of tasks changes.
  • Loading branch information
amitmurthy committed May 13, 2016
1 parent afd457d commit 69d05ae
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 22 deletions.
61 changes: 42 additions & 19 deletions base/asyncmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@


"""
    AsyncCollector(f, results, c...; ntasks=0) -> iterator

Apply `f` to each element of `c` using at most `ntasks` asynchronous
tasks.
If `ntasks` is unspecified, uses `max(100, nworkers())` tasks.
For multiple collection arguments, apply `f` elementwise.
Output is collected into `results`.
Note: `next(::AsyncCollector, state) -> (nothing, state)`
Expand All @@ -22,17 +24,17 @@ end

# Outer constructor: resolves the default concurrency limit before building
# the iterator over `enumerate(zip(c...))`.
function AsyncCollector(f, results, c...; ntasks=0)
    if ntasks == 0
        # Default: at least 100 concurrent tasks, or one per worker when the
        # cluster has more than 100 workers.
        ntasks = max(nworkers(), 100)
    end
    AsyncCollector(f, results, enumerate(zip(c...)), ntasks)
end


# Mutable iteration state for an AsyncCollector.
type AsyncCollectorState
    enum_state            # state of the wrapped enumerator iterator
    active_count::Int     # number of @async tasks currently in flight
    task_done::Condition  # notified each time a task finishes
    done::Bool            # set once the enumerator is exhausted and sync_end() has run
    in_error::Bool        # set by a failing task; causes iteration to stop early
end


Expand All @@ -49,11 +51,20 @@ wait(state::AsyncCollectorState) = wait(state.task_done)
# Open a @sync block and initialise iterator state.
# Both `done` and `in_error` start out false; `in_error` is flipped by a
# failing @async task so that iteration aborts early.
function start(itr::AsyncCollector)
    sync_begin()
    AsyncCollectorState(start(itr.enumerator), 0, Condition(), false, false)
end

# Close @sync block when iterator is done.
function done(itr::AsyncCollector, state::AsyncCollectorState)
if state.in_error
sync_end()

# state.in_error is only being set in the @async block (and an error thrown),
# which in turn should have been caught and thrown by the sync_end() call above.
# Control should not come here.
@assert false "Error should have been captured and thrown previously."
end

if !state.done && done(itr.enumerator, state.enum_state)
state.done = true
sync_end()
Expand All @@ -62,32 +73,44 @@ function done(itr::AsyncCollector, state::AsyncCollectorState)
end

# Schedule one more element of the enumeration as an @async task and return.
# The iterator yields `nothing`; results land in `itr.results` keyed by index.
function next(itr::AsyncCollector, state::AsyncCollectorState)
    # Wait if the maximum number of concurrent tasks are already running.
    while isbusy(itr, state)
        wait(state)
        if state.in_error
            # Stop processing immediately on error.
            return (nothing, state)
        end
    end

    # Get index and mapped function arguments from enumeration iterator.
    (i, args), state.enum_state = next(itr.enumerator, state.enum_state)

    # Execute function call and save result asynchronously.
    @async begin
        try
            itr.results[i] = itr.f(args...)
        catch e
            # The in_error flag causes done() to end the iteration early and
            # call sync_end(). sync_end() then re-throws `e` in the main task.
            state.in_error = true
            rethrow(e)
        finally
            # Always release the slot and wake any waiter in next().
            state.active_count -= 1
            notify(state.task_done, nothing)
        end
    end

    # Count number of concurrent tasks.
    state.active_count += 1

    return (nothing, state)
end


"""
    AsyncGenerator(f, c...; ntasks=0) -> iterator

Apply `f` to each element of `c` using at most `ntasks` asynchronous tasks.
If `ntasks` is unspecified, uses `max(100, nworkers())` tasks.
For multiple collection arguments, apply `f` elementwise.
Results are returned by the iterator as they become available.
Note: `collect(AsyncGenerator(f, c...; ntasks=1))` is equivalent to
Expand Down Expand Up @@ -115,7 +138,7 @@ function done(itr::AsyncGenerator, state::AsyncGeneratorState)
done(itr.collector, state.async_state) && isempty(itr.collector.results)
end

# Pump the source async collector if it is not already busy
# Pump the source async collector if it is not already busy.
function pump_source(itr::AsyncGenerator, state::AsyncGeneratorState)
if !isbusy(itr.collector, state.async_state) &&
!done(itr.collector, state.async_state)
Expand All @@ -132,7 +155,7 @@ function next(itr::AsyncGenerator, state::AsyncGeneratorState)
results = itr.collector.results
while !haskey(results, state.i)

# Wait for results to become available
# Wait for results to become available.
if !pump_source(itr,state) && !haskey(results, state.i)
wait(state.async_state)
end
Expand Down
26 changes: 24 additions & 2 deletions base/workerpool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,33 @@ length(pool::WorkerPool) = pool.count
# A pool is ready when at least one worker id is queued in its channel.
isready(pool::WorkerPool) = isready(pool.channel)

# Take a live worker from the pool, run `rc_f(f, worker, args...)` on it, and
# return the worker to the pool afterwards.
function remotecall_pool(rc_f, f, pool::WorkerPool, args...; kwargs...)
    # Find an active worker.
    worker = 0
    while true
        if pool.count == 0
            if pool === default_worker_pool()
                # No workers in the default pool: fall back to running on the
                # master process (pid 1).
                worker = 1
                break
            else
                throw(ErrorException("No active worker available in pool"))
            end
        end

        worker = take!(pool.channel)
        if worker in procs()
            break
        else
            # Worker has exited; drop it from the pool count and retry.
            pool.count = pool.count - 1
        end
    end

    try
        rc_f(f, worker, args...; kwargs...)
    finally
        # Return the worker to the pool, except for the pid-1 fallback,
        # which is never queued in the channel.
        if worker != 1
            put!(pool.channel, worker)
        end
    end
end

Expand Down
5 changes: 4 additions & 1 deletion test/parallel_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ if Base.JLOptions().code_coverage == 1
elseif Base.JLOptions().code_coverage == 2
cov_flag = `--code-coverage=all`
end

# Test a `remote` invocation when no workers are present
@test remote(myid)() == 1

addprocs(3; exeflags=`$cov_flag $inline_flag --check-bounds=yes --depwarn=error`)

# Test remote()

let
pool = Base.default_worker_pool()

Expand Down

0 comments on commit 69d05ae

Please sign in to comment.