diff --git a/backends/cuda-ref/ceed-cuda-ref-operator.c b/backends/cuda-ref/ceed-cuda-ref-operator.c index 515cc9a847..a9395c7bb7 100644 --- a/backends/cuda-ref/ceed-cuda-ref-operator.c +++ b/backends/cuda-ref/ceed-cuda-ref-operator.c @@ -119,60 +119,48 @@ static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool // Loop over fields for (CeedInt i = 0; i < num_fields; i++) { - bool is_strided = false, skip_e_vec = false; - CeedSize q_size; - CeedInt size; - CeedEvalMode eval_mode; - CeedBasis basis; + bool is_active = false, is_strided = false, skip_e_vec = false; + CeedSize q_size; + CeedInt size; + CeedEvalMode eval_mode; + CeedVector l_vec; + CeedElemRestriction elem_rstr; + // Check whether this field can skip the element restriction: + // Input CEED_VECTOR_ACTIVE + // Output CEED_VECTOR_ACTIVE without CEED_EVAL_NONE + // Input CEED_VECTOR_NONE with CEED_EVAL_WEIGHT + // Input passive vectorr with CEED_EVAL_NONE and strided restriction with CEED_STRIDES_BACKEND + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &l_vec)); + is_active = l_vec == CEED_VECTOR_ACTIVE; + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); - if (eval_mode != CEED_EVAL_WEIGHT) { - CeedElemRestriction elem_rstr; - - // Check whether this field can skip the element restriction: - // Must be passive input, with eval_mode NONE, and have a strided restriction with CEED_STRIDES_BACKEND. - CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); - - // First, check whether the field is input or output: - if (is_input) { - CeedVector l_vec; - - // Check for passive input - CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &l_vec)); - if (l_vec != CEED_VECTOR_ACTIVE && eval_mode == CEED_EVAL_NONE) { - // Check for strided restriction - CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided)); - if (is_strided) { - // Check if vector is already in preferred backend ordering - CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_e_vec)); - } - } - } - if (skip_e_vec) { - // Either an active field or strided local vec in backend ordering - e_vecs[i] = NULL; - } else { - CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i])); - } + skip_e_vec = (is_input && is_active) || (is_active && eval_mode != CEED_EVAL_NONE) || (eval_mode == CEED_EVAL_WEIGHT); + if (!skip_e_vec && is_input && !is_active && eval_mode == CEED_EVAL_NONE) { + CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided)); + if (is_strided) CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_e_vec)); + } + if (skip_e_vec) { + e_vecs[i] = NULL; + } else { + CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i])); } switch (eval_mode) { case CEED_EVAL_NONE: - CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); - q_size = (CeedSize)num_elem * Q * size; - CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); - break; case CEED_EVAL_INTERP: case CEED_EVAL_GRAD: case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); - q_size = (CeedSize)num_elem * Q * size; + q_size = (CeedSize)num_elem * (CeedSize)Q * (CeedSize)size; CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); break; - case CEED_EVAL_WEIGHT: // Only on input fields + case CEED_EVAL_WEIGHT: { + CeedBasis basis; + CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); - q_size = (CeedSize)num_elem * Q; + q_size = (CeedSize)num_elem * (CeedSize)Q; CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); if (is_at_points) { CeedInt num_points[num_elem]; @@ -184,6 +172,7 @@ static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); } break; + } } } // Drop duplicate restrictions @@ -201,7 +190,7 @@ static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); if (vec_i == vec_j && rstr_i == rstr_j) { - CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + if (e_vecs[i]) CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); skip_rstr[j] = true; } } @@ -220,7 +209,7 @@ static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); if (vec_i == vec_j && rstr_i == rstr_j) { - CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + if (e_vecs[i]) CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); skip_rstr[j] = true; apply_add_basis[i] = true; } @@ -348,41 +337,34 @@ static int CeedOperatorSetup_Cuda(CeedOperator op) { // Restrict Operator Inputs //------------------------------------------------------------------------------ static inline int CeedOperatorInputRestrict_Cuda(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, - CeedVector in_vec, const bool skip_active, CeedScalar **e_data, CeedOperator_Cuda *impl, + CeedVector in_vec, CeedVector active_e_vec, const bool skip_active, CeedOperator_Cuda *impl, CeedRequest *request) { - CeedEvalMode eval_mode; - CeedVector l_vec, e_vec = impl->e_vecs_in[input_field]; + bool is_active = false; + CeedVector l_vec, e_vec = impl->e_vecs_in[input_field]; // Get input vector CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); - if (l_vec == CEED_VECTOR_ACTIVE) { - if (skip_active) return CEED_ERROR_SUCCESS; - else l_vec = in_vec; + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active && skip_active) return CEED_ERROR_SUCCESS; + if (is_active) { + l_vec = in_vec; + if (!e_vec) e_vec = active_e_vec; } // Restriction action - CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode)); - if (eval_mode == CEED_EVAL_WEIGHT) { // Skip - } else { - if (!e_vec) { - // No restriction for this field; read data directly from vec. - CeedCallBackend(CeedVectorGetArrayRead(l_vec, CEED_MEM_DEVICE, (const CeedScalar **)e_data)); - } else { - // Restrict, if necessary - if (!impl->skip_rstr_in[input_field]) { - uint64_t state; + if (e_vec) { + // Restrict, if necessary + if (!impl->skip_rstr_in[input_field]) { + uint64_t state; - CeedCallBackend(CeedVectorGetState(l_vec, &state)); - if (state != impl->input_states[input_field] || l_vec == in_vec) { - CeedElemRestriction elem_rstr; + CeedCallBackend(CeedVectorGetState(l_vec, &state)); + if (is_active || state != impl->input_states[input_field]) { + CeedElemRestriction elem_rstr; - CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_field, &elem_rstr)); - CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, l_vec, e_vec, request)); - } - impl->input_states[input_field] = state; + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_field, &elem_rstr)); + CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, l_vec, e_vec, request)); } - // Get e-vec - CeedCallBackend(CeedVectorGetArrayRead(e_vec, CEED_MEM_DEVICE, (const CeedScalar **)e_data)); + impl->input_states[input_field] = state; } } return CEED_ERROR_SUCCESS; @@ -392,24 +374,35 @@ static inline int CeedOperatorInputRestrict_Cuda(CeedOperatorField op_input_fiel // Input Basis Action //------------------------------------------------------------------------------ static inline int CeedOperatorInputBasis_Cuda(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, - CeedInt num_elem, const bool skip_active, CeedScalar *e_data, CeedOperator_Cuda *impl) { + CeedVector in_vec, CeedVector active_e_vec, CeedInt num_elem, const bool skip_active, + CeedOperator_Cuda *impl) { + bool is_active = false; CeedEvalMode eval_mode; - CeedVector e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field]; + CeedVector l_vec, e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field]; // Skip active input - if (skip_active) { - CeedVector l_vec; - - CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); - if (l_vec == CEED_VECTOR_ACTIVE) return CEED_ERROR_SUCCESS; + CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active && skip_active) return CEED_ERROR_SUCCESS; + if (is_active) { + l_vec = in_vec; + if (!e_vec) e_vec = active_e_vec; } // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode)); switch (eval_mode) { - case CEED_EVAL_NONE: - CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, e_data)); + case CEED_EVAL_NONE: { + const CeedScalar *e_vec_array; + + if (e_vec) { + CeedCallBackend(CeedVectorGetArrayRead(e_vec, CEED_MEM_DEVICE, &e_vec_array)); + } else { + CeedCallBackend(CeedVectorGetArrayRead(l_vec, CEED_MEM_DEVICE, &e_vec_array)); + } + CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array)); break; + } case CEED_EVAL_INTERP: case CEED_EVAL_GRAD: case CEED_EVAL_DIV: @@ -430,22 +423,30 @@ static inline int CeedOperatorInputBasis_Cuda(CeedOperatorField op_input_field, // Restore Input Vectors //------------------------------------------------------------------------------ static inline int CeedOperatorInputRestore_Cuda(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, - const bool skip_active, CeedScalar **e_data, CeedOperator_Cuda *impl) { + CeedVector in_vec, CeedVector active_e_vec, const bool skip_active, CeedOperator_Cuda *impl) { + bool is_active = false; CeedEvalMode eval_mode; CeedVector l_vec, e_vec = impl->e_vecs_in[input_field]; // Skip active input CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); - if (skip_active && l_vec == CEED_VECTOR_ACTIVE) return CEED_ERROR_SUCCESS; + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active && skip_active) return CEED_ERROR_SUCCESS; + if (is_active) { + l_vec = in_vec; + if (!e_vec) e_vec = active_e_vec; + } // Restore e-vec CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode)); - if (eval_mode == CEED_EVAL_WEIGHT) { // Skip - } else { - if (!e_vec) { // This was a skip_restriction case - CeedCallBackend(CeedVectorRestoreArrayRead(l_vec, (const CeedScalar **)e_data)); + if (eval_mode == CEED_EVAL_NONE) { + const CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_in[input_field], CEED_MEM_DEVICE, (CeedScalar **)&e_vec_array)); + if (e_vec) { + CeedCallBackend(CeedVectorRestoreArrayRead(e_vec, &e_vec_array)); } else { - CeedCallBackend(CeedVectorRestoreArrayRead(e_vec, (const CeedScalar **)e_data)); + CeedCallBackend(CeedVectorRestoreArrayRead(l_vec, &e_vec_array)); } } return CEED_ERROR_SUCCESS; @@ -456,12 +457,14 @@ static inline int CeedOperatorInputRestore_Cuda(CeedOperatorField op_input_field //------------------------------------------------------------------------------ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { CeedInt Q, num_elem, num_input_fields, num_output_fields; - CeedScalar *e_data_in[CEED_FIELD_MAX] = {NULL}, *e_data_out[CEED_FIELD_MAX] = {NULL}; + Ceed ceed; + CeedVector active_e_vec; CeedQFunctionField *qf_input_fields, *qf_output_fields; CeedQFunction qf; CeedOperatorField *op_input_fields, *op_output_fields; CeedOperator_Cuda *impl; + CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); CeedCallBackend(CeedOperatorGetData(op, &impl)); CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); @@ -472,13 +475,16 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec // Setup CeedCallBackend(CeedOperatorSetup_Cuda(op)); + // Work vector + CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec)); + // Process inputs for (CeedInt i = 0; i < num_input_fields; i++) { CeedInt field = impl->input_field_order[i]; CeedCallBackend( - CeedOperatorInputRestrict_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, false, &e_data_in[field], impl, request)); - CeedCallBackend(CeedOperatorInputBasis_Cuda(op_input_fields[field], qf_input_fields[field], field, num_elem, false, e_data_in[field], impl)); + CeedOperatorInputRestrict_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, false, impl, request)); + CeedCallBackend(CeedOperatorInputBasis_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, num_elem, false, impl)); } // Output pointers, as necessary @@ -487,9 +493,10 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); if (eval_mode == CEED_EVAL_NONE) { - // Set the output Q-Vector to use the E-Vector data directly. - CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_data_out[i])); - CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data_out[i])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); } } @@ -498,12 +505,13 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec // Restore input arrays for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestore_Cuda(op_input_fields[i], qf_input_fields[i], i, false, &e_data_in[i], impl)); + CeedCallBackend(CeedOperatorInputRestore_Cuda(op_input_fields[i], qf_input_fields[i], i, in_vec, active_e_vec, false, impl)); } - // Output basis apply if needed + // Output basis and restriction for (CeedInt i = 0; i < num_output_fields; i++) { - CeedInt field = impl->output_field_order[i]; + bool is_active = false; + CeedInt field = impl->output_field_order[i]; CeedEvalMode eval_mode; CeedVector l_vec, e_vec = impl->e_vecs_out[field], q_vec = impl->q_vecs_out[field]; CeedElemRestriction elem_rstr; @@ -511,7 +519,11 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec // Output vector CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field], &l_vec)); - if (l_vec == CEED_VECTOR_ACTIVE) l_vec = out_vec; + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active) { + l_vec = out_vec; + if (!e_vec) e_vec = active_e_vec; + } // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field], &eval_mode)); @@ -531,14 +543,17 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { - return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); + return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); // LCOV_EXCL_STOP } } // Restore evec if (eval_mode == CEED_EVAL_NONE) { - CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_data_out[field])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_vec_array)); } // Restrict @@ -546,6 +561,9 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field], &elem_rstr)); CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, e_vec, l_vec, request)); } + + // Return work vector + CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec)); return CEED_ERROR_SUCCESS; } @@ -606,12 +624,14 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) { impl->q_vecs_out, num_output_fields, max_num_points, num_elem)); // Reorder fields to allow reuse of buffers + impl->max_active_e_vec_len = 0; { bool is_ordered[CEED_FIELD_MAX]; CeedInt curr_index = 0; for (CeedInt i = 0; i < num_input_fields; i++) is_ordered[i] = false; for (CeedInt i = 0; i < num_input_fields; i++) { + CeedSize e_vec_len_i; CeedVector vec_i; CeedElemRestriction rstr_i; @@ -622,6 +642,8 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) { CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i)); if (vec_i == CEED_VECTOR_NONE) continue; // CEED_EVAL_WEIGHT CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i)); + CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_i, &e_vec_len_i)); + impl->max_active_e_vec_len = e_vec_len_i > impl->max_active_e_vec_len ? e_vec_len_i : impl->max_active_e_vec_len; for (CeedInt j = i + 1; j < num_input_fields; j++) { CeedVector vec_j; CeedElemRestriction rstr_j; @@ -642,6 +664,7 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) { for (CeedInt i = 0; i < num_output_fields; i++) is_ordered[i] = false; for (CeedInt i = 0; i < num_output_fields; i++) { + CeedSize e_vec_len_i; CeedVector vec_i; CeedElemRestriction rstr_i; @@ -651,6 +674,8 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) { curr_index++; CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec_i)); CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &rstr_i)); + CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_i, &e_vec_len_i)); + impl->max_active_e_vec_len = e_vec_len_i > impl->max_active_e_vec_len ? e_vec_len_i : impl->max_active_e_vec_len; for (CeedInt j = i + 1; j < num_output_fields; j++) { CeedVector vec_j; CeedElemRestriction rstr_j; @@ -665,6 +690,7 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) { } } } + CeedCallBackend(CeedOperatorSetSetupDone(op)); return CEED_ERROR_SUCCESS; } @@ -673,25 +699,35 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) { // Input Basis Action AtPoints //------------------------------------------------------------------------------ static inline int CeedOperatorInputBasisAtPoints_Cuda(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, - CeedInt num_elem, const CeedInt *num_points, const bool skip_active, CeedScalar *e_data, - CeedOperator_Cuda *impl) { + CeedVector in_vec, CeedVector active_e_vec, CeedInt num_elem, const CeedInt *num_points, + const bool skip_active, CeedOperator_Cuda *impl) { + bool is_active = false; CeedEvalMode eval_mode; - CeedVector e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field]; + CeedVector l_vec, e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field]; // Skip active input - if (skip_active) { - CeedVector l_vec; - - CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); - if (l_vec == CEED_VECTOR_ACTIVE) return CEED_ERROR_SUCCESS; + CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active && skip_active) return CEED_ERROR_SUCCESS; + if (is_active) { + l_vec = in_vec; + if (!e_vec) e_vec = active_e_vec; } // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode)); switch (eval_mode) { - case CEED_EVAL_NONE: - CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, e_data)); + case CEED_EVAL_NONE: { + const CeedScalar *e_vec_array; + + if (e_vec) { + CeedCallBackend(CeedVectorGetArrayRead(e_vec, CEED_MEM_DEVICE, &e_vec_array)); + } else { + CeedCallBackend(CeedVectorGetArrayRead(l_vec, CEED_MEM_DEVICE, &e_vec_array)); + } + CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array)); break; + } case CEED_EVAL_INTERP: case CEED_EVAL_GRAD: case CEED_EVAL_DIV: @@ -713,12 +749,14 @@ static inline int CeedOperatorInputBasisAtPoints_Cuda(CeedOperatorField op_input //------------------------------------------------------------------------------ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { CeedInt max_num_points, *num_points, num_elem, num_input_fields, num_output_fields; - CeedScalar *e_data_in[CEED_FIELD_MAX] = {NULL}, *e_data_out[CEED_FIELD_MAX] = {NULL}; + Ceed ceed; + CeedVector active_e_vec; CeedQFunctionField *qf_input_fields, *qf_output_fields; CeedQFunction qf; CeedOperatorField *op_input_fields, *op_output_fields; CeedOperator_Cuda *impl; + CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); CeedCallBackend(CeedOperatorGetData(op, &impl)); CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); @@ -730,6 +768,9 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, num_points = impl->num_points; max_num_points = impl->max_num_points; + // Work vector + CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec)); + // Get point coordinates if (!impl->point_coords_elem) { CeedVector point_coords = NULL; @@ -745,9 +786,9 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, CeedInt field = impl->input_field_order[i]; CeedCallBackend( - CeedOperatorInputRestrict_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, false, &e_data_in[field], impl, request)); - CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[field], qf_input_fields[field], field, num_elem, num_points, false, - e_data_in[field], impl)); + CeedOperatorInputRestrict_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, false, impl, request)); + CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, num_elem, + num_points, false, impl)); } // Output pointers, as necessary @@ -756,9 +797,10 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); if (eval_mode == CEED_EVAL_NONE) { - // Set the output Q-Vector to use the E-Vector data directly. - CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_data_out[i])); - CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data_out[i])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); } } @@ -767,12 +809,13 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, // Restore input arrays for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestore_Cuda(op_input_fields[i], qf_input_fields[i], i, false, &e_data_in[i], impl)); + CeedCallBackend(CeedOperatorInputRestore_Cuda(op_input_fields[i], qf_input_fields[i], i, in_vec, active_e_vec, false, impl)); } - // Output basis apply if needed + // Output basis and restriction for (CeedInt i = 0; i < num_output_fields; i++) { - CeedInt field = impl->output_field_order[i]; + bool is_active = false; + CeedInt field = impl->output_field_order[i]; CeedEvalMode eval_mode; CeedVector l_vec, e_vec = impl->e_vecs_out[field], q_vec = impl->q_vecs_out[field]; CeedElemRestriction elem_rstr; @@ -780,7 +823,11 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, // Output vector CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field], &l_vec)); - if (l_vec == CEED_VECTOR_ACTIVE) l_vec = out_vec; + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active) { + l_vec = out_vec; + if (!e_vec) e_vec = active_e_vec; + } // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field], &eval_mode)); @@ -800,14 +847,17 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { - return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); + return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); // LCOV_EXCL_STOP } } // Restore evec if (eval_mode == CEED_EVAL_NONE) { - CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_data_out[field])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_vec_array)); } // Restrict @@ -815,6 +865,9 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec, CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field], &elem_rstr)); CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, e_vec, l_vec, request)); } + + // Restore work vector + CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec)); return CEED_ERROR_SUCCESS; } @@ -825,7 +878,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Cuda(CeedOperator op, CeedRequest *request) { Ceed ceed, ceed_parent; CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; - CeedScalar *assembled_array, *e_data[2 * CEED_FIELD_MAX] = {NULL}; + CeedScalar *assembled_array; CeedVector *active_inputs; CeedQFunctionField *qf_input_fields, *qf_output_fields; CeedQFunction qf; @@ -848,8 +901,8 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Cuda(CeedOperator op, // Process inputs for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestrict_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, true, &e_data[i], impl, request)); - CeedCallBackend(CeedOperatorInputBasis_Cuda(op_input_fields[i], qf_input_fields[i], i, num_elem, true, e_data[i], impl)); + CeedCallBackend(CeedOperatorInputRestrict_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl, request)); + CeedCallBackend(CeedOperatorInputBasis_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, true, impl)); } // Count number of active input fields @@ -949,7 +1002,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Cuda(CeedOperator op, // Restore input arrays for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestore_Cuda(op_input_fields[i], qf_input_fields[i], i, true, &e_data[i], impl)); + CeedCallBackend(CeedOperatorInputRestore_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl)); } // Restore output @@ -1646,12 +1699,14 @@ static int CeedOperatorLinearAssembleQFunctionAtPoints_Cuda(CeedOperator op, Cee //------------------------------------------------------------------------------ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, CeedVector assembled, CeedRequest *request) { CeedInt max_num_points, *num_points, num_elem, num_input_fields, num_output_fields; - CeedScalar *e_data_in[CEED_FIELD_MAX] = {NULL}, *e_data_out[CEED_FIELD_MAX] = {NULL}; + Ceed ceed; + CeedVector active_e_vec_in, active_e_vec_out; CeedQFunctionField *qf_input_fields, *qf_output_fields; CeedQFunction qf; CeedOperatorField *op_input_fields, *op_output_fields; CeedOperator_Cuda *impl; + CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); CeedCallBackend(CeedOperatorGetData(op, &impl)); CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); @@ -1663,16 +1718,9 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C num_points = impl->num_points; max_num_points = impl->max_num_points; - // Create separate output e-vecs - if (impl->has_shared_e_vecs) { - for (CeedInt i = 0; i < impl->num_outputs; i++) { - CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); - CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i])); - } - CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs_out, - impl->q_vecs_out, num_output_fields, max_num_points, num_elem)); - } - impl->has_shared_e_vecs = false; + // Work vector + CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_in)); + CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_out)); // Get point coordinates if (!impl->point_coords_elem) { @@ -1686,8 +1734,8 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C // Process inputs for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestrict_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, true, &e_data_in[i], impl, request)); - CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[i], qf_input_fields[i], i, num_elem, num_points, true, e_data_in[i], impl)); + CeedCallBackend(CeedOperatorInputRestrict_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl, request)); + CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, num_points, true, impl)); } // Clear active input Qvecs @@ -1705,9 +1753,10 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); if (eval_mode == CEED_EVAL_NONE) { - // Set the output Q-Vector to use the E-Vector data directly. - CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_data_out[i])); - CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data_out[i])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); } } @@ -1735,32 +1784,36 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C for (CeedInt s = 0; s < e_vec_size; s++) { bool is_active_input = false; CeedEvalMode eval_mode; - CeedVector vec; + CeedVector l_vec, q_vec = impl->q_vecs_in[i]; CeedBasis basis; - CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); + CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec)); // Skip non-active input - is_active_input = vec == CEED_VECTOR_ACTIVE; + is_active_input = l_vec == CEED_VECTOR_ACTIVE; if (!is_active_input) continue; // Update unit vector - if (s == 0) CeedCallBackend(CeedVectorSetValue(impl->e_vecs_in[i], 0.0)); - else CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs_in[i], s - 1, e_vec_size, 0.0)); - CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs_in[i], s, e_vec_size, 1.0)); + if (s == 0) CeedCallBackend(CeedVectorSetValue(active_e_vec_in, 0.0)); + else CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, s - 1, e_vec_size, 0.0)); + CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, s, e_vec_size, 1.0)); // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); switch (eval_mode) { - case CEED_EVAL_NONE: - CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data_in[i])); + case CEED_EVAL_NONE: { + const CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorGetArrayRead(active_e_vec_in, CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array)); break; + } case CEED_EVAL_INTERP: case CEED_EVAL_GRAD: case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); - CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, - impl->e_vecs_in[i], impl->q_vecs_in[i])); + CeedCallBackend( + CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, active_e_vec_in, q_vec)); break; case CEED_EVAL_WEIGHT: break; // No action @@ -1775,7 +1828,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C CeedInt elem_size = 0; CeedRestrictionType rstr_type; CeedEvalMode eval_mode; - CeedVector l_vec; + CeedVector l_vec, e_vec = impl->e_vecs_out[j], q_vec = impl->q_vecs_out[j]; CeedElemRestriction elem_rstr; CeedBasis basis; @@ -1783,6 +1836,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C // ---- Skip non-active output is_active_output = l_vec == CEED_VECTOR_ACTIVE; if (!is_active_output) continue; + if (!e_vec) e_vec = active_e_vec_out; // ---- Check if elem size matches CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); @@ -1803,16 +1857,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode)); switch (eval_mode) { - case CEED_EVAL_NONE: - CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_out[j], &e_data_out[j])); + case CEED_EVAL_NONE: { + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorTakeArray(q_vec, CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_vec_array)); break; + } case CEED_EVAL_INTERP: case CEED_EVAL_GRAD: case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[j], &basis)); - CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, - impl->q_vecs_out[j], impl->e_vecs_out[j])); + CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec)); break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { @@ -1822,21 +1879,23 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C } // Mask output e-vec - CeedCallBackend(CeedVectorPointwiseMult(impl->e_vecs_out[j], impl->e_vecs_in[i], impl->e_vecs_out[j])); + CeedCallBackend(CeedVectorPointwiseMult(e_vec, active_e_vec_in, e_vec)); // Restrict CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); - CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs_out[j], assembled, request)); + CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, e_vec, assembled, request)); // Reset q_vec for if (eval_mode == CEED_EVAL_NONE) { - CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[j], CEED_MEM_DEVICE, &e_data_out[j])); - CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[j], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data_out[j])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorGetArrayWrite(e_vec, CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); } } // Reset vec - if (s == e_vec_size - 1 && i != num_input_fields - 1) CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); + if (s == e_vec_size - 1 && i != num_input_fields - 1) CeedCallBackend(CeedVectorSetValue(q_vec, 0.0)); } } @@ -1850,13 +1909,16 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C // Restore evec CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); if (eval_mode == CEED_EVAL_NONE) { - CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &e_data_in[i])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &e_vec_array)); } } // Restore input arrays for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestore_Cuda(op_input_fields[i], qf_input_fields[i], i, true, &e_data_in[i], impl)); + CeedCallBackend(CeedOperatorInputRestore_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl)); } return CEED_ERROR_SUCCESS; } diff --git a/backends/cuda-ref/ceed-cuda-ref.h b/backends/cuda-ref/ceed-cuda-ref.h index b6e7016013..9e167463bd 100644 --- a/backends/cuda-ref/ceed-cuda-ref.h +++ b/backends/cuda-ref/ceed-cuda-ref.h @@ -131,7 +131,7 @@ typedef struct { } CeedOperatorAssemble_Cuda; typedef struct { - bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out, has_shared_e_vecs; + bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out; uint64_t *input_states; // State tracking for passive inputs CeedVector *e_vecs_in, *e_vecs_out; CeedVector *q_vecs_in, *q_vecs_out; diff --git a/backends/hip-ref/ceed-hip-ref-operator.c b/backends/hip-ref/ceed-hip-ref-operator.c index 4534e840ee..06625ef9a2 100644 --- a/backends/hip-ref/ceed-hip-ref-operator.c +++ b/backends/hip-ref/ceed-hip-ref-operator.c @@ -118,60 +118,48 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i // Loop over fields for (CeedInt i = 0; i < num_fields; i++) { - bool is_strided = false, skip_e_vec = false; - CeedSize q_size; - CeedInt size; - CeedEvalMode eval_mode; - CeedBasis basis; + bool is_active = false, is_strided = false, skip_e_vec = false; + CeedSize q_size; + CeedInt size; + CeedEvalMode eval_mode; + CeedVector l_vec; + CeedElemRestriction elem_rstr; + // Check whether this field can skip the element restriction: + // Input CEED_VECTOR_ACTIVE + // Output CEED_VECTOR_ACTIVE without CEED_EVAL_NONE + // Input CEED_VECTOR_NONE with CEED_EVAL_WEIGHT + // Input passive vectorr with CEED_EVAL_NONE and strided restriction with CEED_STRIDES_BACKEND + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &l_vec)); + is_active = l_vec == CEED_VECTOR_ACTIVE; + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); - if (eval_mode != CEED_EVAL_WEIGHT) { - CeedElemRestriction elem_rstr; - - // Check whether this field can skip the element restriction: - // Must be passive input, with eval_mode NONE, and have a strided restriction with CEED_STRIDES_BACKEND. - CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr)); - - // First, check whether the field is input or output: - if (is_input) { - CeedVector l_vec; - - // Check for passive input - CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &l_vec)); - if (l_vec != CEED_VECTOR_ACTIVE && eval_mode == CEED_EVAL_NONE) { - // Check for strided restriction - CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided)); - if (is_strided) { - // Check if vector is already in preferred backend ordering - CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_e_vec)); - } - } - } - if (skip_e_vec) { - // Either an active field or strided local vec in backend ordering - e_vecs[i] = NULL; - } else { - CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i])); - } + skip_e_vec = (is_input && is_active) || (is_active && eval_mode != CEED_EVAL_NONE) || (eval_mode == CEED_EVAL_WEIGHT); + if (!skip_e_vec && is_input && !is_active && eval_mode == CEED_EVAL_NONE) { + CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided)); + if (is_strided) CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_e_vec)); + } + if (skip_e_vec) { + e_vecs[i] = NULL; + } else { + CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i])); } switch (eval_mode) { case CEED_EVAL_NONE: - CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); - q_size = (CeedSize)num_elem * Q * size; - CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); - break; case CEED_EVAL_INTERP: case CEED_EVAL_GRAD: case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size)); - q_size = (CeedSize)num_elem * Q * size; + q_size = (CeedSize)num_elem * (CeedSize)Q * (CeedSize)size; CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); break; - case CEED_EVAL_WEIGHT: // Only on input fields + case CEED_EVAL_WEIGHT: { + CeedBasis basis; + CeedCallBackend(CeedOperatorFieldGetBasis(op_fields[i], &basis)); - q_size = (CeedSize)num_elem * Q; + q_size = (CeedSize)num_elem * (CeedSize)Q; CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i])); if (is_at_points) { CeedInt num_points[num_elem]; @@ -183,6 +171,7 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i])); } break; + } } } // Drop duplicate restrictions @@ -200,7 +189,7 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); if (vec_i == vec_j && rstr_i == rstr_j) { - CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + if (e_vecs[i]) CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); skip_rstr[j] = true; } } @@ -219,7 +208,7 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); if (vec_i == vec_j && rstr_i == rstr_j) { - CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + if (e_vecs[i]) CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); skip_rstr[j] = true; apply_add_basis[i] = true; } @@ -347,41 +336,34 @@ static int CeedOperatorSetup_Hip(CeedOperator op) { // Restrict Operator Inputs //------------------------------------------------------------------------------ static inline int CeedOperatorInputRestrict_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, - CeedVector in_vec, const bool skip_active, CeedScalar **e_data, CeedOperator_Hip *impl, + CeedVector in_vec, CeedVector active_e_vec, const bool skip_active, CeedOperator_Hip *impl, CeedRequest *request) { - CeedEvalMode eval_mode; - CeedVector l_vec, e_vec = impl->e_vecs_in[input_field]; + bool is_active = false; + CeedVector l_vec, e_vec = impl->e_vecs_in[input_field]; // Get input vector CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); - if (l_vec == CEED_VECTOR_ACTIVE) { - if (skip_active) return CEED_ERROR_SUCCESS; - else l_vec = in_vec; + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active && skip_active) return CEED_ERROR_SUCCESS; + if (is_active) { + l_vec = in_vec; + if (!e_vec) e_vec = active_e_vec; } // Restriction action - CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode)); - if (eval_mode == CEED_EVAL_WEIGHT) { // Skip - } else { - if (!e_vec) { - // No restriction for this field; read data directly from vec. - CeedCallBackend(CeedVectorGetArrayRead(l_vec, CEED_MEM_DEVICE, (const CeedScalar **)e_data)); - } else { - // Restrict, if necessary - if (!impl->skip_rstr_in[input_field]) { - uint64_t state; + if (e_vec) { + // Restrict, if necessary + if (!impl->skip_rstr_in[input_field]) { + uint64_t state; - CeedCallBackend(CeedVectorGetState(l_vec, &state)); - if (state != impl->input_states[input_field] || l_vec == in_vec) { - CeedElemRestriction elem_rstr; + CeedCallBackend(CeedVectorGetState(l_vec, &state)); + if (is_active || state != impl->input_states[input_field]) { + CeedElemRestriction elem_rstr; - CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_field, &elem_rstr)); - CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, l_vec, e_vec, request)); - } - impl->input_states[input_field] = state; + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_field, &elem_rstr)); + CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, l_vec, e_vec, request)); } - // Get e-vec - CeedCallBackend(CeedVectorGetArrayRead(e_vec, CEED_MEM_DEVICE, (const CeedScalar **)e_data)); + impl->input_states[input_field] = state; } } return CEED_ERROR_SUCCESS; @@ -391,24 +373,35 @@ static inline int CeedOperatorInputRestrict_Hip(CeedOperatorField op_input_field // Input Basis Action //------------------------------------------------------------------------------ static inline int CeedOperatorInputBasis_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, - CeedInt num_elem, const bool skip_active, CeedScalar *e_data, CeedOperator_Hip *impl) { + CeedVector in_vec, CeedVector active_e_vec, CeedInt num_elem, const bool skip_active, + CeedOperator_Hip *impl) { + bool is_active = false; CeedEvalMode eval_mode; - CeedVector e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field]; + CeedVector l_vec, e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field]; // Skip active input - if (skip_active) { - CeedVector l_vec; - - CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); - if (l_vec == CEED_VECTOR_ACTIVE) return CEED_ERROR_SUCCESS; + CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active && skip_active) return CEED_ERROR_SUCCESS; + if (is_active) { + l_vec = in_vec; + if (!e_vec) e_vec = active_e_vec; } // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode)); switch (eval_mode) { - case CEED_EVAL_NONE: - CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, e_data)); + case CEED_EVAL_NONE: { + const CeedScalar *e_vec_array; + + if (e_vec) { + CeedCallBackend(CeedVectorGetArrayRead(e_vec, CEED_MEM_DEVICE, &e_vec_array)); + } else { + CeedCallBackend(CeedVectorGetArrayRead(l_vec, CEED_MEM_DEVICE, &e_vec_array)); + } + CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array)); break; + } case CEED_EVAL_INTERP: case CEED_EVAL_GRAD: case CEED_EVAL_DIV: @@ -429,22 +422,30 @@ static inline int CeedOperatorInputBasis_Hip(CeedOperatorField op_input_field, C // Restore Input Vectors //------------------------------------------------------------------------------ static inline int CeedOperatorInputRestore_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, - const bool skip_active, CeedScalar **e_data, CeedOperator_Hip *impl) { + CeedVector in_vec, CeedVector active_e_vec, const bool skip_active, CeedOperator_Hip *impl) { + bool is_active = false; CeedEvalMode eval_mode; CeedVector l_vec, e_vec = impl->e_vecs_in[input_field]; // Skip active input CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); - if (skip_active && l_vec == CEED_VECTOR_ACTIVE) return CEED_ERROR_SUCCESS; + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active && skip_active) return CEED_ERROR_SUCCESS; + if (is_active) { + l_vec = in_vec; + if (!e_vec) e_vec = active_e_vec; + } // Restore e-vec CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode)); - if (eval_mode == CEED_EVAL_WEIGHT) { // Skip - } else { - if (!e_vec) { // This was a skip_restriction case - CeedCallBackend(CeedVectorRestoreArrayRead(l_vec, (const CeedScalar **)e_data)); + if (eval_mode == CEED_EVAL_NONE) { + const CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_in[input_field], CEED_MEM_DEVICE, (CeedScalar **)&e_vec_array)); + if (e_vec) { + CeedCallBackend(CeedVectorRestoreArrayRead(e_vec, &e_vec_array)); } else { - CeedCallBackend(CeedVectorRestoreArrayRead(e_vec, (const CeedScalar **)e_data)); + CeedCallBackend(CeedVectorRestoreArrayRead(l_vec, &e_vec_array)); } } return CEED_ERROR_SUCCESS; @@ -455,12 +456,14 @@ static inline int CeedOperatorInputRestore_Hip(CeedOperatorField op_input_field, //------------------------------------------------------------------------------ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { CeedInt Q, num_elem, num_input_fields, num_output_fields; - CeedScalar *e_data_in[CEED_FIELD_MAX] = {NULL}, *e_data_out[CEED_FIELD_MAX] = {NULL}; + Ceed ceed; + CeedVector active_e_vec; CeedQFunctionField *qf_input_fields, *qf_output_fields; CeedQFunction qf; CeedOperatorField *op_input_fields, *op_output_fields; CeedOperator_Hip *impl; + CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); CeedCallBackend(CeedOperatorGetData(op, &impl)); CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); CeedCallBackend(CeedOperatorGetNumQuadraturePoints(op, &Q)); @@ -471,13 +474,15 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect // Setup CeedCallBackend(CeedOperatorSetup_Hip(op)); + // Work vector + CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec)); + // Process inputs for (CeedInt i = 0; i < num_input_fields; i++) { CeedInt field = impl->input_field_order[i]; - CeedCallBackend( - CeedOperatorInputRestrict_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, false, &e_data_in[field], impl, request)); - CeedCallBackend(CeedOperatorInputBasis_Hip(op_input_fields[field], qf_input_fields[field], field, num_elem, false, e_data_in[field], impl)); + CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, false, impl, request)); + CeedCallBackend(CeedOperatorInputBasis_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, num_elem, false, impl)); } // Output pointers, as necessary @@ -486,9 +491,10 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); if (eval_mode == CEED_EVAL_NONE) { - // Set the output Q-Vector to use the E-Vector data directly. - CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_data_out[i])); - CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data_out[i])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); } } @@ -497,12 +503,13 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect // Restore input arrays for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, false, &e_data_in[i], impl)); + CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, in_vec, active_e_vec, false, impl)); } - // Output basis apply if needed + // Output basis and restriction for (CeedInt i = 0; i < num_output_fields; i++) { - CeedInt field = impl->output_field_order[i]; + bool is_active = false; + CeedInt field = impl->output_field_order[i]; CeedEvalMode eval_mode; CeedVector l_vec, e_vec = impl->e_vecs_out[field], q_vec = impl->q_vecs_out[field]; CeedElemRestriction elem_rstr; @@ -510,7 +517,11 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect // Output vector CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field], &l_vec)); - if (l_vec == CEED_VECTOR_ACTIVE) l_vec = out_vec; + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active) { + l_vec = out_vec; + if (!e_vec) e_vec = active_e_vec; + } // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field], &eval_mode)); @@ -530,14 +541,17 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { - return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); + return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); // LCOV_EXCL_STOP } } // Restore evec if (eval_mode == CEED_EVAL_NONE) { - CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_data_out[field])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_vec_array)); } // Restrict @@ -545,6 +559,9 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field], &elem_rstr)); CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, e_vec, l_vec, request)); } + + // Return work vector + CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec)); return CEED_ERROR_SUCCESS; } @@ -605,12 +622,14 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { num_output_fields, max_num_points, num_elem)); // Reorder fields to allow reuse of buffers + impl->max_active_e_vec_len = 0; { bool is_ordered[CEED_FIELD_MAX]; CeedInt curr_index = 0; for (CeedInt i = 0; i < num_input_fields; i++) is_ordered[i] = false; for (CeedInt i = 0; i < num_input_fields; i++) { + CeedSize e_vec_len_i; CeedVector vec_i; CeedElemRestriction rstr_i; @@ -621,6 +640,8 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i)); if (vec_i == CEED_VECTOR_NONE) continue; // CEED_EVAL_WEIGHT CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i)); + CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_i, &e_vec_len_i)); + impl->max_active_e_vec_len = e_vec_len_i > impl->max_active_e_vec_len ? e_vec_len_i : impl->max_active_e_vec_len; for (CeedInt j = i + 1; j < num_input_fields; j++) { CeedVector vec_j; CeedElemRestriction rstr_j; @@ -641,6 +662,7 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { for (CeedInt i = 0; i < num_output_fields; i++) is_ordered[i] = false; for (CeedInt i = 0; i < num_output_fields; i++) { + CeedSize e_vec_len_i; CeedVector vec_i; CeedElemRestriction rstr_i; @@ -650,6 +672,8 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { curr_index++; CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec_i)); CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &rstr_i)); + CeedCallBackend(CeedElemRestrictionGetEVectorSize(rstr_i, &e_vec_len_i)); + impl->max_active_e_vec_len = e_vec_len_i > impl->max_active_e_vec_len ? e_vec_len_i : impl->max_active_e_vec_len; for (CeedInt j = i + 1; j < num_output_fields; j++) { CeedVector vec_j; CeedElemRestriction rstr_j; @@ -664,6 +688,7 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { } } } + CeedCallBackend(CeedOperatorSetSetupDone(op)); return CEED_ERROR_SUCCESS; } @@ -672,25 +697,35 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { // Input Basis Action AtPoints //------------------------------------------------------------------------------ static inline int CeedOperatorInputBasisAtPoints_Hip(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field, - CeedInt num_elem, const CeedInt *num_points, const bool skip_active, CeedScalar *e_data, - CeedOperator_Hip *impl) { + CeedVector in_vec, CeedVector active_e_vec, CeedInt num_elem, const CeedInt *num_points, + const bool skip_active, CeedOperator_Hip *impl) { + bool is_active = false; CeedEvalMode eval_mode; - CeedVector e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field]; + CeedVector l_vec, e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field]; // Skip active input - if (skip_active) { - CeedVector l_vec; - - CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); - if (l_vec == CEED_VECTOR_ACTIVE) return CEED_ERROR_SUCCESS; + CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec)); + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active && skip_active) return CEED_ERROR_SUCCESS; + if (is_active) { + l_vec = in_vec; + if (!e_vec) e_vec = active_e_vec; } // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_field, &eval_mode)); switch (eval_mode) { - case CEED_EVAL_NONE: - CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, e_data)); + case CEED_EVAL_NONE: { + const CeedScalar *e_vec_array; + + if (e_vec) { + CeedCallBackend(CeedVectorGetArrayRead(e_vec, CEED_MEM_DEVICE, &e_vec_array)); + } else { + CeedCallBackend(CeedVectorGetArrayRead(l_vec, CEED_MEM_DEVICE, &e_vec_array)); + } + CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array)); break; + } case CEED_EVAL_INTERP: case CEED_EVAL_GRAD: case CEED_EVAL_DIV: @@ -712,12 +747,14 @@ static inline int CeedOperatorInputBasisAtPoints_Hip(CeedOperatorField op_input_ //------------------------------------------------------------------------------ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, CeedVector out_vec, CeedRequest *request) { CeedInt max_num_points, *num_points, num_elem, num_input_fields, num_output_fields; - CeedScalar *e_data_in[CEED_FIELD_MAX] = {NULL}, *e_data_out[CEED_FIELD_MAX] = {NULL}; + Ceed ceed; + CeedVector active_e_vec; CeedQFunctionField *qf_input_fields, *qf_output_fields; CeedQFunction qf; CeedOperatorField *op_input_fields, *op_output_fields; CeedOperator_Hip *impl; + CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); CeedCallBackend(CeedOperatorGetData(op, &impl)); CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); @@ -729,6 +766,9 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, num_points = impl->num_points; max_num_points = impl->max_num_points; + // Work vector + CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec)); + // Get point coordinates if (!impl->point_coords_elem) { CeedVector point_coords = NULL; @@ -743,10 +783,9 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, for (CeedInt i = 0; i < num_input_fields; i++) { CeedInt field = impl->input_field_order[i]; - CeedCallBackend( - CeedOperatorInputRestrict_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, false, &e_data_in[field], impl, request)); - CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(op_input_fields[field], qf_input_fields[field], field, num_elem, num_points, false, - e_data_in[field], impl)); + CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, false, impl, request)); + CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, num_elem, + num_points, false, impl)); } // Output pointers, as necessary @@ -755,9 +794,10 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); if (eval_mode == CEED_EVAL_NONE) { - // Set the output Q-Vector to use the E-Vector data directly. - CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_data_out[i])); - CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data_out[i])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); } } @@ -766,12 +806,13 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, // Restore input arrays for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, false, &e_data_in[i], impl)); + CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, in_vec, active_e_vec, false, impl)); } - // Output basis apply if needed + // Output basis and restriction for (CeedInt i = 0; i < num_output_fields; i++) { - CeedInt field = impl->output_field_order[i]; + bool is_active = false; + CeedInt field = impl->output_field_order[i]; CeedEvalMode eval_mode; CeedVector l_vec, e_vec = impl->e_vecs_out[field], q_vec = impl->q_vecs_out[field]; CeedElemRestriction elem_rstr; @@ -779,7 +820,11 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, // Output vector CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field], &l_vec)); - if (l_vec == CEED_VECTOR_ACTIVE) l_vec = out_vec; + is_active = l_vec == CEED_VECTOR_ACTIVE; + if (is_active) { + l_vec = out_vec; + if (!e_vec) e_vec = active_e_vec; + } // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field], &eval_mode)); @@ -799,14 +844,17 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { - return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); + return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode"); // LCOV_EXCL_STOP } } // Restore evec if (eval_mode == CEED_EVAL_NONE) { - CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_data_out[field])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_vec_array)); } // Restrict @@ -814,6 +862,9 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field], &elem_rstr)); CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, e_vec, l_vec, request)); } + + // Restore work vector + CeedCallBackend(CeedRestoreWorkVector(ceed, &active_e_vec)); return CEED_ERROR_SUCCESS; } @@ -824,7 +875,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, b CeedRequest *request) { Ceed ceed, ceed_parent; CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size; - CeedScalar *assembled_array, *e_data[2 * CEED_FIELD_MAX] = {NULL}; + CeedScalar *assembled_array; CeedVector *active_inputs; CeedQFunctionField *qf_input_fields, *qf_output_fields; CeedQFunction qf; @@ -847,8 +898,8 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, b // Process inputs for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, true, &e_data[i], impl, request)); - CeedCallBackend(CeedOperatorInputBasis_Hip(op_input_fields[i], qf_input_fields[i], i, num_elem, true, e_data[i], impl)); + CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl, request)); + CeedCallBackend(CeedOperatorInputBasis_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, true, impl)); } // Count number of active input fields @@ -948,7 +999,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, b // Restore input arrays for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, true, &e_data[i], impl)); + CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl)); } // Restore output @@ -1339,7 +1390,7 @@ static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op, //------------------------------------------------------------------------------ static int CeedSingleOperatorAssembleSetup_Hip(CeedOperator op, CeedInt use_ceedsize_idx) { Ceed ceed; - Ceed_Hip *hip_data; + Ceed_Hip *Hip_data; char *assembly_kernel_source; const char *assembly_kernel_path; CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0; @@ -1429,8 +1480,8 @@ static int CeedSingleOperatorAssembleSetup_Hip(CeedOperator op, CeedInt use_ceed asmb->block_size_x = elem_size_in; asmb->block_size_y = elem_size_out; - CeedCallBackend(CeedGetData(ceed, &hip_data)); - bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > hip_data->device_prop.maxThreadsPerBlock; + CeedCallBackend(CeedGetData(ceed, &Hip_data)); + bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > Hip_data->device_prop.maxThreadsPerBlock; if (fallback) { // Use fallback kernel with 1D threadblock @@ -1440,7 +1491,7 @@ static int CeedSingleOperatorAssembleSetup_Hip(CeedOperator op, CeedInt use_ceed // Compile kernels CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in)); CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out)); - CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-operator-assemble.h", &assembly_kernel_path)); + CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/Hip/Hip-ref-operator-assemble.h", &assembly_kernel_path)); CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Kernel Source -----\n"); CeedCallBackend(CeedLoadSourceToBuffer(ceed, assembly_kernel_path, &assembly_kernel_source)); CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Assembly Source Complete! -----\n"); @@ -1645,12 +1696,14 @@ static int CeedOperatorLinearAssembleQFunctionAtPoints_Hip(CeedOperator op, Ceed //------------------------------------------------------------------------------ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, CeedVector assembled, CeedRequest *request) { CeedInt max_num_points, *num_points, num_elem, num_input_fields, num_output_fields; - CeedScalar *e_data_in[CEED_FIELD_MAX] = {NULL}, *e_data_out[CEED_FIELD_MAX] = {NULL}; + Ceed ceed; + CeedVector active_e_vec_in, active_e_vec_out; CeedQFunctionField *qf_input_fields, *qf_output_fields; CeedQFunction qf; CeedOperatorField *op_input_fields, *op_output_fields; CeedOperator_Hip *impl; + CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); CeedCallBackend(CeedOperatorGetData(op, &impl)); CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); @@ -1662,16 +1715,9 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce num_points = impl->num_points; max_num_points = impl->max_num_points; - // Create separate output e-vecs - if (impl->has_shared_e_vecs) { - for (CeedInt i = 0; i < impl->num_outputs; i++) { - CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i])); - CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_out[i])); - } - CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs_out, - impl->q_vecs_out, num_output_fields, max_num_points, num_elem)); - } - impl->has_shared_e_vecs = false; + // Work vector + CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_in)); + CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_out)); // Get point coordinates if (!impl->point_coords_elem) { @@ -1685,8 +1731,8 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce // Process inputs for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, true, &e_data_in[i], impl, request)); - CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(op_input_fields[i], qf_input_fields[i], i, num_elem, num_points, true, e_data_in[i], impl)); + CeedCallBackend(CeedOperatorInputRestrict_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl, request)); + CeedCallBackend(CeedOperatorInputBasisAtPoints_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, num_points, true, impl)); } // Clear active input Qvecs @@ -1704,9 +1750,10 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); if (eval_mode == CEED_EVAL_NONE) { - // Set the output Q-Vector to use the E-Vector data directly. - CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_data_out[i])); - CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data_out[i])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); } } @@ -1734,32 +1781,36 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce for (CeedInt s = 0; s < e_vec_size; s++) { bool is_active_input = false; CeedEvalMode eval_mode; - CeedVector vec; + CeedVector l_vec, q_vec = impl->q_vecs_in[i]; CeedBasis basis; - CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); + CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec)); // Skip non-active input - is_active_input = vec == CEED_VECTOR_ACTIVE; + is_active_input = l_vec == CEED_VECTOR_ACTIVE; if (!is_active_input) continue; // Update unit vector - if (s == 0) CeedCallBackend(CeedVectorSetValue(impl->e_vecs_in[i], 0.0)); - else CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs_in[i], s - 1, e_vec_size, 0.0)); - CeedCallBackend(CeedVectorSetValueStrided(impl->e_vecs_in[i], s, e_vec_size, 1.0)); + if (s == 0) CeedCallBackend(CeedVectorSetValue(active_e_vec_in, 0.0)); + else CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, s - 1, e_vec_size, 0.0)); + CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, s, e_vec_size, 1.0)); // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); switch (eval_mode) { - case CEED_EVAL_NONE: - CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data_in[i])); + case CEED_EVAL_NONE: { + const CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorGetArrayRead(active_e_vec_in, CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array)); break; + } case CEED_EVAL_INTERP: case CEED_EVAL_GRAD: case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis)); - CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, - impl->e_vecs_in[i], impl->q_vecs_in[i])); + CeedCallBackend( + CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, active_e_vec_in, q_vec)); break; case CEED_EVAL_WEIGHT: break; // No action @@ -1774,7 +1825,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce CeedInt elem_size = 0; CeedRestrictionType rstr_type; CeedEvalMode eval_mode; - CeedVector l_vec; + CeedVector l_vec, e_vec = impl->e_vecs_out[j], q_vec = impl->q_vecs_out[j]; CeedElemRestriction elem_rstr; CeedBasis basis; @@ -1782,6 +1833,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce // ---- Skip non-active output is_active_output = l_vec == CEED_VECTOR_ACTIVE; if (!is_active_output) continue; + if (!e_vec) e_vec = active_e_vec_out; // ---- Check if elem size matches CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); @@ -1802,16 +1854,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce // Basis action CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode)); switch (eval_mode) { - case CEED_EVAL_NONE: - CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_out[j], &e_data_out[j])); + case CEED_EVAL_NONE: { + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorTakeArray(q_vec, CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorRestoreArray(e_vec, &e_vec_array)); break; + } case CEED_EVAL_INTERP: case CEED_EVAL_GRAD: case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[j], &basis)); - CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, - impl->q_vecs_out[j], impl->e_vecs_out[j])); + CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec)); break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { @@ -1821,21 +1876,23 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce } // Mask output e-vec - CeedCallBackend(CeedVectorPointwiseMult(impl->e_vecs_out[j], impl->e_vecs_in[i], impl->e_vecs_out[j])); + CeedCallBackend(CeedVectorPointwiseMult(e_vec, active_e_vec_in, e_vec)); // Restrict CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr)); - CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs_out[j], assembled, request)); + CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, e_vec, assembled, request)); // Reset q_vec for if (eval_mode == CEED_EVAL_NONE) { - CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_out[j], CEED_MEM_DEVICE, &e_data_out[j])); - CeedCallBackend(CeedVectorSetArray(impl->q_vecs_out[j], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data_out[j])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorGetArrayWrite(e_vec, CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, e_vec_array)); } } // Reset vec - if (s == e_vec_size - 1 && i != num_input_fields - 1) CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0)); + if (s == e_vec_size - 1 && i != num_input_fields - 1) CeedCallBackend(CeedVectorSetValue(q_vec, 0.0)); } } @@ -1849,13 +1906,16 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce // Restore evec CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); if (eval_mode == CEED_EVAL_NONE) { - CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &e_data_in[i])); + CeedScalar *e_vec_array; + + CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, &e_vec_array)); + CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &e_vec_array)); } } // Restore input arrays for (CeedInt i = 0; i < num_input_fields; i++) { - CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, true, &e_data_in[i], impl)); + CeedCallBackend(CeedOperatorInputRestore_Hip(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl)); } return CEED_ERROR_SUCCESS; } diff --git a/backends/hip-ref/ceed-hip-ref.h b/backends/hip-ref/ceed-hip-ref.h index fb2c5b565e..52e88129a1 100644 --- a/backends/hip-ref/ceed-hip-ref.h +++ b/backends/hip-ref/ceed-hip-ref.h @@ -135,7 +135,7 @@ typedef struct { } CeedOperatorAssemble_Hip; typedef struct { - bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out, has_shared_e_vecs; + bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out; uint64_t *input_states; // State tracking for passive inputs CeedVector *e_vecs_in, *e_vecs_out; CeedVector *q_vecs_in, *q_vecs_out; diff --git a/interface/ceed-basis.c b/interface/ceed-basis.c index 67b9d43345..aa19489e0a 100644 --- a/interface/ceed-basis.c +++ b/interface/ceed-basis.c @@ -331,11 +331,6 @@ static int CeedBasisApplyAtPointsCheckDims(CeedBasis basis, CeedInt num_elem, co if (x_ref != CEED_VECTOR_NONE) CeedCall(CeedVectorGetLength(x_ref, &x_length)); if (u != CEED_VECTOR_NONE) CeedCall(CeedVectorGetLength(u, &u_length)); - // Check compatibility of topological and geometrical dimensions - CeedCheck((t_mode == CEED_TRANSPOSE && v_length % num_nodes == 0) || (t_mode == CEED_NOTRANSPOSE && u_length % num_nodes == 0) || - (eval_mode == CEED_EVAL_WEIGHT), - ceed, CEED_ERROR_DIMENSION, "Length of input/output vectors incompatible with basis dimensions and number of points"); - // Check compatibility coordinates vector for (CeedInt i = 0; i < num_elem; i++) total_num_points += num_points[i]; CeedCheck((x_length >= total_num_points * dim) || (eval_mode == CEED_EVAL_WEIGHT), ceed, CEED_ERROR_DIMENSION, @@ -1819,11 +1814,6 @@ static int CeedBasisApplyCheckDims(CeedBasis basis, CeedInt num_elem, CeedTransp CeedCall(CeedVectorGetLength(v, &v_length)); if (u) CeedCall(CeedVectorGetLength(u, &u_length)); - // Check compatibility of topological and geometrical dimensions - CeedCheck((t_mode == CEED_TRANSPOSE && v_length % num_nodes == 0 && u_length % num_qpts == 0) || - (t_mode == CEED_NOTRANSPOSE && u_length % num_nodes == 0 && v_length % num_qpts == 0), - ceed, CEED_ERROR_DIMENSION, "Length of input/output vectors incompatible with basis dimensions"); - // Check vector lengths to prevent out of bounds issues bool has_good_dims = true; switch (eval_mode) { diff --git a/interface/ceed-vector.c b/interface/ceed-vector.c index 7c10ca98bc..99672bce09 100644 --- a/interface/ceed-vector.c +++ b/interface/ceed-vector.c @@ -862,7 +862,7 @@ int CeedVectorPointwiseMult(CeedVector w, CeedVector x, CeedVector y) { CeedCall(CeedVectorGetLength(w, &length_w)); CeedCall(CeedVectorGetLength(x, &length_x)); CeedCall(CeedVectorGetLength(y, &length_y)); - CeedCheck(length_w == length_x && length_w == length_y, ceed, CEED_ERROR_UNSUPPORTED, + CeedCheck(length_x >= length_x && length_y >= length_w, ceed, CEED_ERROR_UNSUPPORTED, "Cannot multiply vectors of different lengths." " x length: %" CeedSize_FMT " y length: %" CeedSize_FMT, length_x, length_y);