Skip to content

Commit

Permalink
Fix some minor security risks about ModuleESolver::init_esolver() and…
Browse files Browse the repository at this point in the history
… some potential functions
  • Loading branch information
PeizeLin committed Jul 6, 2024
1 parent 08518df commit 2870c62
Show file tree
Hide file tree
Showing 11 changed files with 53 additions and 59 deletions.
5 changes: 2 additions & 3 deletions source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ void Driver::driver_run() {
ModuleBase::TITLE("Driver", "driver_line");
ModuleBase::timer::tick("Driver", "driver_line");

//! 1: initialize the ESolver
ModuleESolver::ESolver* p_esolver = nullptr;
ModuleESolver::init_esolver(p_esolver);
//! 1: initialize the ESolver
ModuleESolver::ESolver *p_esolver = ModuleESolver::init_esolver();

//! 2: setup cell and atom information

Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/potentials/H_Hartree_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ PotHartree::PotHartree(const ModulePW::PW_Basis* rho_basis_in)
this->fixed_mode = false;
}

void PotHartree::cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff)
void PotHartree::cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff)
{
if(GlobalV::use_paw)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/potentials/H_Hartree_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class PotHartree : public PotBase
public:
PotHartree(const ModulePW::PW_Basis* rho_basis_in);

void cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff);
void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff);
};

} // namespace elecstate
Expand Down
18 changes: 6 additions & 12 deletions source/module_elecstate/potentials/pot_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,12 @@ namespace elecstate
class PotBase
{
public:
PotBase(){};
virtual ~PotBase(){};

virtual void cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff)
{
return;
}

virtual void cal_fixed_v(double* vl_pseudo)
{
return;
}
PotBase(){}
virtual ~PotBase(){}

virtual void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff){}

virtual void cal_fixed_v(double* vl_pseudo){}

bool fixed_mode = 0;
bool dynamic_mode = 0;
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/potentials/pot_surchem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class PotSurChem : public PotBase
}
}

void cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff) override
void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff) override
{
if (!this->allocated)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/potentials/pot_xc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace elecstate
{

void PotXC::cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff)
void PotXC::cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff)
{
ModuleBase::TITLE("PotXC", "cal_v_eff");
ModuleBase::timer::tick("PotXC", "cal_v_eff");
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/potentials/pot_xc.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class PotXC : public PotBase
this->fixed_mode = false;
}

void cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff) override;
void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff) override;

ModuleBase::matrix* vofk = nullptr;
double* etxc_ = nullptr;
Expand Down
10 changes: 5 additions & 5 deletions source/module_elecstate/potentials/potential_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void Potential::allocate()
}
}

void Potential::update_from_charge(const Charge* chg, const UnitCell* ucell)
void Potential::update_from_charge(const Charge*const chg, const UnitCell*const ucell)
{
ModuleBase::TITLE("Potential", "update_from_charge");
ModuleBase::timer::tick("Potential", "update_from_charge");
Expand Down Expand Up @@ -243,11 +243,11 @@ void Potential::cal_fixed_v(double* vl_pseudo)
ModuleBase::timer::tick("Potential", "cal_fixed_v");
}

void Potential::cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff)
void Potential::cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff)
{
ModuleBase::TITLE("Potential", "cal_v_eff");
int nspin_current = this->v_effective.nr;
int nrxx = this->v_effective.nc;
const int nspin_current = this->v_effective.nr;
const int nrxx = this->v_effective.nc;
ModuleBase::timer::tick("Potential", "cal_v_eff");
// first of all, set v_effective to zero.
this->v_effective.zero_out();
Expand Down Expand Up @@ -275,7 +275,7 @@ void Potential::cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::
ModuleBase::timer::tick("Potential", "cal_v_eff");
}

void Potential::init_pot(int istep, const Charge* chg)
void Potential::init_pot(int istep, const Charge*const chg)
{
ModuleBase::TITLE("Potential", "init_pot");
ModuleBase::timer::tick("Potential", "init_pot");
Expand Down
6 changes: 3 additions & 3 deletions source/module_elecstate/potentials/potential_new.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ class Potential : public PotBase
~Potential();

// initialize potential when SCF begin
void init_pot(int istep, const Charge* chg);
void init_pot(int istep, const Charge*const chg);
// initialize potential components before SCF
void pot_register(std::vector<std::string>& components_list);
// update potential from current charge
void update_from_charge(const Charge* chg, const UnitCell* ucell);
void update_from_charge(const Charge*const chg, const UnitCell*const ucell);
// interface for SCF-converged, etxc vtxc for Energy, vnew for force_scc
void get_vnew(const Charge* chg, ModuleBase::matrix& vnew);

Expand Down Expand Up @@ -170,7 +170,7 @@ class Potential : public PotBase
}

private:
void cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff) override;
void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff) override;
void cal_fixed_v(double* vl_pseudo) override;
// interpolate potential on the smooth mesh if necessary
void interpolate_vrs();
Expand Down
59 changes: 30 additions & 29 deletions source/module_esolver/esolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
#include "esolver_lj.h"
#include "esolver_of.h"
#include "module_md/md_para.h"
#include <stdexcept>

namespace ModuleESolver
{

void ESolver::printname(void)
{
std::cout << classname << std::endl;
std::cout << classname << std::endl;
}

std::string determine_type(void)
Expand Down Expand Up @@ -85,14 +86,14 @@ std::string determine_type(void)

auto device_info = GlobalV::device_flag;

for (char &c : device_info)
for (char &c : device_info)
{
if (std::islower(c))
if (std::islower(c))
{
c = std::toupper(c);
}
}
if (GlobalV::MY_RANK == 0)
if (GlobalV::MY_RANK == 0)
{
std::cout << " RUNNING WITH DEVICE : " << device_info << " / "
<< base_device::information::get_device_info(GlobalV::device_flag) << std::endl;
Expand All @@ -106,74 +107,74 @@ std::string determine_type(void)


//Some API to operate E_Solver
void init_esolver(ESolver*& p_esolver)
ESolver* init_esolver()
{
//determine type of esolver based on INPUT information
std::string esolver_type = determine_type();
const std::string esolver_type = determine_type();

//initialize the corresponding Esolver child class
if (esolver_type == "ksdft_pw")
{
#if ((defined __CUDA) || (defined __ROCM))
if (GlobalV::device_flag == "gpu")
if (GlobalV::device_flag == "gpu")
{
if (GlobalV::precision_flag == "single")
if (GlobalV::precision_flag == "single")
{
p_esolver = new ESolver_KS_PW<std::complex<float>, base_device::DEVICE_GPU>();
}
else
return new ESolver_KS_PW<std::complex<float>, base_device::DEVICE_GPU>();
}
else
{
p_esolver = new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_GPU>();
}
return;
return new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_GPU>();
}
}
#endif
if (GlobalV::precision_flag == "single")
if (GlobalV::precision_flag == "single")
{
p_esolver = new ESolver_KS_PW<std::complex<float>, base_device::DEVICE_CPU>();
}
else
return new ESolver_KS_PW<std::complex<float>, base_device::DEVICE_CPU>();
}
else
{
p_esolver = new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_CPU>();
}
}
return new ESolver_KS_PW<std::complex<double>, base_device::DEVICE_CPU>();
}
}
#ifdef __LCAO
else if (esolver_type == "ksdft_lcao")
{
if (GlobalV::GAMMA_ONLY_LOCAL)
{
p_esolver = new ESolver_KS_LCAO<double, double>();
return new ESolver_KS_LCAO<double, double>();
}
else if (GlobalV::NSPIN < 4)
{
p_esolver = new ESolver_KS_LCAO<std::complex<double>, double>();
return new ESolver_KS_LCAO<std::complex<double>, double>();
}
else
{
p_esolver = new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
return new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
}
}
else if (esolver_type == "ksdft_lcao_tddft")
{
p_esolver = new ESolver_KS_LCAO_TDDFT();
return new ESolver_KS_LCAO_TDDFT();
}
#endif
else if (esolver_type == "sdft_pw")
{
p_esolver = new ESolver_SDFT_PW();
return new ESolver_SDFT_PW();
}
else if(esolver_type == "ofdft")
{
p_esolver = new ESolver_OF();
return new ESolver_OF();
}
else if (esolver_type == "lj_pot")
{
p_esolver = new ESolver_LJ();
return new ESolver_LJ();
}
else if (esolver_type == "dp_pot")
{
p_esolver = new ESolver_DP(INPUT.mdp.pot_file);
return new ESolver_DP(INPUT.mdp.pot_file);
}
throw std::invalid_argument("esolver_type = "+std::string(esolver_type)+". Wrong in "+std::string(__FILE__)+" line "+std::to_string(__LINE__));
}


Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/esolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ std::string determine_type(void);
* the corresponding ESolver child class. It supports various ESolver types including ksdft_pw,
* ksdft_lcao, ksdft_lcao_tddft, sdft_pw, ofdft, lj_pot, and dp_pot.
*
* @param [in, out] p_esolver A pointer to an ESolver object that will be initialized.
* @return [out] A pointer to an ESolver object that will be initialized.
*/
void init_esolver(ESolver*& p_esolver);
ESolver* init_esolver();

void clean_esolver(ESolver*& pesolver);

Expand Down

0 comments on commit 2870c62

Please sign in to comment.