Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BLAS.with_num_threads #41785

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ Standard library changes

#### LinearAlgebra

* A new unexported function `BLAS.with_num_threads` allows you to temporarily change the number of
BLAS threads. ([#41785])

#### Markdown

#### Printf
Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ LinearAlgebra.BLAS.trsv!
LinearAlgebra.BLAS.trsv
LinearAlgebra.BLAS.set_num_threads
LinearAlgebra.BLAS.get_num_threads
LinearAlgebra.BLAS.with_num_threads
```

## LAPACK functions
Expand Down
53 changes: 53 additions & 0 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ Set the number of threads the BLAS library should use equal to `n::Integer`.

Also accepts `nothing`, in which case julia tries to guess the default number of threads.
Passing `nothing` is discouraged and mainly exists for historical reasons.

See also [`with_num_threads`](@ref BLAS.with_num_threads) to temporarily change
number of BLAS threads.
"""
set_num_threads(nt::Integer)::Nothing = lbt_set_num_threads(Int32(nt))
function set_num_threads(::Nothing)
Expand Down Expand Up @@ -157,6 +160,56 @@ function check()
end
end

"""
with_num_threads(f, num_threads::Integer)

Run function `f()` with BLAS threads `num_threads` and then
restore to previous threads setting.

!!! compat "Julia 1.8"
`with_num_threads` requires at least Julia 1.8.

# Example

Depending on the number of available CPU cores, the result can be different:

```julia
julia> BLAS.get_num_threads()
johnnychen94 marked this conversation as resolved.
Show resolved Hide resolved
8

julia> with_num_threads(4) do
BLAS.get_num_threads()
# or doing some basic BLAS computation
end
4

julia> BLAS.get_num_threads()
8
```

!!! warning
This function is not thread safe. If there are multiple
threads calling BLAS routines, then the threads they are
using will also be changed until this function finishes.

!!! warning
This interface is experimental and subject to change or
removal without notice.

See also [`set_num_threads`](@ref BLAS.set_num_threads) to permanently change
number of BLAS threads.
"""
function with_num_threads(f, num_threads::Integer)
prev_num_threads = BLAS.get_num_threads()
BLAS.set_num_threads(num_threads)
try
return f()
catch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary catch

rethrow()
finally
BLAS.set_num_threads(prev_num_threads)
end
end

# Level 1
## copy
Expand Down
24 changes: 24 additions & 0 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,30 @@ end
@test BLAS.get_num_threads() === default
end

@testset "with_num_threads" begin
prev_num_threads = BLAS.get_num_threads()
context_num_threads = BLAS.with_num_threads(1) do
BLAS.get_num_threads()
end
@test context_num_threads == 1
@test prev_num_threads == BLAS.get_num_threads()

@testset "thread unsafe" begin
prev_num_threads = BLAS.get_num_threads()
context_num_threads = 1
# task A
t = @async BLAS.with_num_threads(context_num_threads) do
sleep(0.5)
end
sleep(0.1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This execution order is not guaranteed. Use a barrier instead to guarantee the ordering of events

# check that main thread is affected by task A
@test BLAS.get_num_threads() == context_num_threads
# when the task finishes, the num threads get restored
wait(t)
@test prev_num_threads == BLAS.get_num_threads()
end
end

# https://github.com/JuliaLang/julia/pull/39845
@test LinearAlgebra.BLAS.libblas == "libblastrampoline"
@test LinearAlgebra.BLAS.liblapack == "libblastrampoline"
Expand Down