Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some minor security risks about ModuleESolver::init_esolver() and some potential functions #4590

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ void Driver::driver_run() {
ModuleBase::TITLE("Driver", "driver_line");
ModuleBase::timer::tick("Driver", "driver_line");

//! 1: initialize the ESolver
ModuleESolver::ESolver* p_esolver = nullptr;
ModuleESolver::init_esolver(p_esolver);
//! 1: initialize the ESolver
ModuleESolver::ESolver *p_esolver = ModuleESolver::init_esolver();

//! 2: setup cell and atom information

Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/potentials/H_Hartree_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ PotHartree::PotHartree(const ModulePW::PW_Basis* rho_basis_in)
this->fixed_mode = false;
}

void PotHartree::cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff)
void PotHartree::cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff)
{
if(GlobalV::use_paw)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/potentials/H_Hartree_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class PotHartree : public PotBase
public:
PotHartree(const ModulePW::PW_Basis* rho_basis_in);

void cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff);
void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff);
};

} // namespace elecstate
Expand Down
18 changes: 6 additions & 12 deletions source/module_elecstate/potentials/pot_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,12 @@ namespace elecstate
class PotBase
{
public:
PotBase(){};
virtual ~PotBase(){};

virtual void cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff)
{
return;
}

virtual void cal_fixed_v(double* vl_pseudo)
{
return;
}
PotBase(){}
virtual ~PotBase(){}

virtual void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff){}

virtual void cal_fixed_v(double* vl_pseudo){}

bool fixed_mode = 0;
bool dynamic_mode = 0;
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/potentials/pot_surchem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class PotSurChem : public PotBase
}
}

void cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff) override
void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff) override
{
if (!this->allocated)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/potentials/pot_xc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace elecstate
{

void PotXC::cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff)
void PotXC::cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff)
{
ModuleBase::TITLE("PotXC", "cal_v_eff");
ModuleBase::timer::tick("PotXC", "cal_v_eff");
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/potentials/pot_xc.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class PotXC : public PotBase
this->fixed_mode = false;
}

void cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff) override;
void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff) override;

ModuleBase::matrix* vofk = nullptr;
double* etxc_ = nullptr;
Expand Down
10 changes: 5 additions & 5 deletions source/module_elecstate/potentials/potential_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void Potential::allocate()
}
}

void Potential::update_from_charge(const Charge* chg, const UnitCell* ucell)
void Potential::update_from_charge(const Charge*const chg, const UnitCell*const ucell)
{
ModuleBase::TITLE("Potential", "update_from_charge");
ModuleBase::timer::tick("Potential", "update_from_charge");
Expand Down Expand Up @@ -243,11 +243,11 @@ void Potential::cal_fixed_v(double* vl_pseudo)
ModuleBase::timer::tick("Potential", "cal_fixed_v");
}

void Potential::cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff)
void Potential::cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff)
{
ModuleBase::TITLE("Potential", "cal_v_eff");
int nspin_current = this->v_effective.nr;
int nrxx = this->v_effective.nc;
const int nspin_current = this->v_effective.nr;
const int nrxx = this->v_effective.nc;
ModuleBase::timer::tick("Potential", "cal_v_eff");
// first of all, set v_effective to zero.
this->v_effective.zero_out();
Expand Down Expand Up @@ -275,7 +275,7 @@ void Potential::cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::
ModuleBase::timer::tick("Potential", "cal_v_eff");
}

void Potential::init_pot(int istep, const Charge* chg)
void Potential::init_pot(int istep, const Charge*const chg)
{
ModuleBase::TITLE("Potential", "init_pot");
ModuleBase::timer::tick("Potential", "init_pot");
Expand Down
6 changes: 3 additions & 3 deletions source/module_elecstate/potentials/potential_new.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ class Potential : public PotBase
~Potential();

// initialize potential when SCF begin
void init_pot(int istep, const Charge* chg);
void init_pot(int istep, const Charge*const chg);
// initialize potential components before SCF
void pot_register(std::vector<std::string>& components_list);
// update potential from current charge
void update_from_charge(const Charge* chg, const UnitCell* ucell);
void update_from_charge(const Charge*const chg, const UnitCell*const ucell);
// interface for SCF-converged, etxc vtxc for Energy, vnew for force_scc
void get_vnew(const Charge* chg, ModuleBase::matrix& vnew);

Expand Down Expand Up @@ -170,7 +170,7 @@ class Potential : public PotBase
}

private:
void cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff) override;
void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff) override;
void cal_fixed_v(double* vl_pseudo) override;
// interpolate potential on the smooth mesh if necessary
void interpolate_vrs();
Expand Down
59 changes: 30 additions & 29 deletions source/module_esolver/esolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
#include "esolver_lj.h"
#include "esolver_of.h"
#include "module_md/md_para.h"
#include <stdexcept>

namespace ModuleESolver
{

void ESolver::printname(void)
{
std::cout << classname << std::endl;
std::cout << classname << std::endl;
}

std::string determine_type(void)
Expand Down Expand Up @@ -85,14 +86,14 @@ std::string determine_type(void)

auto device_info = GlobalV::device_flag;

for (char &c : device_info)
for (char &c : device_info)
{
if (std::islower(c))
if (std::islower(c))
{
c = std::toupper(c);
}
}
if (GlobalV::MY_RANK == 0)
if (GlobalV::MY_RANK == 0)
{
std::cout << " RUNNING WITH DEVICE : " << device_info << " / "
<< base_device::information::get_device_info(GlobalV::device_flag) << std::endl;
Expand All @@ -106,74 +107,74 @@ std::string determine_type(void)


//Some API to operate E_Solver
void init_esolver(ESolver*& p_esolver)
ESolver* init_esolver()
{
//determine type of esolver based on INPUT information
std::string esolver_type = determine_type();
const std::string esolver_type = determine_type();

//initialize the corresponding Esolver child class
if (esolver_type == "ksdft_pw")
{
#if ((defined __CUDA) || (defined __ROCM))
if (GlobalV::device_flag == "gpu")
if (GlobalV::device_flag == "gpu")
{
if (GlobalV::precision_flag == "single")
if (GlobalV::precision_flag == "single")
{
p_esolver = new ESolver_KS_PW<std::complex<float>, base_device::DEVICE_GPU>();
}
else
return new ESolver_KS_PW<std::complex<float>, base_device::DEVICE_GPU>();
}
else
{
p_esolver = new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_GPU>();
}
return;
return new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_GPU>();
}
}
#endif
if (GlobalV::precision_flag == "single")
if (GlobalV::precision_flag == "single")
{
p_esolver = new ESolver_KS_PW<std::complex<float>, base_device::DEVICE_CPU>();
}
else
return new ESolver_KS_PW<std::complex<float>, base_device::DEVICE_CPU>();
}
else
{
p_esolver = new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_CPU>();
}
}
return new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_CPU>();
}
}
#ifdef __LCAO
else if (esolver_type == "ksdft_lcao")
{
if (GlobalV::GAMMA_ONLY_LOCAL)
{
p_esolver = new ESolver_KS_LCAO<double, double>();
return new ESolver_KS_LCAO<double, double>();
}
else if (GlobalV::NSPIN < 4)
{
p_esolver = new ESolver_KS_LCAO<std::complex<double>, double>();
return new ESolver_KS_LCAO<std::complex<double>, double>();
}
else
{
p_esolver = new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
return new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
}
}
else if (esolver_type == "ksdft_lcao_tddft")
{
p_esolver = new ESolver_KS_LCAO_TDDFT();
return new ESolver_KS_LCAO_TDDFT();
}
#endif
else if (esolver_type == "sdft_pw")
{
p_esolver = new ESolver_SDFT_PW();
return new ESolver_SDFT_PW();
}
else if(esolver_type == "ofdft")
{
p_esolver = new ESolver_OF();
return new ESolver_OF();
}
else if (esolver_type == "lj_pot")
{
p_esolver = new ESolver_LJ();
return new ESolver_LJ();
}
else if (esolver_type == "dp_pot")
{
p_esolver = new ESolver_DP(INPUT.mdp.pot_file);
return new ESolver_DP(INPUT.mdp.pot_file);
}
throw std::invalid_argument("esolver_type = "+std::string(esolver_type)+". Wrong in "+std::string(__FILE__)+" line "+std::to_string(__LINE__));
}


Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/esolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ std::string determine_type(void);
* the corresponding ESolver child class. It supports various ESolver types including ksdft_pw,
* ksdft_lcao, ksdft_lcao_tddft, sdft_pw, ofdft, lj_pot, and dp_pot.
*
* @param [in, out] p_esolver A pointer to an ESolver object that will be initialized.
* @return [out] A pointer to an ESolver object that will be initialized.
*/
void init_esolver(ESolver*& p_esolver);
ESolver* init_esolver();

void clean_esolver(ESolver*& pesolver);

Expand Down
Loading