Fuse duplicate fused-elementwise ops (#915)
Summary:
Pull Request resolved: #915

### Why?
~~The MTML Dedicated Arch has `FusedElementwise` operators that have the same input tensor and the same elementwise operator.~~

Originally, this pass was implemented for the MTML Dedicated Arch. However, the model has multiple `fused_elementwise` ops with the same input tensor and the same elementwise ops but different input tensor accessors (they read from different offsets and have different read sizes).

### What/How?
In this diff, we add a fusion pass to get rid of this redundancy (a minimal sketch follows the list). We do this by:
* Taking all `fused_elementwise` operators that are equal (same inputs, input accessors, and elementwise ops).
* Removing all but one of them.
* Transferring the outputs & output_accessors of the removed ops to the remaining operator.
* Updating src_ops / dst_ops for the affected tensors.
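
Here is a minimal, self-contained sketch of that bookkeeping. It uses plain dicts in place of AITemplate's `Operator`/`Tensor` attributes and a hashable grouping key, so the names and structure are illustrative only; the real pass (added below) compares ops pairwise and also rewrites `src_ops`/`dst_ops` on the affected tensors.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def dedup_fused_elementwise(ops: List[dict]) -> List[dict]:
    """Keep one op per (inputs, input_accessors, funcs) group and move the
    outputs of its duplicates onto the surviving op. Illustrative sketch only."""
    groups: Dict[Tuple, List[dict]] = defaultdict(list)
    for op in ops:
        key = (tuple(op["inputs"]), tuple(op["input_accessors"]), tuple(op["funcs"]))
        groups[key].append(op)

    kept = []
    for group in groups.values():
        primary, duplicates = group[0], group[1:]
        for dup in duplicates:
            # The surviving op inherits every output of its duplicates;
            # the duplicates themselves are dropped from the graph.
            primary["outputs"].extend(dup["outputs"])
            primary["output_accessors"].extend(dup["output_accessors"])
        kept.append(primary)
    return kept
```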

### Example
Here's a real example from the MTML Dedicated Arch memory planning graph.
```
### Before the fusion pass ###
# fused_elementwise_1033
(Tensor(name=elementwise_99_0, shape=[batch_size, 128]))     # output tensor
= fused_elementwise(func=[<FuncEnum.RELU: 18>])(
    Tensor(name=elementwise_94_0, shape=[batch_size, 640])   # input tensor
)

# fused_elementwise_1034
(Tensor(name=elementwise_101_0, shape=[batch_size, 128]))
= fused_elementwise(func=[<FuncEnum.RELU: 18>])(             # same elementwise op
    Tensor(name=elementwise_94_0, shape=[batch_size, 640])   # same input tensor
)

### After the fusion pass ###
# fused_elementwise_1033
(
    Tensor(name=elementwise_99_0, shape=[batch_size, 128]),  # original output tensor
    Tensor(name=elementwise_101_0, shape=[batch_size, 128])  # output added from fusion
)
= fused_elementwise(func=[<FuncEnum.RELU: 18>])(
    Tensor(name=elementwise_94_0, shape=[batch_size, 640])
)
```

 ---

Reviewed By: chenyang78

Differential Revision: D48359513

fbshipit-source-id: d2cdcb7991dfc77ff0e6528e3ffb6395c869e065
ColinPeppler authored and facebook-github-bot committed Aug 28, 2023
1 parent dd44640 commit c60dc19
Showing 5 changed files with 481 additions and 1 deletion.
10 changes: 10 additions & 0 deletions python/aitemplate/backend/common/elementwise_common.py
@@ -1101,6 +1101,16 @@ def _gen_write_outputs_str(
            read_t=fused_elementwise_metadata.max_read_t,
            data_idx=index_variable,
        )

        # This is for fusing duplicate fused-elementwise ops. The newly fused op
        # will have multiple outputs but only a single original output. This allows
        # us to calculate the original output once and re-use it for all outputs.
        if (
            len(fused_elementwise_metadata.original_outputs) == 1
            and len(fused_elementwise_metadata.outputs) > 1
        ):
            output_idx = 0

        write_out = KERNEL_WRITE_OUTPUT_TEMPLATE.render(
            get_strided_address=get_strided_addr_str,
            output_name=output_name,
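
The backend change above means that, when the fused op has several outputs but only one original output, the generated kernel computes that value once and stores it through every output accessor. Below is a rough NumPy analogue for the RELU example from the summary; it is illustrative only, since the real code is rendered from the backend templates and applies per-output strided accessors.

```python
import numpy as np


def fused_relu_two_outputs(x: np.ndarray):
    """Compute the single original elementwise result once, then write it to both outputs."""
    value = np.maximum(x, 0.0)  # the one original RELU computation (output_idx == 0)
    out_a = value.copy()        # stored for elementwise_99_0
    out_b = value.copy()        # stored for elementwise_101_0
    return out_a, out_b
```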
151 changes: 151 additions & 0 deletions python/aitemplate/compiler/transform/fuse_duplicate_fused_elementwise.py
@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import itertools
import logging
from collections import defaultdict
from typing import Dict, List

from aitemplate.compiler.base import Operator, Tensor
from aitemplate.compiler.transform import transform_utils
from aitemplate.utils import graph_utils


_LOGGER = logging.getLogger(__name__)


def _fused_elementwise_ops_are_equal(op1: Operator, op2: Operator) -> bool:
    """We consider two fused_elementwise ops to be duplicates when:
    1. Their elementwise operations are the same.
    2. Their inputs and input accessors are the same.

    NOTE: We assume the inputs are in the same order as the sub-elementwise
    operations. Otherwise, this is a problem because some elementwise operations
    are non-commutative.
    """
    op1_elementwise_ops = op1._attrs["elementwise_ops"]
    op2_elementwise_ops = op2._attrs["elementwise_ops"]
    op1_inputs, op2_inputs = op1._attrs["inputs"], op2._attrs["inputs"]
    op1_input_accessors = op1._attrs["input_accessors"]
    op2_input_accessors = op2._attrs["input_accessors"]
    if (
        len(op1_elementwise_ops) != len(op2_elementwise_ops)
        or len(op1_inputs) != len(op2_inputs)
        or len(op1_input_accessors) != len(op2_input_accessors)
    ):
        return False

    are_elementwise_equal = all(
        a._attrs["func"] == b._attrs["func"]
        for a, b in zip(op1_elementwise_ops, op2_elementwise_ops)
    )
    are_input_accessors_equal = all(
        input1 == input2 and input_accessor1 == input_accessor2
        for input1, input2, input_accessor1, input_accessor2 in zip(
            op1_inputs, op2_inputs, op1_input_accessors, op2_input_accessors
        )
    )
    return are_elementwise_equal and are_input_accessors_equal


def find_duplicate_fused_elementwise(
    sorted_graph: List[Tensor],
) -> Dict[Operator, List[Operator]]:
    sorted_ops = graph_utils.get_sorted_ops(sorted_graph)
    fused_elementwise_ops = filter(
        lambda operator: operator._attrs["op"] == "fused_elementwise", sorted_ops
    )
    visited = set()
    fusion_groups = defaultdict(list)

    for op1, op2 in itertools.combinations(fused_elementwise_ops, 2):
        if op1 in visited or op2 in visited:
            continue
        if _fused_elementwise_ops_are_equal(op1, op2):
            fusion_groups[op1].append(op2)
            visited.add(op2)

    return fusion_groups


def fuse_duplicate_fused_elementwise(
    sorted_graph: List[Tensor], _workdir: str
) -> List[Tensor]:
    """This pass finds all duplicate fused_elementwise ops and fuses them once
    more. It assumes any fuse_elementwise passes are complete.

    We do the fusion by taking all the duplicate fused_elementwise ops and
    effectively deleting all but one. We make sure to transfer the outputs and
    output_accessors of the duplicate ops to the remaining op. That means the
    newly fused op will have multiple outputs.

    Parameters
    ----------
    sorted_graph : List[Tensor]
        Input graph
    _workdir : str
        Required by optimize_graph.py

    Returns
    ----------
    sorted_graph : List[Tensor]
        Modified input graph with duplicate fused_elementwise ops fused together.
    """

    fusion_groups = find_duplicate_fused_elementwise(sorted_graph)
    for primary_op, duplicate_ops in fusion_groups.items():
        # Primary op inherits the outputs from the duplicate ops.
        for key in ("outputs", "output_accessors"):
            duplicate_ops_outputs = [
                output for op in duplicate_ops for output in op._attrs[key]
            ]
            primary_op._attrs[key] += duplicate_ops_outputs
            if key != "outputs":
                continue

            # Make sure to update src_ops in the output tensors.
            for output_tensor in duplicate_ops_outputs:
                old_src_ops = output_tensor._attrs["src_ops"]
                output_tensor._attrs["src_ops"] = set(old_src_ops) - set(
                    duplicate_ops
                ) | {primary_op}

            # Make sure to update dst_ops in the input tensors.
            for input_tensor in primary_op._attrs["inputs"]:
                input_tensor._attrs["dst_ops"] = set(
                    input_tensor._attrs["dst_ops"]
                ) - set(duplicate_ops)

        # Assumption: if the input accessors are the same, then the outputs'
        # original shapes must be the same.
        prev_shape = primary_op._attrs["output_accessors"][0].original_shapes
        for output_accessor in primary_op._attrs["output_accessors"]:
            shape = output_accessor.original_shapes
            assert (
                prev_shape == shape
            ), "Output shapes mismatch in fuse_duplicate_fused_elementwise: {}, {}".format(
                prev_shape, shape
            )
            prev_shape = shape

        _LOGGER.info(
            "Fusing {} with {}".format(
                primary_op._attrs["name"],
                ", ".join([dup_op._attrs["name"] for dup_op in duplicate_ops]),
            )
        )

    return transform_utils.sanitize_sorted_graph(sorted_graph)
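
The pass is registered in `optimize_graph` (next file), so it runs automatically during compilation. If you wanted to invoke it by hand on an already-optimized graph, the call would look roughly like the sketch below; building a real `sorted_graph` is elided here and the workdir value is arbitrary.

```python
from aitemplate.compiler.transform.fuse_duplicate_fused_elementwise import (
    fuse_duplicate_fused_elementwise,
)

# `sorted_graph` is a topologically sorted List[Tensor] produced by the compiler,
# after the regular elementwise-fusion passes have created fused_elementwise ops.
sorted_graph = fuse_duplicate_fused_elementwise(sorted_graph, "/tmp/ait_workdir")
```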
6 changes: 6 additions & 0 deletions python/aitemplate/compiler/transform/optimize_graph.py
@@ -22,6 +22,9 @@
from aitemplate.compiler.transform.dedup_make_jagged_ops import dedup_make_jagged_ops
from aitemplate.compiler.transform.fuse_bmm_permute import fuse_bmm_permute
from aitemplate.compiler.transform.fuse_conv_elementwise import fuse_conv_elementwise
from aitemplate.compiler.transform.fuse_duplicate_fused_elementwise import (
    fuse_duplicate_fused_elementwise,
)
from aitemplate.compiler.transform.fuse_expand_bmm import fuse_expand_bmm
from aitemplate.compiler.transform.fuse_group_ops import fuse_group_ops
from aitemplate.compiler.transform.fuse_mm_elementwise import fuse_mm_elementwise
@@ -127,6 +130,9 @@ def optimize_graph(
        transform_permute_to_reshape,
        transform_memory_ops,
        eliminate_permutations,
        # fuse_duplicate_fused_elementwise must run after elementwise fusion and
        # after passes that modify/replace a fused_elementwise's input/output accessor.
        fuse_duplicate_fused_elementwise,
    ]

    if not optimize: