diff --git a/base/asyncmap.jl b/base/asyncmap.jl index bdd35f3450064..b6ead5b2de888 100644 --- a/base/asyncmap.jl +++ b/base/asyncmap.jl @@ -15,16 +15,18 @@ Note: `for task in AsyncCollector(f, results, c...) end` is equivalent to """ type AsyncCollector f + on_error results enumerator::Enumerate ntasks::Int end -function AsyncCollector(f, results, c...; ntasks=0) +AsyncCollector(f::Function, results, c...; kwargs...) = AsyncCollector(f::Function, e->rethrow(e), results, c...; kwargs...) +function AsyncCollector(f::Function, on_error::Function, results, c...; ntasks=0) if ntasks == 0 ntasks = 100 end - AsyncCollector(f, results, enumerate(zip(c...)), ntasks) + AsyncCollector(f, on_error, results, enumerate(zip(c...)), ntasks) end @@ -33,6 +35,7 @@ type AsyncCollectorState active_count::Int task_done::Condition done::Bool + in_error::Bool end @@ -49,12 +52,12 @@ wait(state::AsyncCollectorState) = wait(state.task_done) # Open a @sync block and initialise iterator state. function start(itr::AsyncCollector) sync_begin() - AsyncCollectorState(start(itr.enumerator), 0, Condition(), false) + AsyncCollectorState(start(itr.enumerator), 0, Condition(), false, false) end # Close @sync block when iterator is done. function done(itr::AsyncCollector, state::AsyncCollectorState) - if !state.done && done(itr.enumerator, state.enum_state) + if (!state.done && done(itr.enumerator, state.enum_state)) || state.in_error state.done = true sync_end() end @@ -72,14 +75,28 @@ function next(itr::AsyncCollector, state::AsyncCollectorState) # Execute function call and save result asynchronously @async begin - itr.results[i] = itr.f(args...) - state.active_count -= 1 - notify(state.task_done, nothing) + try + itr.results[i] = itr.f(args...) + catch e + try + itr.results[i] = itr.on_error(e) + catch e2 + state.in_error = true + notify(state.task_done, e2; error=true) + + # The "notify" above raises an exception if "next" is waiting for tasks to finish. + # If the calling task is waiting on sync_end(), the rethrow() below will be captured + # by it. + rethrow(e2) + end + finally + state.active_count -= 1 + notify(state.task_done, nothing) + end end # Count number of concurrent tasks state.active_count += 1 - return (nothing, state) end diff --git a/base/deprecated.jl b/base/deprecated.jl index dc463d55626cf..604bf6e8f2a8c 100644 --- a/base/deprecated.jl +++ b/base/deprecated.jl @@ -1003,13 +1003,6 @@ function pmap(f, c...; err_retry=nothing, err_stop=nothing, pids=nothing) end end - if err_stop != nothing - depwarn("err_stop is deprecated, use pmap(@catch(f), c...).", :pmap) - if err_stop == false - f = @catch(f) - end - end - if pids == nothing p = default_worker_pool() else @@ -1017,7 +1010,13 @@ function pmap(f, c...; err_retry=nothing, err_stop=nothing, pids=nothing) p = WorkerPool(pids) end - return pmap(p, f, c...) + if err_stop != nothing + depwarn("err_stop is deprecated, use pmap(@catch(f), c...; on_error = e->e).", :pmap) + return pmap(p, f, c...; on_error=e->e) + else + return pmap(p, f, c...) + end + end diff --git a/base/error.jl b/base/error.jl index b237db748ee76..3e5e0a66a5b08 100644 --- a/base/error.jl +++ b/base/error.jl @@ -85,26 +85,3 @@ end retry(f::Function, t::Type; kw...) = retry(f, e->isa(e, t); kw...) - -""" - @catch(f) -> Function - -Returns a lambda that executes `f` and returns either the result of `f` or -an `Exception` thrown by `f`. - -**Examples** -```julia -julia> r = @catch(length)([1,2,3]) -3 - -julia> r = @catch(length)() -MethodError(length,()) - -julia> typeof(r) -MethodError -``` -""" -catchf(f) = (args...) -> try f(args...) catch ex; ex end -macro catch(f) - esc(:(Base.catchf($f))) -end diff --git a/base/pmap.jl b/base/pmap.jl index 489379d40a0f2..c7fe73b696faf 100644 --- a/base/pmap.jl +++ b/base/pmap.jl @@ -15,11 +15,18 @@ Note that `f` must be made available to all worker processes; see and Loading Packages `) for details. """ -function pgenerate(p::WorkerPool, f, c) - if length(p) == 0 - return AsyncGenerator(f, c) +function pgenerate(p::WorkerPool, f, c; distributed=true, batch_size=1, on_error = e->rethrow(e)) + if (distributed == false) || + (length(p) == 0) || + (length(p) == 1 && fetch(p.channel) == myid()) + + return AsyncGenerator(f, c; on_error=on_error) + end + if batch_size == :auto + batches = batchsplit(c, min_batch_count = length(p) * 3) + else + batches = batchsplit(c, max_batch_size = batch_size) end - batches = batchsplit(c, min_batch_count = length(p) * 3) return flatten(AsyncGenerator(remote(p, b -> asyncmap(f, b)), batches)) end @@ -46,7 +53,7 @@ Note that `f` must be made available to all worker processes; see and Loading Packages `) for details. """ -pmap(p::WorkerPool, f, c...) = collect(pgenerate(p, f, c...)) +pmap(p::WorkerPool, f, c...; kwargs...) = collect(pgenerate(p, f, c...; kwargs...)) """ diff --git a/base/workerpool.jl b/base/workerpool.jl index d186c5396eb27..55df585b7dc4d 100644 --- a/base/workerpool.jl +++ b/base/workerpool.jl @@ -34,7 +34,17 @@ length(pool::WorkerPool) = pool.count isready(pool::WorkerPool) = isready(pool.channel) function remotecall_pool(rc_f, f, pool::WorkerPool, args...; kwargs...) - worker = take!(pool.channel) + # Find an active worker + while true + pool.count == 0 && throw(ErrorException("No active worker available in pool")) + worker = take!(pool.channel) + if worker in procs() + break; + else + pool.count = pool.count - 1 + end + end + try rc_f(f, worker, args...; kwargs...) finally