diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index 327beb020901b..c96fc9b965b1e 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -79,7 +79,7 @@ import LinearAlgebra: BlasReal, BlasComplex, BlasFloat, BlasInt, DimensionMismat include("lbt.jl") """ -get_config() + get_config() Return an object representing the current `libblastrampoline` configuration. diff --git a/stdlib/LinearAlgebra/src/lbt.jl b/stdlib/LinearAlgebra/src/lbt.jl index 67ce521a9aa7e..dd8c99b3f05f6 100644 --- a/stdlib/LinearAlgebra/src/lbt.jl +++ b/stdlib/LinearAlgebra/src/lbt.jl @@ -17,6 +17,7 @@ const LBT_INTERFACE_MAP = Dict( LBT_INTERFACE_ILP64 => :ilp64, LBT_INTERFACE_UNKNOWN => :unknown, ) +const LBT_INV_INTERFACE_MAP = Dict(v => k for (k, v) in LBT_INTERFACE_MAP) const LBT_F2C_PLAIN = 0 const LBT_F2C_REQUIRED = 1 @@ -26,6 +27,7 @@ const LBT_F2C_MAP = Dict( LBT_F2C_REQUIRED => :required, LBT_F2C_UNKNOWN => :unknown, ) +const LBT_INV_F2C_MAP = Dict(v => k for (k, v) in LBT_F2C_MAP) struct LBTLibraryInfo libname::String @@ -164,14 +166,74 @@ function lbt_get_default_func() return ccall((:lbt_get_default_func, libblastrampoline), Ptr{Cvoid}, ()) end -#= -Don't define footgun API (yet) +""" + lbt_find_backing_library(symbol_name, interface; config::LBTConfig = lbt_get_config()) -function lbt_get_forward(symbol_name, interface, f2c = LBT_F2C_PLAIN) - return ccall((:lbt_get_forward, libblastrampoline), Ptr{Cvoid}, (Cstring, Int32, Int32), symbol_name, interface, f2c) +Return the `LBTLibraryInfo` that represents the backing library for the given symbol +exported from libblastrampoline. This allows us to discover which library will service +a particular BLAS call from Julia code. This method returns `nothing` if either of the +following conditions are met: + + * No loaded library exports the desired symbol (the default function will be called) + * The symbol was set via `lbt_set_forward()`, which does not track library provenance. + +If the given `symbol_name` is not contained within the list of exported symbols, an +`ArgumentError` will be thrown. +""" +function lbt_find_backing_library(symbol_name, interface::Symbol; + config::LBTConfig = lbt_get_config()) + if interface ∉ (:ilp64, :lp64) + throw(Argument("Invalid interface specification: '$(interface)'")) + end + symbol_idx = findfirst(s -> s == symbol_name, config.exported_symbols) + if symbol_idx === nothing + throw(ArgumentError("Invalid exported symbol name '$(symbol_name)'")) + end + # Convert to zero-indexed + symbol_idx -= 1 + + forward_byte_offset = div(symbol_idx, 8) + forward_byte_mask = 1 << mod(symbol_idx, 8) + for lib in filter(l -> l.interface == interface, config.loaded_libs) + if lib.active_forwards[forward_byte_offset+1] & forward_byte_mask != 0x00 + return lib + end + end + + # No backing library was found + return nothing end + +## NOTE: Manually setting forwards is referred to as the 'footgun API'. It allows truly +## bizarre and complex setups to be created. If you run into strange errors while using +## it, the first thing you should ask yourself is whether it's truly function lbt_set_forward(symbol_name, addr, interface, f2c = LBT_F2C_PLAIN; verbose::Bool = false) - return ccall((:lbt_set_forward, libblastrampoline), Int32, (Cstring, Ptr{Cvoid}, Int32, Int32, Int32), symbol_name, addr, interface, f2c, verbose ? 1 : 0) + return ccall( + (:lbt_set_forward, libblastrampoline), + Int32, + (Cstring, Ptr{Cvoid}, Int32, Int32, Int32), + string(symbol_name), + addr, + Int32(interface), + Int32(f2c), + verbose ? Int32(1) : Int32(0), + ) +end +function lbt_set_forward(symbol_name, addr, interface::Symbol, f2c::Symbol = :plain; kwargs...) + return lbt_set_forward(symbol_name, addr, LBT_INV_INTERFACE_MAP[interface], LBT_INV_F2C_MAP[f2c]; kwargs...) +end + +function lbt_get_forward(symbol_name, interface, f2c = LBT_F2C_PLAIN) + return ccall( + (:lbt_get_forward, libblastrampoline), + Ptr{Cvoid}, + (Cstring, Int32, Int32), + string(symbol_name), + Int32(interface), + Int32(f2c), + ) +end +function lbt_get_forward(symbol_name, interface::Symbol, f2c::Symbol = :plain) + return lbt_get_forward(symbol_name, LBT_INV_INTERFACE_MAP[interface], LBT_INV_F2C_MAP[f2c]) end -=# \ No newline at end of file