Skip to content

Commit

Permalink
Add threadpool support to runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
jpsamaroo committed Sep 18, 2021
1 parent dba8a08 commit b060dba
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 36 deletions.
27 changes: 18 additions & 9 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,9 @@ true
"""
istaskfailed(t::Task) = (load_state_acquire(t) === task_state_failed)

Threads.threadid(t::Task) = Int(ccall(:jl_get_task_tid, Int16, (Any,), t)+1)
Threads.threadid(t::Task) = Int(ccall(:jl_get_task_relative_tid, Int16, (Any,), t)+1)
Threads.rawthreadid(t::Task) = Int(ccall(:jl_get_task_tid, Int16, (Any,), t)+1)
Threads.threadpoolid(t::Task) = Int(ccall(:jl_get_task_poolid, Int16, (Any,), t)+1)

task_result(t::Task) = t.result

Expand Down Expand Up @@ -599,8 +601,9 @@ function list_deletefirst!(W::InvasiveLinkedListSynchronized{T}, t::T) where T
end

const StickyWorkqueue = InvasiveLinkedListSynchronized{Task}
global const Workqueues = [StickyWorkqueue()]
global const Workqueues = [StickyWorkqueue()] # default threadpool is first
global const Workqueue = Workqueues[1] # default work queue is thread 1
global const AllWorkqueues = [Workqueues]
function __preinit_threads__()
if length(Workqueues) < Threads.nthreads()
resize!(Workqueues, Threads.nthreads())
Expand All @@ -613,7 +616,9 @@ end

function enq_work(t::Task)
(t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable")
tp = Threads.threadpoolid(t)
tid = Threads.threadid(t)
_tid = Threads.rawthreadid(t)
# Note there are three reasons a Task might be put into a sticky queue
# even if t.sticky == false:
# 1. The Task's stack is currently being used by the scheduler for a certain thread.
Expand All @@ -627,23 +632,27 @@ function enq_work(t::Task)
# set it to be sticky.
# XXX: Ideally we would be able to unset this
current_task().sticky = true
tp = Threads.threadpoolid()
tid = Threads.threadid()
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
_tid = Threads.rawthreadid()
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, _tid-1)
end
push!(Workqueues[tid], t)
push!(AllWorkqueues[tp][tid], t)
else
if ccall(:jl_enqueue_task, Cint, (Any,), t) != 0
# if multiq is full, give to a random thread (TODO fix)
if tid == 0
tp = Threads.threadpoolid()
tid = mod(time_ns() % Int, Threads.nthreads()) + 1
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, _tid-1)
_tid = Threads.rawthreadid(t)
end
push!(Workqueues[tid], t)
push!(AllWorkqueues[tp][tid], t)
else
tid = 0
_tid = 0
end
end
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (_tid - 1) % Int16)
return t
end

Expand Down Expand Up @@ -819,7 +828,7 @@ end

function wait()
GC.safepoint()
W = Workqueues[Threads.threadid()]
W = AllWorkqueues[Threads.threadpoolid()][Threads.threadid()]
poptask(W)
result = try_yieldto(ensure_rescheduled)
process_events()
Expand Down
51 changes: 43 additions & 8 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,59 @@
export threadid, nthreads, @threads, @spawn

"""
Threads.threadid()
Threads.threadid() -> Int
Get the ID number of the current thread of execution. The master thread has ID `1`.
Get the ID number of the current thread of execution within the current
threadpool. The master thread has ID `1`.
"""
threadid() = Int(ccall(:jl_threadid, Int16, ())+1)
threadid() = Int(ccall(:jl_relative_threadid, Int16, ())+1)

# Inclusive upper bound on threadid()
"""
Threads.nthreads()
Threads.threadid() -> Int
Get the number of threads available to the Julia process. This is the inclusive upper bound
on [`threadid()`](@ref).
Get the ID number of the current thread of execution within the Julia session.
The master thread has ID `1`.
"""
rawthreadid() = Int(ccall(:jl_threadid, Int16, ())+1)

"""
Threads.nthreads(tp::Int=Threads.threadpoolid()) -> Int
Get the number of threads available in the specified threadpool. This is the
inclusive upper bound on [`threadid()`](@ref).
See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
[`LinearAlgebra`](@ref man-linalg) standard library, and `nprocs()` in the
[`Distributed`](@ref man-distributed) standard library.
"""
nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))
nthreads(tp::Int=Threads.threadpoolid()) =
ccall(:jl_num_threads, Cint, (Cint,), tp-1)

"""
Threads.spawn_threadpool(n::Int) -> Int
Spawns a new threadpool of size `n`, and returns the threadpool ID.
"""
function spawn_threadpool(n::Int)
wq = [Base.StickyWorkqueue() for _ in 1:n]
push!(Base.AllWorkqueues, wq)
return ccall(:jl_start_threads_dedicated, Cint, (Cint,), n)
end

"""
Threads.threadpoolid() -> Int
Returns the threadpool ID that the current thread resides in. The default
threadpool has ID `1`.
"""
threadpoolid() = ccall(:jl_threadpoolid, Cint, ())+1

"""
Threads.npools() -> Int
Returns the number of threadpools currently configured.
"""
npools() = Int(unsafe_load(cglobal(:jl_threadpools, Cint)))

function threading_run(func)
ccall(:jl_enter_threaded_region, Cvoid, ())
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
XX(jl_string_type) \
XX(jl_symbol_type) \
XX(jl_task_type) \
XX(jl_threadpools) \
XX(jl_top_module) \
XX(jl_true) \
XX(jl_tuple_typename) \
Expand Down
4 changes: 4 additions & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@
XX(jl_get_root_symbol) \
XX(jl_get_safe_restore) \
XX(jl_get_size) \
XX(jl_get_task_poolid) \
XX(jl_get_task_relative_tid) \
XX(jl_get_task_tid) \
XX(jl_gettimeofday) \
XX(jl_get_tls_world_age) \
Expand Down Expand Up @@ -458,6 +460,7 @@
XX(jl_spawn) \
XX(jl_specializations_get_linfo) \
XX(jl_specializations_lookup) \
XX(jl_start_threads_dedicated) \
XX(jl_static_show) \
XX(jl_static_show_func_sig) \
XX(jl_stderr_obj) \
Expand Down Expand Up @@ -496,6 +499,7 @@
XX(jl_test_cpu_feature) \
XX(jl_threadid) \
XX(jl_threading_enabled) \
XX(jl_threadpoolid) \
XX(jl_throw) \
XX(jl_throw_out_of_memory_error) \
XX(jl_too_few_args) \
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,7 @@ JL_DLLEXPORT jl_sym_t *jl_get_UNAME(void) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_sym_t *jl_get_ARCH(void) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_get_libllvm(void) JL_NOTSAFEPOINT;
extern JL_DLLIMPORT int jl_n_threads;
extern JL_DLLIMPORT int jl_threadpools;

// environment entries
JL_DLLEXPORT jl_value_t *jl_environ(int i);
Expand Down
10 changes: 10 additions & 0 deletions src/task.c
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,16 @@ JL_DLLEXPORT int16_t jl_get_task_tid(jl_task_t *t) JL_NOTSAFEPOINT
return t->tid;
}

JL_DLLEXPORT int16_t jl_get_task_relative_tid(jl_task_t *t) JL_NOTSAFEPOINT
{
return jl_tid_to_relative(t->tid);
}

JL_DLLEXPORT int16_t jl_get_task_poolid(jl_task_t *t) JL_NOTSAFEPOINT
{
return jl_tid_to_poolid(t->tid);
}


#ifdef _OS_WINDOWS_
#if defined(_CPU_X86_)
Expand Down
115 changes: 96 additions & 19 deletions src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -287,17 +287,46 @@ void jl_pgcstack_getkey(jl_get_pgcstack_func **f, jl_pgcstack_key_t *k)
#endif

jl_ptls_t *jl_all_tls_states JL_GLOBALLY_ROOTED;
int *jl_threadpool_map;
int *jl_threadpool_sizes;
uint8_t jl_measure_compile_time_enabled = 0;
uint64_t jl_cumulative_compile_time = 0;

// return calling thread's ID
// Also update the suspended_threads list in signals-mach when changing the
// type of the thread id.
// return calling thread's absolute ID
JL_DLLEXPORT int16_t jl_threadid(void)
{
return jl_current_task->tid;
}

// return calling thread's relative ID (with respect to its threadpool)
int16_t jl_tid_to_relative(int16_t rawtid)
{
if (rawtid < 0)
return rawtid;
int poolid = jl_threadpool_map[rawtid];
int tp_offset = 0;
for (int tp = 0; tp < poolid; tp++) {
tp_offset += jl_threadpool_sizes[tp];
}
return rawtid - tp_offset;
}
JL_DLLEXPORT int16_t jl_relative_threadid(void)
{
return jl_tid_to_relative(jl_current_task->tid);
}

int16_t jl_tid_to_poolid(int16_t tid)
{
int tp_offset = 0;
for (int tp = 0; tp < jl_threadpools; tp++) {
int tp_size = jl_threadpool_sizes[tp];
if (tp_offset + tp_size > tid)
return tp;
tp_offset += tp_size;
}
return -1;
}

jl_ptls_t jl_init_threadtls(int16_t tid)
{
jl_ptls_t ptls = (jl_ptls_t)calloc(1, sizeof(jl_tls_states_t));
Expand Down Expand Up @@ -467,18 +496,61 @@ void jl_init_threading(void)
}
if (jl_n_threads <= 0)
jl_n_threads = 1;
int jl_extra_threads = 8; // FIXME: ENV[NUM_EXTRA_THREADS_NAME]
int jl_max_threadpools = 8; // FIXME: ENV[NUM_THREADPOOLS_NAME]
#ifndef __clang_analyzer__
jl_all_tls_states = (jl_ptls_t*)calloc(jl_n_threads, sizeof(void*));
jl_all_tls_states = (jl_ptls_t*)calloc(jl_n_threads+jl_extra_threads, sizeof(void*));
#endif
jl_threadpools = 0;
jl_threadpool_map = (int*)calloc(jl_max_threadpools, sizeof(int));
jl_threadpool_sizes = (int*)calloc(jl_max_threadpools, sizeof(int));
}

static uv_barrier_t thread_init_done;

int jl_start_threads_(size_t nthreads, size_t cur_n_threads, uv_barrier_t *barrier, int exclusive)
{
int cpumasksize = uv_cpumask_size();
if (cpumasksize < jl_n_threads) // also handles error case
cpumasksize = jl_n_threads;
char *mask = (char*)alloca(cpumasksize);
int i;
uv_thread_t uvtid;

int tp = jl_threadpools;
jl_threadpools++;
if (tp > 0) {
jl_threadpool_sizes[tp] = nthreads;
} else {
jl_threadpool_sizes[tp] = jl_n_threads;
}

uv_barrier_init(barrier, nthreads);

for (i = cur_n_threads; i < cur_n_threads+nthreads; ++i) {
jl_threadarg_t *t = (jl_threadarg_t*)malloc_s(sizeof(jl_threadarg_t)); // ownership will be passed to the thread
t->tid = i;
t->barrier = barrier;
uv_thread_create(&uvtid, jl_threadfun, t);
jl_threadpool_map[i] = tp;
if (exclusive) {
mask[i] = 1;
uv_thread_setaffinity(&uvtid, mask, NULL, cpumasksize);
mask[i] = 0;
}
uv_thread_detach(&uvtid);
}

uv_barrier_wait(barrier);

return tp;
}

void jl_start_threads(void)
{
int cpumasksize = uv_cpumask_size();
char *cp;
int i, exclusive;
int exclusive;
uv_thread_t uvtid;
if (cpumasksize < jl_n_threads) // also handles error case
cpumasksize = jl_n_threads;
Expand Down Expand Up @@ -509,22 +581,27 @@ void jl_start_threads(void)
size_t nthreads = jl_n_threads;

// create threads
uv_barrier_init(&thread_init_done, nthreads);
jl_start_threads_(nthreads-1, 1, &thread_init_done, exclusive);
}

for (i = 1; i < nthreads; ++i) {
jl_threadarg_t *t = (jl_threadarg_t*)malloc_s(sizeof(jl_threadarg_t)); // ownership will be passed to the thread
t->tid = i;
t->barrier = &thread_init_done;
uv_thread_create(&uvtid, jl_threadfun, t);
if (exclusive) {
mask[i] = 1;
uv_thread_setaffinity(&uvtid, mask, NULL, cpumasksize);
mask[i] = 0;
}
uv_thread_detach(&uvtid);
}
JL_DLLEXPORT int jl_start_threads_dedicated(size_t nthreads, int exclusive)
{
uv_barrier_t tbar;
return jl_start_threads_(nthreads, jl_n_threads, &tbar, exclusive);
}

JL_DLLEXPORT int jl_num_threads(int tp)
{
if (jl_threadpools == 0)
// Pre-init
return jl_n_threads;
assert(tp < jl_threadpools);
return jl_threadpool_sizes[tp];
}

uv_barrier_wait(&thread_init_done);
JL_DLLEXPORT int jl_threadpoolid(void)
{
return jl_threadpool_map[jl_current_task->tid];
}

unsigned volatile _threadedregion; // HACK: keep track of whether it is safe to do IO
Expand Down

0 comments on commit b060dba

Please sign in to comment.