Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sort eigenvalues after LOBPCG #964

Merged
merged 4 commits into from
Mar 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions src/eigen/lobpcg_hyper_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,21 @@ end
end


function final_retval(X, AX, resid_history, niter, n_matvec)
λ = real(diag(X' * AX))
residuals = AX .- X*Diagonal(λ)
(; λ, X,
residual_norms=[norm(residuals[:, i]) for i = 1:size(residuals, 2)],
function final_retval(X, AX, BX, resid_history, niter, n_matvec)
λ = @views [(X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n]) for n=1:size(X, 2)]
λ = real(oftype(X[:, 1], λ)) # Offload to GPU if needed
residuals = AX .- BX .* λ'
if !issorted(λ)
p = sortperm(λ)
λ = λ[p]
residuals = residuals[:, p]
X = X[:, p]
AX = AX[:, p]
BX = BX[:, p]
resid_history = resid_history[p, :]
end
(; λ, X, AX, BX,
residual_norms=norm.(eachcol(residuals)),
residual_history=resid_history[:, 1:niter+1], n_matvec)
end

Expand Down Expand Up @@ -358,10 +368,11 @@ end
nlocked = 0
niter = 0 # the first iteration is fake
λs = @views [(X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n]) for n=1:M]
λs = oftype(X[:, 1], λs) # Offload to GPU if needed
λs = real(oftype(X[:, 1], λs)) # Offload to GPU if needed
new_X = X
new_AX = AX
new_BX = BX
# The full_ arrays contain all the vectors, the others only get the active ones
full_X = X
full_AX = AX
full_BX = BX
Expand Down Expand Up @@ -435,7 +446,7 @@ end
if nlocked >= n_conv_check # Converged!
X .= new_X # Update the part of X which is still active
AX .= new_AX
return final_retval(full_X, full_AX, resid_history, niter, n_matvec)
return final_retval(full_X, full_AX, full_BX, resid_history, niter, n_matvec)
end
newly_locked = nlocked - prev_nlocked
active = newly_locked+1:size(X,2) # newly active vectors
Expand Down Expand Up @@ -524,5 +535,5 @@ end
niter = niter + 1
end

final_retval(full_X, full_AX, resid_history, maxiter, n_matvec)
final_retval(full_X, full_AX, full_BX, resid_history, maxiter, n_matvec)
end
Loading