From a850b6137bb2e0d2ee42c5164f90f5529fa8a865 Mon Sep 17 00:00:00 2001 From: Christian Glusa Date: Mon, 12 Feb 2024 14:44:53 -0700 Subject: [PATCH] MueLu: Fix #12736 and use Kokkos views in GetMatrixDiagonal --- .../src/Utils/MueLu_UtilitiesBase_def.hpp | 30 ++++++++----------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/packages/muelu/src/Utils/MueLu_UtilitiesBase_def.hpp b/packages/muelu/src/Utils/MueLu_UtilitiesBase_def.hpp index e101a7991e2a..ea1d6de062c2 100644 --- a/packages/muelu/src/Utils/MueLu_UtilitiesBase_def.hpp +++ b/packages/muelu/src/Utils/MueLu_UtilitiesBase_def.hpp @@ -197,7 +197,16 @@ UtilitiesBase:: const auto rowMap = A.getRowMap(); auto diag = Xpetra::VectorFactory::Build(rowMap, true); - A.getLocalDiagCopy(*diag); + const CrsMatrixWrap* crsOp = dynamic_cast(&A); + if ((crsOp != NULL) && (rowMap->lib() == Xpetra::UseTpetra)) { + using local_vector_type = typename Vector::dual_view_type::t_dev_um; + using execution_space = typename local_vector_type::execution_space; + Kokkos::View offsets("offsets", rowMap->getLocalNumElements()); + crsOp->getCrsGraph()->getLocalDiagOffsets(offsets); + crsOp->getCrsMatrix()->getLocalDiagCopy(*diag, offsets); + } else { + A.getLocalDiagCopy(*diag); + } return diag; } @@ -623,22 +632,9 @@ UtilitiesBase:: GetMatrixOverlappedDiagonal(const Matrix& A) { // FIXME_KOKKOS RCP rowMap = A.getRowMap(), colMap = A.getColMap(); - RCP localDiag = VectorFactory::Build(rowMap); - - const CrsMatrixWrap* crsOp = dynamic_cast(&A); - if ((crsOp != NULL) && (rowMap->lib() == Xpetra::UseTpetra)) { - Teuchos::ArrayRCP offsets; - crsOp->getLocalDiagOffsets(offsets); - crsOp->getLocalDiagCopy(*localDiag, offsets()); - } else { - auto localDiagVals = localDiag->getDeviceLocalView(Xpetra::Access::ReadWrite); - const auto diagVals = GetMatrixDiagonal(A)->getDeviceLocalView(Xpetra::Access::ReadOnly); - Kokkos::deep_copy(localDiagVals, diagVals); - } - - RCP diagonal = VectorFactory::Build(colMap); - RCP importer; - importer = A.getCrsGraph()->getImporter(); + RCP localDiag = GetMatrixDiagonal(A); + RCP diagonal = VectorFactory::Build(colMap); + RCP importer = A.getCrsGraph()->getImporter(); if (importer == Teuchos::null) { importer = ImportFactory::Build(rowMap, colMap); }