Skip to content

Commit

Permalink
Cleanup Normal Form Equation Construction
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 1, 2023
1 parent eadf16f commit 954a799
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 23 deletions.
5 changes: 3 additions & 2 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}

# Use normal form to solve the Linear Problem
if cache.JᵀJ !== nothing
__update_JᵀJ!(Val{iip}(), cache, :JᵀJ, cache.J)
__update_Jᵀf!(Val{iip}(), cache, :Jᵀf, :JᵀJ, cache.J, cache.fu)
__update_JᵀJ!(cache, Val(:JᵀJ))
__update_Jᵀf!(cache, Val(:JᵀJ))
A, b = __maybe_symmetric(cache.JᵀJ), _vec(cache.Jᵀf)
else
A, b = cache.J, _vec(cache.fu)
Expand Down Expand Up @@ -148,6 +148,7 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
return nothing
end

# FIXME: Reinit `JᵀJ` operator if `p` is changed
function __reinit_internal!(cache::GaussNewtonCache;
termination_condition = get_termination_mode(cache.tc_cache_1), kwargs...)
abstol, reltol, tc_cache_1 = init_termination_cache(cache.abstol, cache.reltol,
Expand Down
31 changes: 10 additions & 21 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,29 +209,18 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
end

# Generic Handling of Krylov Methods for Normal Form Linear Solves
# FIXME: Use MaybeInplace here for efficient matmuls
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
return __update_JᵀJ!(iip, cache, sym, getproperty(cache, sym), J)
function __update_JᵀJ!(cache::AbstractNonlinearSolveCache)
if !(cache.JᵀJ isa KrylovJᵀJ)
@bb cache.JᵀJ = transpose(cache.J) × cache.J
end
end
__update_JᵀJ!(::Val{false}, cache, sym::Symbol, _, J) = setproperty!(cache, sym, J' * J)
__update_JᵀJ!(::Val{true}, cache, sym::Symbol, _, J) = mul!(getproperty(cache, sym), J', J)
__update_JᵀJ!(::Val{false}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H
__update_JᵀJ!(::Val{true}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H

function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu)
end
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), J' * fu))
end
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
return mul!(_vec(getproperty(cache, sym1)), J', fu)
end
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), H.Jᵀ * fu))
end
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
return mul!(_vec(getproperty(cache, sym1)), H.Jᵀ, fu)
function __update_Jᵀf!(cache::AbstractNonlinearSolveCache)
if cache.JᵀJ isa KrylovJᵀJ
@bb cache.Jᵀf = cache.JᵀJ.Jᵀ × cache.fu
else
@bb cache.Jᵀf = transpose(cache.J) × vec(cache.fu)
end
end

# Left-Right Multiplication
Expand Down

0 comments on commit 954a799

Please sign in to comment.