Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the missing EXX operator and minor refactor in LR-HSE #4696

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions source/module_lr/esolver_lrtd_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ void LR::ESolver_LR<T, TR>::set_dimension()

template <typename T, typename TR>
LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol,
const Input_para& inp,
UnitCell& ucell)
const Input_para& inp, UnitCell& ucell)
: input(inp), ucell(ucell)
#ifdef __EXX
, exx_info(GlobalC::exx_info)
#endif
{
redirect_log(inp.out_alllog);
ModuleBase::TITLE("ESolver_LR", "ESolver_LR");
Expand Down Expand Up @@ -162,9 +164,7 @@ LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol
{
Cpxgemr2d(this->nbasis, this->nbands, &(*ks_sol.psi)(ik, 0, 0), 1, start_band + 1, ks_sol.ParaV.desc_wfc,
&(*this->psi_ks)(ik, 0, 0), 1, 1, this->paraC_.desc, this->paraC_.blacs_ctxt);
for (int ib = 0;ib < this->nbands;++ib) {
this->eig_ks(ik, ib) = ks_sol.pelec->ekb(ik, start_band + ib);
}
for (int ib = 0;ib < this->nbands;++ib) { this->eig_ks(ik, ib) = ks_sol.pelec->ekb(ik, start_band + ib); }
}
}
#else
Expand Down Expand Up @@ -198,8 +198,8 @@ LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol
this->move_exx_lri(ks_sol.exx_lri_complex);
} else // construct C, V from scratch
{
this->exx_lri = std::make_shared<Exx_LRI<T>>(GlobalC::exx_info.info_ri);
this->exx_lri->init(MPI_COMM_WORLD, this->kv); // using GlobalC::ORB
this->exx_lri = std::make_shared<Exx_LRI<T>>(exx_info.info_ri);
this->exx_lri->init(MPI_COMM_WORLD, this->kv);
this->exx_lri->cal_exx_ions();
}
}
Expand All @@ -209,6 +209,9 @@ LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol

template <typename T, typename TR>
LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : input(inp), ucell(ucell)
#ifdef __EXX
, exx_info(GlobalC::exx_info)
#endif
{
redirect_log(inp.out_alllog);
ModuleBase::TITLE("ESolver_LR", "ESolver_LR");
Expand Down Expand Up @@ -359,8 +362,8 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
#ifdef __EXX
if ((xc_kernel == "hf" || xc_kernel == "hse") && this->input.lr_solver != "spectrum")
{
this->exx_lri = std::make_shared<Exx_LRI<T>>(GlobalC::exx_info.info_ri);
this->exx_lri->init(MPI_COMM_WORLD, this->kv); // using GlobalC::ORB
this->exx_lri = std::make_shared<Exx_LRI<T>>(exx_info.info_ri);
this->exx_lri->init(MPI_COMM_WORLD, this->kv);
this->exx_lri->cal_exx_ions();
}
// else
Expand All @@ -380,7 +383,7 @@ void LR::ESolver_LR<T, TR>::runner(int istep, UnitCell& cell)
// allocate and initialize A matrix and density matrix
hamilt::Hamilt<T>* phamilt = new HamiltCasidaLR<T>(xc_kernel, this->nspin, this->nbasis, this->nocc, this->nvirt, this->ucell, GlobalC::GridD, this->psi_ks, this->eig_ks,
#ifdef __EXX
this->exx_lri.get(),
this->exx_lri.get(), this->exx_info.info_global.hybrid_alpha,
#endif
this->gint_, this->pot, this->kv, &this->paraX_, &this->paraC_, &this->paraMat_);
// solve the Casida equation
Expand Down
1 change: 1 addition & 0 deletions source/module_lr/esolver_lrtd_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ namespace LR
std::shared_ptr<Exx_LRI<T>> exx_lri = nullptr;
void move_exx_lri(std::shared_ptr<Exx_LRI<double>>&);
void move_exx_lri(std::shared_ptr<Exx_LRI<std::complex<double>>>&);
const Exx_Info& exx_info;
#endif
};
}
5 changes: 3 additions & 2 deletions source/module_lr/hamilt_casida.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace LR
const ModuleBase::matrix& eig_ks,
#ifdef __EXX
Exx_LRI<T>* exx_lri_in,
const double& exx_alpha,
#endif
TGint* gint_in,
PotHxcLR* pot_in,
Expand All @@ -44,10 +45,10 @@ namespace LR
this->DM_trans, gint_in, pot_in, ucell_in, gd_in, kv_in, pX_in, pc_in, pmat_in);
this->ops->add(lr_hxc);
#ifdef __EXX
if (xc_kernel == "hf")
if (xc_kernel == "hf" || xc_kernel == "hse")
{ //add Exx operator
hamilt::Operator<T>* lr_exx = new OperatorLREXX<T>(nspin, naos, nocc, nvirt, ucell_in, psi_ks_in,
this->DM_trans, exx_lri_in, kv_in, pX_in, pc_in, pmat_in);
this->DM_trans, exx_lri_in, kv_in, pX_in, pc_in, pmat_in, exx_alpha);
this->ops->add(lr_exx);
}
#endif
Expand Down
21 changes: 9 additions & 12 deletions source/module_lr/operator_casida/operator_lr_exx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ namespace LR
for (int iat2 = 0;iat2 < ucell.nat;++iat2) {
for (auto cell : this->BvK_cells) {
this->Ds_onebase[is][iat1][std::make_pair(iat2, cell)] =
RI::Tensor<T>({ static_cast<size_t>(ucell.atoms[ucell.iat2it[iat1]].nw), static_cast<size_t>(ucell.atoms[ucell.iat2it[iat2]].nw) });
}
}
}
}
RI::Tensor<T>({ static_cast<size_t>(ucell.atoms[ucell.iat2it[iat1]].nw), static_cast<size_t>(ucell.atoms[ucell.iat2it[iat2]].nw) });
}
}
}
}
}

template<>
Expand Down Expand Up @@ -86,12 +86,11 @@ namespace LR

// 1. set_Ds (once)
// convert to vector<T*> for the interface of RI_2D_Comm::split_m2D_ktoR (interface will be unified to ct::Tensor)
std::cout << "ib=" << ib << std::endl;
std::vector<std::vector<T>> DMk_trans_vector = this->DM_trans[ib]->get_DMK_vector();
assert(DMk_trans_vector.size() == nks);
std::vector<const std::vector<T>*> DMk_trans_pointer(nks);
for (int is = 0;is < nks;++is) { DMk_trans_pointer[is] = &DMk_trans_vector[is];
}
for (int is = 0;is < nks;++is) { DMk_trans_pointer[is] = &DMk_trans_vector[is]; }

// if multi-k, DM_trans(TR=double) -> Ds_trans(TR=T=complex<double>)
std::vector<std::map<TA, std::map<TAC, RI::Tensor<T>>>> Ds_trans =
RI_2D_Comm::split_m2D_ktoR<T>(this->kv, DMk_trans_pointer, *this->pmat);
Expand All @@ -118,10 +117,8 @@ namespace LR
for (int is = 0;is < this->nspin;++is)
{
this->cal_DM_onebase(this->pX->local2global_col(io), this->pX->local2global_row(iv), ik, is); //set Ds_onebase
psi_out_bfirst(ik, io * this->pX->get_row_size() + iv) -=
0.5 * //minus for exchange, 0.5 for spin
GlobalC::exx_info.info_global.hybrid_alpha *
this->exx_lri->exx_lri.post_2D.cal_energy(this->Ds_onebase[is], this->exx_lri->Hexxs[is]);
psi_out_bfirst(ik, io * this->pX->get_row_size() + iv) -= 0.5 * //minus for exchange, 0.5 for spin
alpha * this->exx_lri->exx_lri.post_2D.cal_energy(this->Ds_onebase[is], this->exx_lri->Hexxs[is]);
}
}
}
Expand Down
14 changes: 8 additions & 6 deletions source/module_lr/operator_casida/operator_lr_exx.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ namespace LR
const K_Vectors& kv_in,
Parallel_2D* pX_in,
Parallel_2D* pc_in,
Parallel_Orbitals* pmat_in)
Parallel_Orbitals* pmat_in,
const double& alpha = 1.0)
: nspin(nspin), naos(naos), nocc(nocc), nvirt(nvirt),
psi_ks(psi_ks_in), DM_trans(DM_trans_in), exx_lri(exx_lri_in), kv(kv_in),
pX(pX_in), pc(pc_in), pmat(pmat_in), ucell(ucell_in)
pX(pX_in), pc(pc_in), pmat(pmat_in), ucell(ucell_in), alpha(alpha)
{
ModuleBase::TITLE("OperatorLREXX", "OperatorLREXX");
this->cal_type = hamilt::calculation_type::lcao_exx;
Expand All @@ -57,10 +58,11 @@ namespace LR
virtual void act(const psi::Psi<T>& psi_in, psi::Psi<T>& psi_out, const int nbands) const override;
private:
//global sizes
int nspin = 1;
int naos;
int nocc;
int nvirt;
const int& nspin;
const int& naos;
const int& nocc;
const int& nvirt;
const double& alpha;
const K_Vectors& kv;
/// ground state wavefunction
const psi::Psi<T>* psi_ks = nullptr;
Expand Down
2 changes: 1 addition & 1 deletion source/module_lr/potentials/pot_hxc_lrtd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace LR
std::set<std::string> local_xc = { "lda", "pbe", "hse" };
if (local_xc.find(this->xc_kernel) != local_xc.end())
{
XC_Functional::set_xc_type(this->xc_kernel);
XC_Functional::set_xc_type(this->xc_kernel); // for hse, (1-alpha) and omega are set here
this->xc_kernel_components_.cal_kernel(chg_gs, ucell_in, this->nspin);
}
}
Expand Down
Loading