Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: enables NSCF calculation for DeePKS #1746

Merged
merged 3 commits into from
Jan 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions source/module_deepks/LCAO_deepks.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class LCAO_Deepks
int nmaxd = 0; //#. descriptors per l
int inlmax = 0; //tot. number {i,n,l} - atom, n, l

bool init_pdm = false; //for DeePKS NSCF calculation

// deep neural network module that provides corrected Hamiltonian term and
// related derivatives.
torch::jit::script::Module module;
Expand Down
45 changes: 43 additions & 2 deletions source/module_deepks/LCAO_deepks_pdm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,32 @@ void LCAO_Deepks::cal_projected_DM(const ModuleBase::matrix &dm,
ModuleBase::TITLE("LCAO_Deepks", "cal_projected_DM");
ModuleBase::timer::tick("LCAO_Deepks","cal_projected_DM");

const int pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1);
if (GlobalV::init_chg == "file" && !this->init_pdm) //for DeePKS NSCF calculation
{
ifstream ifs("pdm.dat");
if (!ifs)
{
ModuleBase::WARNING_QUIT("LCAO_Deepks::cal_projected_DM", "Can not find the file pdm.dat . Please do DeePKS SCF calculation first.");
}
for(int inl=0;inl<this->inlmax;inl++)
{
for(int ind=0;ind<pdm_size;ind++)
{
double c;
ifs >> c;
pdm[inl][ind] = c;
}
}
this->init_pdm = true;
return;
}

if(dm.nr == 0 && dm.nc ==0)
{
return;
}

const int pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1);
for(int inl=0;inl<inlmax;inl++)
{
ModuleBase::GlobalFunc::ZEROS(pdm[inl],pdm_size);
Expand Down Expand Up @@ -146,6 +166,28 @@ void LCAO_Deepks::cal_projected_DM_k(const std::vector<ModuleBase::ComplexMatrix
const int nks,
const std::vector<ModuleBase::Vector3<double>> &kvec_d)
{
const int pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1);

if (GlobalV::init_chg == "file" && !this->init_pdm) //for DeePKS NSCF calculation
{
ifstream ifs("pdm.dat");
if (!ifs)
{
ModuleBase::WARNING_QUIT("LCAO_Deepks::cal_projected_DM_k","Can not find the file pdm.dat . Please do DeePKS SCF calculation first.");
}
for(int inl=0;inl<this->inlmax;inl++)
{
for(int ind=0;ind<pdm_size;ind++)
{
double c;
ifs >> c;
pdm[inl][ind] = c;
}
}
this->init_pdm = true;
return;
}

//check for skipping
if(dm[0].nr == 0 && dm[0].nc ==0)
{
Expand All @@ -154,7 +196,6 @@ void LCAO_Deepks::cal_projected_DM_k(const std::vector<ModuleBase::ComplexMatrix
}
ModuleBase::timer::tick("LCAO_Deepks","cal_projected_DM_k");

const int pdm_size = (this->lmaxd * 2 + 1) * (this->lmaxd * 2 + 1);
for(int inl=0;inl<inlmax;inl++)
{
ModuleBase::GlobalFunc::ZEROS(pdm[inl],pdm_size);
Expand Down
1 change: 1 addition & 0 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ void ESolver_KS_LCAO::afterscf(const int istep)
GlobalC::kv.nks,
GlobalC::kv.kvec_d);
}
GlobalC::ld.check_projected_dm(); //print out the projected dm for NSCF calculaiton
GlobalC::ld.cal_descriptor(); // final descriptor
GlobalC::ld.check_descriptor(GlobalC::ucell);

Expand Down
48 changes: 48 additions & 0 deletions source/module_esolver/esolver_ks_lcao_elec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,54 @@ namespace ModuleESolver
bp.Macroscopic_polarization(this->psi);
}

//below is for DeePKS NSCF calculation
#ifdef __DEEPKS
const Parallel_Orbitals* pv = this->LOWF.ParaV;
if (GlobalV::deepks_out_labels || GlobalV::deepks_scf)
{
if (GlobalV::GAMMA_ONLY_LOCAL)
{
GlobalC::ld.cal_projected_DM(this->LOC.dm_gamma[0],
GlobalC::ucell,
GlobalC::ORB,
GlobalC::GridD,
pv->trace_loc_row,
pv->trace_loc_col);
}
else
{
GlobalC::ld.cal_projected_DM_k(this->LOC.dm_k,
GlobalC::ucell,
GlobalC::ORB,
GlobalC::GridD,
pv->trace_loc_row,
pv->trace_loc_col,
GlobalC::kv.nks,
GlobalC::kv.kvec_d);
}
GlobalC::ld.cal_descriptor(); // final descriptor
GlobalC::ld.cal_gedm(GlobalC::ucell.nat);
if (GlobalV::GAMMA_ONLY_LOCAL)
{
GlobalC::ld.add_v_delta(GlobalC::ucell,
GlobalC::ORB,
GlobalC::GridD,
pv->trace_loc_row,
pv->trace_loc_col,
pv->nrow,
pv->ncol);
}
else
{
GlobalC::ld.add_v_delta_k(GlobalC::ucell,
GlobalC::ORB,
GlobalC::GridD,
pv->trace_loc_row,
pv->trace_loc_col,
pv->nnr);
}
}
#endif
return;
}

Expand Down