Skip to content

Commit

Permalink
refine error message
Browse files Browse the repository at this point in the history
  • Loading branch information
pkuzyc committed Nov 12, 2024
1 parent ec0d61f commit 62d0aff
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 23 deletions.
45 changes: 28 additions & 17 deletions paddle/fluid/pir/dialect/distributed/ir/dist_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,22 @@ void VerifyOpArgNum(const pir::OpBase* op,
PADDLE_ENFORCE_EQ(input_size,
num_inputs,
common::errors::PreconditionNotMet(
"The size of inputs must be equal to "
"%u.",
num_inputs));
"Mismatched inputs size, expected:%u, "
"but received:%u.",
num_inputs,
input_size));
}

VLOG(4) << "Verifying outputs num:";
{
auto output_size = op->num_results();
PADDLE_ENFORCE_EQ(output_size,
num_outputs,
common::errors::PreconditionNotMet("The size of outputs "
"must be equal to %u.",
num_outputs));
common::errors::PreconditionNotMet(
"Mismatched outputs size, expected:%u, "
"but received:%u.",
num_outputs,
output_size));
}

VLOG(4) << "Verifying attributes:";
Expand Down Expand Up @@ -111,13 +114,15 @@ void ShardTensorOp::VerifySig() {
(*this)->operand_source(0).type().isa<paddle::dialect::DenseTensorType>(),
true,
common::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
"Mismatched input type. ShardTensorOp requires 'DenseTensorType' for "
"input."));

PADDLE_ENFORCE_EQ(
(*this)->result(0).type().isa<paddle::dialect::DistDenseTensorType>(),
true,
common::errors::PreconditionNotMet(
"Type validation failed for the 0th output."));
"Mismatched output type. ShardTensorOp requires "
"'DistDenseTensorType' for output."));

VLOG(4) << "End Verifying for: ShardTensorOp.";
}
Expand Down Expand Up @@ -452,15 +457,18 @@ void MoESubMeshTensorsOp::VerifySig() {
.isa<paddle::dialect::DistDenseTensorType>(),
true,
common::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
"Mismatched input type. MoESubMeshTensorsOp requires "
"'DistDenseTensorType' for input."));

auto output_size = num_results();
for (size_t i = 0; i < output_size; ++i) {
PADDLE_ENFORCE_EQ(
(*this)->result(i).type().isa<paddle::dialect::DistDenseTensorType>(),
true,
common::errors::PreconditionNotMet(
"Type validation failed for the %d output.", i));
"Mismatched type of %u'th output. MoESubMeshTensorsOp requires "
"'DistDenseTensorType.",
i));
}

VLOG(4) << "End Verifying for: moe_sub_mesh_tensors op.";
Expand Down Expand Up @@ -535,13 +543,16 @@ void MoEGlobalMeshTensorOp::VerifySig() {

auto input_size = num_operands();
for (size_t i = 0; i < input_size; ++i) {
PADDLE_ENFORCE_EQ((*this)
->operand_source(i)
.type()
.isa<paddle::dialect::DistDenseTensorType>(),
true,
common::errors::PreconditionNotMet(
"Type validation failed for the %d input.", i));
PADDLE_ENFORCE_EQ(
(*this)
->operand_source(i)
.type()
.isa<paddle::dialect::DistDenseTensorType>(),
true,
common::errors::PreconditionNotMet(
"Mismatched type of %u'th input. MoEGlobalMeshTensorOp requires "
"'DistDenseTensorType'.",
i));
}

PADDLE_ENFORCE_EQ(
Expand Down
14 changes: 9 additions & 5 deletions python/paddle/distributed/auto_parallel/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,15 @@ def _dist_reshape(
"""
tgt_global_shape = infer_positive_shape(dist_tensor.shape, global_shape)
tgt_local_shape = _cal_local_shape(tgt_global_shape, mesh, placements)
src_local_shape = dist_tensor._local_value().shape
if paddle.in_dynamic_mode():
src_local_shape = dist_tensor._local_value().shape
elif paddle.framework.in_pir_mode():
src_local_shape = dist_tensor._local_shape
else:
raise NotImplementedError(
"dist_reshape is only supported in dynamic and pir mode."
)

assert np.prod(tgt_local_shape) == np.prod(
src_local_shape
), f"The local shapes {src_local_shape} and {tgt_local_shape} are mismatched."
Expand All @@ -212,10 +220,6 @@ def _dist_reshape(
mesh,
placements,
)
else:
raise NotImplementedError(
"dist_reshape is only supported in dynamic mode."
)


def _reshard_mesh_shape(
Expand Down
2 changes: 1 addition & 1 deletion test/auto_parallel/pir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 60)
py_test_modules(test_custom_spec_pir MODULES test_custom_spec_pir ENVS
FLAGS_enable_pir_api=1)
py_test_modules(test_moe_api MODULES test_pir_moe_utils_api ENVS
py_test_modules(test_pir_moe_utils_api MODULES test_pir_moe_utils_api ENVS
FLAGS_enable_pir_api=1)
endif()
py_test_modules(test_pir_1f1b_plan MODULES test_pir_1f1b_plan ENVS
Expand Down

0 comments on commit 62d0aff

Please sign in to comment.