-
Notifications
You must be signed in to change notification settings - Fork 504
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e02dd11
commit 8343077
Showing
1 changed file
with
138 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
// SPDX-License-Identifier: LGPL-3.0-or-later | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "paddle/include/paddle_inference_api.h" | ||
|
||
namespace deepmd { | ||
/** | ||
* @brief Check TensorFlow status. Exit if not OK. | ||
* @param[in] status TensorFlow status. | ||
**/ | ||
// void check_status(const tensorflow::Status& status); | ||
|
||
/** | ||
* @brief Get the value of a tensor. | ||
* @param[in] session TensorFlow session. | ||
* @param[in] name The name of the tensor. | ||
* @param[in] scope The scope of the tensor. | ||
* @return The value of the tensor. | ||
**/ | ||
template <typename VT> | ||
VT predictor_get_scalar(const std::shared_ptr<paddle_infer::Predictor>& predictor, | ||
const std::string& name_); | ||
|
||
/** | ||
* @brief Get the vector of a tensor. | ||
* @param[out] o_vec The output vector. | ||
* @param[in] session TensorFlow session. | ||
* @param[in] name The name of the tensor. | ||
* @param[in] scope The scope of the tensor. | ||
**/ | ||
// template <typename VT> | ||
// void session_get_vector(std::vector<VT>& o_vec, | ||
// tensorflow::Session* session, | ||
// const std::string name_, | ||
// const std::string scope = ""); | ||
|
||
/** | ||
* @brief Get the type of a tensor. | ||
* @param[in] session TensorFlow session. | ||
* @param[in] name The name of the tensor. | ||
* @param[in] scope The scope of the tensor. | ||
* @return The type of the tensor as int. | ||
**/ | ||
paddle_infer::DataType predictor_get_dtype(const std::shared_ptr<paddle_infer::Predictor>& predictor, | ||
const std::string& name_); | ||
|
||
/** | ||
* @brief Get input tensors. | ||
* @param[out] input_tensors Input tensors. | ||
* @param[in] dcoord_ Coordinates of atoms. | ||
* @param[in] ntypes Number of atom types. | ||
* @param[in] datype_ Atom types. | ||
* @param[in] dbox Box matrix. | ||
* @param[in] cell_size Cell size. | ||
* @param[in] fparam_ Frame parameters. | ||
* @param[in] aparam_ Atom parameters. | ||
* @param[in] atommap Atom map. | ||
* @param[in] scope The scope of the tensors. | ||
* @param[in] aparam_nall Whether the atomic dimesion of atomic parameters is | ||
* nall. | ||
*/ | ||
template <typename MODELTYPE, typename VALUETYPE> | ||
int predictor_input_tensors( | ||
const std::shared_ptr<paddle_infer::Predictor>& predictor, | ||
const std::vector<VALUETYPE>& dcoord_, | ||
const int& ntypes, | ||
const std::vector<int>& datype_, | ||
const std::vector<VALUETYPE>& dbox, | ||
const double& cell_size, | ||
const std::vector<VALUETYPE>& fparam_, | ||
const std::vector<VALUETYPE>& aparam_, | ||
const deepmd::AtomMap& atommap, | ||
const bool aparam_nall = false); | ||
|
||
/** | ||
* @brief Get input tensors. | ||
* @param[out] input_tensors Input tensors. | ||
* @param[in] dcoord_ Coordinates of atoms. | ||
* @param[in] ntypes Number of atom types. | ||
* @param[in] datype_ Atom types. | ||
* @param[in] dlist Neighbor list. | ||
* @param[in] fparam_ Frame parameters. | ||
* @param[in] aparam_ Atom parameters. | ||
* @param[in] atommap Atom map. | ||
* @param[in] nghost Number of ghost atoms. | ||
* @param[in] ago Update the internal neighbour list if ago is 0. | ||
* @param[in] scope The scope of the tensors. | ||
* @param[in] aparam_nall Whether the atomic dimesion of atomic parameters is | ||
* nall. | ||
*/ | ||
template <typename MODELTYPE, typename VALUETYPE> | ||
int predictor_input_tensors( | ||
const std::shared_ptr<paddle_infer::Predictor>& predictor, | ||
const std::vector<VALUETYPE>& dcoord_, | ||
const int& ntypes, | ||
const std::vector<int>& datype_, | ||
const std::vector<VALUETYPE>& dbox, | ||
InputNlist& dlist, | ||
const std::vector<VALUETYPE>& fparam_, | ||
const std::vector<VALUETYPE>& aparam_, | ||
const deepmd::AtomMap& atommap, | ||
const int nghost, | ||
const int ago, | ||
const bool aparam_nall = false); | ||
|
||
/** | ||
* @brief Get input tensors for mixed type. | ||
* @param[out] input_tensors Input tensors. | ||
* @param[in] nframes Number of frames. | ||
* @param[in] dcoord_ Coordinates of atoms. | ||
* @param[in] ntypes Number of atom types. | ||
* @param[in] datype_ Atom types. | ||
* @param[in] dlist Neighbor list. | ||
* @param[in] fparam_ Frame parameters. | ||
* @param[in] aparam_ Atom parameters. | ||
* @param[in] atommap Atom map. | ||
* @param[in] nghost Number of ghost atoms. | ||
* @param[in] ago Update the internal neighbour list if ago is 0. | ||
* @param[in] scope The scope of the tensors. | ||
* @param[in] aparam_nall Whether the atomic dimesion of atomic parameters is | ||
* nall. | ||
*/ | ||
template <typename MODELTYPE, typename VALUETYPE> | ||
int predictor_input_tensors_mixed_type( | ||
const std::shared_ptr<paddle_infer::Predictor>& predictor, | ||
const int& nframes, | ||
const std::vector<VALUETYPE>& dcoord_, | ||
const int& ntypes, | ||
const std::vector<int>& datype_, | ||
const std::vector<VALUETYPE>& dbox, | ||
const double& cell_size, | ||
const std::vector<VALUETYPE>& fparam_, | ||
const std::vector<VALUETYPE>& aparam_, | ||
const deepmd::AtomMap& atommap, | ||
const bool aparam_nall = false); | ||
|
||
} // namespace deepmd |