diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index 8c1f8818e6..42e4601ae9 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -35,15 +35,18 @@ Diago_DavSubspace::Diago_DavSubspace(const std::vector& precond // TODO: Added memory usage statistics //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - // the product of H and psi in the reduced basis set + resmem_complex_op()(this->ctx, this->psi_in_iter, this->nbase_x * this->dim, "DAV::psi_in_iter"); + setmem_complex_op()(this->ctx, this->psi_in_iter, 0, this->nbase_x * this->dim); + + // the product of H and psi in the reduced psi set resmem_complex_op()(this->ctx, this->hphi, this->nbase_x * this->dim, "DAV::hphi"); setmem_complex_op()(this->ctx, this->hphi, 0, this->nbase_x * this->dim); - // Hamiltonian on the reduced basis set + // Hamiltonian on the reduced psi set resmem_complex_op()(this->ctx, this->hcc, this->nbase_x * this->nbase_x, "DAV::hcc"); setmem_complex_op()(this->ctx, this->hcc, 0, this->nbase_x * this->nbase_x); - // Overlap on the reduced basis set + // Overlap on the reduced psi set resmem_complex_op()(this->ctx, this->scc, this->nbase_x * this->nbase_x, "DAV::scc"); setmem_complex_op()(this->ctx, this->scc, 0, this->nbase_x * this->nbase_x); @@ -56,11 +59,7 @@ Diago_DavSubspace::Diago_DavSubspace(const std::vector& precond if (this->device == base_device::GpuDevice) { resmem_real_op()(this->ctx, this->d_precondition, nbasis_in); - syncmem_var_h2d_op()(this->ctx, - this->cpu_ctx, - this->d_precondition, - this->precondition.data(), - nbasis_in); + syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition.data(), nbasis_in); } #endif } @@ -68,6 +67,8 @@ Diago_DavSubspace::Diago_DavSubspace(const std::vector& precond template Diago_DavSubspace::~Diago_DavSubspace() { + delmem_complex_op()(this->ctx, this->psi_in_iter); + delmem_complex_op()(this->ctx, this->hphi); delmem_complex_op()(this->ctx, this->hcc); delmem_complex_op()(this->ctx, this->scc); @@ -82,17 +83,18 @@ Diago_DavSubspace::~Diago_DavSubspace() } template -int Diago_DavSubspace::diag_once(hamilt::Hamilt* phm_in, - psi::Psi& psi, - Real* eigenvalue_in_hsolver, - const std::vector& is_occupied) +int Diago_DavSubspace::diag_once( + + const Func& hpsi_func, + T* psi_in, + + psi::Psi& psi, + + Real* eigenvalue_in_hsolver, + const std::vector& is_occupied) { ModuleBase::timer::tick("Diago_DavSubspace", "diag_once"); - // TODO: Allocate memory in the constructor - psi::Psi basis(1, this->nbase_x, this->dim, &(psi.get_ngk(0))); - ModuleBase::Memory::record("DAV::basis", this->nbase_x * this->dim * sizeof(T)); - // the eigenvalues in dav iter std::vector eigenvalue_iter(this->nbase_x, 0.0); @@ -102,7 +104,7 @@ int Diago_DavSubspace::diag_once(hamilt::Hamilt* phm_in, // unconv[m] store the number of the m th unconvergent band std::vector unconv(this->n_band); - // the dimension of the reduced basis set + // the dimension of the reduced psi set int nbase = 0; // the number of the unconvergent bands @@ -116,16 +118,19 @@ int Diago_DavSubspace::diag_once(hamilt::Hamilt* phm_in, syncmem_complex_op()(this->ctx, this->ctx, - &basis(m, 0), + this->psi_in_iter + m * this->dim, psi.get_k_first() ? &psi(m, 0) : &psi(m, 0, 0), this->dim); } - // calculate H|psi> - hpsi_info dav_hpsi_in(&basis, psi::Range(1, 0, 0, this->n_band - 1), this->hphi); - phm_in->ops->hPsi(dav_hpsi_in); + // auto psi_iter_wrapper = psi::Psi(this->psi_in_iter, 1, this->nbase_x, this->dim); + // // calculate H|psi> + // hpsi_info dav_hpsi_in(&psi_iter_wrapper, psi::Range(1, 0, 0, psi_iter_wrapper.get_nbands() - 1), this->hphi); + // phm_in->ops->hPsi(dav_hpsi_in); + + hpsi_func(this->hphi, this->psi_in_iter, this->nbase_x, this->dim, 0, this->nbase_x - 1); - this->cal_elem(this->dim, nbase, this->notconv, basis, this->hphi, this->hcc, this->scc); + this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc); this->diag_zhegvx(nbase, this->n_band, @@ -150,17 +155,20 @@ int Diago_DavSubspace::diag_once(hamilt::Hamilt* phm_in, { dav_iter++; - this->cal_grad(phm_in, - this->dim, - nbase, - this->notconv, - basis, - this->hphi, - this->vcc, - unconv.data(), - &eigenvalue_iter); + this->cal_grad( - this->cal_elem(this->dim, nbase, this->notconv, basis, this->hphi, this->hcc, this->scc); + hpsi_func, + + this->dim, + nbase, + this->notconv, + this->psi_in_iter, + this->hphi, + this->vcc, + unconv.data(), + &eigenvalue_iter); + + this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc); this->diag_zhegvx(nbase, this->n_band, @@ -214,7 +222,7 @@ int Diago_DavSubspace::diag_once(hamilt::Hamilt* phm_in, this->n_band, // n: col of B,C nbase, // k: col of A, row of B this->one, - basis.get_pointer(), // A dim * nbase + this->psi_in_iter, // A dim * nbase this->dim, this->vcc, // B nbase * n_band this->nbase_x, @@ -240,7 +248,7 @@ int Diago_DavSubspace::diag_once(hamilt::Hamilt* phm_in, nbase, eigenvalue_in_hsolver, psi, - basis, + this->psi_in_iter, this->hphi, this->hcc, this->scc, @@ -257,11 +265,11 @@ int Diago_DavSubspace::diag_once(hamilt::Hamilt* phm_in, } template -void Diago_DavSubspace::cal_grad(hamilt::Hamilt* phm_in, +void Diago_DavSubspace::cal_grad(const Func& hpsi_func, const int& dim, const int& nbase, const int& notconv, - psi::Psi& basis, + T* psi_iter, T* hphi, T* vcc, const int* unconv, @@ -281,17 +289,17 @@ void Diago_DavSubspace::cal_grad(hamilt::Hamilt* phm_in, gemm_op()(this->ctx, 'N', 'N', - this->dim, // m: row of A,C - notconv, // n: col of B,C - nbase, // k: col of A, row of B - this->one, // alpha - &basis(0, 0), // A - this->dim, // LDA - vcc, // B - this->nbase_x, // LDB - this->zero, // belta - &basis(nbase, 0), // C dim * notconv - this->dim // LDC + this->dim, // m: row of A,C + notconv, // n: col of B,C + nbase, // k: col of A, row of B + this->one, // alpha + psi_iter, // A + this->dim, // LDA + vcc, // B + this->nbase_x, // LDB + this->zero, // belta + psi_iter + nbase * this->dim, // C dim * notconv + this->dim // LDC ); for (int m = 0; m < notconv; m++) @@ -301,25 +309,25 @@ void Diago_DavSubspace::cal_grad(hamilt::Hamilt* phm_in, vector_mul_vector_op()(this->ctx, this->dim, - &basis(nbase + m, 0), - &basis(nbase + m, 0), + psi_iter + (nbase + m) * this->dim, + psi_iter + (nbase + m) * this->dim, e_temp_cpu.data()); } gemm_op()(this->ctx, 'N', 'N', - this->dim, // m: row of A,C - notconv, // n: col of B,C - nbase, // k: col of A, row of B - this->one, // alpha - hphi, // A dim * nbase - this->dim, // LDA - vcc, // B nbase * notconv - this->nbase_x, // LDB - this->one, // belta - &basis(nbase, 0), // C dim * notconv - this->dim // LDC + this->dim, // m: row of A,C + notconv, // n: col of B,C + nbase, // k: col of A, row of B + this->one, // alpha + hphi, // A dim * nbase + this->dim, // LDA + vcc, // B nbase * notconv + this->nbase_x, // LDB + this->one, // belta + psi_iter + (nbase) * this->dim, + this->dim // LDC ); // "precondition!!!" @@ -331,27 +339,37 @@ void Diago_DavSubspace::cal_grad(hamilt::Hamilt* phm_in, double x = this->precondition[i] - (*eigenvalue_iter)[m]; pre[i] = 0.5 * (1.0 + x + sqrt(1 + (x - 1.0) * (x - 1.0))); } - vector_div_vector_op()(this->ctx, this->dim, &basis(nbase + m, 0), &basis(nbase + m, 0), pre.data()); + vector_div_vector_op()(this->ctx, + this->dim, + psi_iter + (nbase + m) * this->dim, + psi_iter + (nbase + m) * this->dim, + pre.data()); } // "normalize!!!" in order to improve numerical stability of subspace diagonalization std::vector psi_norm(notconv, 0.0); for (size_t i = 0; i < notconv; i++) { - psi_norm[i] = dot_real_op()(this->ctx, this->dim, &basis(nbase + i, 0), &basis(nbase + i, 0), false); + psi_norm[i] = dot_real_op()(this->ctx, + this->dim, + psi_iter + (nbase + i) * this->dim, + psi_iter + (nbase + i) * this->dim, + false); assert(psi_norm[i] > 0.0); psi_norm[i] = sqrt(psi_norm[i]); vector_div_constant_op()(this->ctx, this->dim, - &basis(nbase + i, 0), - &basis(nbase + i, 0), + psi_iter + (nbase + i) * this->dim, + psi_iter + (nbase + i) * this->dim, psi_norm[i]); } - // "calculate H|psi>" for not convergence bands - hpsi_info dav_hpsi_in(&basis, psi::Range(1, 0, nbase, nbase + notconv - 1), &hphi[nbase * this->dim]); - phm_in->ops->hPsi(dav_hpsi_in); + // auto psi_iter_wrapper = psi::Psi(psi_iter, 1, this->nbase_x, this->dim); + // // "calculate H|psi>" for not convergence bands + // hpsi_info dav_hpsi_in(&psi_iter_wrapper, psi::Range(1, 0, nbase, nbase + notconv - 1), &hphi[nbase * this->dim]); + // phm_in->ops->hPsi(dav_hpsi_in); + hpsi_func(&hphi[nbase * this->dim], psi_iter, this->nbase_x, this->dim, nbase, nbase + notconv - 1); ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad"); return; @@ -361,7 +379,7 @@ template void Diago_DavSubspace::cal_elem(const int& dim, int& nbase, const int& notconv, - const psi::Psi& basis, + const T* psi_iter, const T* hphi, T* hcc, T* scc) @@ -375,7 +393,7 @@ void Diago_DavSubspace::cal_elem(const int& dim, notconv, this->dim, this->one, - &basis(0, 0), + psi_iter, this->dim, &hphi[nbase * this->dim], this->dim, @@ -390,9 +408,9 @@ void Diago_DavSubspace::cal_elem(const int& dim, notconv, this->dim, this->one, - &basis(0, 0), + psi_iter, this->dim, - &basis(nbase, 0), + psi_iter + nbase * this->dim, this->dim, this->zero, &scc[nbase * this->nbase_x], @@ -630,7 +648,7 @@ void Diago_DavSubspace::refresh(const int& dim, int& nbase, const Real* eigenvalue_in_hsolver, const psi::Psi& psi, - psi::Psi& basis, + T* psi_iter, T* hp, T* sp, T* hc, @@ -638,10 +656,14 @@ void Diago_DavSubspace::refresh(const int& dim, { ModuleBase::timer::tick("Diago_DavSubspace", "refresh"); - // update basis + // update psi for (size_t i = 0; i < nband; i++) { - syncmem_complex_op()(this->ctx, this->ctx, &basis(i, 0), &psi(i, 0), this->dim); + syncmem_complex_op()(this->ctx, + this->ctx, + psi_iter + i * this->dim, + &psi(i, 0), + this->dim); } gemm_op()(this->ctx, 'N', @@ -655,11 +677,15 @@ void Diago_DavSubspace::refresh(const int& dim, this->vcc, this->nbase_x, this->zero, - &basis(nband, 0), + psi_iter + nband * this->dim, this->dim); // update hphi - syncmem_complex_op()(this->ctx, this->ctx, hphi, &basis(nband, 0), this->dim * nband); + syncmem_complex_op()(this->ctx, + this->ctx, + hphi, + psi_iter + nband * this->dim, + this->dim * nband); nbase = nband; @@ -725,18 +751,19 @@ void Diago_DavSubspace::refresh(const int& dim, } template -int Diago_DavSubspace::diag(hamilt::Hamilt* phm_in, +int Diago_DavSubspace::diag(const Func& hpsi_func, + T* psi_in, + + hamilt::Hamilt* phm_in, psi::Psi& psi, + Real* eigenvalue_in_hsolver, const std::vector& is_occupied, const bool& scf_type) { - /// record the times of trying iterative diagonalization - this->notconv = 0; - + /** bool outputscc = false; bool outputeigenvalue = false; - if (outputscc) { std::cout << "before dav 111" << std::endl; @@ -765,7 +792,6 @@ int Diago_DavSubspace::diag(hamilt::Hamilt* phm_in, } std::cout << std::endl; } - if (outputeigenvalue) { // output: eigenvalue_in_hsolver @@ -776,6 +802,10 @@ int Diago_DavSubspace::diag(hamilt::Hamilt* phm_in, } std::cout << std::endl; } + */ + + /// record the times of trying iterative diagonalization + this->notconv = 0; int sum_iter = 0; int ntry = 0; @@ -786,18 +816,27 @@ int Diago_DavSubspace::diag(hamilt::Hamilt* phm_in, DiagoIterAssist::diagH_subspace(phm_in, psi, psi, eigenvalue_in_hsolver, psi.get_nbands()); } - sum_iter += this->diag_once(phm_in, psi, eigenvalue_in_hsolver, is_occupied); + sum_iter += this->diag_once( + + hpsi_func, + psi_in, + + psi, + + eigenvalue_in_hsolver, + is_occupied); ++ntry; } while (this->test_exit_cond(ntry, this->notconv, scf_type)); - if (notconv > std::max(5, psi.get_nbands() / 4)) + if (notconv > std::max(5, this->n_band / 4)) { std::cout << "\n notconv = " << this->notconv; std::cout << "\n Diago_DavSubspace::diag', too many bands are not converged! \n"; } + /** if (outputeigenvalue) { // output: eigenvalue_in_hsolver @@ -808,7 +847,6 @@ int Diago_DavSubspace::diag(hamilt::Hamilt* phm_in, } std::cout << std::endl; } - if (outputscc) { std::cout << "after dav 222 " << std::endl; @@ -837,6 +875,7 @@ int Diago_DavSubspace::diag(hamilt::Hamilt* phm_in, } std::cout << std::endl; } + */ return sum_iter; } diff --git a/source/module_hsolver/diago_dav_subspace.h b/source/module_hsolver/diago_dav_subspace.h index d2c5a3ab41..de702a1e27 100644 --- a/source/module_hsolver/diago_dav_subspace.h +++ b/source/module_hsolver/diago_dav_subspace.h @@ -3,6 +3,8 @@ #include "diagh.h" +#include + namespace hsolver { @@ -27,8 +29,14 @@ class Diago_DavSubspace : public DiagH virtual ~Diago_DavSubspace() override; - int diag(hamilt::Hamilt* phm_in, + using Func = std::function; + + int diag(const Func& hpsi_func, + T* psi_in, + + hamilt::Hamilt* phm_in, psi::Psi& phi, + Real* eigenvalue_in, const std::vector& is_occupied, const bool& scf_type); @@ -62,6 +70,8 @@ class Diago_DavSubspace : public DiagH /// record for how many bands not have convergence eigenvalues int notconv = 0; + T* psi_in_iter = nullptr; + /// the product of H and psi in the reduced basis set T* hphi = nullptr; @@ -79,30 +89,24 @@ class Diago_DavSubspace : public DiagH base_device::DEVICE_CPU* cpu_ctx = {}; base_device::AbacusDevice_t device = {}; - void cal_grad(hamilt::Hamilt* phm_in, + void cal_grad(const Func& hpsi_func, const int& dim, const int& nbase, const int& notconv, - psi::Psi& basis, + T* psi_iter, T* hphi, T* vcc, const int* unconv, std::vector* eigenvalue_iter); - void cal_elem(const int& dim, - int& nbase, - const int& notconv, - const psi::Psi& basis, - const T* hphi, - T* hcc, - T* scc); + void cal_elem(const int& dim, int& nbase, const int& notconv, const T* psi_iter, const T* hphi, T* hcc, T* scc); void refresh(const int& dim, const int& nband, int& nbase, const Real* eigenvalue, const psi::Psi& psi, - psi::Psi& basis, + T* psi_iter, T* hphi, T* hcc, T* scc, @@ -118,7 +122,8 @@ class Diago_DavSubspace : public DiagH bool init, bool is_subspace); - int diag_once(hamilt::Hamilt* phm_in, + int diag_once(const Func& hpsi_func, + T* psi_in, psi::Psi& psi, Real* eigenvalue_in, const std::vector& is_occupied); diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 168b398919..ad04584bea 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -716,145 +716,169 @@ void HSolverPW::updatePsiK(hamilt::Hamilt* pHamilt, psi::P template void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::Psi& psi, Real* eigenvalue) { - if (this->method != "cg") + if (this->method == "cg") { - if (this->method == "dav_subspace") - { -#ifdef __MPI - const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; -#else - const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; -#endif - - // this->pdiagh = new Diago_DavSubspace(this->precondition, - - // psi.get_nbands(), - // psi.get_k_first() ? psi.get_current_nbas() : psi.get_nk() * psi.get_nbasis(), - - // GlobalV::PW_DIAG_NDIM, - // DiagoIterAssist::PW_DIAG_THR, - // DiagoIterAssist::PW_DIAG_NMAX, - // DiagoIterAssist::need_subspace, - // comm_info); - Diago_DavSubspace dav_subspace(this->precondition, - psi.get_nbands(), - psi.get_k_first() ? psi.get_current_nbas() - : psi.get_nk() * psi.get_nbasis(), - GlobalV::PW_DIAG_NDIM, - DiagoIterAssist::PW_DIAG_THR, - DiagoIterAssist::PW_DIAG_NMAX, - DiagoIterAssist::need_subspace, - comm_info); - - // this->pdiagh->method = this->method; - - bool scf; - if (GlobalV::CALCULATION == "nscf") + // warp the hpsi_func and spsi_func into a lambda function + using ct_Device = typename ct::PsiToContainer::type; + auto cg = reinterpret_cast*>(this->pdiagh); + // warp the hpsi_func and spsi_func into a lambda function + auto ngk_pointer = psi.get_ngk_pointer(); + auto hpsi_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { + ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); + // psi_in should be a 2D tensor: + // psi_in.shape() = [nbands, nbasis] + const auto ndim = psi_in.shape().ndim(); + REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); + // Convert a Tensor object to a psi::Psi object + auto psi_wrapper = psi::Psi(psi_in.data(), + 1, + ndim == 1 ? 1 : psi_in.shape().dim_size(0), + ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), + ngk_pointer); + psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); + using hpsi_info = typename hamilt::Operator::hpsi_info; + hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data()); + hm->ops->hPsi(info); + ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); + }; + auto spsi_func = [this, hm](const ct::Tensor& psi_in, ct::Tensor& spsi_out) { + ModuleBase::timer::tick("DiagoCG_New", "spsi_func"); + // psi_in should be a 2D tensor: + // psi_in.shape() = [nbands, nbasis] + const auto ndim = psi_in.shape().ndim(); + REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); + if (GlobalV::use_uspp) { - scf = false; + // Convert a Tensor object to a psi::Psi object + hm->sPsi(psi_in.data(), + spsi_out.data(), + ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), + ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), + ndim == 1 ? 1 : psi_in.shape().dim_size(0)); } else { - scf = true; + base_device::memory::synchronize_memory_op()( + this->ctx, + this->ctx, + spsi_out.data(), + psi_in.data(), + static_cast((ndim == 1 ? 1 : psi_in.shape().dim_size(0)) + * (ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1)))); } - DiagoIterAssist::avg_iter - // += static_cast((reinterpret_cast*>(this->pdiagh))->diag( - += static_cast(dav_subspace.diag(hm, psi, eigenvalue, is_occupied, scf)); - - // delete reinterpret_cast*>(this->pdiagh); - this->pdiagh = nullptr; - } - else if (this->method == "bpcg") - { - this->pdiagh->diag(hm, psi, eigenvalue); - } - else // method == "dav" - { - // Allow 5 tries at most. If ntry > ntry_max = 5, exit diag loop. - const int ntry_max = 5; - // In non-self consistent calculation, do until totally converged. Else allow 5 eigenvecs to be NOT - // converged. - const int notconv_max = ("nscf" == GlobalV::CALCULATION) ? 0 : 5; - // do diag and add davidson iteration counts up to avg_iter - const Real david_diag_thr = DiagoIterAssist::PW_DIAG_THR; - const int david_maxiter = DiagoIterAssist::PW_DIAG_NMAX; - auto david = (reinterpret_cast*>(this->pdiagh)); - DiagoIterAssist::avg_iter += static_cast( - david->diag(hm, psi, eigenvalue, david_diag_thr, david_maxiter, ntry_max, notconv_max)); - } - return; + ModuleBase::timer::tick("DiagoCG_New", "spsi_func"); + }; + auto psi_tensor = ct::TensorMap(psi.get_pointer(), + ct::DataTypeToEnum::value, + ct::DeviceTypeToEnum::value, + ct::TensorShape({psi.get_nbands(), psi.get_nbasis()})) + .slice({0, 0}, {psi.get_nbands(), psi.get_current_nbas()}); + auto eigen_tensor = ct::TensorMap(eigenvalue, + ct::DataTypeToEnum::value, + ct::DeviceTypeToEnum::value, + ct::TensorShape({psi.get_nbands()})); + auto prec_tensor = ct::TensorMap(precondition.data(), + ct::DataTypeToEnum::value, + ct::DeviceTypeToEnum::value, + ct::TensorShape({static_cast(precondition.size())})) + .to_device() + .slice({0}, {psi.get_current_nbas()}); + + cg->diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor); + // TODO: Double check tensormap's potential problem + ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor); } - // warp the hpsi_func and spsi_func into a lambda function - using ct_Device = typename ct::PsiToContainer::type; - auto cg = reinterpret_cast*>(this->pdiagh); - // warp the hpsi_func and spsi_func into a lambda function - auto ngk_pointer = psi.get_ngk_pointer(); - auto hpsi_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { - ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); - // psi_in should be a 2D tensor: - // psi_in.shape() = [nbands, nbasis] - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - // Convert a Tensor object to a psi::Psi object - auto psi_wrapper = psi::Psi(psi_in.data(), - 1, - ndim == 1 ? 1 : psi_in.shape().dim_size(0), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ngk_pointer); - psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); - using hpsi_info = typename hamilt::Operator::hpsi_info; - hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data()); - hm->ops->hPsi(info); - ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); - }; - auto spsi_func = [this, hm](const ct::Tensor& psi_in, ct::Tensor& spsi_out) { - ModuleBase::timer::tick("DiagoCG_New", "spsi_func"); - // psi_in should be a 2D tensor: - // psi_in.shape() = [nbands, nbasis] - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - - if (GlobalV::use_uspp) + else if (this->method == "dav_subspace") + { +#ifdef __MPI + const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; +#else + const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; +#endif + this->pdiagh = new Diago_DavSubspace(this->precondition, + + psi.get_nbands(), + psi.get_k_first() ? psi.get_current_nbas() + : psi.get_nk() * psi.get_nbasis(), + + GlobalV::PW_DIAG_NDIM, + DiagoIterAssist::PW_DIAG_THR, + DiagoIterAssist::PW_DIAG_NMAX, + DiagoIterAssist::need_subspace, + comm_info); + + this->pdiagh->method = this->method; + + bool scf; + if (GlobalV::CALCULATION == "nscf") { - // Convert a Tensor object to a psi::Psi object - hm->sPsi(psi_in.data(), - spsi_out.data(), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ndim == 1 ? 1 : psi_in.shape().dim_size(0)); + scf = false; } else { - base_device::memory::synchronize_memory_op()( - this->ctx, - this->ctx, - spsi_out.data(), - psi_in.data(), - static_cast((ndim == 1 ? 1 : psi_in.shape().dim_size(0)) - * (ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1)))); + scf = true; } - ModuleBase::timer::tick("DiagoCG_New", "spsi_func"); - }; - auto psi_tensor = ct::TensorMap(psi.get_pointer(), - ct::DataTypeToEnum::value, - ct::DeviceTypeToEnum::value, - ct::TensorShape({psi.get_nbands(), psi.get_nbasis()})) - .slice({0, 0}, {psi.get_nbands(), psi.get_current_nbas()}); - auto eigen_tensor = ct::TensorMap(eigenvalue, - ct::DataTypeToEnum::value, - ct::DeviceTypeToEnum::value, - ct::TensorShape({psi.get_nbands()})); - auto prec_tensor = ct::TensorMap(precondition.data(), - ct::DataTypeToEnum::value, - ct::DeviceTypeToEnum::value, - ct::TensorShape({static_cast(precondition.size())})) - .to_device() - .slice({0}, {psi.get_current_nbas()}); - - cg->diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor); - // TODO: Double check tensormap's potential problem - ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor); + auto ngk_pointer = psi.get_ngk_pointer(); + + std::function hpsi_func = [hm, ngk_pointer]( + T* hpsi_out, + T* psi_in, + const int nband_in, + const int nbasis_in, + const int band_index1, + const int band_index2) + { + ModuleBase::timer::tick("DavSubspace", "hpsi_func"); + + // Convert "pointer data stucture" to a psi::Psi object + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nband_in, nbasis_in, ngk_pointer); + + psi::Range bands_range(1, 0, band_index1, band_index2); + + using hpsi_info = typename hamilt::Operator::hpsi_info; + hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out); + hm->ops->hPsi(info); + + ModuleBase::timer::tick("DavSubspace", "hpsi_func"); + }; + + + DiagoIterAssist::avg_iter + += static_cast((reinterpret_cast*>(this->pdiagh)) + ->diag( + + hpsi_func, + psi.get_pointer(), + + hm, + psi, + eigenvalue, + is_occupied, + scf)); + + delete reinterpret_cast*>(this->pdiagh); + this->pdiagh = nullptr; + } + else if (this->method == "bpcg") + { + this->pdiagh->diag(hm, psi, eigenvalue); + } + else if (this->method == "dav") + { + // Allow 5 tries at most. If ntry > ntry_max = 5, exit diag loop. + const int ntry_max = 5; + // In non-self consistent calculation, do until totally converged. Else allow 5 eigenvecs to be NOT converged. + const int notconv_max = ("nscf" == GlobalV::CALCULATION)? 0: 5; + // do diag and add davidson iteration counts up to avg_iter + const Real david_diag_thr = DiagoIterAssist::PW_DIAG_THR; + const int david_maxiter = DiagoIterAssist::PW_DIAG_NMAX; + auto david = (reinterpret_cast*>(this->pdiagh)); + DiagoIterAssist::avg_iter += static_cast( + david->diag(hm, psi, eigenvalue, david_diag_thr, david_maxiter, ntry_max, notconv_max) + ); + } + return; } template