Skip to content

Commit

Permalink
Fix NonlinearOperator with Array
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Aug 31, 2023
1 parent ca92705 commit 0d39e70
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
20 changes: 9 additions & 11 deletions src/nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,6 @@ for f in MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS
end
end

function LinearAlgebra.det(A::LinearAlgebra.Symmetric{<:AbstractJuMPScalar})
return GenericNonlinearExpr{variable_ref_type(eltype(A))}(:det, A)
end

# Multivariate operators

# The multivariate operators in MOI are +, -, *, ^, /, ifelse, atan
Expand Down Expand Up @@ -510,8 +506,8 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V}
for i in length(f.args):-1:1
if f.args[i] isa GenericNonlinearExpr{V}
push!(stack, (ret, i, f.args[i]))
elseif arg.args[i] isa AbstractArray
child.args[i] = moi_function.(arg.args[i])
elseif f.args[i] isa AbstractArray
ret.args[i] = moi_function.(f.args[i])
else
ret.args[i] = moi_function(f.args[i])
end
Expand Down Expand Up @@ -827,32 +823,34 @@ function Base.show(io::IO, f::NonlinearOperator)
return print(io, "NonlinearOperator(:$(f.head), $(f.func))")
end

const AbstractJuMPScalarOrArray = Union{AbstractJuMPScalar, AbstractArray{<:AbstractJuMPScalar}}

# Fast overload for unary calls

(f::NonlinearOperator)(x) = f.func(x)

(f::NonlinearOperator)(x::AbstractJuMPScalar) = NonlinearExpr(f.head, Any[x])
(f::NonlinearOperator)(x::AbstractJuMPScalarOrArray) = NonlinearExpr(f.head, Any[x])

# Fast overload for binary calls

(f::NonlinearOperator)(x, y) = f.func(x, y)

function (f::NonlinearOperator)(x::AbstractJuMPScalar, y)
function (f::NonlinearOperator)(x::AbstractJuMPScalarOrArray, y)
return GenericNonlinearExpr(f.head, Any[x, y])
end

function (f::NonlinearOperator)(x, y::AbstractJuMPScalar)
function (f::NonlinearOperator)(x, y::AbstractJuMPScalarOrArray)
return GenericNonlinearExpr(f.head, Any[x, y])
end

function (f::NonlinearOperator)(x::AbstractJuMPScalar, y::AbstractJuMPScalar)
function (f::NonlinearOperator)(x::AbstractJuMPScalarOrArray, y::AbstractJuMPScalarOrArray)
return GenericNonlinearExpr(f.head, Any[x, y])
end

# Fallback for more arguments
function (f::NonlinearOperator)(x, y, z...)
args = (x, y, z...)
if any(Base.Fix2(isa, AbstractJuMPScalar), args)
if any(Base.Fix2(isa, AbstractJuMPScalarOrArray), args)
return GenericNonlinearExpr(f.head, Any[a for a in args])
end
return f.func(args...)
Expand Down
11 changes: 11 additions & 0 deletions test/test_nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
module TestNLPExpr

using JuMP
using LinearAlgebra
using Test

function test_extension_univariate_operators(
Expand Down Expand Up @@ -828,4 +829,14 @@ function test_redefinition_of_function()
return
end

function test_array()
model = Model()
@variable(model, x)
op_norm = NonlinearOperator(:det, det)
@objective(model, Min, op_norm([x]))
f = MOI.get(model, MOI.ObjectiveFunction{MOI.ScalarNonlinearFunction}())
@test f.head == :norm
@test f.args == [[index(x)]]
end

end # module

0 comments on commit 0d39e70

Please sign in to comment.