diff --git a/anphon/scph.cpp b/anphon/scph.cpp index 0e214faa..c891b4eb 100644 --- a/anphon/scph.cpp +++ b/anphon/scph.cpp @@ -1081,7 +1081,7 @@ void Scph::compute_V4_elements_mpi_over_kpoint(std::complex ***v4_out, maxsize=(maxsize<<31)-1; const size_t count = nk2_prod * ns4; - const size_t count_sub = nk2_prod * ns2; + const size_t count_sub = ns4; if (count <= maxsize) { #ifdef MPI_CXX_DOUBLE_COMPLEX @@ -1091,21 +1091,36 @@ void Scph::compute_V4_elements_mpi_over_kpoint(std::complex ***v4_out, MPI_Allreduce(&v4_mpi[0][0][0], &v4_out[0][0][0], count, MPI_COMPLEX16, MPI_SUM, MPI_COMM_WORLD); #endif - } else { - for (is = 0; is < ns2; ++is) { + } else if (count_sub <= maxsize) { + for (size_t ik_prod = 0; ik_prod < nk2_prod; ++ik_prod) { #ifdef MPI_CXX_DOUBLE_COMPLEX - MPI_Allreduce(&v4_mpi[0][0][is], &v4_out[0][0][is], count_sub, - MPI_CXX_DOUBLE_COMPLEX, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(&v4_mpi[ik_prod][0][0], &v4_out[ik_prod][0][0], + count_sub, + MPI_CXX_DOUBLE_COMPLEX, MPI_SUM, MPI_COMM_WORLD); #else - MPI_Allreduce(&v4_mpi[0][0][is], &v4_out[0][0][is], count_sub, + MPI_Allreduce(&v4_mpi[ik_prod][0][0], &v4_out[ik_prod][0][0], + count_sub, MPI_COMPLEX16, MPI_SUM, MPI_COMM_WORLD); #endif } + } else { + for (size_t ik_prod = 0; ik_prod < nk2_prod; ++ik_prod) { + for (is = 0; is < ns2; ++is) { +#ifdef MPI_CXX_DOUBLE_COMPLEX + MPI_Allreduce(&v4_mpi[ik_prod][is][0], &v4_out[ik_prod][is][0], + ns2, + MPI_CXX_DOUBLE_COMPLEX, MPI_SUM, MPI_COMM_WORLD); +#else + MPI_Allreduce(&v4_mpi[ik_prod][is][0], &v4_out[ik_prod][is][0], + ns2, + MPI_COMPLEX16, MPI_SUM, MPI_COMM_WORLD); +#endif + } + } } memory->deallocate(v4_mpi); - zerofill_elements_acoustic_at_gamma(omega2_harmonic, v4_out, 4); @@ -1318,7 +1333,7 @@ void Scph::compute_V4_elements_mpi_over_band(std::complex ***v4_out, maxsize=(maxsize<<31)-1; const size_t count = nk2_prod * ns4; - const size_t count_sub = nk2_prod * ns2; + const size_t count_sub = ns4; if (count <= maxsize) { #ifdef MPI_CXX_DOUBLE_COMPLEX @@ -1328,16 +1343,32 @@ void Scph::compute_V4_elements_mpi_over_band(std::complex ***v4_out, MPI_Allreduce(&v4_mpi[0][0][0], &v4_out[0][0][0], count, MPI_COMPLEX16, MPI_SUM, MPI_COMM_WORLD); #endif - } else { - for (is = 0; is < ns2; ++is) { + } else if (count_sub <= maxsize) { + for (size_t ik_prod = 0; ik_prod < nk2_prod; ++ik_prod) { #ifdef MPI_CXX_DOUBLE_COMPLEX - MPI_Allreduce(&v4_mpi[0][0][is], &v4_out[0][0][is], count_sub, - MPI_CXX_DOUBLE_COMPLEX, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(&v4_mpi[ik_prod][0][0], &v4_out[ik_prod][0][0], + count_sub, + MPI_CXX_DOUBLE_COMPLEX, MPI_SUM, MPI_COMM_WORLD); #else - MPI_Allreduce(&v4_mpi[0][0][is], &v4_out[0][0][is], count_sub, + MPI_Allreduce(&v4_mpi[ik_prod][0][0], &v4_out[ik_prod][0][0], + count_sub, MPI_COMPLEX16, MPI_SUM, MPI_COMM_WORLD); #endif } + } else { + for (size_t ik_prod = 0; ik_prod < nk2_prod; ++ik_prod) { + for (is = 0; is < ns2; ++is) { +#ifdef MPI_CXX_DOUBLE_COMPLEX + MPI_Allreduce(&v4_mpi[ik_prod][is][0], &v4_out[ik_prod][is][0], + ns2, + MPI_CXX_DOUBLE_COMPLEX, MPI_SUM, MPI_COMM_WORLD); +#else + MPI_Allreduce(&v4_mpi[ik_prod][is][0], &v4_out[ik_prod][is][0], + ns2, + MPI_COMPLEX16, MPI_SUM, MPI_COMM_WORLD); +#endif + } + } } memory->deallocate(v4_mpi);