Skip to content

Commit

Permalink
generalize gint_kernel_rho (#3087)
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 authored Oct 23, 2023
1 parent 79547be commit 23124d3
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 100 deletions.
20 changes: 10 additions & 10 deletions source/module_hamilt_lcao/module_gint/gint_fvl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ void Gint::gint_kernel_force(
//Gint_Tools::mult_psi_DM(*this->gridt, this->bxyz, na_grid, LD_pool, block_iw, block_size, block_index, cal_flag,
// psir_vlbr3.ptr_2D, psir_vlbr3_DM.ptr_2D, DM_in, 2);
Gint_Tools::mult_psi_DM_new(*this->gridt, this->bxyz, grid_index, na_grid, LD_pool, block_iw, block_size, block_index, cal_flag,
psir_vlbr3.ptr_2D, psir_vlbr3_DM.ptr_2D, this->DMRGint[is], 2);
psir_vlbr3.ptr_2D, psir_vlbr3_DM.ptr_2D, this->DMRGint[is], false);
}
else
{
Gint_Tools::mult_psi_DMR(*this->gridt, this->bxyz, grid_index, na_grid, block_index, block_size, cal_flag,
psir_vlbr3.ptr_2D, psir_vlbr3_DM.ptr_2D, DM_in[GlobalV::CURRENT_SPIN], this->DMRGint[is], 2);
psir_vlbr3.ptr_2D, psir_vlbr3_DM.ptr_2D, DM_in[GlobalV::CURRENT_SPIN], this->DMRGint[is], false);
}

if(isforce)
Expand Down Expand Up @@ -210,24 +210,24 @@ void Gint::gint_kernel_force_meta(
dpsir_z_vlbr3.ptr_2D, dpsirz_v_DM.ptr_2D, DM_in, 2);
*/
Gint_Tools::mult_psi_DM_new(*this->gridt, this->bxyz, grid_index, na_grid, LD_pool, block_iw, block_size, block_index, cal_flag,
psir_vlbr3.ptr_2D, psir_vlbr3_DM.ptr_2D, this->DMRGint[is], 2);
psir_vlbr3.ptr_2D, psir_vlbr3_DM.ptr_2D, this->DMRGint[is], false);
Gint_Tools::mult_psi_DM_new(*this->gridt, this->bxyz, grid_index, na_grid, LD_pool, block_iw, block_size, block_index, cal_flag,
dpsir_x_vlbr3.ptr_2D, dpsirx_v_DM.ptr_2D, this->DMRGint[is], 2);
dpsir_x_vlbr3.ptr_2D, dpsirx_v_DM.ptr_2D, this->DMRGint[is], false);
Gint_Tools::mult_psi_DM_new(*this->gridt, this->bxyz, grid_index, na_grid, LD_pool, block_iw, block_size, block_index, cal_flag,
dpsir_y_vlbr3.ptr_2D, dpsiry_v_DM.ptr_2D, this->DMRGint[is], 2);
dpsir_y_vlbr3.ptr_2D, dpsiry_v_DM.ptr_2D, this->DMRGint[is], false);
Gint_Tools::mult_psi_DM_new(*this->gridt, this->bxyz, grid_index, na_grid, LD_pool, block_iw, block_size, block_index, cal_flag,
dpsir_z_vlbr3.ptr_2D, dpsirz_v_DM.ptr_2D, this->DMRGint[is], 2);
dpsir_z_vlbr3.ptr_2D, dpsirz_v_DM.ptr_2D, this->DMRGint[is], false);
}
else
{
Gint_Tools::mult_psi_DMR(*this->gridt, this->bxyz, grid_index, na_grid, block_index, block_size, cal_flag,
psir_vlbr3.ptr_2D, psir_vlbr3_DM.ptr_2D, DM_in[GlobalV::CURRENT_SPIN], this->DMRGint[is], 2);
psir_vlbr3.ptr_2D, psir_vlbr3_DM.ptr_2D, DM_in[GlobalV::CURRENT_SPIN], this->DMRGint[is], false);
Gint_Tools::mult_psi_DMR(*this->gridt, this->bxyz, grid_index, na_grid, block_index, block_size, cal_flag,
dpsir_x_vlbr3.ptr_2D, dpsirx_v_DM.ptr_2D, DM_in[GlobalV::CURRENT_SPIN], this->DMRGint[is], 2);
dpsir_x_vlbr3.ptr_2D, dpsirx_v_DM.ptr_2D, DM_in[GlobalV::CURRENT_SPIN], this->DMRGint[is], false);
Gint_Tools::mult_psi_DMR(*this->gridt, this->bxyz, grid_index, na_grid, block_index, block_size, cal_flag,
dpsir_y_vlbr3.ptr_2D, dpsiry_v_DM.ptr_2D, DM_in[GlobalV::CURRENT_SPIN], this->DMRGint[is], 2);
dpsir_y_vlbr3.ptr_2D, dpsiry_v_DM.ptr_2D, DM_in[GlobalV::CURRENT_SPIN], this->DMRGint[is], false);
Gint_Tools::mult_psi_DMR(*this->gridt, this->bxyz, grid_index, na_grid, block_index, block_size, cal_flag,
dpsir_z_vlbr3.ptr_2D, dpsirz_v_DM.ptr_2D, DM_in[GlobalV::CURRENT_SPIN], this->DMRGint[is], 2);
dpsir_z_vlbr3.ptr_2D, dpsirz_v_DM.ptr_2D, DM_in[GlobalV::CURRENT_SPIN], this->DMRGint[is], false);
}

if(isforce)
Expand Down
6 changes: 3 additions & 3 deletions source/module_hamilt_lcao/module_gint/gint_rho.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void Gint::gint_kernel_rho(
block_index, cal_flag,
psir_ylm.ptr_2D,
psir_DM.ptr_2D,
inout->DM[is], 1);
inout->DM[is], inout->if_symm);
}
else
{
Expand All @@ -54,7 +54,7 @@ void Gint::gint_kernel_rho(
block_index, cal_flag,
psir_ylm.ptr_2D,
psir_DM.ptr_2D,
this->DMRGint[is], 1);
this->DMRGint[is], inout->if_symm);
}

}
Expand All @@ -69,7 +69,7 @@ void Gint::gint_kernel_rho(
psir_DM.ptr_2D,
inout->DM_R[is],
this->DMRGint[is],
1);
inout->if_symm);
}

//do sum_mu g_mu(r)psi_mu(r) to get electron density on grid
Expand Down
95 changes: 14 additions & 81 deletions source/module_hamilt_lcao/module_gint/gint_tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -827,30 +827,18 @@ namespace Gint_Tools
const double*const*const psi, // psir_vlbr3[bxyz][LD_pool]
double ** psi_DM,
const double*const*const DM,
const int job) // 1: density, 2: force
const bool if_symm) // true: density, use dsymv; false: potential/transition density use dgemv
{
constexpr char side='L', uplo='U';
constexpr char transa='N', transb='N';
constexpr double alpha_symm=1, beta=1;
constexpr int inc=1;
double alpha_gemm;

switch(job)
{
case 1:
alpha_gemm=2.0;
break;
case 2:
alpha_gemm=1.0;
break;
default:
ModuleBase::WARNING_QUIT("psir_dm","job can only be 1 or 2");
}
double alpha_gemm = if_symm ? 2.0 : 1.0;

for (int ia1=0; ia1<na_grid; ia1++)
{
const int iw1_lo=block_iw[ia1];
if(job==1)//density
if (if_symm)//density
{
//ia1==ia2, diagonal part
// find the first ib and last ib for non-zeros cal_flag
Expand Down Expand Up @@ -903,18 +891,7 @@ namespace Gint_Tools
}
}

int start;
switch(job)
{
case 1:
start=ia1+1;
break;
case 2:
start=0;
break;
default:
ModuleBase::WARNING_QUIT("psi_dm","job can only be 1 or 2");
}
int start = if_symm ? ia1 + 1 : 0;

for (int ia2=start; ia2<na_grid; ia2++)
{
Expand Down Expand Up @@ -981,7 +958,7 @@ namespace Gint_Tools
const double*const*const psi, // psir_vlbr3[bxyz][LD_pool]
double ** psi_DM,
const hamilt::HContainer<double>* DM,
const int job) // 1: density, 2: force
const bool if_symm) // 1: density, 2: force
{
bool *all_out_of_range = new bool[na_grid];
for(int ia=0; ia<na_grid; ++ia) //number of atoms
Expand All @@ -1001,21 +978,9 @@ namespace Gint_Tools
constexpr char transa='N', transb='N';
constexpr double alpha_symm=1, beta=1;
constexpr int inc=1;
double alpha_gemm;
double alpha_gemm = if_symm ? 2.0 : 1.0;

switch(job)
{
case 1:
alpha_gemm=2.0;
break;
case 2:
alpha_gemm=1.0;
break;
default:
ModuleBase::WARNING_QUIT("psir_dm","job can only be 1 or 2");
}

for (int ia1=0; ia1<na_grid; ia1++)
for (int ia1 = 0; ia1 < na_grid; ia1++)
{
if(all_out_of_range[ia1]) continue;

Expand All @@ -1024,7 +989,7 @@ namespace Gint_Tools
const double* tmp_matrix = DM->find_pair(iat1, iat1)->get_pointer(0);
const int iw1_lo=block_iw[ia1];

if(job==1)//density
if (if_symm)//density
{
//ia1==ia2, diagonal part
// find the first ib and last ib for non-zeros cal_flag
Expand Down Expand Up @@ -1077,18 +1042,7 @@ namespace Gint_Tools
}
}

int start;
switch(job)
{
case 1:
start=ia1+1;
break;
case 2:
start=0;
break;
default:
ModuleBase::WARNING_QUIT("psi_dm","job can only be 1 or 2");
}
int start = if_symm ? ia1 + 1 : 0;

for (int ia2=start; ia2<na_grid; ia2++)
{
Expand Down Expand Up @@ -1180,7 +1134,7 @@ namespace Gint_Tools
double ** psi_DMR,
double* DMR,
const hamilt::HContainer<double>* DM,
const int job)
const bool if_symm)
{
double *psi2, *psi2_dmr;
int iwi, iww;
Expand All @@ -1204,18 +1158,8 @@ namespace Gint_Tools
const char trans='N';
const double alpha=1.0, beta=1.0;
const int inc=1;
double alpha1;
switch(job)
{
case 1:
alpha1=2.0;
break;
case 2:
alpha1=1.0;
break;
default:
ModuleBase::WARNING_QUIT("psir_dmr","job can only be 1 or 2");
}
double alpha1;
alpha1 = if_symm ? 2.0 : 1.0;

for (int ia1=0; ia1<na_grid; ia1++)
{
Expand All @@ -1235,7 +1179,7 @@ namespace Gint_Tools
const int R1y = gt.ucell_index2y[id1];
const int R1z = gt.ucell_index2z[id1];
const double* tmp_matrix = DM->find_matrix(iat, iat, 0, 0, 0)->get_pointer();
if(job==1) //density
if (if_symm) //density
{
const int idx1=block_index[ia1];
int* find_start = gt.find_R2[iat];
Expand Down Expand Up @@ -1307,18 +1251,7 @@ namespace Gint_Tools
}

// get (j,beta,R2)
int start;
switch(job)
{
case 1:
start=ia1+1;
break;
case 2:
start=0;
break;
default:
ModuleBase::WARNING_QUIT("psi_dmr","job can only be 1 or 2");
}
int start = if_symm ? ia1 + 1 : 0;

for (int ia2=start; ia2<na_grid; ia2++)
{
Expand Down
14 changes: 8 additions & 6 deletions source/module_hamilt_lcao/module_gint/gint_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Gint_inout
bool isforce;
bool isstress;
int ispin;

bool if_symm = false; // if true, use dsymv in gint_kernel_rho; if false, use dgemv.

//output
double** rho;
Expand All @@ -38,11 +38,12 @@ class Gint_inout
Gint_Tools::job_type job;

// electron density and kin_r, multi-k
Gint_inout(double **DM_R_in, double** rho_in, Gint_Tools::job_type job_in)
Gint_inout(double** DM_R_in, double** rho_in, Gint_Tools::job_type job_in, bool if_symm_in = true)
{
DM_R = DM_R_in;
rho = rho_in;
job = job_in;
if_symm = if_symm_in;
}

// force, multi-k
Expand Down Expand Up @@ -94,11 +95,12 @@ class Gint_inout
}

// electron density and kin_r, gamma point
Gint_inout(double ***DM_in, double** rho_in, Gint_Tools::job_type job_in)
Gint_inout(double*** DM_in, double** rho_in, Gint_Tools::job_type job_in, bool if_symm_in = true)
{
DM = DM_in;
rho = rho_in;
job = job_in;
if_symm = if_symm_in;
}

// force, gamma point
Expand Down Expand Up @@ -270,7 +272,7 @@ namespace Gint_Tools
const double*const*const psi, // psir_vlbr3[bxyz][LD_pool]
double** psi_DM,
const double*const*const DM,
const int job);
const bool if_symm);

// sum_nu,R rho_mu,nu(R) psi_nu, for multi-k
void mult_psi_DMR(
Expand All @@ -285,7 +287,7 @@ namespace Gint_Tools
double** psi_DMR,
double* DMR,
const hamilt::HContainer<double>* DM,
const int job);
const bool if_symm);

// sum_nu rho_mu,nu psi_nu, for gamma point
void mult_psi_DM_new(
Expand All @@ -301,7 +303,7 @@ namespace Gint_Tools
const double*const*const psi, // psir_vlbr3[bxyz][LD_pool]
double** psi_DM,
const hamilt::HContainer<double>* DM,
const int job);
const bool if_symm);

}

Expand Down

0 comments on commit 23124d3

Please sign in to comment.