Skip to content

Commit

Permalink
[Distributed] Make worker state variable threadsafe
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Sep 13, 2021
1 parent 70cc57c commit e6bd8fb
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 14 deletions.
61 changes: 48 additions & 13 deletions stdlib/Distributed/src/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ mutable struct Worker
add_msgs::Array{Any,1}
gcflag::Bool
state::WorkerState
c_state::Condition # wait for state changes
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily
c_state::Threads.Condition # wait for state changes, lock for state
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily

r_stream::IO
w_stream::IO
Expand Down Expand Up @@ -133,7 +133,7 @@ mutable struct Worker
if haskey(map_pid_wrkr, id)
return map_pid_wrkr[id]
end
w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func)
w=new(id, [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
w.initialized = Event()
register_worker(w)
w
Expand All @@ -143,12 +143,16 @@ mutable struct Worker
end

function set_worker_state(w, state)
w.state = state
notify(w.c_state; all=true)
lock(w.c_state) do
w.state = state
notify(w.c_state; all=true)
end
end

function check_worker_state(w::Worker)
lock(w.c_state)
if w.state === W_CREATED
unlock(w.c_state)
if !isclusterlazy()
if PGRP.topology === :all_to_all
# Since higher pids connect with lower pids, the remote worker
Expand All @@ -168,6 +172,8 @@ function check_worker_state(w::Worker)
errormonitor(t)
wait_for_conn(w)
end
else
unlock(w.c_state)
end
end

Expand All @@ -186,13 +192,25 @@ function exec_conn_func(w::Worker)
end

function wait_for_conn(w)
lock(w.c_state)
if w.state === W_CREATED
unlock(w.c_state)
timeout = worker_timeout() - (time() - w.ct_time)
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")

@async (sleep(timeout); notify(w.c_state; all=true))
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
T = Threads.@spawn begin
sleep($timeout)
lock(w.c_state) do
notify(w.c_state; all=true)
end
end
errormonitor(T)
lock(w.c_state) do
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
end
else
unlock(w.c_state)
end
nothing
end
Expand Down Expand Up @@ -483,7 +501,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
while true
if isempty(launched)
istaskdone(t_launch) && break
@async (sleep(1); notify(launch_ntfy))
@async begin
sleep(1)
notify(launch_ntfy)
end
wait(launch_ntfy)
end

Expand Down Expand Up @@ -636,7 +657,12 @@ function create_worker(manager, wconfig)
# require the value of config.connect_at which is set only upon connection completion
for jw in PGRP.workers
if (jw.id != 1) && (jw.id < w.id)
(jw.state === W_CREATED) && wait(jw.c_state)
# wait for wl to join
lock(jw.c_state) do
if jw.state === W_CREATED
wait(jw.c_state)
end
end
push!(join_list, jw)
end
end
Expand All @@ -659,7 +685,12 @@ function create_worker(manager, wconfig)
end

for wl in wlist
(wl.state === W_CREATED) && wait(wl.c_state)
lock(wl.c_state) do
if wl.state === W_CREATED
# wait for wl to join
wait(wl.c_state)
end
end
push!(join_list, wl)
end
end
Expand All @@ -676,7 +707,11 @@ function create_worker(manager, wconfig)
@async manage(w.manager, w.id, w.config, :register)
# wait for rr_ntfy_join with timeout
timedout = false
@async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1))
@async begin
sleep($timeout)
timedout = true
put!(rr_ntfy_join, 1)
end
wait(rr_ntfy_join)
if timedout
error("worker did not connect within $timeout seconds")
Expand Down
2 changes: 1 addition & 1 deletion stdlib/Distributed/src/managers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy:
# Wait for all launches to complete.
@sync for (i, (machine, cnt)) in enumerate(manager.machines)
let machine=machine, cnt=cnt
@async try
@async try
launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
catch e
print(stderr, "exception launching on machine $(machine) : $(e)\n")
Expand Down
1 change: 1 addition & 0 deletions stdlib/Distributed/test/distributed_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1696,4 +1696,5 @@ include("splitrange.jl")
# Run topology tests last after removing all workers, since a given
# cluster at any time only supports a single topology.
rmprocs(workers())
include("threads.jl")
include("topology.jl")
63 changes: 63 additions & 0 deletions stdlib/Distributed/test/threads.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using Test
using Distributed, Base.Threads
using Base.Iterators: product

exeflags = ("--startup-file=no",
"--check-bounds=yes",
"--depwarn=error",
"--threads=2")

function call_on(f, wid, tid)
remotecall(wid) do
t = Task(f)
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid - 1)
schedule(t)
@assert threadid(t) == tid
t
end
end

# Run function on process holding the data to only serialize the result of f.
# This becomes useful for things that cannot be serialized (e.g. running tasks)
# or that would be unnecessarily big if serialized.
fetch_from_owner(f, rr) = remotecall_fetch(f fetch, rr.where, rr)

isdone(rr) = fetch_from_owner(istaskdone, rr)
isfailed(rr) = fetch_from_owner(istaskfailed, rr)

@testset "RemoteChannel allows put!/take! from thread other than 1" begin
ws = ts = product(1:2, 1:2)
@testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws
@testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts
# We want (the default) lazyness, so that we wait for `Worker.c_state`!
procs_added = addprocs(2; exeflags, lazy=true)
@everywhere procs_added using Base.Threads

p1 = procs_added[w1]
p2 = procs_added[w2]
chan_id = first(procs_added)
chan = RemoteChannel(chan_id)
send = call_on(p1, t1) do
put!(chan, nothing)
end
recv = call_on(p2, t2) do
take!(chan)
end

# Wait on the spawned tasks on the owner
@sync begin
Threads.@spawn fetch_from_owner(wait, recv)
Threads.@spawn fetch_from_owner(wait, send)
end

# Check the tasks
@test isdone(send)
@test isdone(recv)

@test !isfailed(send)
@test !isfailed(recv)

rmprocs(procs_added)
end
end
end

0 comments on commit e6bd8fb

Please sign in to comment.