diff --git a/python/paddle/distributed/auto_parallel/__init__.py b/python/paddle/distributed/auto_parallel/__init__.py index 3b5ccaa062f6e..edcd53bdc7a52 100644 --- a/python/paddle/distributed/auto_parallel/__init__.py +++ b/python/paddle/distributed/auto_parallel/__init__.py @@ -15,12 +15,6 @@ from .interface import shard_tensor # noqa: F401 from .interface import shard_op # noqa: F401 from .process_mesh import ProcessMesh -# from .interface import set_shard_mask # noqa: F401 -# from .interface import set_offload_device # noqa: F401 -# from .interface import set_pipeline_stage # noqa: F401 -# from .interface import ProcessMesh # noqa: F401 -from .completion import complete_annotation # noqa: F401 -from .completion import complete_backward_annotation # noqa: F401 from .reshard import reshard # noqa: F401 from .cost_model import estimate_cost diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 660b1a54221a7..54491f9e6c16e 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy from copy import deepcopy +import time from paddle.fluid import core from paddle.fluid import framework -from .utils import compute_compatible_process_mesh -from .utils import compute_compatible_dim_mapping -from .utils import compute_compatible_dims_mapping from .utils import print_program_with_dist_attr from .operators import find_best_compatible_distributed_operator_impl from .dist_context import get_default_distributed_context @@ -29,865 +28,602 @@ from .dist_attribute import OperatorDistributedAttribute from paddle.distributed.fleet.meta_optimizers.common import OpRole -ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"] +def compute_compatible_process_mesh(process_mesh_list): + """Compute the compatible process mesh given a list of process meshes.""" + if not process_mesh_list: + return None -def is_elementwise_like_op(op_type): - if op_type in ELEMENTWISE_LIKE_OP_LIST: - return True - else: - return False - + def _compute_compatible_process_mesh_two(pm1, pm2): + if pm1 is None: + return True, pm2 + if pm2 is None: + return True, pm1 + if pm1 == pm2: + return True, pm1 + if pm1.processes == pm2.processes: + if len(pm1.topology) >= len(pm2.topology): + return True, pm1 + else: + return True, pm2 + process_set1 = set(pm1.processes) + process_set2 = set(pm2.processes) + if process_set1.issubset(process_set2): + return True, pm2 + if process_set2.issubset(process_set1): + return True, pm1 + return False, None + + compatible_result = None + for process_mesh in process_mesh_list: + compatible, compatible_result = _compute_compatible_process_mesh_two( + compatible_result, process_mesh) + if not compatible: + return None + return copy.deepcopy(compatible_result) + + +def compute_compatible_dim_mapping(dim_mapping_list): + """Compute the compatible dim mapping given a list of dim mapping.""" + if not dim_mapping_list: + return None -def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True): - """ - Update tensor's process mesh by using its predecessor's process mesh if in the forward direction, - and by using its successor's process mesh if in the backward direction. Note: only the equal - process meshes are compatible for now. 
+ def _compute_compatible_dim_mapping_two(dm1, dm2): + if dm1 == -1: + return True, dm2 + if dm2 == -1: + return True, dm1 + if dm1 == dm2: + return True, dm1 + return False, None + + compatible_result = -1 + for mapping in dim_mapping_list: + compatible, compatible_result = _compute_compatible_dim_mapping_two( + compatible_result, mapping) + if not compatible: + return None + return compatible_result + + +def compute_compatible_dims_mapping(dims_mapping_list): + """Compute the compatible dims mapping given a list of dims mapping. + Each of dims mapping is also a list. """ - changed = False - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node) - if tensor_dist_attr.is_annotated("process_mesh"): - return changed - tensor_process_mesh = tensor_dist_attr.process_mesh - if fwd: - inputs_process_meshes = [] - for pred_op_node in tensor_node.inputs: - if pred_op_node.op() is not None: - op_dist_attr = dist_context.get_op_dist_attr_for_graph( - pred_op_node) - op_process_mesh = op_dist_attr.process_mesh - inputs_process_meshes.append(op_process_mesh) - compatible_process_mesh = compute_compatible_process_mesh( - inputs_process_meshes) - if compatible_process_mesh is not None and tensor_process_mesh is None: - tensor_dist_attr.process_mesh = compatible_process_mesh - changed = True - else: - outputs_process_meshes = [] - for succ_op_node in tensor_node.outputs: - if succ_op_node.op() is not None: - op_dist_attr = dist_context.get_op_dist_attr_for_graph( - succ_op_node) - op_process_mesh = op_dist_attr.process_mesh - outputs_process_meshes.append(op_process_mesh) - compatible_process_mesh = compute_compatible_process_mesh( - outputs_process_meshes) - if compatible_process_mesh is not None and tensor_process_mesh is None: - tensor_dist_attr.process_mesh = compatible_process_mesh - changed = True - return changed - - -def update_op_node_process_mesh(dist_context, op_node, fwd=True): - """ - Update op's process mesh by using its predecessor's process mesh if in the forward direction, - and by using its successor's process mesh if in the backward direction. Note: only the equal - process meshes are compatible for now. 
- """ - changed = False - op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node) - if op_dist_attr.is_annotated("process_mesh"): - return changed - op_process_mesh = op_dist_attr.process_mesh - if fwd: - inputs_process_meshes = [] - for tensor_node in op_node.inputs: - if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - tensor_process_mesh = tensor_dist_attr.process_mesh - inputs_process_meshes.append(tensor_process_mesh) - compatible_process_mesh = compute_compatible_process_mesh( - inputs_process_meshes) - if compatible_process_mesh is not None and op_process_mesh is None: - op_dist_attr.process_mesh = compatible_process_mesh - changed = True - else: - outputs_process_meshes = [] - for tensor_node in op_node.outputs: - if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - tensor_process_mesh = tensor_dist_attr.process_mesh - outputs_process_meshes.append(tensor_process_mesh) - compatible_process_mesh = compute_compatible_process_mesh( - outputs_process_meshes) - if compatible_process_mesh is not None and op_process_mesh is None: - op_dist_attr.process_mesh = compatible_process_mesh - changed = True - return changed - - -def update_op_dims_mapping_by_default_dist_impl(dist_context, op_node): - """Each operator has a default distributed operator, only allowed to be sharded in batch dimension.""" - changed = False - if (not op_node.is_op()) or (op_node.op() is None): - return False - op_desc = op_node.op() - dist_op = dist_context.get_dist_op_for_graph(op_node) - op_dist_attr = dist_op.dist_attr - # The following statement will be replaced by a more elegent way - if op_desc.type() == "shape" or op_desc.type() == "slice": - return False - output_names = op_desc.output_names() - xshape_arg_names = [] - if "XShape" in output_names: - xshape_arg_names = op_desc.output("XShape") - batch_dim_mappings = [] - for arg_name in op_desc.input_arg_names(): - serial_tensor = dist_op.get_serial_input(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if len(dims_mapping) > 1: - for idx, mapping in enumerate(dims_mapping[1:]): - assert mapping == -1, \ - "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) - batch_dim_mappings.append(dims_mapping[0]) - for arg_name in op_desc.output_arg_names(): - serial_tensor = dist_op.get_serial_output(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if arg_name not in xshape_arg_names: - if len(dims_mapping) > 1: - for idx, mapping in enumerate(dims_mapping[1:]): - assert mapping == -1, \ - "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) - batch_dim_mappings.append(dims_mapping[0]) - else: - assert dims_mapping[0] == -1, \ - "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\ - .format(op_desc.type(), mapping) - if len(dims_mapping) > 2: - for idx, mapping in enumerate(dims_mapping[2:]): - assert mapping == -1, \ - "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) - batch_dim_mappings.append(dims_mapping[1]) - - compatible_dim_mapping = 
compute_compatible_dim_mapping(batch_dim_mappings) - assert compatible_dim_mapping is not None, "There is no compatible dim mapping." - for arg_name in op_desc.input_arg_names(): - serial_tensor = dist_op.get_serial_input(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if compatible_dim_mapping != dims_mapping[0]: - dims_mapping[0] = compatible_dim_mapping - changed = True - for arg_name in op_desc.output_arg_names(): - serial_tensor = dist_op.get_serial_output(arg_name) - if serial_tensor.is_parameter: - continue - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if arg_name not in xshape_arg_names: - if compatible_dim_mapping != dims_mapping[0]: - dims_mapping[0] = compatible_dim_mapping + if not dims_mapping_list: + return None + length = len(dims_mapping_list[0]) + for dims_mapping in dims_mapping_list: + if dims_mapping is None: + return None + if len(dims_mapping) != length: + return None + compatible_result = [] + for dim_mappings in zip(*dims_mapping_list): + compatible_dim_mapping = compute_compatible_dim_mapping( + list(dim_mappings)) + if compatible_dim_mapping is None: + return None + compatible_result.append(compatible_dim_mapping) + return compatible_result + + +class Completer: + def __init__(self, dist_context): + assert dist_context is not None + self._dist_context = dist_context + + def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True): + changed = False + if (not tensor_node.is_var()) or (tensor_node.var() is None): + return False + tensor_desc = tensor_node.var() + # Skip reader tensor + if tensor_desc.type() == core.VarDesc.VarType.READER: + return False + tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node) + assert tensor_dist_attr is not None + if tensor_dist_attr.is_annotated("dims_mapping"): + return False + tensor_dims_mapping = tensor_dist_attr.dims_mapping + if fwd: + dims_mapping_list = [] + for pred_op_node in tensor_node.inputs: + if pred_op_node.op() is not None: + if pred_op_node.op().type() == "create_py_reader" \ + or pred_op_node.op().type() == "create_double_buffer_reader" \ + or pred_op_node.op().type() == "read": + continue + op_dist_attr = self._dist_context.get_op_dist_attr_for_graph( + pred_op_node) + if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + op_dims_mapping = op_dist_attr.get_output_dims_mapping( + tensor_desc.name()) + dims_mapping_list.append(op_dims_mapping) + dims_mapping_list.append(tensor_dims_mapping) + compatible_dims_mapping = compute_compatible_dims_mapping( + dims_mapping_list) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != tensor_dims_mapping): + tensor_dist_attr.dims_mapping = compatible_dims_mapping changed = True else: - if compatible_dim_mapping != dims_mapping[1]: - dims_mapping[1] = compatible_dim_mapping + dims_mapping_list = [] + for succ_op_node in tensor_node.outputs: + if succ_op_node.op() is not None: + if succ_op_node.op().type() == "create_py_reader" \ + or succ_op_node.op().type() == "create_double_buffer_reader" \ + or succ_op_node.op().type() == "read": + continue + op_dist_attr = self._dist_context.get_op_dist_attr_for_graph( + succ_op_node) + if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + op_dims_mapping = op_dist_attr.get_input_dims_mapping( + tensor_desc.name()) + dims_mapping_list.append(op_dims_mapping) + dims_mapping_list.append(tensor_dims_mapping) + compatible_dims_mapping = compute_compatible_dims_mapping( + 
dims_mapping_list) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != tensor_dims_mapping): + tensor_dist_attr.dims_mapping = compatible_dims_mapping changed = True + return changed - return changed - - -def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_context, op_node): - """Element-wise operator can be sharded in any way (but should take care of broadcasting).""" - changed = False - if (not op_node.is_op()) or (op_node.op() is None): - return False - op_desc = op_node.op() - op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node) - - input_arg_names = op_desc.input_arg_names() - input_dims_mapping_dict = {} - input_dims_mapping_lens = {} - max_dims_mapping_len = -1 - for arg_name in input_arg_names: - dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if max_dims_mapping_len < len(dims_mapping): - max_dims_mapping_len = len(dims_mapping) - input_dims_mapping_dict[arg_name] = dims_mapping - input_dims_mapping_lens[arg_name] = len(dims_mapping) - - dims_mapping_list = [] - for arg_name in input_arg_names: - if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: - new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)] - for i in range(input_dims_mapping_lens[arg_name]): - new_idx = (max_dims_mapping_len - - input_dims_mapping_lens[arg_name]) + i - new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i] - dims_mapping_list.append(new_dims_mapping) - else: - dims_mapping_list.append(input_dims_mapping_dict[arg_name]) - output_arg_names = op_desc.output_arg_names() - for arg_name in output_arg_names: - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - assert len(dims_mapping) == max_dims_mapping_len - dims_mapping_list.append(dims_mapping) - - compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list) - assert compatible_dims_mapping is not None, "There is no compatible dim mapping." 
- - for arg_name in input_arg_names: - if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: - new_dims_mapping = [ - -1 for _ in range(input_dims_mapping_lens[arg_name]) - ] - for i in range(input_dims_mapping_lens[arg_name]): - new_idx = (max_dims_mapping_len - - input_dims_mapping_lens[arg_name]) + i - new_dims_mapping[i] = compatible_dims_mapping[new_idx] - if new_dims_mapping != input_dims_mapping_dict[arg_name]: - op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping) + def _update_op_node_dims_mapping(self, op_node, fwd=True): + changed = False + if (not op_node.is_op()) or (op_node.op() is None): + return False + # Skip reader op + op_desc = op_node.op() + if op_desc.type() == "create_py_reader" \ + or op_desc.type() == "create_double_buffer_reader" \ + or op_desc.type() == "read": + return False + dist_op = self._dist_context.get_dist_op_for_graph(op_node) + op_dist_attr = dist_op.dist_attr + if fwd: + for tensor_node in op_node.inputs: + if tensor_node.var() is not None: + if tensor_node.var().type() == core.VarDesc.VarType.READER: + continue + tensor_desc = tensor_node.var() + if op_dist_attr.is_annotated_input_dims_mapping( + tensor_desc.name()): + continue + tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node) + if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + tensor_dims_mapping = tensor_dist_attr.dims_mapping + op_dims_mapping = op_dist_attr.get_input_dims_mapping( + tensor_desc.name()) + compatible_dims_mapping = compute_compatible_dims_mapping( + [op_dims_mapping, tensor_dims_mapping]) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != op_dims_mapping): + op_dist_attr.set_input_dims_mapping( + tensor_desc.name(), compatible_dims_mapping) + changed = True + # Find the most compatible implemenetations from the distributed operator + op_dist_impl = find_best_compatible_distributed_operator_impl( + dist_op, fwd=True) + assert op_dist_impl is not None, "Cannot find the dist op implementation." 
+ dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if dim_changed: changed = True + if op_dist_impl.is_auto_compatible(dist_op): + if op_dist_impl.type == "elementwise": + op_dist_attr.impl_type = "default" + else: + op_dist_attr.impl_type = op_dist_impl.type + op_dist_attr.impl_idx = op_dist_impl.idx else: - if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: - op_dist_attr.set_input_dims_mapping(arg_name, - compatible_dims_mapping) + for tensor_node in op_node.outputs: + if tensor_node.var() is not None: + if tensor_node.var().type() == core.VarDesc.VarType.READER: + continue + tensor_desc = tensor_node.var() + if op_dist_attr.is_annotated_output_dims_mapping( + tensor_desc.name()): + continue + tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node) + if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + tensor_dims_mapping = tensor_dist_attr.dims_mapping + op_dims_mapping = op_dist_attr.get_output_dims_mapping( + tensor_desc.name()) + compatible_dims_mapping = compute_compatible_dims_mapping( + [op_dims_mapping, tensor_dims_mapping]) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != op_dims_mapping): + op_dist_attr.set_output_dims_mapping( + tensor_desc.name(), compatible_dims_mapping) + changed = True + # Find the most compatible implemenetations from the distributed operator + op_dist_impl = find_best_compatible_distributed_operator_impl( + dist_op, fwd=False) + assert op_dist_impl is not None, "Cannot find the dist op implementation." + dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if dim_changed: changed = True + if op_dist_impl.is_auto_compatible(dist_op): + if op_dist_impl.type == "elementwise": + op_dist_attr.impl_type = "default" + else: + op_dist_attr.impl_type = op_dist_impl.type + op_dist_attr.impl_idx = op_dist_impl.idx + return changed - for arg_name in output_arg_names: - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if compatible_dims_mapping != dims_mapping: - op_dist_attr.set_output_dims_mapping(arg_name, - compatible_dims_mapping) - changed = True - - return changed - - -def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): - changed = False - if (not tensor_node.is_var()) or (tensor_node.var() is None): - return False - tensor_desc = tensor_node.var() - # Skip reader tensor - if tensor_desc.type() == core.VarDesc.VarType.READER: - return False - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node) - assert tensor_dist_attr is not None - if tensor_dist_attr.is_annotated("dims_mapping"): - return False - tensor_dims_mapping = tensor_dist_attr.dims_mapping - if fwd: - dims_mapping_list = [] - for pred_op_node in tensor_node.inputs: - if pred_op_node.op() is not None: - if pred_op_node.op().type() == "create_py_reader" \ - or pred_op_node.op().type() == "create_double_buffer_reader" \ - or pred_op_node.op().type() == "read": - continue - op_dist_attr = dist_context.get_op_dist_attr_for_graph( - pred_op_node) - op_dims_mapping = op_dist_attr.get_output_dims_mapping( - tensor_desc.name()) - dims_mapping_list.append(op_dims_mapping) - dims_mapping_list.append(tensor_dims_mapping) - compatible_dims_mapping = compute_compatible_dims_mapping( - dims_mapping_list) - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != tensor_dims_mapping): - tensor_dist_attr.dims_mapping = compatible_dims_mapping - changed = True - else: - dims_mapping_list = [] - for succ_op_node in tensor_node.outputs: - if 
succ_op_node.op() is not None: - if succ_op_node.op().type() == "create_py_reader" \ - or succ_op_node.op().type() == "create_double_buffer_reader" \ - or succ_op_node.op().type() == "read": - continue - op_dist_attr = dist_context.get_op_dist_attr_for_graph( - succ_op_node) - op_dims_mapping = op_dist_attr.get_input_dims_mapping( - tensor_desc.name()) - dims_mapping_list.append(op_dims_mapping) - dims_mapping_list.append(tensor_dims_mapping) - compatible_dims_mapping = compute_compatible_dims_mapping( - dims_mapping_list) - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != tensor_dims_mapping): - tensor_dist_attr.dims_mapping = compatible_dims_mapping - changed = True - return changed - - -def update_op_node_dims_mapping(dist_context, op_node, fwd=True): - changed = False - if (not op_node.is_op()) or (op_node.op() is None): - return False - # Skip reader op - op_desc = op_node.op() - if op_desc.type() == "create_py_reader" \ - or op_desc.type() == "create_double_buffer_reader" \ - or op_desc.type() == "read": - return False - dist_op = dist_context.get_dist_op_for_graph(op_node) - op_dist_attr = dist_op.dist_attr - if fwd: - for tensor_node in op_node.inputs: - if tensor_node.var() is not None: - if tensor_node.var().type() == core.VarDesc.VarType.READER: - continue - tensor_desc = tensor_node.var() - if op_dist_attr.is_annotated_input_dims_mapping( - tensor_desc.name()): - continue - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - tensor_dims_mapping = tensor_dist_attr.dims_mapping - op_dims_mapping = op_dist_attr.get_input_dims_mapping( - tensor_desc.name()) - compatible_dims_mapping = compute_compatible_dims_mapping( - [op_dims_mapping, tensor_dims_mapping]) - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != op_dims_mapping): - op_dist_attr.set_input_dims_mapping(tensor_desc.name(), - compatible_dims_mapping) - changed = True - # Find the most compatible implemenetations from the distributed operator - op_dist_impl = find_best_compatible_distributed_operator_impl( - dist_op, fwd=True) - assert op_dist_impl is not None, "Cannot find the dist op implementation." - dim_changed = op_dist_impl.update_dims_mapping(dist_op) - if dim_changed: - changed = True - if op_dist_impl.is_auto_compatible(dist_op): - if op_dist_impl.type == "elementwise": - op_dist_attr.impl_type = "default" - else: - op_dist_attr.impl_type = op_dist_impl.type - op_dist_attr.impl_idx = op_dist_impl.idx - else: - for tensor_node in op_node.outputs: - if tensor_node.var() is not None: - if tensor_node.var().type() == core.VarDesc.VarType.READER: - continue - tensor_desc = tensor_node.var() - if op_dist_attr.is_annotated_output_dims_mapping( - tensor_desc.name()): - continue - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - tensor_dims_mapping = tensor_dist_attr.dims_mapping - op_dims_mapping = op_dist_attr.get_output_dims_mapping( - tensor_desc.name()) - compatible_dims_mapping = compute_compatible_dims_mapping( - [op_dims_mapping, tensor_dims_mapping]) - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != op_dims_mapping): - op_dist_attr.set_output_dims_mapping( - tensor_desc.name(), compatible_dims_mapping) - changed = True - # Find the most compatible implemenetations from the distributed operator - op_dist_impl = find_best_compatible_distributed_operator_impl( - dist_op, fwd=False) - assert op_dist_impl is not None, "Cannot find the dist op implementation." 
- dim_changed = op_dist_impl.update_dims_mapping(dist_op) - if dim_changed: - changed = True - if op_dist_impl.is_auto_compatible(dist_op): - if op_dist_impl.type == "elementwise": - op_dist_attr.impl_type = "default" + def _update_process_mesh(self): + def _find_nearset_node(nodes, idx): + for node in reversed(nodes[:idx]): + node_dist_attr = self._dist_context.get_dist_attr_for_graph( + node) + if node_dist_attr.process_mesh is not None: + return node + + total_reach_fix_point = False + while not total_reach_fix_point: + total_changed = False + for is_fwd in [True, False]: + all_nodes = self._dist_context.serial_ordered_nodes \ + if is_fwd else reversed(self._dist_context.serial_ordered_nodes) + reach_fix_point = False + while not reach_fix_point: + changed = False + for idx, node in enumerate(all_nodes): + nearest_node = _find_nearset_node( + self._dist_context.serial_ordered_nodes, idx) + if nearest_node is None: + continue + nearest_node_dis_attr = self._dist_context.get_dist_attr_for_graph( + nearest_node) + nearest_process_mesh = nearest_node_dis_attr.process_mesh + cur_node_dist_attr = self._dist_context.get_dist_attr_for_graph( + node) + cur_process_mesh = cur_node_dist_attr.process_mesh + compatible_process_mesh = compute_compatible_process_mesh( + [cur_process_mesh, nearest_process_mesh]) + if compatible_process_mesh is not None \ + and cur_process_mesh != compatible_process_mesh: + cur_node_dist_attr.process_mesh = compatible_process_mesh + changed = True + if changed: + reach_fix_point = False + total_changed = True + else: + reach_fix_point = True + if total_changed: + total_reach_fix_point = False else: - op_dist_attr.impl_type = op_dist_impl.type - op_dist_attr.impl_idx = op_dist_impl.idx - return changed - - -def complete_annotation(program, dist_context=None): - """ Complete annotation for the partial annotated program. - - Arguments: - program: partial annotated program. - dist_context: the distributed context is used to store distributed attributes for program. - If not provided, the default one will be used. - Returns: - program: completed annotated program. 
- """ - - # Use the default distribted context for completeion if there is no one - if dist_context is None: - dist_context = get_default_distributed_context() - dist_context.serial_program = program - else: - dist_context.serial_program = program - - # print_program_with_dist_attr(program, dist_context) - - # Initialize distributed attributes for all var and op node in program - dist_context.init_dist_attr_for_program() - - # Initialize distributed attributes for all var and op node in graph - dist_context.init_dist_attr_for_graph() - - # Complete process mesh for each node - all_nodes = list(dist_context.serial_graph.all_nodes()) + total_reach_fix_point = True - def sort_key_fun(node): - first = -1 - if node.is_op(): - first = 0 - else: - first = 1 - second = -1 - if node.is_op() and node.op() is not None: - second = node.op().id() - if node.is_var() and node.var() is not None: - second = node.var().id() - return (first, second) - - all_nodes.sort(key=sort_key_fun) - - reach_fix_point = False - while not reach_fix_point: - total_changed = False - reach_fwd_fix_point = False - reach_bwd_fix_point = False - while not reach_fwd_fix_point: + def _update_dims_mapping(self): + # Complete dims_mapping for each node + reach_fix_point = False + while not reach_fix_point: changed = False - for node in all_nodes: - if node.is_var() and node.var() is not None: - tensor_changed = update_tensor_node_process_mesh( - dist_context, node, fwd=True) - if tensor_changed: - changed = True - if node.is_op() and node.op() is not None: - op_changed = update_op_node_process_mesh( - dist_context, node, fwd=True) - if op_changed: - changed = True + for is_fwd in [True, False]: + all_nodes = self._dist_context.serial_ordered_nodes \ + if is_fwd else reversed(self._dist_context.serial_ordered_nodes) + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_changed = self._update_tensor_node_dims_mapping( + node, fwd=is_fwd) + if tensor_changed: + changed = True + if node.is_op() and node.op() is not None: + op_changed = self._update_op_node_dims_mapping( + node, fwd=is_fwd) + if op_changed: + changed = True if changed: - reach_fwd_fix_point = False - total_changed = True + reach_fix_point = False else: - reach_fwd_fix_point = True - while not reach_bwd_fix_point: - changed = False - for node in all_nodes: - if node.is_var() and node.var() is not None: - tensor_changed = update_tensor_node_process_mesh( - dist_context, node, fwd=False) - if tensor_changed: - changed = True - if node.is_op() and node.op() is not None: - op_changed = update_op_node_process_mesh( - dist_context, node, fwd=False) - if op_changed: - changed = True - if changed: - reach_bwd_fix_point = False - total_changed = True - else: - reach_bwd_fix_point = True - if total_changed: - reach_fix_point = False - else: - reach_fix_point = True - # Validation the completion of process meshes and should be moved to a proper location - is_wrong = False - for node in all_nodes: - if node.is_var() and node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - node) - if tensor_dist_attr.process_mesh is None: - msg_str = "" - for op_node in node.inputs: - if op_node.op() is not None: - op_dist_attr = dist_context.get_op_dist_attr_for_graph( - op_node) - msg_str += "{} [{}], ".format( - op_node.op().type(), - op_dist_attr.process_mesh) - else: - msg_str += "{} [{}], ".format(op_node.name(), - None) - for op_node in node.outputs: - if op_node.op() is not None: - op_dist_attr = 
dist_context.get_op_dist_attr_for_graph( - op_node) - msg_str += "{} [{}], ".format( - op_node.op().type(), - op_dist_attr.process_mesh) - else: - msg_str += "{} [{}], ".format(op_node.name(), - None) - msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_tensor api explicitly to annotate it".format( - node.var().name(), msg_str[:-2]) - is_wrong = True - print(msg_str) - if node.is_op() and node.op() is not None: - op_dist_attr = dist_context.get_op_dist_attr_for_graph(node) - if op_dist_attr.process_mesh is None: - msg_str = "" - for tensor_node in node.inputs: - if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - msg_str += "{} [{}], ".format( - tensor_node.var().name(), - tensor_dist_attr.process_mesh) - else: - msg_str += "{} [{}], ".format( - tensor_node.name(), None) - for tensor_node in node.outputs: - if tensor_node.var() is not None: - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - msg_str += "{} [{}], ".format( - tensor_node.var().name(), - tensor_dist_attr.process_mesh) - else: - msg_str += "{} [{}], ".format( - tensor_node.name(), None) - msg_str = "Cannot decide ProcessMesh of {} among {}. Please use shard_op api explicitly to annotate it".format( - node.op().type(), msg_str[:-2]) - is_wrong = True - print(msg_str) - if node.is_op() and node.op() is None: - print("op op is None", node.name()) - if is_wrong: - assert False, "Cannot complete process_meshes of the program." - - # Complete dims_mapping for each node - reach_fix_point = False - while not reach_fix_point: - changed = False - for node in all_nodes: - if node.is_var() and node.var() is not None: - tensor_changed = update_tensor_node_dims_mapping( - dist_context, node, fwd=True) - if tensor_changed: - changed = True - if node.is_op() and node.op() is not None: - op_changed = update_op_node_dims_mapping( - dist_context, node, fwd=True) - if op_changed: - changed = True - for node in reversed(all_nodes): - if node.is_var() and node.var() is not None: - tensor_changed = update_tensor_node_dims_mapping( - dist_context, node, fwd=False) - if tensor_changed: - changed = True - if node.is_op() and node.op() is not None: - op_changed = update_op_node_dims_mapping( - dist_context, node, fwd=False) - if op_changed: - changed = True - if changed: - reach_fix_point = False - else: - reach_fix_point = True - - # Copy the corresponding distributed attribute from graph to program - dist_context.copy_dist_attr_from_graph_to_program() - dist_context.clear_dist_info_for_graph() - - # Do the validation check and amend some completion - dist_context.amend_dist_attr_for_program() - - # print_program_with_dist_attr(program, dist_context) - dist_context.validate_dist_attr_for_program() + reach_fix_point = True + + def complete_forward_annotation(self, serial_main_program): + """ Complete annotation for the partial annotated serial_main_program. + + Arguments: + serial_main_program: partial annotated serial_main_program. + + Returns: + serial_main_program: completed annotated serial_main_program. 
+ """ + + # Use the default distribted context for completeion if there is no one + self._dist_context.serial_program = serial_main_program + + # Initialize distributed attributes for all var and op node in serial_main_program + self._dist_context.init_dist_attr_for_program() + + # Initialize distributed attributes for all var and op node in graph + self._dist_context.init_dist_attr_for_graph() + + self._update_process_mesh() + + # Complete dims_mapping for each node + self._update_dims_mapping() + + # Copy the corresponding distributed attribute from graph to serial_main_program + self._dist_context.copy_dist_attr_from_graph_to_program() + self._dist_context.clear_dist_info_for_graph() + + # print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context) + # Do the validation check and amend some completion + self._dist_context.amend_dist_attr_for_program() + + # print_serial_main_program_with_dist_attr(serial_main_program, self._dist_context) + self._dist_context.validate_dist_attr_for_program() + + return serial_main_program + + def complete_backward_annotation(self, serial_main_program): + """Complete the annotation of vars and ops in the backward phase for parallel program.""" + + def _is_grad_var_name(name): + if "@GRAD" in name: + return True + return False + + def _get_forward_varname_from_grad_varname(grad_var_name): + assert _is_grad_var_name( + grad_var_name), "[{}] is not a grad varnme.".format( + grad_var_name) + return grad_var_name[:grad_var_name.find("@GRAD")] + + def _get_op_by_id(ops, id): + for op in ops: + if op.desc.id() == id: + return op + return None + + first_backward_op_idx = -1 + for idx, op in enumerate(serial_main_program.global_block().ops): + if int(op.attr('op_role')) == int( + int(core.op_proto_and_checker_maker.OpRole.Backward) | int( + core.op_proto_and_checker_maker.OpRole.Loss)): + assert op.type == "fill_constant" + first_backward_op_idx = idx + break + + assert first_backward_op_idx >= 0, "No backward procedure found in this program." 
+ + ops = list(serial_main_program.global_block().ops) + vars = serial_main_program.global_block().vars + dist_op_context = self._dist_context.dist_op_context + + for idx in range(first_backward_op_idx, len(ops)): + + # complete the initial grad loss op + if idx == first_backward_op_idx: + assert ops[idx].type == "fill_constant" + assert len( + ops[idx].input_arg_names + ) == 0, "first backward op should has only ONE output, but got [{}]".format( + len(ops[idx].input_arg_names)) + assert len( + ops[idx].output_arg_names + ) == 1, "first backward op should has only ONE output, but got [{}]".format( + len(ops[idx].output_arg_names)) + + grad_var = vars[ops[idx].output_arg_names[0]] + forward_var_name = _get_forward_varname_from_grad_varname( + grad_var.name) + forward_var = vars[forward_var_name] + + # TODO complete other attribte for grad var + tensor_dist_attr = TensorDistributedAttribute() + process_mesh = self._dist_context.get_tensor_dist_attr_for_program( + forward_var).process_mesh + dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( + forward_var).dims_mapping + tensor_dist_attr.dims_mapping = dims_mapping + tensor_dist_attr.process_mesh = process_mesh + self._dist_context.set_tensor_dist_attr_for_program( + grad_var, tensor_dist_attr) - return program - - -def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): - """Complete the annotation of vars and ops in the backward phase for parallel program.""" - - def _is_grad_var_name(name): - if "@GRAD" in name: - return True - return False - - def _get_forward_varname_from_grad_varname(grad_var_name): - assert _is_grad_var_name( - grad_var_name), "[{}] is not a grad varnme.".format(grad_var_name) - return grad_var_name[:grad_var_name.find("@GRAD")] - - def _get_op_by_id(ops, id): - for op in ops: - if op.desc.id() == id: - return op - return None + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.process_mesh = process_mesh + op_dist_attr.set_output_dims_mapping(grad_var.name, + dims_mapping) + self._dist_context.set_op_dist_attr_for_program(ops[idx], + op_dist_attr) + continue - if dist_context is None: - dist_context = get_default_distributed_context() - - first_backward_op_idx = -1 - for idx, op in enumerate(auto_parallel_main_prog.global_block().ops): - if int(op.attr('op_role')) == int( - int(core.op_proto_and_checker_maker.OpRole.Backward) | int( - core.op_proto_and_checker_maker.OpRole.Loss)): - assert op.type == "fill_constant" - first_backward_op_idx = idx - break - - assert first_backward_op_idx >= 0, "No backward procedure found in this program." 
- - ops = list(auto_parallel_main_prog.global_block().ops) - vars = auto_parallel_main_prog.global_block().vars - dist_op_context = dist_context.dist_op_context - - for idx in range(first_backward_op_idx, len(ops)): - - # complete the initial grad loss op - if idx == first_backward_op_idx: - assert ops[idx].type == "fill_constant" - assert len( - ops[idx].input_arg_names - ) == 0, "first backward op should has only ONE output, but got [{}]".format( - len(ops[idx].input_arg_names)) - assert len( - ops[idx].output_arg_names - ) == 1, "first backward op should has only ONE output, but got [{}]".format( - len(ops[idx].output_arg_names)) - - grad_var = vars[ops[idx].output_arg_names[0]] - forward_var_name = _get_forward_varname_from_grad_varname( - grad_var.name) - forward_var = vars[forward_var_name] - - # TODO complete other attribte for grad var - tensor_dist_attr = TensorDistributedAttribute() - process_mesh = dist_context.get_tensor_dist_attr_for_program( - forward_var).process_mesh - dims_mapping = dist_context.get_tensor_dist_attr_for_program( - forward_var).dims_mapping - tensor_dist_attr.dims_mapping = dims_mapping - tensor_dist_attr.process_mesh = process_mesh - dist_context.set_tensor_dist_attr_for_program(grad_var, - tensor_dist_attr) - - op_dist_attr = OperatorDistributedAttribute() - op_dist_attr.process_mesh = process_mesh - op_dist_attr.set_output_dims_mapping(grad_var.name, dims_mapping) - dist_context.set_op_dist_attr_for_program(ops[idx], op_dist_attr) - continue - - # complete the annotation of grad op (xxx_grad op or sum op) - # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id - grad_op = ops[idx] - if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id: - # TODO support the case where one forward op corresponding to multiple xxx_grad op - forward_op = _get_op_by_id( - ops[:first_backward_op_idx], - dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]) - assert forward_op is not None - - # op dist attr - forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( - forward_op) - forward_op_process_mesh = forward_op_dist_attr.process_mesh - grad_op_dist_attr = OperatorDistributedAttribute() - grad_op_dist_attr.process_mesh = forward_op_process_mesh - - # var - for input_name in grad_op.input_arg_names: - input_var = vars[input_name] - ref_dims_mapping = None - if "@GRAD" in input_name: - forward_name = _get_forward_varname_from_grad_varname( - input_name) - ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping( - forward_name) - else: - if forward_op_dist_attr.get_input_dims_mapping(input_name): - ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( + # complete the annotation of grad op (xxx_grad op or sum op) + # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id + grad_op = ops[idx] + if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id: + # TODO support the case where one forward op corresponding to multiple xxx_grad op + forward_op = _get_op_by_id( + ops[:first_backward_op_idx], + dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]) + assert forward_op is not None + + # op dist attr + forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( + forward_op) + forward_op_process_mesh = forward_op_dist_attr.process_mesh + grad_op_dist_attr = OperatorDistributedAttribute() + grad_op_dist_attr.process_mesh = forward_op_process_mesh + + # var + for input_name in grad_op.input_arg_names: + input_var = vars[input_name] + ref_dims_mapping = None + if "@GRAD" in input_name: + 
forward_name = _get_forward_varname_from_grad_varname( input_name) - else: ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping( - input_name) - - assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( - input_var.name) - grad_op_dist_attr.set_input_dims_mapping(input_name, - ref_dims_mapping) - - for output_name in grad_op.desc.output_names(): - assert len(grad_op.desc.output(output_name)) in [0, 1] - if _is_grad_var_name(output_name): - input_name = _get_forward_varname_from_grad_varname( - output_name) - else: - assert grad_op.type in [ - "cast", "c_identity", "c_allreduce_sum" - ] - input_name = "X" - assert input_name in forward_op.desc.input_names( - ), "var [{}] in op [{}]'s output but could not find [{}] in its forward op".format( - output_name, grad_op.type, input_name) - if len(grad_op.desc.output(output_name)) == 1: - # tensor dist attr - output_var = vars[grad_op.desc.output(output_name)[0]] - forward_name = _get_forward_varname_from_grad_varname( - output_var.name) - ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( - forward_name) - - output_var_dist_attr = TensorDistributedAttribute() - output_var_dist_attr.dims_mapping = ref_dims_mapping - output_var_dist_attr.process_mesh = forward_op_process_mesh - dist_context.set_tensor_dist_attr_for_program( - output_var, output_var_dist_attr) - - grad_op_dist_attr.set_output_dims_mapping(output_var.name, - ref_dims_mapping) - - dist_context.set_op_dist_attr_for_program(grad_op, - grad_op_dist_attr) - - # only sum op for merge mutiple version grad has no a corresponding mapping in grad_op_id_to_op_id - else: - assert grad_op.type == "sum", "got unexpect op [{}]".format( - str(grad_op.type)) - assert all(map(_is_grad_var_name, grad_op.input_arg_names)) - assert len(grad_op.output_arg_names) == 1 - - ref_forward_var_name = _get_forward_varname_from_grad_varname( - grad_op.output_arg_names[0]) - forward_var = vars[ref_forward_var_name] - ref_forward_var_dims_mapping = dist_context.get_tensor_dist_attr_for_program( - forward_var).dims_mapping - ref_forward_var_process_mesh = dist_context.get_tensor_dist_attr_for_program( - forward_var).process_mesh - - # output - tensor_dist_attr = TensorDistributedAttribute() - tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping - tensor_dist_attr.process_mesh = ref_forward_var_process_mesh - dist_context.set_tensor_dist_attr_for_program( - vars[grad_op.output_arg_names[0]], tensor_dist_attr) - - # op - grad_op_dist_attr = OperatorDistributedAttribute() - grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh - for var_name in grad_op.input_arg_names: - assert _get_forward_varname_from_grad_varname( - var_name) == ref_forward_var_name - grad_op_dist_attr.set_input_dims_mapping( - var_name, ref_forward_var_dims_mapping) - - grad_op_dist_attr.set_output_dims_mapping( - grad_op.output_arg_names[0], ref_forward_var_dims_mapping) - dist_context.set_op_dist_attr_for_program(grad_op, - grad_op_dist_attr) - - -def complete_update_annotation(auto_parallel_main_prog, dist_context): - """Complete the annotation of vars and ops in the update phase for parallel program.""" - - if dist_context is None: - dist_context = get_default_distributed_context() - - ops = list(auto_parallel_main_prog.global_block().ops) - vars = auto_parallel_main_prog.global_block().vars - learning_rate_completed = False - - for idx in range(len(ops)): - - # complete the annotation of the optimizer op. 
- # TODO to add attribute for moment var - op = ops[idx] - if int(op.attr('op_role')) == int(OpRole.Optimize): - if op.type == "clip_by_norm": - - param_grad = vars[op.input("X")[0]] - param_grad_dist_attr = dist_context.get_tensor_dist_attr_for_program( - param_grad) - assert param_grad_dist_attr is not None - ref_process_mesh = param_grad_dist_attr.process_mesh - ref_dims_mapping = param_grad_dist_attr.dims_mapping - - out = vars[op.output("Out")[0]] - out_dist_attr = TensorDistributedAttribute() - out_dist_attr.process_mesh = ref_process_mesh - out_dist_attr.dims_mapping = ref_dims_mapping - dist_context.set_tensor_dist_attr_for_program(out, - out_dist_attr) + forward_name) + else: + if forward_op_dist_attr.get_input_dims_mapping( + input_name): + ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( + input_name) + else: + ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping( + input_name) + + assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( + input_var.name) + grad_op_dist_attr.set_input_dims_mapping(input_name, + ref_dims_mapping) - op_dist_attr = OperatorDistributedAttribute() - op_dist_attr.process_mesh = ref_process_mesh - op_dist_attr.set_input_dist_attr(param_grad.name, - param_grad_dist_attr) - op_dist_attr.set_output_dist_attr(out.name, out_dist_attr) - dist_context.set_op_dist_attr_for_program(op, op_dist_attr) - - if "Grad" in op.input_names and "Param" in ops[idx].input_names: - assert len(op.input( - "Param")) == 1, "Only support one-to-one now." - assert len(op.input( - "Grad")) == 1, "Only support one-to-one now." - param = vars[op.input("Param")[0]] - grad_var = vars[op.input("Grad")[0]] - - param_dist_attr = dist_context.get_tensor_dist_attr_for_program( - param) - assert param_dist_attr is not None - ref_process_mesh = dist_context.get_tensor_dist_attr_for_program( - param).process_mesh - assert ref_process_mesh is not None - ref_dims_mapping = dist_context.get_tensor_dist_attr_for_program( - param).dims_mapping - assert ref_dims_mapping is not None - op_dist_attr = OperatorDistributedAttribute() - op_dist_attr.process_mesh = ref_process_mesh - op_dist_attr.set_input_dims_mapping(grad_var.name, - ref_dims_mapping) - op_dist_attr.set_input_dims_mapping(param.name, - ref_dims_mapping) - op_dist_attr.set_output_dims_mapping(param.name, - ref_dims_mapping) - learning_var = vars[op.input("LearningRate")[0]] - op_dist_attr.set_input_dims_mapping(learning_var.name, [-1]) - op_dist_attr.set_output_dims_mapping(learning_var.name, [-1]) - - if not learning_rate_completed: - learning_rate_completed = True - var_dist_attr = TensorDistributedAttribute() - var_dist_attr.process_mesh = ref_process_mesh - var_dist_attr.dims_mapping = [-1] - dist_context.set_tensor_dist_attr_for_program(learning_var, - var_dist_attr) - - for input_name in op.desc.input_names(): - - if input_name in [ - 'Param', 'Grad', 'LearningRate', "SkipUpdate", - "Beta1Tensor", "Beta2Tensor", "EpsilonTensor", - "MasterParam" - ]: - continue + for output_name in grad_op.desc.output_names(): + assert len(grad_op.desc.output(output_name)) in [0, 1] + if _is_grad_var_name(output_name): + input_name = _get_forward_varname_from_grad_varname( + output_name) + else: + assert grad_op.type in [ + "cast", "c_identity", "c_allreduce_sum" + ] + input_name = "X" + assert input_name in forward_op.desc.input_names( + ), "var [{}] in op [{}]'s output but could not find [{}] in its forward op".format( + output_name, grad_op.type, input_name) + if len(grad_op.desc.output(output_name)) 
== 1: + # tensor dist attr + output_var = vars[grad_op.desc.output(output_name)[0]] + forward_name = _get_forward_varname_from_grad_varname( + output_var.name) + ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( + forward_name) - assert len(op.desc.input(input_name)) == 1 - input_var = vars[op.desc.input(input_name)[0]] - input_var_attr = TensorDistributedAttribute() + output_var_dist_attr = TensorDistributedAttribute() + output_var_dist_attr.dims_mapping = ref_dims_mapping + output_var_dist_attr.process_mesh = forward_op_process_mesh + self._dist_context.set_tensor_dist_attr_for_program( + output_var, output_var_dist_attr) - if "Beta1Pow" in input_name or "Beta2Pow" in input_name: - input_var_attr.dims_mapping = [-1] - op_dist_attr.set_input_dims_mapping(input_var.name, - [-1]) - op_dist_attr.set_output_dims_mapping(input_var.name, - [-1]) - else: - assert "Moment" in input_name - input_var_attr.dims_mapping = ref_dims_mapping - op_dist_attr.set_input_dims_mapping(input_var.name, - ref_dims_mapping) - op_dist_attr.set_output_dims_mapping(input_var.name, - ref_dims_mapping) + grad_op_dist_attr.set_output_dims_mapping( + output_var.name, ref_dims_mapping) - input_var_attr.process_mesh = ref_process_mesh - dist_context.set_tensor_dist_attr_for_program( - input_var, input_var_attr) + self._dist_context.set_op_dist_attr_for_program( + grad_op, grad_op_dist_attr) - dist_context.set_op_dist_attr_for_program(op, op_dist_attr) - continue + # only sum op for merge mutiple version grad has no a corresponding mapping in grad_op_id_to_op_id + else: + assert grad_op.type == "sum", "got unexpect op [{}]".format( + str(grad_op.type)) + assert all(map(_is_grad_var_name, grad_op.input_arg_names)) + assert len(grad_op.output_arg_names) == 1 + + ref_forward_var_name = _get_forward_varname_from_grad_varname( + grad_op.output_arg_names[0]) + forward_var = vars[ref_forward_var_name] + ref_forward_var_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( + forward_var).dims_mapping + ref_forward_var_process_mesh = self._dist_context.get_tensor_dist_attr_for_program( + forward_var).process_mesh + + # output + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping + tensor_dist_attr.process_mesh = ref_forward_var_process_mesh + self._dist_context.set_tensor_dist_attr_for_program( + vars[grad_op.output_arg_names[0]], tensor_dist_attr) + + # op + grad_op_dist_attr = OperatorDistributedAttribute() + grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh + for var_name in grad_op.input_arg_names: + assert _get_forward_varname_from_grad_varname( + var_name) == ref_forward_var_name + grad_op_dist_attr.set_input_dims_mapping( + var_name, ref_forward_var_dims_mapping) + + grad_op_dist_attr.set_output_dims_mapping( + grad_op.output_arg_names[0], ref_forward_var_dims_mapping) + self._dist_context.set_op_dist_attr_for_program( + grad_op, grad_op_dist_attr) + + def complete_update_annotation(self, serial_main_program): + """Complete the annotation of vars and ops in the update phase for parallel program.""" + ops = list(serial_main_program.global_block().ops) + vars = serial_main_program.global_block().vars + learning_rate_completed = False + + for idx in range(len(ops)): + + # complete the annotation of the optimizer op. 
+ # TODO to add attribute for moment var + op = ops[idx] + if int(op.attr('op_role')) == int(OpRole.Optimize): + + if "Grad" in op.input_names and "Param" in ops[idx].input_names: + assert len(op.input( + "Param")) == 1, "Only support one-to-one now." + assert len(op.input( + "Grad")) == 1, "Only support one-to-one now." + param = vars[op.input("Param")[0]] + grad_var = vars[op.input("Grad")[0]] + + param_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( + param) + assert param_dist_attr is not None + ref_process_mesh = self._dist_context.get_tensor_dist_attr_for_program( + param).process_mesh + assert ref_process_mesh is not None + ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( + param).dims_mapping + assert ref_dims_mapping is not None + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.process_mesh = ref_process_mesh + op_dist_attr.set_input_dims_mapping(grad_var.name, + ref_dims_mapping) + op_dist_attr.set_input_dims_mapping(param.name, + ref_dims_mapping) + op_dist_attr.set_output_dims_mapping(param.name, + ref_dims_mapping) + learning_var = vars[op.input("LearningRate")[0]] + op_dist_attr.set_input_dims_mapping(learning_var.name, [-1]) + op_dist_attr.set_output_dims_mapping(learning_var.name, + [-1]) + + if not learning_rate_completed: + learning_rate_completed = True + var_dist_attr = TensorDistributedAttribute() + var_dist_attr.process_mesh = ref_process_mesh + var_dist_attr.dims_mapping = [-1] + self._dist_context.set_tensor_dist_attr_for_program( + learning_var, var_dist_attr) + + for input_name in op.desc.input_names(): + + if input_name in [ + 'Param', 'Grad', 'LearningRate', "SkipUpdate", + "Beta1Tensor", "Beta2Tensor", "EpsilonTensor", + "MasterParam" + ]: + continue + + assert len(op.desc.input(input_name)) == 1 + input_var = vars[op.desc.input(input_name)[0]] + input_var_attr = TensorDistributedAttribute() + + if "Beta1Pow" in input_name or "Beta2Pow" in input_name: + input_var_attr.dims_mapping = [-1] + op_dist_attr.set_input_dims_mapping(input_var.name, + [-1]) + op_dist_attr.set_output_dims_mapping(input_var.name, + [-1]) + else: + assert "Moment" in input_name + input_var_attr.dims_mapping = ref_dims_mapping + op_dist_attr.set_input_dims_mapping( + input_var.name, ref_dims_mapping) + op_dist_attr.set_output_dims_mapping( + input_var.name, ref_dims_mapping) + + input_var_attr.process_mesh = ref_process_mesh + self._dist_context.set_tensor_dist_attr_for_program( + input_var, input_var_attr) + + self._dist_context.set_op_dist_attr_for_program( + op, op_dist_attr) + continue diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index ad3a53ff17d76..e06811df88179 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -247,23 +247,23 @@ def get_op_dist_attr_for_graph(self, serial_op_node): # new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr) # self._dist_ops_for_graph[serial_op_node_id] = new_dist_op - # def get_dist_attr_for_graph(self, serial_node): - # if serial_node.is_var() and serial_node.var() is not None: - # serial_tensor_node_id = serial_node.id() - # dist_tensor = self._dist_tensors_for_graph.get( - # serial_tensor_node_id, None) - # if dist_tensor: - # return dist_tensor.dist_attr - # else: - # return None - # if serial_node.is_op() and serial_node.op() is not None: - # serial_op_node_id = serial_node.id() - # dist_op = 
self._dist_ops_for_graph.get(serial_op_node_id, None) - # if dist_op: - # return dist_op.dist_attr - # else: - # return None - # return None + def get_dist_attr_for_graph(self, serial_node): + if serial_node.is_var() and serial_node.var() is not None: + serial_tensor_node_id = serial_node.id() + dist_tensor = self._dist_tensors_for_graph.get( + serial_tensor_node_id, None) + if dist_tensor: + return dist_tensor.dist_attr + else: + return None + if serial_node.is_op() and serial_node.op() is not None: + serial_op_node_id = serial_node.id() + dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None) + if dist_op: + return dist_op.dist_attr + else: + return None + return None def init_dist_attr_for_program(self): assert self._serial_program, \ diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index d6035d02953ac..43f5fa264790f 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -32,7 +32,7 @@ from .dist_context import DistributedContext from .dist_context import get_default_distributed_context from .dist_context import set_default_distributed_context -from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation +from .completion import Completer from .partitioner import Partitioner from .process_group import get_all_process_groups from .process_group import get_process_group @@ -130,8 +130,8 @@ def _generate_backward(self, main_program, startup_program, loss, no_grad_set, callbacks, distop_context=self._dist_context.dist_op_context) - complete_backward_annotation( - main_program, dist_context=self._dist_context) + self._completer = Completer(self._dist_context) + self._completer.complete_backward_annotation(main_program) return params_grads @@ -142,8 +142,8 @@ def _apply_optimize(self, main_program, startup_program, params_grads): params_grads) # update completion - complete_update_annotation( - main_program, dist_context=self._dist_context) + self._completer = Completer(self._dist_context) + self._completer.complete_update_annotation(main_program) return optimize_ops @@ -179,8 +179,9 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False): # Annotation completion self._dist_context = DistributedContext() _logger.info("Start annotation dist attr.") - completed_main_program = complete_annotation(serial_main_program, - self._dist_context) + self._completer = Completer(self._dist_context) + completed_main_program = self._completer.complete_forward_annotation( + serial_main_program) else: completed_main_program = serial_main_program self._dist_context = copy.deepcopy(dist_context) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py index 05d71aca5db2c..bc4f1671f4e20 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py @@ -27,6 +27,7 @@ from paddle.fluid import layers from paddle.nn.layer.transformer import _convert_param_attr_to_list import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr from paddle.distributed.auto_parallel.utils import 
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py
index 05d71aca5db2c..bc4f1671f4e20 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py
@@ -27,6 +27,7 @@
 from paddle.fluid import layers
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
@@ -154,10 +155,9 @@ def test_mlp_dp(self):
         dist_context = DistributedContext()
         train_program, start_program = mlp_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_mlp_mp(self):
@@ -171,10 +171,9 @@ def test_mlp_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = mlp_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_mlp_dp_mp(self):
@@ -189,10 +188,9 @@ def test_mlp_dp_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = mlp_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     # def test_mlp_misc(self):
@@ -212,8 +210,8 @@ def test_mlp_dp_mp(self):
     #     train_program, start_program = mlp_pretrain_forward(train_program,
     #                                                         start_program)
     #     # pdb.set_trace()
-    #     complete_train_program = auto.complete_annotation(train_program,
-    #                                                       dist_context)
+    #     completer = Completer(dist_context)
+    #     complete_train_program = completer.complete_forward_annotation(train_program)
     #     # print_program_with_dist_attr(complete_train_program,
     #     #                              dist_context)
     #     dist_context.finalize_distributed_attr_for_program(
@@ -423,8 +421,9 @@ def test_attn_dp(self):
         dist_context = DistributedContext()
         train_program, start_program = attn_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         # print_program_with_dist_attr(complete_train_program,
         #                              dist_context)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
@@ -440,10 +439,9 @@ def test_attn_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = attn_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_attn_dp_mp(self):
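The remaining hunks in this file, and in every test file below, apply the same mechanical substitution. Shown once outside diff notation (the `train_program` and `dist_context` arguments come from the surrounding test; this only restates the pattern, it adds no new functionality):

```python
from paddle.distributed.auto_parallel.completion import Completer


def complete_forward(train_program, dist_context):
    # Old API: complete_train_program = auto.complete_annotation(train_program,
    #                                                            dist_context)
    # New API: build a Completer over the DistributedContext, then ask it for
    # the forward-annotated program.
    completer = Completer(dist_context)
    return completer.complete_forward_annotation(train_program)
```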
@@ -458,10 +456,9 @@ def test_attn_dp_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = attn_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

@@ -747,10 +744,9 @@ def test_decoder_dp(self):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                 start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_decoder_mp(self):
@@ -764,10 +760,9 @@ def test_decoder_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                 start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_decoder_dp_mp(self):
@@ -782,10 +777,9 @@ def test_decoder_dp_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                 start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py
index c2c1e63155c3a..1293a9644027d 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py
@@ -31,6 +31,7 @@
 from paddle.distributed.fleet import fleet
 import paddle.static as static
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
@@ -817,10 +818,9 @@ def test_gpt_dp(self):
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_gpt_mp(self):
@@ -834,10 +834,9 @@ def test_gpt_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_gpt_dp_mp(self):
@@ -852,10 +851,9 @@ def test_gpt_dp_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
index 83254de61298b..fd19a5bd8b866 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
@@ -23,6 +23,7 @@
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -154,8 +155,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context

     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)

     params_grads = parallelizer._generate_backward(
         complete_train_program,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py
index b21cbb5ae78bc..27de9f325063b 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py
@@ -18,6 +18,7 @@
 import paddle
 from paddle.fluid import core
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -42,8 +43,9 @@ def get_dist_prog(train_program,
     parallelizer._dist_context = dist_context

     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(
-        train_program, dist_context
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program
     ) if complete_train_program is None else complete_train_program

     # parallelizer._apply_serial_forward_pass(complete_train_program,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
index 3a28595c833e0..9d4de771076cd 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
@@ -36,6 +36,7 @@
 from paddle.distributed import fleet

 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -433,6 +434,12 @@ def forward(self, input):
         out = F.gelu(out, approximate=True)
         out = self.linear1(out)
+        auto.shard_tensor(
+            out,
+            dist_attr={
+                "process_mesh": _global_process_mesh[1],
+                "dims_mapping": [0, -1]
+            })
         out = self.linear2(out)
         out = F.gelu(out, approximate=True)
         out = self.linear3(out)
@@ -476,8 +483,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context

     # auto completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)

     params_grads = parallelizer._generate_backward(
         complete_train_program,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py
index 21cf8a904b690..deff2144411fc 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py
@@ -28,6 +28,7 @@
 from paddle.fluid import layers
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
@@ -49,8 +50,9 @@ def get_programs(annotated_func):
     global _global_process_mesh
     dist_context.process_mesh = _global_process_mesh
     train_program, start_program = annotated_func(train_program, start_program)
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)

     rank_id = 3
     dist_strategy = fleet.DistributedStrategy()
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
index dc2ad1d900f52..01e62d886e2b7 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
@@ -31,6 +31,7 @@
 from paddle.distributed import fleet
 import paddle.static as static
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
@@ -881,8 +882,9 @@ def test_gpt_dp_mp(self):
         dist_context.process_mesh = _global_process_mesh
         train_program, startup_program, loss = gpt_pretrain_forward(
             train_program, startup_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)

         # serial backward pass
         params_grads = parallelizer._generate_backward(
@@ -913,8 +915,9 @@ def test_gpt_dp_mp(self):
                   "w") as fw:
             fw.write(str(auto_parallel_startup_prog))
         # with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw: - # from paddle.distributed.auto_parallel.completion import complete_backward_annotation - # complete_backward_annotation(auto_parallel_main_prog) + # from paddle.distributed.auto_parallel.completion import Completer + # completer = Completer() + # completer.complete_forward_annotation(auto_parallel_main_prog) # fw.write(str(auto_parallel_main_prog)) nrank = 4 # col parallel diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index 614b996d26521..b234e25823f4b 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -22,6 +22,7 @@ import paddle.nn.functional as F import paddle.utils as utils import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer @@ -152,8 +153,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): parallelizer._dist_context = dist_context # serial forward & backward completion - complete_train_program = auto.complete_annotation(train_program, - dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) params_grads = parallelizer._generate_backward( complete_train_program, @@ -299,7 +301,6 @@ def test_mlp_pp(self): for key in list(_g_process_group_map.keys()): del _g_process_group_map[key] reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) - # print_program_with_dist_attr(dist_main_prog, dist_context) # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py index cfbb7653fad8e..40847a769033a 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py @@ -22,6 +22,7 @@ import paddle.nn.functional as F import paddle.utils as utils import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed import fleet from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer @@ -116,8 +117,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id): parallelizer._dist_context = dist_context # serial forward & backward completion - complete_train_program = auto.complete_annotation(train_program, - dist_context) + completer = Completer(dist_context) + complete_train_program = completer.complete_forward_annotation( + train_program) params_grads = parallelizer._generate_backward( complete_train_program, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py index 272c1c212f08e..869bcd4c7ab32 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -22,6 +22,7 @@ import paddle.nn.functional 
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
@@ -132,8 +133,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context

     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)

     params_grads = parallelizer._generate_backward(
         complete_train_program,
@@ -263,8 +265,9 @@ def test_allgather(self):
         dist_context = DistributedContext()
         dist_strategy = fleet.DistributedStrategy()
         partitioner = Partitioner(dist_context, rank_id)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
             complete_train_program, startup_program, [])
         reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py
index ed64fa0630fa1..78ad64b1dd852 100755
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py
@@ -154,7 +154,7 @@ def test_update(self):
         ops = train_program.global_block().ops
         vars = train_program.global_block().vars
         from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
-        from paddle.distributed.auto_parallel.completion import is_elementwise_like_op
+        from paddle.distributed.auto_parallel.operators.common import is_elementwise_op
         from paddle.distributed.auto_parallel.dist_op import DistributedOperator

         for op in ops:
@@ -163,7 +163,7 @@ def test_update(self):
             if dist_op_impl_container is None:
                 op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
                 dist_op = DistributedOperator(op, op_dist_attr)
-                if is_elementwise_like_op(op.type):
+                if is_elementwise_op(op.type):
                     changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
                         dist_op)
                     self.assertFalse(changed)
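Finally, the searcher test picks up the elementwise check from its new home in `operators.common` instead of `completion`. A small usage sketch; the op types listed are illustrative examples, not the definitive set the helper recognizes:

```python
from paddle.distributed.auto_parallel.operators.common import is_elementwise_op

# The searcher falls back to the elementwise dims_mapping rule only for ops
# that this predicate accepts; other ops keep their specialized distributed impl.
for op_type in ["elementwise_add", "gelu", "dropout", "matmul_v2"]:
    print(op_type, "->", is_elementwise_op(op_type))
```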