Skip to content

Commit

Permalink
gpu - only overwite portion of basis target used
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Sep 26, 2024
1 parent 1945a8d commit 354029d
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 18 deletions.
19 changes: 16 additions & 3 deletions backends/cuda-ref/ceed-cuda-ref-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ static int CeedBasisApplyCore_Cuda(CeedBasis basis, bool apply_add, const CeedIn

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes, num_qpts;
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
}
CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
Expand Down Expand Up @@ -206,9 +211,14 @@ static int CeedBasisApplyAtPointsCore_Cuda(CeedBasis basis, bool apply_add, cons

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes;
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
length =
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
}

Expand Down Expand Up @@ -283,9 +293,12 @@ static int CeedBasisApplyNonTensorCore_Cuda(CeedBasis basis, bool apply_add, con

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp;
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
}

Expand Down
7 changes: 6 additions & 1 deletion backends/cuda-shared/ceed-cuda-shared-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,14 @@ static int CeedBasisApplyAtPointsCore_Cuda_shared(CeedBasis basis, bool apply_ad

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes;
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
length =
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
}

Expand Down
14 changes: 12 additions & 2 deletions backends/hip-ref/ceed-hip-ref-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,14 @@ static int CeedBasisApplyCore_Hip(CeedBasis basis, bool apply_add, const CeedInt

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes, num_qpts;
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
}
CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
Expand Down Expand Up @@ -204,9 +209,14 @@ static int CeedBasisApplyAtPointsCore_Hip(CeedBasis basis, bool apply_add, const

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes;
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
length =
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
}

Expand Down
7 changes: 6 additions & 1 deletion backends/hip-shared/ceed-hip-shared-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,14 @@ static int CeedBasisApplyAtPointsCore_Hip_shared(CeedBasis basis, bool apply_add

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes;
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
length =
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
}

Expand Down
26 changes: 15 additions & 11 deletions interface/ceed-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,10 @@ static int CeedBasisApplyAtPointsCheckDims(CeedBasis basis, CeedInt num_elem, co

// Check compatibility coordinates vector
for (CeedInt i = 0; i < num_elem; i++) total_num_points += num_points[i];
CeedCheck((x_length >= total_num_points * dim) || (eval_mode == CEED_EVAL_WEIGHT), ceed, CEED_ERROR_DIMENSION,
CeedCheck((x_length >= (CeedSize)total_num_points * (CeedSize)dim) || (eval_mode == CEED_EVAL_WEIGHT), ceed, CEED_ERROR_DIMENSION,
"Length of reference coordinate vector incompatible with basis dimension and number of points."
" Found reference coordinate vector of length %" CeedSize_FMT ", not of length %" CeedSize_FMT ".",
x_length, total_num_points * dim);
x_length, (CeedSize)total_num_points * (CeedSize)dim);

// Check CEED_EVAL_WEIGHT only on CEED_NOTRANSPOSE
CeedCheck(eval_mode != CEED_EVAL_WEIGHT || t_mode == CEED_NOTRANSPOSE, ceed, CEED_ERROR_UNSUPPORTED,
Expand All @@ -346,13 +346,16 @@ static int CeedBasisApplyAtPointsCheckDims(CeedBasis basis, CeedInt num_elem, co
bool has_good_dims = true;
switch (eval_mode) {
case CEED_EVAL_INTERP:
has_good_dims = ((t_mode == CEED_TRANSPOSE && (u_length >= total_num_points * num_q_comp || v_length >= num_elem * num_nodes * num_comp)) ||
(t_mode == CEED_NOTRANSPOSE && (v_length >= total_num_points * num_q_comp || u_length >= num_elem * num_nodes * num_comp)));
has_good_dims = ((t_mode == CEED_TRANSPOSE && (u_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp ||
v_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp)) ||
(t_mode == CEED_NOTRANSPOSE && (v_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp ||
u_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp)));
break;
case CEED_EVAL_GRAD:
has_good_dims =
((t_mode == CEED_TRANSPOSE && (u_length >= total_num_points * num_q_comp * dim || v_length >= num_elem * num_nodes * num_comp)) ||
(t_mode == CEED_NOTRANSPOSE && (v_length >= total_num_points * num_q_comp * dim || u_length >= num_elem * num_nodes * num_comp)));
has_good_dims = ((t_mode == CEED_TRANSPOSE && (u_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp * (CeedSize)dim ||
v_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp)) ||
(t_mode == CEED_NOTRANSPOSE && (v_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp * (CeedSize)dim ||
u_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp)));
break;
case CEED_EVAL_WEIGHT:
has_good_dims = t_mode == CEED_NOTRANSPOSE && (v_length >= total_num_points);
Expand Down Expand Up @@ -1822,12 +1825,13 @@ static int CeedBasisApplyCheckDims(CeedBasis basis, CeedInt num_elem, CeedTransp
case CEED_EVAL_GRAD:
case CEED_EVAL_DIV:
case CEED_EVAL_CURL:
has_good_dims =
((t_mode == CEED_TRANSPOSE && u_length >= num_elem * num_comp * num_qpts * q_comp && v_length >= num_elem * num_comp * num_nodes) ||
(t_mode == CEED_NOTRANSPOSE && v_length >= num_elem * num_qpts * num_comp * q_comp && u_length >= num_elem * num_comp * num_nodes));
has_good_dims = ((t_mode == CEED_TRANSPOSE && u_length >= (CeedSize)num_elem * (CeedSize)num_comp * (CeedSize)num_qpts * (CeedSize)q_comp &&
v_length >= (CeedSize)num_elem * (CeedSize)num_comp * (CeedSize)num_nodes) ||
(t_mode == CEED_NOTRANSPOSE && v_length >= (CeedSize)num_elem * (CeedSize)num_qpts * (CeedSize)num_comp * (CeedSize)q_comp &&
u_length >= (CeedSize)num_elem * (CeedSize)num_comp * (CeedSize)num_nodes));
break;
case CEED_EVAL_WEIGHT:
has_good_dims = v_length >= num_elem * num_qpts;
has_good_dims = v_length >= (CeedSize)num_elem * (CeedSize)num_qpts;
break;
}
CeedCheck(has_good_dims, ceed, CEED_ERROR_DIMENSION, "Input/output vectors too short for basis and evaluation mode");
Expand Down

0 comments on commit 354029d

Please sign in to comment.