Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BLAS.with_num_threads #41785

Closed

Conversation

johnnychen94
Copy link
Sponsor Member

@johnnychen94 johnnychen94 commented Aug 4, 2021

I added this BLAS.with_num_threads as a convenient wrapper, but I'm not sure if we want it live in Base because it's not thread-safe.

In some cases, multi-threading could happen at a more efficient higher level instead of at the BLAS level. Thus we may want to temporarily disable BLAS threads during the julia threads execution.

As an example, SVD is one typical case: if there are a bunch of SVD operations and if we can do embarrassingly parallel at a higher level, then it's better to disable the BLAS threads.

using LinearAlgebra

function make_test_data(n)
    U = rand(n, n)
    S = rand(n)
    V = rand(n, n)
    U * Diagonal(S) * V'
end

function svd_solver(X, k=5)
    U, S, Vt = svd(X)
    S[k:end] .= 0
    U * Diagonal(S) * Vt
end

dataset = [make_test_data(64) for _ in 1:64]

function svd_with_blas_threads(dataset)
    losses = zeros(size(dataset))
    Threads.@threads for i in eachindex(dataset)
        X = dataset[i]
        losses[i] = norm(X - svd_solver(X))
    end
    return losses
end

function svd_without_blas_threads(dataset)
    losses = zeros(size(dataset))
    BLAS.with_num_threads(1) do 
        Threads.@threads for i in eachindex(dataset)
            X = dataset[i]
            losses[i] = norm(X - svd_solver(X))
        end
    end
    return losses
end

Threads.nthreads() # 8
BLAS.get_num_threads() # 8

@btime svd_with_blas_threads(dataset); # 15.149 ms (1194 allocations: 18.55 MiB)
@btime svd_without_blas_threads(dataset);# 10.411 ms (1194 allocations: 18.55 MiB)

Would need to add a news entry if approved.

@ViralBShah ViralBShah added the linear algebra Linear algebra label Aug 15, 2021
@ViralBShah
Copy link
Member

@KristofferC @dkarrasch @chriselrod @YingboMa Thoughts on this?

I generally don't like code and APIs that code in numbers of threads, but this is a reality we have to work with today. Since this is an unexported API, I feel comfortable merging.

@ViralBShah
Copy link
Member

@johnnychen94 Any ideas why the doctests are failing?

@johnnychen94
Copy link
Sponsor Member Author

@johnnychen94 Any ideas why the doctests are failing?

That seems to be noise. After rebasing on the current master the test passes now.

@chriselrod
Copy link
Contributor

@KristofferC @dkarrasch @chriselrod @YingboMa Thoughts on this?

I generally don't like code and APIs that code in numbers of threads, but this is a reality we have to work with today. Since this is an unexported API, I feel comfortable merging.

Would you prefer they're integrated with Julia's threading and the scheduler handles it?

I think this is an improvement, but it'd be concerning if we want to break it later (e.g. after/if this merges).

@ViralBShah
Copy link
Member

ViralBShah commented Aug 19, 2021

I think this is an improvement, but it'd be concerning if we want to break it later (e.g. after/if this merges).

That PR has not seen any real work since it was created. I believe this PR would be compatible with that, in any case.

@DilumAluthge
Copy link
Member

Also, if we make the additions in this PR experimental and not part of the public API, then we can break it (and even remove it) in a future minor release.

@johnnychen94
Copy link
Sponsor Member Author

Should I only tag with_num_threads as experimental or also to set_num_threads?

@ViralBShah
Copy link
Member

Only with_num_threads needs to be marked experimental.

@ViralBShah
Copy link
Member

Let's rebase this and get it merged.

@testset "thread unsafe" begin
prev_num_threads = BLAS.get_num_threads()
# thread unsafe
@async BLAS.with_num_threads(1) do
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this test testing?

Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is used to simulate the case that when there exists a task with with_num_threads, then all threads are affected until that task is finished.

BLAS.set_num_threads(num_threads)
local retval
try
retval = f()
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
retval = f()
return f()

and fixup the rest?

Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I must have messed up with rebasing on a computer that has different local git history. This is pointed out in #41785 (comment) already.
Thanks for catching this again!

Copy link
Member

@vchuravy vchuravy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am really not a fan of this. Mutating global state like this makes this effectively unusable from a library writers perspective.

Another question that came to mind is if this is portable across blas vendors? Or is this OpenBLAS specific.

t = @async BLAS.with_num_threads(context_num_threads) do
sleep(0.5)
end
sleep(0.1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This execution order is not guaranteed. Use a barrier instead to guarantee the ordering of events

BLAS.set_num_threads(num_threads)
try
return f()
catch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary catch

@chriselrod
Copy link
Contributor

chriselrod commented Nov 3, 2021

Another question that came to mind is if this is portable across blas vendors? Or is this OpenBLAS specific.

BLAS.get_num_threads and BLAS.set_num_threads are supported by MKL as well.

@johnnychen94
Copy link
Sponsor Member Author

I am really not a fan of this.

Okay. Should I fix the comments or just close this PR? I'm okay with having this on my local projects. I propose this PR as I feel like there might be other people wanting it.

@johnnychen94
Copy link
Sponsor Member Author

This isn't a very good solution so let me just close this one.

@johnnychen94 johnnychen94 deleted the jc/blas_with_num_threads branch November 11, 2021 09:13
@ViralBShah
Copy link
Member

Yes, I have been thinking that it isn't a good idea to change global state in base Julia like this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
linear algebra Linear algebra
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants