diff --git a/base/asyncmap.jl b/base/asyncmap.jl index bdd35f3450064..bafdce61bb81d 100644 --- a/base/asyncmap.jl +++ b/base/asyncmap.jl @@ -15,24 +15,25 @@ Note: `for task in AsyncCollector(f, results, c...) end` is equivalent to """ type AsyncCollector f + on_error results enumerator::Enumerate ntasks::Int end -function AsyncCollector(f, results, c...; ntasks=0) +function AsyncCollector(f, results, c...; ntasks=0, on_error=nothing) if ntasks == 0 ntasks = 100 end - AsyncCollector(f, results, enumerate(zip(c...)), ntasks) + AsyncCollector(f, on_error, results, enumerate(zip(c...)), ntasks) end - type AsyncCollectorState enum_state active_count::Int task_done::Condition done::Bool + in_error::Bool end @@ -49,12 +50,12 @@ wait(state::AsyncCollectorState) = wait(state.task_done) # Open a @sync block and initialise iterator state. function start(itr::AsyncCollector) sync_begin() - AsyncCollectorState(start(itr.enumerator), 0, Condition(), false) + AsyncCollectorState(start(itr.enumerator), 0, Condition(), false, false) end # Close @sync block when iterator is done. function done(itr::AsyncCollector, state::AsyncCollectorState) - if !state.done && done(itr.enumerator, state.enum_state) + if (!state.done && done(itr.enumerator, state.enum_state)) || state.in_error state.done = true sync_end() end @@ -72,14 +73,32 @@ function next(itr::AsyncCollector, state::AsyncCollectorState) # Execute function call and save result asynchronously @async begin - itr.results[i] = itr.f(args...) - state.active_count -= 1 - notify(state.task_done, nothing) + try + itr.results[i] = itr.f(args...) + catch e + try + if isa(itr.on_error, Function) + itr.results[i] = itr.on_error(e) + else + rethrow(e) + end + catch e2 + state.in_error = true + notify(state.task_done, e2; error=true) + + # The "notify" above raises an exception if "next" is waiting for tasks to finish. + # If the calling task is waiting on sync_end(), the rethrow() below will be captured + # by it. 
+ rethrow(e2) + end + finally + state.active_count -= 1 + notify(state.task_done, nothing) + end end # Count number of concurrent tasks state.active_count += 1 - return (nothing, state) end @@ -97,8 +116,8 @@ type AsyncGenerator collector::AsyncCollector end -function AsyncGenerator(f, c...; ntasks=0) - AsyncGenerator(AsyncCollector(f, Dict{Int,Any}(), c...; ntasks=ntasks)) +function AsyncGenerator(f, c...; ntasks=0, on_error=nothing) + AsyncGenerator(AsyncCollector(f, Dict{Int,Any}(), c...; ntasks=ntasks, on_error=on_error)) end @@ -153,7 +172,7 @@ Transform collection `c` by applying `@async f` to each element. For multiple collection arguments, apply f elementwise. """ -asyncmap(f, c...) = collect(AsyncGenerator(f, c...)) +asyncmap(f, c...; on_error=nothing) = collect(AsyncGenerator(f, c...; on_error=on_error)) """ @@ -161,7 +180,7 @@ asyncmap(f, c...) = collect(AsyncGenerator(f, c...)) In-place version of `asyncmap()`. """ -asyncmap!(f, c) = (for x in AsyncCollector(f, c, c) end; c) +asyncmap!(f, c; on_error=nothing) = (for x in AsyncCollector(f, c, c; on_error=on_error) end; c) """ @@ -169,4 +188,4 @@ asyncmap!(f, c) = (for x in AsyncCollector(f, c, c) end; c) Like `asyncmap()`, but stores output in `results` rather returning a collection. """ -asyncmap!(f, r, c1, c...) = (for x in AsyncCollector(f, r, c1, c...) end; r) +asyncmap!(f, r, c1, c...; on_error=nothing) = (for x in AsyncCollector(f, r, c1, c...; on_error=on_error) end; r) diff --git a/base/deprecated.jl b/base/deprecated.jl index dc463d55626cf..45d4748b33612 100644 --- a/base/deprecated.jl +++ b/base/deprecated.jl @@ -995,7 +995,14 @@ export call # and added to pmap.jl # pmap(f, c...) = pmap(default_worker_pool(), f, c...) -function pmap(f, c...; err_retry=nothing, err_stop=nothing, pids=nothing) +function pmap(f, c...; kwargs...) 
+ kwdict = merge(DEFAULT_PMAP_ARGS, AnyDict(kwargs)) + validate_pmap_kwargs(kwdict, append!([:err_retry, :pids, :err_stop], PMAP_KW_NAMES)) + + err_retry = get(kwdict, :err_retry, nothing) + err_stop = get(kwdict, :err_stop, nothing) + pids = get(kwdict, :pids, nothing) + if err_retry != nothing depwarn("err_retry is deprecated, use pmap(retry(f), c...).", :pmap) if err_retry == true @@ -1003,13 +1010,6 @@ function pmap(f, c...; err_retry=nothing, err_stop=nothing, pids=nothing) end end - if err_stop != nothing - depwarn("err_stop is deprecated, use pmap(@catch(f), c...).", :pmap) - if err_stop == false - f = @catch(f) - end - end - if pids == nothing p = default_worker_pool() else @@ -1017,7 +1017,16 @@ function pmap(f, c...; err_retry=nothing, err_stop=nothing, pids=nothing) p = WorkerPool(pids) end - return pmap(p, f, c...) + if err_stop != nothing + depwarn("err_stop is deprecated, use pmap(f, c...; on_error = error_handling_func).", :pmap) + if err_stop == false + kwdict[:on_error] = e->e + end + end + + pmap(p, f, c...; distributed=kwdict[:distributed], + batch_size=kwdict[:batch_size], + on_error=kwdict[:on_error]) end diff --git a/base/error.jl b/base/error.jl index b237db748ee76..3e5e0a66a5b08 100644 --- a/base/error.jl +++ b/base/error.jl @@ -85,26 +85,3 @@ end retry(f::Function, t::Type; kw...) = retry(f, e->isa(e, t); kw...) - -""" - @catch(f) -> Function - -Returns a lambda that executes `f` and returns either the result of `f` or -an `Exception` thrown by `f`. - -**Examples** -```julia -julia> r = @catch(length)([1,2,3]) -3 - -julia> r = @catch(length)() -MethodError(length,()) - -julia> typeof(r) -MethodError -``` -""" -catchf(f) = (args...) -> try f(args...) 
catch ex; ex end -macro catch(f) -    esc(:(Base.catchf($f))) -end diff --git a/base/pmap.jl b/base/pmap.jl index 489379d40a0f2..5e8650cd298e9 100644 --- a/base/pmap.jl +++ b/base/pmap.jl @@ -15,38 +15,81 @@ Note that `f` must be made available to all worker processes; see  and Loading Packages `) for details. """ -function pgenerate(p::WorkerPool, f, c) -    if length(p) == 0 -        return AsyncGenerator(f, c) +function pgenerate(p::WorkerPool, f, c; config=DEFAULT_PMAP_ARGS) +    batch_size = config[:batch_size] +    on_error = config[:on_error] +    distributed = config[:distributed] + +    if (distributed == false) || +        (length(p) == 0) || +        (length(p) == 1 && fetch(p.channel) == myid()) + +        return AsyncGenerator(f, c; on_error=on_error) +    end + +    if batch_size == :auto +        batches = batchsplit(c, min_batch_count = length(p) * 3) +    else +        batches = batchsplit(c, max_batch_size = batch_size) end -    batches = batchsplit(c, min_batch_count = length(p) * 3) -    return flatten(AsyncGenerator(remote(p, b -> asyncmap(f, b)), batches)) +    return flatten(AsyncGenerator(remote(p, b -> asyncmap(f, b; on_error=on_error)), batches; on_error=on_error)) end -pgenerate(p::WorkerPool, f, c1, c...) = pgenerate(p, a->f(a...), zip(c1, c...)) +pgenerate(p::WorkerPool, f, c1, c...; kwargs...) = pgenerate(p, a->f(a...), zip(c1, c...); kwargs...) -pgenerate(f, c) = pgenerate(default_worker_pool(), f, c...) -pgenerate(f, c1, c...) = pgenerate(a->f(a...), zip(c1, c...)) +pgenerate(f, c; kwargs...) = pgenerate(default_worker_pool(), f, c; kwargs...) +pgenerate(f, c1, c...; kwargs...) = pgenerate(a->f(a...), zip(c1, c...); kwargs...) """ -    pmap([::WorkerPool], f, c...) -> collection +    pmap([::WorkerPool], f, c...; distributed=true, batch_size=1, on_error=nothing) -> collection Transform collection `c` by applying `f` to each element using available workers and tasks. For multiple collection arguments, apply f elementwise.
-Note that `err_retry=true` and `err_stop=false` are deprecated, -use `pmap(retry(f), c)` or `pmap(@catch(f), c)` instead -(or to retry on a different worker, use `asyncmap(retry(remote(f)), c)`). - Note that `f` must be made available to all worker processes; see [Code Availability and Loading Packages](:ref:`Code Availability and Loading Packages `) for details. + +If a worker pool is not specified, all available workers, i.e., the default worker pool +is used. + +By default, `pmap` distributes the computation over all specified workers. To use only the +local process and distribute over tasks, specify `distributed=false` + +`pmap` can also use a mix of processes and tasks via the `batch_size` argument. For batch sizes +greater than 1, the collection is split into multiple batches, which are distributed across +workers. Each such batch is processed in parallel via tasks in each worker. `batch_size=:auto` +will automatically calculate a batch size depending on the length of the collection and number +of workers available. + +Any error stops pmap from processing the remainder of the collection. To override this behavior +you can specify an error handling function via argument `on_error` which takes in a single argument, i.e., +the exception. The function can stop the processing by rethrowing the error, or, to continue, return any value +which is then returned inline with the results to the caller. """ -pmap(p::WorkerPool, f, c...) = collect(pgenerate(p, f, c...)) +function pmap(p::WorkerPool, f, c...; kwargs...)
+ kwdict = merge(DEFAULT_PMAP_ARGS, AnyDict(kwargs)) + validate_pmap_kwargs(kwdict, PMAP_KW_NAMES) + + collect(pgenerate(p, f, c...; config=kwdict)) +end + + +const DEFAULT_PMAP_ARGS = AnyDict( + :distributed => true, + :batch_size => 1, + :on_error => nothing) + +const PMAP_KW_NAMES = [:distributed, :batch_size, :on_error] +function validate_pmap_kwargs(kwdict, kwnames) + unsupported = filter(x -> !(x in kwnames), collect(keys(kwdict))) + length(unsupported) > 0 && throw(ArgumentError("keyword arguments $unsupported are not supported.")) + nothing +end """ @@ -72,7 +115,7 @@ function batchsplit(c; min_batch_count=1, max_batch_size=100) # If there are not enough batches, use a smaller batch size if length(head) < min_batch_count batch_size = max(1, div(sum(length, head), min_batch_count)) - return partition(flatten(head), batch_size) + return partition(collect(flatten(head)), batch_size) end return flatten((head, tail)) diff --git a/base/workerpool.jl b/base/workerpool.jl index d186c5396eb27..a29c3437a1549 100644 --- a/base/workerpool.jl +++ b/base/workerpool.jl @@ -34,7 +34,18 @@ length(pool::WorkerPool) = pool.count isready(pool::WorkerPool) = isready(pool.channel) function remotecall_pool(rc_f, f, pool::WorkerPool, args...; kwargs...) -    worker = take!(pool.channel) +    # Find an active worker +    worker = 0 +    while true +        pool.count == 0 && throw(ErrorException("No active worker available in pool")) +        worker = take!(pool.channel) +        if worker in procs() +            break +        else +            pool.count = pool.count - 1 +        end +    end + try rc_f(f, worker, args...; kwargs...) finally diff --git a/test/error.jl b/test/error.jl index c55078b4505c2..d3471909e55fa 100644 --- a/test/error.jl +++ b/test/error.jl @@ -1,12 +1,5 @@ # This file is a part of Julia.
License is MIT: http://julialang.org/license - -@test map(typeof, map(@catch(i->[1,2,3][i]), 1:6)) == - [Int, Int, Int, BoundsError, BoundsError, BoundsError] - -@test typeof(@catch(open)("/no/file/with/this/name")) == SystemError - - let function foo_error(c, n) c[1] += 1 @@ -33,37 +26,37 @@ let # 3 failed attempts, so exception is raised c = [0] - ex = @catch(retry(foo_error))(c,3) + ex = try retry(foo_error)(c,3); catch e; e; end @test ex.msg == "foo" @test c[1] == 3 c = [0] - ex = @catch(retry(foo_error, ErrorException))(c,3) + ex = try (retry(foo_error, ErrorException))(c,3); catch e; e; end @test typeof(ex) == ErrorException @test ex.msg == "foo" @test c[1] == 3 c = [0] - ex = @catch(retry(foo_error, e->e.msg == "foo"))(c,3) + ex = try (retry(foo_error, e->e.msg == "foo"))(c,3) catch e; e; end @test typeof(ex) == ErrorException @test ex.msg == "foo" @test c[1] == 3 # No retry if condition does not match c = [0] - ex = @catch(retry(foo_error, e->e.msg == "bar"))(c,3) + ex = try (retry(foo_error, e->e.msg == "bar"))(c,3) catch e; e; end @test typeof(ex) == ErrorException @test ex.msg == "foo" @test c[1] == 1 c = [0] - ex = @catch(retry(foo_error, e->e.http_status_code == "503"))(c,3) + ex = try (retry(foo_error, e->e.http_status_code == "503"))(c,3) catch e; e; end @test typeof(ex) == ErrorException @test ex.msg == "foo" @test c[1] == 1 c = [0] - ex = @catch(retry(foo_error, SystemError))(c,3) + ex = try (retry(foo_error, SystemError))(c,3) catch e; e; end @test typeof(ex) == ErrorException @test ex.msg == "foo" @test c[1] == 1 diff --git a/test/parallel_exec.jl b/test/parallel_exec.jl index 3026cfe49db6f..3088ab78bdf2b 100644 --- a/test/parallel_exec.jl +++ b/test/parallel_exec.jl @@ -9,7 +9,7 @@ if Base.JLOptions().code_coverage == 1 elseif Base.JLOptions().code_coverage == 2 cov_flag = `--code-coverage=all` end -addprocs(3; exeflags=`$cov_flag $inline_flag --check-bounds=yes --depwarn=error`) +addprocs(4; exeflags=`$cov_flag $inline_flag --check-bounds=yes 
--depwarn=error`) # Test remote() @@ -670,47 +670,43 @@ let ex @test repeated == 1 end -# The below block of tests are usually run only on local development systems, since: -# - tests which print errors -# - addprocs tests are memory intensive -# - ssh addprocs requires sshd to be running locally with passwordless login enabled. -# The test block is enabled by defining env JULIA_TESTFULL=1 - -DoFullTest = Bool(parse(Int,(get(ENV, "JULIA_TESTFULL", "0")))) +# pmap tests. Needs at least 4 processors dedicated to the below tests. Which we currently have +# since the parallel tests are now spawned as a separate set. +s = "abcdefghijklmnopqrstuvwxyz"; +ups = uppercase(s); -if DoFullTest - # pmap tests - # needs at least 4 processors dedicated to the below tests - ppids = remotecall_fetch(()->addprocs(4), 1) - pool = WorkerPool(ppids) - s = "abcdefghijklmnopqrstuvwxyz"; - ups = uppercase(s); - - unmangle_exception = e -> begin - if isa(e, CompositeException) - e = e.exceptions[1].ex - if isa(e, RemoteException) - e = e.captured.ex.exceptions[1].ex - end +unmangle_exception = e -> begin + if isa(e, CompositeException) + e = e.exceptions[1].ex + if isa(e, RemoteException) + e = e.captured.ex.exceptions[1].ex end - return e end + return e +end + +errifeqa = x->(x=='a') ? error("foobar") : uppercase(x) +errifeven = x->iseven(Int(x)) ? error("foobar") : uppercase(x) - for mapf in [map, asyncmap, (f, c) -> pmap(pool, f, c)] - @test ups == bytestring(UInt8[UInt8(c) for c in mapf(x->uppercase(x), s)]) - @test ups == bytestring(UInt8[UInt8(c) for c in mapf(x->uppercase(Char(x)), s.data)]) +for (throws_err, mapf) in [ (true, map), + (true, asyncmap), + (true, pmap), + (false, (f,c)->pmap(f, c; on_error = e->true)) + ] + @test ups == bytestring(UInt8[UInt8(c) for c in mapf(x->uppercase(x), s)]) + @test ups == bytestring(UInt8[UInt8(c) for c in mapf(x->uppercase(Char(x)), s.data)]) + + if throws_err # retry, on error exit - errifeqa = x->(x=='a') ? - error("EXPECTED TEST ERROR. 
TO BE IGNORED.") : uppercase(x) try res = mapf(retry(errifeqa), s) error("unexpected") catch e e = unmangle_exception(e) @test isa(e, ErrorException) - @test e.msg == "EXPECTED TEST ERROR. TO BE IGNORED." + @test e.msg == "foobar" end # no retry, on error exit @@ -720,32 +716,38 @@ if DoFullTest catch e e = unmangle_exception(e) @test isa(e, ErrorException) - @test e.msg == "EXPECTED TEST ERROR. TO BE IGNORED." + @test e.msg == "foobar" end - - # no retry, on error continue - res = mapf(@catch(errifeqa), Any[s...]) + else + res = mapf(errifeven, s) @test length(res) == length(ups) - res[1] = unmangle_exception(res[1]) - @test isa(res[1], ErrorException) - @test res[1].msg == "EXPECTED TEST ERROR. TO BE IGNORED." - @test ups[2:end] == string(res[2:end]...) - end - - # retry, on error exit - mapf = (f, c) -> asyncmap(retry(remote(pool, f), n=10, max_delay=0), c) - errifevenid = x->iseven(myid()) ? - error("EXPECTED TEST ERROR. TO BE IGNORED.") : uppercase(x) - res = mapf(errifevenid, s) - @test length(res) == length(ups) - @test ups == bytestring(UInt8[UInt8(c) for c in res]) - - # retry, on error continue - mapf = (f, c) -> asyncmap(@catch(retry(remote(pool, f), n=10, max_delay=0)), c) - res = mapf(errifevenid, s) - @test length(res) == length(ups) - @test ups == bytestring(UInt8[UInt8(c) for c in res]) + for i in 1:length(s) + if iseven(Int(s[i])) + @test res[i] == true + else + @test res[i] == uppercase(s[i]) + end + end + end +end + +# retry till success. +errifevenid = x->iseven(myid()) ? error("foobar") : myid() +mapf = (f, c) -> asyncmap(retry(remote(f), n=typemax(Int), max_delay=0), c) +res = mapf(errifevenid, s) +@test length(res) == length(ups) +@test all(isodd, res) + + +# The below block of tests are usually run only on local development systems, since: +# - tests which print errors +# - addprocs tests are memory intensive +# - ssh addprocs requires sshd to be running locally with passwordless login enabled. 
+# The test block is enabled by defining env JULIA_TESTFULL=1 +DoFullTest = Bool(parse(Int,(get(ENV, "JULIA_TESTFULL", "0")))) + +if DoFullTest # Topology tests need to run externally since a given cluster at any # time can only support a single topology and the current session # is already running in parallel under the default topology.