Skip to content

Commit

Permalink
MueLu: Fix #12736 and use Kokkos views in GetMatrixDiagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
cgcgcg committed Feb 12, 2024
1 parent 561af35 commit a850b61
Showing 1 changed file with 13 additions and 17 deletions.
30 changes: 13 additions & 17 deletions packages/muelu/src/Utils/MueLu_UtilitiesBase_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,16 @@ UtilitiesBase<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
const auto rowMap = A.getRowMap();
auto diag = Xpetra::VectorFactory<Scalar, LocalOrdinal, GlobalOrdinal, Node>::Build(rowMap, true);

A.getLocalDiagCopy(*diag);
const CrsMatrixWrap* crsOp = dynamic_cast<const CrsMatrixWrap*>(&A);
if ((crsOp != NULL) && (rowMap->lib() == Xpetra::UseTpetra)) {
using local_vector_type = typename Vector::dual_view_type::t_dev_um;
using execution_space = typename local_vector_type::execution_space;
Kokkos::View<size_t*, execution_space> offsets("offsets", rowMap->getLocalNumElements());
crsOp->getCrsGraph()->getLocalDiagOffsets(offsets);
crsOp->getCrsMatrix()->getLocalDiagCopy(*diag, offsets);
} else {
A.getLocalDiagCopy(*diag);
}

return diag;
}
Expand Down Expand Up @@ -623,22 +632,9 @@ UtilitiesBase<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
GetMatrixOverlappedDiagonal(const Matrix& A) {
// FIXME_KOKKOS
RCP<const Map> rowMap = A.getRowMap(), colMap = A.getColMap();
RCP<Vector> localDiag = VectorFactory::Build(rowMap);

const CrsMatrixWrap* crsOp = dynamic_cast<const CrsMatrixWrap*>(&A);
if ((crsOp != NULL) && (rowMap->lib() == Xpetra::UseTpetra)) {
Teuchos::ArrayRCP<size_t> offsets;
crsOp->getLocalDiagOffsets(offsets);
crsOp->getLocalDiagCopy(*localDiag, offsets());
} else {
auto localDiagVals = localDiag->getDeviceLocalView(Xpetra::Access::ReadWrite);
const auto diagVals = GetMatrixDiagonal(A)->getDeviceLocalView(Xpetra::Access::ReadOnly);
Kokkos::deep_copy(localDiagVals, diagVals);
}

RCP<Vector> diagonal = VectorFactory::Build(colMap);
RCP<const Import> importer;
importer = A.getCrsGraph()->getImporter();
RCP<Vector> localDiag = GetMatrixDiagonal(A);
RCP<Vector> diagonal = VectorFactory::Build(colMap);
RCP<const Import> importer = A.getCrsGraph()->getImporter();
if (importer == Teuchos::null) {
importer = ImportFactory::Build(rowMap, colMap);
}
Expand Down

0 comments on commit a850b61

Please sign in to comment.