Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unstick parent if all child tasks are done #41393

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ primitive type Char <: AbstractChar 32 end
primitive type Int8 <: Signed 8 end
#primitive type UInt8 <: Unsigned 8 end
primitive type Int16 <: Signed 16 end
primitive type UInt16 <: Unsigned 16 end
#primitive type UInt16 <: Unsigned 16 end
#primitive type Int32 <: Signed 32 end
#primitive type UInt32 <: Unsigned 32 end
#primitive type Int64 <: Signed 64 end
Expand Down
49 changes: 47 additions & 2 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,21 @@ end
elseif field === :exception
# TODO: this field name should be deprecated in 2.0
return t._isexception ? t.result : nothing
elseif field === :sticky
return getfield(t, :sticky_count) != 0
else
return getfield(t, field)
end
end

function setproperty!(t::Task, field::Symbol, x)
if field === :sticky
t.sticky_count = convert(Bool, x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what prevents an @async-scheduled task from being marked as unsticky since sticky_count will always be >= 1, right? Maybe it would be good to add comment here. Or a test.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is rather a compat layer for supporting .sticky property.

@async task is never marked as unsticky because (1) its sticky_count is initialized to 1 (in the C function jl_new_task) and (2) a decrement (if any) is always paired with a preceding increment.

else
setfield!(t, field, convert(fieldtype(Task, field), x))
end
end

"""
istaskdone(t::Task) -> Bool

Expand Down Expand Up @@ -611,6 +621,37 @@ function __preinit_threads__()
nothing
end

# Factored out so that the behavior after saturation can be tested:
is_sticky_count_saturated(t::Task) = t.sticky_count === typemax(t.sticky_count)

# This is a struct rather than a closure so that `serialize` can be dispatched
# to ignore `parent` field.
struct StickyCountDecrementer
code::Any
parent::Union{Nothing,Task}
end

unset_parent(f::StickyCountDecrementer) = StickyCountDecrementer(f.code, nothing)

function (f::StickyCountDecrementer)()
try
f.code()
finally
parent_task = f.parent
if parent_task !== nothing && !is_sticky_count_saturated(parent_task)
# Once `parent_task.sticky_count` hits the typemax (which
# practically never happens), we stop un-sticking the parent task.
# This only affects the performance in rare cases (which already
# torturing the scheulder anyway) and does not sacrifice the
# correctness. Checking saturation should be done for all tasks
# includding those started with `parent_task.sticky_count < typemax
# -1` since there may be sticky tasks started realying on that the
# coutner is saturated.
tkf marked this conversation as resolved.
Show resolved Hide resolved
parent_task.sticky_count -= 1
end
end
end

function enq_work(t::Task)
(t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable")
tid = Threads.threadid(t)
Expand All @@ -625,8 +666,12 @@ function enq_work(t::Task)
# t.sticky && tid == 0 is a task that needs to be co-scheduled with
# the parent task. If the parent (current_task) is not sticky we must
# set it to be sticky.
# XXX: Ideally we would be able to unset this
current_task().sticky = true
parent_task = current_task()
if t.sticky && !is_sticky_count_saturated(parent_task)
parent_task.sticky_count += 1
t.code = StickyCountDecrementer(t.code, parent_task)
end

tid = Threads.threadid()
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
end
Expand Down
1 change: 1 addition & 0 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1868,6 +1868,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
add_builtin("UInt8", (jl_value_t*)jl_uint8_type);
add_builtin("Int32", (jl_value_t*)jl_int32_type);
add_builtin("Int64", (jl_value_t*)jl_int64_type);
add_builtin("UInt16", (jl_value_t*)jl_uint16_type);
add_builtin("UInt32", (jl_value_t*)jl_uint32_type);
add_builtin("UInt64", (jl_value_t*)jl_uint64_type);
#ifdef _P64
Expand Down
2 changes: 1 addition & 1 deletion src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,6 @@ static void post_boot_hooks(void)
jl_char_type = (jl_datatype_t*)core("Char");
jl_int8_type = (jl_datatype_t*)core("Int8");
jl_int16_type = (jl_datatype_t*)core("Int16");
jl_uint16_type = (jl_datatype_t*)core("UInt16");
jl_float16_type = (jl_datatype_t*)core("Float16");
jl_float32_type = (jl_datatype_t*)core("Float32");
jl_float64_type = (jl_datatype_t*)core("Float64");
Expand All @@ -819,6 +818,7 @@ static void post_boot_hooks(void)
jl_uint8_type->super = jl_unsigned_type;
jl_int32_type->super = jl_signed_type;
jl_int64_type->super = jl_signed_type;
jl_uint16_type->super = jl_unsigned_type;
jl_uint32_type->super = jl_unsigned_type;
jl_uint64_type->super = jl_unsigned_type;

Expand Down
6 changes: 4 additions & 2 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2092,6 +2092,8 @@ void jl_init_types(void) JL_GC_DISABLED
jl_any_type, jl_emptysvec, 32);
jl_int64_type = jl_new_primitivetype((jl_value_t*)jl_symbol("Int64"), core,
jl_any_type, jl_emptysvec, 64);
jl_uint16_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt16"), core,
jl_any_type, jl_emptysvec, 16);
jl_uint32_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt32"), core,
jl_any_type, jl_emptysvec, 32);
jl_uint64_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt64"), core,
Expand Down Expand Up @@ -2543,8 +2545,8 @@ void jl_init_types(void) JL_GC_DISABLED
"rngState1",
"rngState2",
"rngState3",
"sticky_count",
"_state",
"sticky",
"_isexception"),
jl_svec(14,
jl_any_type,
Expand All @@ -2558,8 +2560,8 @@ void jl_init_types(void) JL_GC_DISABLED
jl_uint64_type,
jl_uint64_type,
jl_uint64_type,
jl_uint16_type,
jl_uint8_type,
jl_bool_type,
jl_bool_type),
jl_emptysvec,
0, 1, 6);
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1833,8 +1833,8 @@ typedef struct _jl_task_t {
uint64_t rngState1;
uint64_t rngState2;
uint64_t rngState3;
uint16_t sticky; // 0 means this Task can be migrated to a new thread
uint8_t _state;
uint8_t sticky; // record whether this Task can be migrated to a new thread
uint8_t _isexception; // set if `result` is an exception to throw or that we exited with

// hidden state:
Expand Down
3 changes: 3 additions & 0 deletions stdlib/Serialization/src/Serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,9 @@ function serialize(s::AbstractSerializer, linfo::Core.MethodInstance)
nothing
end

serialize(s::AbstractSerializer, f::Base.StickyCountDecrementer) =
invoke(serialize, Tuple{typeof(s),Any}, s, Base.unset_parent(f))

function serialize(s::AbstractSerializer, t::Task)
serialize_cycle(s, t) && return
if istaskstarted(t) && !istaskdone(t)
Expand Down
36 changes: 35 additions & 1 deletion test/threads_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ fib34666(x) =
wait(child)
end
wait(parent)
@test parent.sticky == true
@test parent.sticky == false
end

function jitter_channel(f, k, delay, ntasks, schedule)
Expand Down Expand Up @@ -912,3 +912,37 @@ end
@test reproducible_rand(r, 10) == val
end
end

# [ADD TESTS ABOVE THIS COMMENT]
#
# The following tests must be done at the end, since they need to monkey-patch runtime.
const MAX_STICKY_COUNT = 3
@assert MAX_STICKY_COUNT <= typemax(fieldtype(Task, :sticky_count))
Base.is_sticky_count_saturated(t::Task) = t.sticky_count == MAX_STICKY_COUNT

@testset "Saturated sticky_count" begin
@testset for nchild in MAX_STICKY_COUNT-1:MAX_STICKY_COUNT+1
local is_sticky_pre, is_sticky_post, sticky_count_pre, sticky_count_post
@sync Threads.@spawn begin
is_sticky_pre = current_task().sticky
sticky_count_pre = current_task().sticky_count
@sync for _ in 1:nchild
@async nothing
end
is_sticky_post = current_task().sticky
sticky_count_post = current_task().sticky_count
end
@test !is_sticky_pre
@test sticky_count_pre == 0
if nchild < MAX_STICKY_COUNT
@test !is_sticky_post
@test sticky_count_post == 0
else
@test is_sticky_post
@test sticky_count_post == MAX_STICKY_COUNT
end
end
end

# Please do not add tests at the end of this file. Pleaes add tests above the above
# comment [ADD TESTS ABOVE THIS COMMENT].