Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BLAS.get_num_threads #36360

Merged
merged 30 commits into from
Jun 30, 2020
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
bea54cc
add BLAS.get_num_threads
jw3126 Jun 19, 2020
2de23cd
fix
jw3126 Jun 19, 2020
0e6f8fb
fix
jw3126 Jun 19, 2020
9f396c2
fix
jw3126 Jun 19, 2020
e84c258
fix
jw3126 Jun 19, 2020
9a0edcb
warn if get/set of num_bals_threads fails
jw3126 Jun 19, 2020
f6daa79
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 20, 2020
8b2c8c4
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 20, 2020
a15f851
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 20, 2020
826d8ff
fix
jw3126 Jun 20, 2020
ce61636
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 20, 2020
b38afaa
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 20, 2020
0ee0efa
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 20, 2020
8e4fedd
fix
jw3126 Jun 20, 2020
33a95d5
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 20, 2020
35cf5a6
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 20, 2020
bc370b4
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 20, 2020
dd455b5
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 20, 2020
abd2084
fix
jw3126 Jun 22, 2020
b8e9055
fix
jw3126 Jun 22, 2020
9eabcb6
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 23, 2020
920b90b
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 23, 2020
b011dc6
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 23, 2020
95ccdf9
Update stdlib/LinearAlgebra/test/blas.jl
jw3126 Jun 23, 2020
72ef30e
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 23, 2020
4cf2628
Update stdlib/LinearAlgebra/test/blas.jl
jw3126 Jun 23, 2020
bb69931
Update stdlib/LinearAlgebra/test/blas.jl
jw3126 Jun 23, 2020
06550d6
Update stdlib/LinearAlgebra/src/blas.jl
jw3126 Jun 23, 2020
b6aa076
improve docstrings
jw3126 Jun 23, 2020
c524c53
add to NEWS.md
jw3126 Jun 23, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 50 additions & 11 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,26 +107,65 @@ end
openblas_get_config() = strip(unsafe_string(ccall((@blasfunc(openblas_get_config), libblas), Ptr{UInt8}, () )))

"""
set_num_threads(n)
set_num_threads(n::Integer)
set_num_threads(::Nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like an unusual API, I would prefer set_num_threads() instead of set_num_threads(nothing).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to allow the pattern

default = BLAS.get_num_threads() # returns nothing on exotic platforms
BLAS.set_num_threads(1)
# do stuff
BLAS.set_num_threads(default)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That will still work if you make the signature set_num_threads(n=nothing)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me allowing nothing is a hack to allow the above pattern on strange platforms. It is not a thing I would encourage or that I think needs more convenient syntax.

Copy link
Contributor Author

@jw3126 jw3126 Jun 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I imagine if you have code like set_num_threads() it is more likely you forgot to pass the number of threads, than that you really want to invoke the nothing method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should define set_num_threads() = set_num_threads(nothing). It would send a wrong message that set_num_threads(nothing) is somehow a reasonable default. But it's not. It is the last resort that exists only for supporting the rollback use case.

But this is not clear from the current docstring. I think it's better to clarify this.

Copy link
Contributor

@mcabbott mcabbott Jun 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps clearest to show the pattern which motivates this:

Set the number of threads the BLAS library should use.

Also accepts `nothing`, in which case it tries to set set the default number of threads.
On exotic variants of BLAS, `nothing` may be returned by ` get_num_threads()`.
Thus the following pattern may fail to set the number of threads, but will not give an error:

old = get_num_threads()
set_num_threads(1)
@threads for i in 1:10
    # single-threaded BLAS calls
end
set_num_threads(old)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fredrikekre Does the argument here makes sense? It'd be nice if you can have a look at the docstring.


Set the number of threads the BLAS library should use equal to `n::Integer`.
If `nothing` is passed to this function, julia tries to figure out the optimial number of threads.
The exact heuristic is an implementation detail.

Set the number of threads the BLAS library should use.
On exotic variants of `BLAS` this function can fail.
"""
function set_num_threads(n::Integer)
function set_num_threads(n::Integer)::Nothing
blas = vendor()
if blas === :openblas
return ccall((:openblas_set_num_threads, libblas), Cvoid, (Int32,), n)
elseif blas === :openblas64
return ccall((:openblas_set_num_threads64_, libblas), Cvoid, (Int32,), n)
if blas === :openblas || blas == :openblas64
return ccall((@blasfunc(openblas_set_num_threads), libblas), Cvoid, (Cint,), n)
elseif blas === :mkl
# MKL may let us set the number of threads in several ways
return ccall((:MKL_Set_Num_Threads, libblas), Cvoid, (Cint,), n)
end

# OSX BLAS looks at an environment variable
@static if Sys.isapple()
elseif Sys.isapple()
# OSX BLAS looks at an environment variable
ENV["VECLIB_MAXIMUM_THREADS"] = n
else
@warn "Failed to set number of BLAS threads." maxlog=1
end
return nothing
end

_tryparse_env_int(key) = tryparse(Int, get(ENV, key, ""))

function set_num_threads(::Nothing)
n = something(
_tryparse_env_int("OPENBLAS_NUM_THREADS"),
_tryparse_env_int("OMP_NUM_THREADS"),
max(1, Sys.CPU_THREADS ÷ 2),
)
set_num_threads(n)
end

"""
get_num_threads()

Get the number of threads the BLAS library is using.

On exotic variants of `BLAS` this function can fail, which is indicated by returning `nothing`.
"""
function get_num_threads()::Union{Int, Nothing}
blas = LinearAlgebra.BLAS.vendor()
jw3126 marked this conversation as resolved.
Show resolved Hide resolved
if blas === :openblas || blas === :openblas64
jw3126 marked this conversation as resolved.
Show resolved Hide resolved
return Int(ccall((@blasfunc(openblas_get_num_threads), libblas), Cint, ()))
elseif blas == :mkl
return Int(ccall((:mkl_get_max_threads, libblas), Cint, ()))
elseif Sys.isapple()
tkf marked this conversation as resolved.
Show resolved Hide resolved
key = "VECLIB_MAXIMUM_THREADS"
nt = _tryparse_env_int(key)
if nt === nothing
@warn "Failed to read environment variable $key" maxlog=1
else
return nt
end
end
@warn "Could not get number of BLAS threads. Returning `nothing` instead." maxlog=1
return nothing
end

Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,4 +553,14 @@ Base.stride(A::WrappedArray, i::Int) = stride(A.A, i)
end
end

@testset "get_set_num_threads" begin
default = BLAS.get_num_threads()
@test default isa Integer
@test default > 0
BLAS.set_num_threads(1)
@test BLAS.get_num_threads() == 1
BLAS.set_num_threads(default)
@test BLAS.get_num_threads() == default
end

end # module TestBLAS