Skip to content

Commit

Permalink
cc: refactor DeepPotModelDevi, making it framework-independent (#3134)
Browse files Browse the repository at this point in the history
Refactor `DeepPotModelDevi` as a step of #3122. Now, it is just a
wrapper of multiple `DeepPot` classes. Models can have different
behaviors inside different `DeepPot`.

One may argue that the new class needs to prepare the input multiple
times. However, it's not expensive only to copy the memory. Also, during
the simulations, usually we run it every 100 steps.
  • Loading branch information
njzjz authored Jan 12, 2024
1 parent 04f07ef commit 828df66
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 331 deletions.
41 changes: 7 additions & 34 deletions source/api_cc/include/DeepPot.h
Original file line number Diff line number Diff line change
Expand Up @@ -480,39 +480,39 @@ class DeepPotModelDevi {
**/
double cutoff() const {
assert(inited);
return rcut;
return dps[0].cutoff();
};
/**
* @brief Get the number of types.
* @return The number of types.
**/
int numb_types() const {
assert(inited);
return ntypes;
return dps[0].numb_types();
};
/**
* @brief Get the number of types with spin.
* @return The number of types with spin.
**/
int numb_types_spin() const {
assert(inited);
return ntypes_spin;
return dps[0].numb_types_spin();
};
/**
* @brief Get the dimension of the frame parameter.
* @return The dimension of the frame parameter.
**/
int dim_fparam() const {
assert(inited);
return dfparam;
return dps[0].dim_fparam();
};
/**
* @brief Get the dimension of the atomic parameter.
* @return The dimension of the atomic parameter.
**/
int dim_aparam() const {
assert(inited);
return daparam;
return dps[0].dim_aparam();
};
/**
* @brief Compute the average energy.
Expand Down Expand Up @@ -590,39 +590,12 @@ class DeepPotModelDevi {
**/
bool is_aparam_nall() const {
assert(inited);
return aparam_nall;
return dps[0].is_aparam_nall();
};

private:
unsigned numb_models;
std::vector<tensorflow::Session*> sessions;
int num_intra_nthreads, num_inter_nthreads;
std::vector<tensorflow::GraphDef*> graph_defs;
std::vector<deepmd::DeepPot> dps;
bool inited;
template <class VT>
VT get_scalar(const std::string name) const;
// VALUETYPE get_rcut () const;
// int get_ntypes () const;
double rcut;
double cell_size;
int dtype;
std::string model_type;
std::string model_version;
int ntypes;
int ntypes_spin;
int dfparam;
int daparam;
bool aparam_nall;
template <typename VALUETYPE>
void validate_fparam_aparam(const int& nloc,
const std::vector<VALUETYPE>& fparam,
const std::vector<VALUETYPE>& aparam) const;

// copy neighbor list info from host
bool init_nbor;
std::vector<std::vector<int> > sec;
deepmd::AtomMap atommap;
NeighborListData nlist_data;
InputNlist nlist;
};
} // namespace deepmd
Loading

0 comments on commit 828df66

Please sign in to comment.