Skip to content

Commit

Permalink
[hotfix] Symmetric mul! for GPUs (#201)
Browse files Browse the repository at this point in the history
* mul symmetric

* gpu fix

* symmetric fixed

* symv instead of Symmetric

* introduced symul! function

* symul update

* Added comments
  • Loading branch information
sshin23 authored Aug 18, 2022
1 parent 22b1f99 commit 87f4e0b
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 7 deletions.
3 changes: 2 additions & 1 deletion lib/MadNLPGPU/src/MadNLPGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import MadNLP:
AbstractOptions, AbstractLinearSolver, AbstractNLPModel, set_options!,
SymbolicException,FactorizationException,SolveException,InertiaException,
introduce, factorize!, solve!, improve!, is_inertia, inertia, tril_to_full!,
LapackOptions, input_type, is_supported, default_options
LapackOptions, input_type, is_supported, default_options, symul!

symul!(y, A, x::CuVector{T}, α = 1., β = 0.) where T = CUBLAS.symv!('L', T(α), A, x, T(β), y)


include("kernels.jl")
Expand Down
4 changes: 2 additions & 2 deletions lib/MadNLPGPU/src/kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ function LinearAlgebra.mul!(y::AbstractVector, kkt::MadNLP.DenseKKTSystem{T, VT,

# x and y can be host arrays. Copy them on the device to avoid side effect.
copyto!(d_x, x)
LinearAlgebra.mul!(d_y, kkt.aug_com, d_x)
symul!(d_y, kkt.aug_com, d_x)
copyto!(y, d_y)
end
function LinearAlgebra.mul!(y::MadNLP.ReducedKKTVector, kkt::MadNLP.DenseKKTSystem{T, VT, MT}, x::MadNLP.ReducedKKTVector) where {T, VT<:CuVector{T}, MT<:CuMatrix{T}}
Expand Down Expand Up @@ -266,7 +266,7 @@ function LinearAlgebra.mul!(y::AbstractVector, kkt::MadNLP.DenseCondensedKKTSyst

# Call parent() as CUDA does not dispatch on proper copyto! when passed a view
copyto!(d_x, 1, parent(x), 1, length(x))
LinearAlgebra.mul!(d_y, kkt.aug_com, d_x)
symul!(d_y, kkt.aug_com, d_x)
copyto!(y, d_y)
else
# Load buffers
Expand Down
6 changes: 3 additions & 3 deletions src/KKT/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ is_reduced(::DenseKKTSystem) = true
num_variables(kkt::DenseKKTSystem) = length(kkt.pr_diag)

function mul!(y::AbstractVector, kkt::DenseKKTSystem, x::AbstractVector)
mul!(y, Symmetric(kkt.aug_com, :L), x)
symul!(y, kkt.aug_com, x)
end
function mul!(y::ReducedKKTVector, kkt::DenseKKTSystem, x::ReducedKKTVector)
mul!(full(y), kkt.aug_com, full(x))
Expand Down Expand Up @@ -354,7 +354,7 @@ function _mul_expanded!(y::AbstractVector, kkt::DenseCondensedKKTSystem, x::Abst

# / x (variable)
yx .= Σx .* xx
mul!(yx, kkt.hess, xx, 1.0, 1.0)
symul!(yx, kkt.hess, xx)
mul!(yx, kkt.jac', xy, 1.0, 1.0)

# / s (slack)
Expand All @@ -371,7 +371,7 @@ end
function mul!(y::AbstractVector, kkt::DenseCondensedKKTSystem, x::AbstractVector)
# TODO: implement properly with AbstractKKTRHS
if length(y) == length(x) == size(kkt.aug_com, 1)
mul!(y, Symmetric(kkt.aug_com, :L), x)
symul!(y, kkt.aug_com, x)
else
_mul_expanded!(y, kkt, x)
end
Expand Down
2 changes: 1 addition & 1 deletion src/MadNLP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import MathOptInterface
import Libdl: dlopen, dlext, RTLD_DEEPBIND, RTLD_GLOBAL
import Printf: @sprintf
import LinearAlgebra: BLAS, Adjoint, Symmetric, mul!, ldiv!, norm, dot
import LinearAlgebra.BLAS: axpy!, libblas, liblapack, BlasInt, @blasfunc
import LinearAlgebra.BLAS: axpy!, symv!, libblas, liblapack, BlasInt, @blasfunc
import SparseArrays: AbstractSparseMatrix, SparseMatrixCSC, sparse, getcolptr, rowvals, nnz
import Base: string, show, print, size, getindex, copyto!, @kwdef
import SuiteSparse: UMFPACK
Expand Down
5 changes: 5 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ for (name,level,color) in [(:trace,TRACE,7),(:debug,DEBUG,6),(:info,INFO,256),(:
end

# BLAS
# CUBLAS does not extend LinearAlgebra.BLAS.symv!, so calling symv!
# directly would never reach CUBLAS.symv! on device arrays. symul! is a
# thin wrapper that dispatches on the vector type; GPU back ends add
# their own methods. Computes y = α * Sym(A) * x + β * y, reading only
# the lower triangle of A ('L').
function symul!(y, A, x::AbstractVector{T}, α = 1, β = 0) where {T}
    return BLAS.symv!('L', T(α), A, x, T(β), y)
end

const blas_num_threads = Ref{Int}(1)
function set_blas_num_threads(n::Integer;permanent::Bool=false)
permanent && (blas_num_threads[]=n)
Expand Down

0 comments on commit 87f4e0b

Please sign in to comment.