Skip to content

Commit

Permalink
add comments for the fft class
Browse files Browse the repository at this point in the history
  • Loading branch information
A-006 committed Nov 8, 2024
1 parent 0610758 commit da26acc
Show file tree
Hide file tree
Showing 7 changed files with 486 additions and 194 deletions.
2 changes: 1 addition & 1 deletion examples/scf/pw_Si2/INPUT
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ scf_thr 1e-7
scf_nmax 100
device cpu
ks_solver dav_subspace
precision single
precision double
145 changes: 123 additions & 22 deletions source/module_basis/module_pw/module_fft/fft_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,50 +13,151 @@ class FFT_BASE
FFT_BASE();
virtual ~FFT_BASE();

// init parameters of fft
/**
* @brief Initialize the fft parameters As virtual function.
*
* The function is used to initialize the fft parameters.
*/
virtual __attribute__((weak))
void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in,
int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false);
void initfft(int nx_in,
int ny_in,
int nz_in,
int lixy_in,
int rixy_in,
int ns_in,
int nplane_in,
int nproc_in,
bool gamma_only_in,
bool xprime_in = true);

//init fftw_plans
/**
* @brief Setup the fft Plan and data As pure virtual function.
*
* The function is set as pure virtual function.In order to
* override the function in the derived class.In the derived
* class, the function is used to setup the fft Plan and data.
*/
virtual void setupFFT()=0;

//destroy fftw_plans
/**
* @brief Clean the fft Plan As pure virtual function.
*
* The function is set as pure virtual function.In order to
* override the function in the derived class.In the derived
* class, the function is used to clean the fft Plan.
*/
virtual void cleanFFT()=0;
//clear fftw_data

/**
* @brief Clear the fft data As pure virtual function.
*
* The function is set as pure virtual function.In order to
* override the function in the derived class.In the derived
* class, the function is used to clear the fft data.
*/
virtual void clear()=0;

// access the real space data
virtual __attribute__((weak)) FPTYPE* get_rspace_data() const;
/**
* @brief Get the real space data in cpu-like fft
*
* The function is used to get the real space data.While the
* FFT_BASE is an abstract class,the function will be override,
* The attribute weak is used to avoid define the function.
*/
virtual __attribute__((weak))
FPTYPE* get_rspace_data() const;

virtual __attribute__((weak)) std::complex<FPTYPE>* get_auxr_data() const;
virtual __attribute__((weak))
std::complex<FPTYPE>* get_auxr_data() const;

virtual __attribute__((weak)) std::complex<FPTYPE>* get_auxg_data() const;
virtual __attribute__((weak))
std::complex<FPTYPE>* get_auxg_data() const;

virtual __attribute__((weak)) std::complex<FPTYPE>* get_auxr_3d_data() const;
/**
* @brief Get the auxiliary real space data in 3D
*
* The function is used to get the auxiliary real space data in 3D.
* While the FFT_BASE is an abstract class,the function will be override,
* The attribute weak is used to avoid define the function.
*/
virtual __attribute__((weak))
std::complex<FPTYPE>* get_auxr_3d_data() const;

//forward fft in x-y direction
virtual __attribute__((weak)) void fftxyfor(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fftxybac(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
/**
* @brief Forward FFT in x-y direction
* @param in input data
* @param out output data
*
* This function performs the forward FFT in the x-y direction.
* It involves two axes, x and y. The FFT is applied multiple times
* along the left and right boundaries in the primary direction(which is
* determined by the xprime flag).Notably, the Y axis operates in
* "many-many-FFT" mode.
*/
virtual __attribute__((weak))
void fftxyfor(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak))
void fftxybac(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fftzfor(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
/**
* @brief Forward FFT in z direction
* @param in input data
* @param out output data
*
* This function performs the forward FFT in the z direction.
* It involves only one axis, z. The FFT is applied only once.
* Notably, the Z axis operates in many FFT with nz*ns.
*/
virtual __attribute__((weak))
void fftzfor(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fftzbac(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
virtual __attribute__((weak))
void fftzbac(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fftxyr2c(FPTYPE* in, std::complex<FPTYPE>* out) const;
/**
* @brief Forward FFT in x-y direction with real to complex
* @param in input data, real type
* @param out output data, complex type
*
* This function performs the forward FFT in the x-y direction
* with real to complex.There is no difference between fftxyfor.
*/
virtual __attribute__((weak))
void fftxyr2c(FPTYPE* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fftxyc2r(std::complex<FPTYPE>* in, FPTYPE* out) const;
virtual __attribute__((weak))
void fftxyc2r(std::complex<FPTYPE>* in,
FPTYPE* out) const;

virtual __attribute__((weak)) void fft3D_forward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
/**
* @brief Forward FFT in 3D
* @param in input data
* @param out output data
*
* This function performs the forward FFT for gpu-like fft.
* It involves three axes, x, y, and z. The FFT is applied multiple times
* for fft3D_forward.
*/
virtual __attribute__((weak))
void fft3D_forward(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fft3D_backward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
virtual __attribute__((weak))
void fft3D_backward(std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out) const;

protected:
int ny=0;
int nx=0;
int nx=0;
int ny=0;
int nz=0;

};
}
#endif // FFT_BASE_H
36 changes: 20 additions & 16 deletions source/module_basis/module_pw/module_fft/fft_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ std::unique_ptr<FFT_BASE> make_unique(Args &&... args)
}
namespace ModulePW
{
// #include "fft_gpu.h"
FFT_Bundle::FFT_Bundle()
{
}
Expand All @@ -26,11 +25,6 @@ FFT_Bundle::FFT_Bundle(std::string device_in,std::string precision_in)
assert(precision_in=="single" || precision_in=="double" || precision_in=="mixing");
this->device = device_in;
this->precision = precision_in;
// if (device=="cpu")
// {
// fft_float = make_unique<FFT_CPU<float>>();
// fft_double = make_unique<FFT_CPU<double>>();
// }
}

FFT_Bundle::~FFT_Bundle()
Expand All @@ -54,8 +48,17 @@ void FFT_Bundle::setfft(std::string device_in,std::string precision_in)
this->precision = precision_in;

}
void FFT_Bundle::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in,
int nproc_in, bool gamma_only_in, bool xprime_in , bool mpifft_in)
void FFT_Bundle::initfft(int nx_in,
int ny_in,
int nz_in,
int lixy_in,
int rixy_in,
int ns_in,
int nplane_in,
int nproc_in,
bool gamma_only_in,
bool xprime_in ,
bool mpifft_in)
{
if (this->precision=="single")
{
Expand All @@ -74,6 +77,14 @@ void FFT_Bundle::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_
{
fft_float = make_unique<FFT_CPU<float>>(this->fft_mode);
fft_double = make_unique<FFT_CPU<double>>(this->fft_mode);
if (float_flag)
{
fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in);
}
if (double_flag)
{
fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in);
}
}
if (device=="gpu")
{
Expand All @@ -85,14 +96,7 @@ void FFT_Bundle::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_
// fft_double = make_unique<FFT_CUDA<double>>();
// #endif
}
if (float_flag)
{
fft_float->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in);
}
if (double_flag)
{
fft_double->initfft(nx_in,ny_in,nz_in,lixy_in,rixy_in,ns_in,nplane_in,nproc_in,gamma_only_in,xprime_in,mpifft_in);
}

}
void FFT_Bundle::initfftmode(int fft_mode_in)
{
Expand Down
Loading

0 comments on commit da26acc

Please sign in to comment.