diff --git a/src/eigen/lobpcg_hyper_impl.jl b/src/eigen/lobpcg_hyper_impl.jl index 44eb12cbc..7f25e9882 100644 --- a/src/eigen/lobpcg_hyper_impl.jl +++ b/src/eigen/lobpcg_hyper_impl.jl @@ -298,11 +298,21 @@ end end -function final_retval(X, AX, resid_history, niter, n_matvec) - λ = real(diag(X' * AX)) - residuals = AX .- X*Diagonal(λ) - (; λ, X, - residual_norms=[norm(residuals[:, i]) for i = 1:size(residuals, 2)], +function final_retval(X, AX, BX, resid_history, niter, n_matvec) + λ = @views [(X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n]) for n=1:size(X, 2)] + λ = real(oftype(X[:, 1], λ)) # Offload to GPU if needed + residuals = AX .- BX .* λ' + if !issorted(λ) + p = sortperm(λ) + λ = λ[p] + residuals = residuals[:, p] + X = X[:, p] + AX = AX[:, p] + BX = BX[:, p] + resid_history = resid_history[p, :] + end + (; λ, X, AX, BX, + residual_norms=norm.(eachcol(residuals)), residual_history=resid_history[:, 1:niter+1], n_matvec) end @@ -358,10 +368,11 @@ end nlocked = 0 niter = 0 # the first iteration is fake λs = @views [(X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n]) for n=1:M] - λs = oftype(X[:, 1], λs) # Offload to GPU if needed + λs = real(oftype(X[:, 1], λs)) # Offload to GPU if needed new_X = X new_AX = AX new_BX = BX + # The full_ arrays contain all the vectors, the others only get the active ones full_X = X full_AX = AX full_BX = BX @@ -435,7 +446,7 @@ end if nlocked >= n_conv_check # Converged! X .= new_X # Update the part of X which is still active AX .= new_AX - return final_retval(full_X, full_AX, resid_history, niter, n_matvec) + return final_retval(full_X, full_AX, full_BX, resid_history, niter, n_matvec) end newly_locked = nlocked - prev_nlocked active = newly_locked+1:size(X,2) # newly active vectors @@ -524,5 +535,5 @@ end niter = niter + 1 end - final_retval(full_X, full_AX, resid_history, maxiter, n_matvec) + final_retval(full_X, full_AX, full_BX, resid_history, maxiter, n_matvec) end