From c60dc19788217556ba12ea378c02b9fd0aea9ffe Mon Sep 17 00:00:00 2001
From: Colin Peppler
Date: Mon, 28 Aug 2023 07:43:41 -0700
Subject: [PATCH] Fuse duplicate fused-elementwise ops (#915)

Summary:
Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/915

### Why?
Originally, this pass was implemented for the MTML Dedicated Arch. However, that model has multiple `fused_elementwise` ops with the same input tensor and the same elementwise ops but different input tensor accessors (they read from different offsets and have different read sizes).

### What/How?
In this diff, we add a fusion pass that eliminates this redundancy by:
* Taking all `fused_elementwise` ops that are equal (same inputs, input accessors, and elementwise ops).
* Removing all but one.
* Transferring any outputs & output_accessors to the remaining op.
* Updating src_ops / dst_ops for the affected tensors.

### Example
Here's a real example from the MTML Dedicated Arch memory planning graph.
```
### Before the fusion pass ###
# fused_elementwise_1033
(Tensor(name=elementwise_99_0, shape=[batch_size, 128]))    # output tensor
 = fused_elementwise(func=[])(
    Tensor(name=elementwise_94_0, shape=[batch_size, 640])  # input tensor
)

# fused_elementwise_1034
(Tensor(name=elementwise_101_0, shape=[batch_size, 128]))   # output tensor
 = fused_elementwise(func=[])(                              # same elementwise ops
    Tensor(name=elementwise_94_0, shape=[batch_size, 640])  # same input tensor
)

### After the fusion pass ###
# fused_elementwise_1033
(
    Tensor(name=elementwise_99_0, shape=[batch_size, 128]),   # original output tensor
    Tensor(name=elementwise_101_0, shape=[batch_size, 128])   # output added from fusion
) = fused_elementwise(func=[])(
    Tensor(name=elementwise_94_0, shape=[batch_size, 640])
)
```
---

Reviewed By: chenyang78

Differential Revision: D48359513

fbshipit-source-id: d2cdcb7991dfc77ff0e6528e3ffb6395c869e065
---
 .../backend/common/elementwise_common.py      |  12 +
 .../fuse_duplicate_fused_elementwise.py       | 157 +++++++++
 .../compiler/transform/optimize_graph.py      |   6 +
 .../test_fuse_duplicate_fused_elementwise.py  | 317 ++++++++++++++++++
 tests/unittest/compiler/test_move_view_ops.py |   2 +-
 5 files changed, 493 insertions(+), 1 deletion(-)
 create mode 100644 python/aitemplate/compiler/transform/fuse_duplicate_fused_elementwise.py
 create mode 100644 tests/unittest/compiler/test_fuse_duplicate_fused_elementwise.py

diff --git a/python/aitemplate/backend/common/elementwise_common.py b/python/aitemplate/backend/common/elementwise_common.py
index 71d7101a1..4f73a0960 100644
--- a/python/aitemplate/backend/common/elementwise_common.py
+++ b/python/aitemplate/backend/common/elementwise_common.py
@@ -1101,6 +1101,18 @@ def _gen_write_outputs_str(
             read_t=fused_elementwise_metadata.max_read_t,
             data_idx=index_variable,
         )
+
+        # This is for fusing duplicate fused-elementwise ops. The newly fused op
+        # will have multiple outputs but only a single original output, allowing
+        # us to calculate the original output once and reuse it for all outputs.
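+        # E.g., with original output Y and fused outputs [Y, Y'], the value
+        # computed for Y is simply written out again for Y'.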
+        if (
+            len(fused_elementwise_metadata.original_outputs) == 1
+            and len(fused_elementwise_metadata.outputs) > 1
+        ):
+            output_idx = 0
+
         write_out = KERNEL_WRITE_OUTPUT_TEMPLATE.render(
             get_strided_address=get_strided_addr_str,
             output_name=output_name,
diff --git a/python/aitemplate/compiler/transform/fuse_duplicate_fused_elementwise.py b/python/aitemplate/compiler/transform/fuse_duplicate_fused_elementwise.py
new file mode 100644
index 000000000..1fc7307c1
--- /dev/null
+++ b/python/aitemplate/compiler/transform/fuse_duplicate_fused_elementwise.py
@@ -0,0 +1,157 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import itertools
+import logging
+from collections import defaultdict
+from typing import Dict, List
+
+from aitemplate.compiler.base import Operator, Tensor
+from aitemplate.compiler.transform import transform_utils
+from aitemplate.utils import graph_utils
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def _fused_elementwise_ops_are_equal(op1: Operator, op2: Operator) -> bool:
+    """We consider two fused_elementwise ops to be duplicates when:
+    1. Their elementwise operations are the same.
+    2. Their inputs and input accessors are the same.
+
+    NOTE: We assume the inputs are in the same order as the sub-elementwise
+    operations. Otherwise, this is a problem because some elementwise
+    operations are non-commutative.
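+
+    For example, if one of the elementwise ops is SUB, then sub(a, b) and
+    sub(b, a) are different computations, which is why we compare inputs
+    and input accessors position by position below.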
+    """
+    op1_elementwise_ops = op1._attrs["elementwise_ops"]
+    op2_elementwise_ops = op2._attrs["elementwise_ops"]
+    op1_inputs, op2_inputs = op1._attrs["inputs"], op2._attrs["inputs"]
+    op1_input_accessors = op1._attrs["input_accessors"]
+    op2_input_accessors = op2._attrs["input_accessors"]
+    if (
+        len(op1_elementwise_ops) != len(op2_elementwise_ops)
+        or len(op1_inputs) != len(op2_inputs)
+        or len(op1_input_accessors) != len(op2_input_accessors)
+    ):
+        return False
+
+    are_elementwise_equal = all(
+        a._attrs["func"] == b._attrs["func"]
+        for a, b in zip(op1_elementwise_ops, op2_elementwise_ops)
+    )
+    are_input_accessors_equal = all(
+        input1 == input2 and input_accessor1 == input_accessor2
+        for input1, input2, input_accessor1, input_accessor2 in zip(
+            op1_inputs, op2_inputs, op1_input_accessors, op2_input_accessors
+        )
+    )
+    return are_elementwise_equal and are_input_accessors_equal
+
+
+def find_duplicate_fused_elementwise(
+    sorted_graph: List[Tensor],
+) -> Dict[Operator, List[Operator]]:
+    sorted_ops = graph_utils.get_sorted_ops(sorted_graph)
+    fused_elementwise_ops = filter(
+        lambda operator: operator._attrs["op"] == "fused_elementwise", sorted_ops
+    )
+    visited = set()
+    fusion_groups = defaultdict(list)
+
+    for op1, op2 in itertools.combinations(fused_elementwise_ops, 2):
+        if op1 in visited or op2 in visited:
+            continue
+        if _fused_elementwise_ops_are_equal(op1, op2):
+            fusion_groups[op1].append(op2)
+            visited.add(op2)
+
+    return fusion_groups
+
+
+def fuse_duplicate_fused_elementwise(
+    sorted_graph: List[Tensor], _workdir: str
+) -> List[Tensor]:
+    """This pass finds all duplicate fused_elementwise ops and fuses them once
+    more. It assumes all elementwise fusion passes have already run.
+
+    We do the fusion by taking all the duplicate fused_elementwise ops and
+    effectively deleting all but one. We make sure to transfer the outputs and
+    output_accessors of the duplicate ops to the remaining op. This means the
+    newly fused op will have multiple outputs.
+
+    Parameters
+    ----------
+    sorted_graph : List[Tensor]
+        Input graph
+
+    _workdir : str
+        Required by optimize_graph.py
+
+    Returns
+    -------
+    sorted_graph : List[Tensor]
+        Modified input graph with duplicate fused_elementwise ops fused together.
+    """
+
+    fusion_groups = find_duplicate_fused_elementwise(sorted_graph)
+    for primary_op, duplicate_ops in fusion_groups.items():
+        # The primary op inherits the outputs from the duplicate ops.
+
+        for key in ("outputs", "output_accessors"):
+            duplicate_ops_outputs = [
+                output for op in duplicate_ops for output in op._attrs[key]
+            ]
+            primary_op._attrs[key] += duplicate_ops_outputs
+            if key != "outputs":
+                continue
+
+            # Make sure to update src_ops in the output tensors.
+            for output_tensor in duplicate_ops_outputs:
+                old_src_ops = output_tensor._attrs["src_ops"]
+                output_tensor._attrs["src_ops"] = set(old_src_ops) - set(
+                    duplicate_ops
+                ) | {primary_op}
+
+            # Make sure to update dst_ops in the input tensors.
+            for input_tensor in primary_op._attrs["inputs"]:
+                input_tensor._attrs["dst_ops"] = set(
+                    input_tensor._attrs["dst_ops"]
+                ) - set(duplicate_ops)
+
+        # Assumption: If the input accessors are the same, then the outputs'
+        # original shapes must be the same.
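+        # (This holds because the kernel computes the original output's value
+        # once and writes that same value to every output; see elementwise_common.)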
+        prev_shape = primary_op._attrs["output_accessors"][0].original_shapes
+        for output_accessor in primary_op._attrs["output_accessors"]:
+            shape = output_accessor.original_shapes
+            assert (
+                prev_shape == shape
+            ), "Output shapes mismatch in fuse_duplicate_fused_elementwise: {}, {}".format(
+                prev_shape, shape
+            )
+            prev_shape = shape
+
+        _LOGGER.info(
+            "Fusing {} with {}".format(
+                primary_op._attrs["name"],
+                ", ".join([dup_op._attrs["name"] for dup_op in duplicate_ops]),
+            )
+        )
+
+    return transform_utils.sanitize_sorted_graph(sorted_graph)
diff --git a/python/aitemplate/compiler/transform/optimize_graph.py b/python/aitemplate/compiler/transform/optimize_graph.py
index edf0eede5..3eab2ad9d 100644
--- a/python/aitemplate/compiler/transform/optimize_graph.py
+++ b/python/aitemplate/compiler/transform/optimize_graph.py
@@ -22,6 +22,9 @@
 from aitemplate.compiler.transform.dedup_make_jagged_ops import dedup_make_jagged_ops
 from aitemplate.compiler.transform.fuse_bmm_permute import fuse_bmm_permute
 from aitemplate.compiler.transform.fuse_conv_elementwise import fuse_conv_elementwise
+from aitemplate.compiler.transform.fuse_duplicate_fused_elementwise import (
+    fuse_duplicate_fused_elementwise,
+)
 from aitemplate.compiler.transform.fuse_expand_bmm import fuse_expand_bmm
 from aitemplate.compiler.transform.fuse_group_ops import fuse_group_ops
 from aitemplate.compiler.transform.fuse_mm_elementwise import fuse_mm_elementwise
@@ -127,6 +130,9 @@ def optimize_graph(
         transform_permute_to_reshape,
         transform_memory_ops,
         eliminate_permutations,
+        # fuse_duplicate_fused_elementwise must run after elementwise fusion and
+        # after passes that modify/replace a fused_elementwise's input/output accessor.
+        fuse_duplicate_fused_elementwise,
     ]
 
     if not optimize:
diff --git a/tests/unittest/compiler/test_fuse_duplicate_fused_elementwise.py b/tests/unittest/compiler/test_fuse_duplicate_fused_elementwise.py
new file mode 100644
index 000000000..3b278df09
--- /dev/null
+++ b/tests/unittest/compiler/test_fuse_duplicate_fused_elementwise.py
@@ -0,0 +1,317 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+from typing import List
+
+import torch
+
+from aitemplate.compiler import compile_model, ops
+from aitemplate.compiler.base import Tensor
+from aitemplate.compiler.ops.common.epilogue import FuncEnum
+from aitemplate.compiler.transform.fuse_utils import is_elementwise_type
+from aitemplate.testing import detect_target
+from aitemplate.testing.test_utils import gen_input_tensor, get_random_torch_tensor
+from aitemplate.utils.graph_utils import get_sorted_ops
+
+
+class TestFuseDuplicateFusedElementwise(unittest.TestCase):
+    """
+    This tests the compiler's behavior when fusing duplicate fused-elementwise
+    ops. See fuse_duplicate_fused_elementwise.
+
+    We cover the following cases:
+    1. Test duplicates
+    2. Test duplicates with memory ops
+    3. Test non-duplicates
+    4. Test all interactions
+    5. Test same input accessors
+    6. Test different input accessors
+    """
+
+    SHAPE = [32, 64, 100]
+
+    @staticmethod
+    def _count_fused_elementwise_ops(
+        graph: List[Tensor], target_elementwise_ops: List[FuncEnum]
+    ) -> int:
+        fused_elementwise_ops = filter(
+            lambda op: op._attrs["op"] == "fused_elementwise", get_sorted_ops(graph)
+        )
+
+        count = 0
+        for op in fused_elementwise_ops:
+            elementwise_ops = op._attrs["elementwise_ops"]
+            if len(target_elementwise_ops) != len(elementwise_ops):
+                continue
+            if all(
+                is_elementwise_type(ew_op, target)
+                for ew_op, target in zip(elementwise_ops, target_elementwise_ops)
+            ):
+                count += 1
+        return count
+
+    def test_fuse_duplicates(self):
+        """When the input and elementwise ops are the same."""
+        x = gen_input_tensor(shape=self.SHAPE, name="input_x")
+        sigmoid1 = ops.elementwise(FuncEnum.SIGMOID)(x)
+        sigmoid2 = ops.elementwise(FuncEnum.SIGMOID)(x)
+        softmax1 = ops.softmax()(sigmoid1, dim=0)
+        softmax2 = ops.softmax()(sigmoid2, dim=0)
+        model_output = softmax1 + softmax2
+        model_output._attrs["is_output"] = True
+        model_output._attrs["name"] = "output"
+
+        x_pt = get_random_torch_tensor(self.SHAPE)
+        sigmoid1_pt = torch.sigmoid(x_pt)
+        sigmoid2_pt = torch.sigmoid(x_pt)
+        softmax1_pt = torch.nn.functional.softmax(sigmoid1_pt, dim=0)
+        softmax2_pt = torch.nn.functional.softmax(sigmoid2_pt, dim=0)
+        y_pt = softmax1_pt + softmax2_pt
+        y_ait = torch.empty_like(y_pt)
+
+        with compile_model(
+            model_output,
+            detect_target(),
+            "/tmp",
+            "fuse_duplicate_fused_elementwise_dups",
+        ) as module:
+            module.run_with_tensors({"input_x": x_pt}, {"output": y_ait})
+            nsigmoid = self._count_fused_elementwise_ops(
+                module.debug_sorted_graph, [FuncEnum.SIGMOID]
+            )
+            self.assertEqual(nsigmoid, 1)
+            self.assertTrue(torch.allclose(y_pt, y_ait, atol=1e-2, rtol=1e-2))
+
+    def test_fuse_duplicates_with_concat_output_accessor(self):
+        """Fused_elementwise ops that have the same input and elementwise ops,
+        and whose output accessors write to the same concat output."""
+        x = gen_input_tensor(shape=self.SHAPE, name="input_x")
+        sigmoid1 = ops.elementwise(FuncEnum.SIGMOID)(x)
+        sigmoid2 = ops.elementwise(FuncEnum.SIGMOID)(x)
+        model_output = ops.concatenate()([sigmoid1, sigmoid2])
+        model_output._attrs["is_output"] = True
+        model_output._attrs["name"] = "output"
+
+        x_pt = get_random_torch_tensor(self.SHAPE)
+        sigmoid1_pt = torch.sigmoid(x_pt)
+        sigmoid2_pt = torch.sigmoid(x_pt)
+        y_pt = torch.concat([sigmoid1_pt, sigmoid2_pt])
+        y_ait = torch.empty_like(y_pt)
+
+        with compile_model(
+            model_output,
+            detect_target(),
+            "/tmp",
+            "fuse_duplicate_fused_elementwise_dups_with_accessors",
+        ) as module:
+            module.run_with_tensors({"input_x": x_pt}, {"output": y_ait})
+            nsigmoid = self._count_fused_elementwise_ops(
+                module.debug_sorted_graph, [FuncEnum.SIGMOID]
+            )
+            self.assertEqual(nsigmoid, 1)
+            self.assertTrue(torch.allclose(y_pt, y_ait, atol=1e-2, rtol=1e-2))
+
+    def test_dont_fuse_non_duplicates(self):
+        """Fused-elementwise ops that have different inputs or different
+        elementwise ops aren't fused together.
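+        E.g., relu(x) and gelu(x) differ by elementwise op, while gelu(x)
+        and gelu(z) differ by input tensor.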
+        """
+        x = gen_input_tensor(shape=self.SHAPE, name="input_x")
+        z = gen_input_tensor(shape=self.SHAPE, name="input_z")
+        relu_x = ops.elementwise(FuncEnum.RELU)(x)
+        gelu_x = ops.elementwise(FuncEnum.GELU)(x)
+        gelu_z = ops.elementwise(FuncEnum.GELU)(z)
+        softmax1 = ops.softmax()(relu_x, dim=0)
+        softmax2 = ops.softmax()(gelu_x, dim=0)
+        softmax3 = ops.softmax()(gelu_z, dim=0)
+        model_output = softmax1 + softmax2 + softmax3
+        model_output._attrs["is_output"] = True
+        model_output._attrs["name"] = "output"
+
+        x_pt = get_random_torch_tensor(self.SHAPE)
+        z_pt = get_random_torch_tensor(self.SHAPE)
+        relu_x_pt = torch.nn.functional.relu(x_pt)
+        gelu_x_pt = torch.nn.functional.gelu(x_pt)
+        gelu_z_pt = torch.nn.functional.gelu(z_pt)
+        softmax1_pt = torch.nn.functional.softmax(relu_x_pt, dim=0)
+        softmax2_pt = torch.nn.functional.softmax(gelu_x_pt, dim=0)
+        softmax3_pt = torch.nn.functional.softmax(gelu_z_pt, dim=0)
+
+        y_pt = softmax1_pt + softmax2_pt + softmax3_pt
+        y_ait = torch.empty_like(y_pt)
+
+        with compile_model(
+            model_output,
+            detect_target(),
+            "/tmp",
+            "fuse_duplicate_fused_elementwise_non_dups",
+        ) as module:
+            module.run_with_tensors(
+                {"input_x": x_pt, "input_z": z_pt}, {"output": y_ait}
+            )
+            graph = module.debug_sorted_graph
+            nrelu = self._count_fused_elementwise_ops(graph, [FuncEnum.RELU])
+            ngelu = self._count_fused_elementwise_ops(graph, [FuncEnum.GELU])
+            self.assertEqual(nrelu, 1)
+            self.assertEqual(ngelu, 2)
+            self.assertTrue(torch.allclose(y_pt, y_ait, atol=1e-2, rtol=1e-2))
+
+    def test_all_interactions(self):
+        """Test all interactions:
+        1. Fusing duplicates
+        2. Fusing duplicates with accessors that write to a concat's output tensor
+        3. Not fusing non-duplicates
+        """
+        x = gen_input_tensor(shape=self.SHAPE, name="input_x")
+        z = gen_input_tensor(shape=self.SHAPE, name="input_z")
+        p = gen_input_tensor(shape=self.SHAPE, name="input_p")
+
+        # First ReLU op with x as the input.
+        relu1 = ops.elementwise(FuncEnum.RELU)(x)
+        tanh = ops.elementwise(FuncEnum.TANH)(relu1)
+        concat1 = ops.concatenate()([relu1, tanh])
+
+        # Fuse relu2 with relu1. This ReLU uses a tensor accessor to write
+        # directly to concat2's output.
+        relu2 = ops.elementwise(FuncEnum.RELU)(x)
+        concat2 = ops.concatenate()([relu2, p])
+
+        # Fuse relu3 with relu1.
+        relu3 = ops.elementwise(FuncEnum.RELU)(x)
+        softmax = ops.softmax()(relu3, dim=0)
+        concat3 = ops.concatenate()([softmax, softmax])
+
+        # Don't fuse ops with different inputs or different elementwise ops.
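+        # (gelu differs from the ReLUs by op; relu4 differs by input tensor.)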
+        gelu = ops.elementwise(FuncEnum.GELU)(x)
+        relu4 = ops.elementwise(FuncEnum.RELU)(z)
+        concat4 = ops.concatenate()([relu4, gelu])
+
+        model_output = concat1 + concat2 + concat3 + concat4
+        model_output._attrs["is_output"] = True
+        model_output._attrs["name"] = "output"
+
+        # Setup PyTorch
+        x_pt = get_random_torch_tensor(self.SHAPE)
+        z_pt = get_random_torch_tensor(self.SHAPE)
+        p_pt = get_random_torch_tensor(self.SHAPE)
+
+        relu1_pt = torch.nn.functional.relu(x_pt)
+        tanh_pt = torch.nn.functional.tanh(relu1_pt)
+        concat1_pt = torch.concat([relu1_pt, tanh_pt])
+
+        relu2_pt = torch.nn.functional.relu(x_pt)
+        concat2_pt = torch.concat([relu2_pt, p_pt])
+
+        relu3_pt = torch.nn.functional.relu(x_pt)
+        softmax_pt = torch.nn.functional.softmax(relu3_pt, dim=0)
+        concat3_pt = torch.concat([softmax_pt, softmax_pt])
+
+        relu4_pt = torch.nn.functional.relu(z_pt)
+        gelu_pt = torch.nn.functional.gelu(x_pt)
+        concat4_pt = torch.concat([relu4_pt, gelu_pt])
+
+        y_pt = concat1_pt + concat2_pt + concat3_pt + concat4_pt
+        y_ait = torch.empty_like(y_pt)
+
+        with compile_model(
+            model_output,
+            detect_target(),
+            "/tmp",
+            "fuse_duplicate_fused_elementwise_all_interactions",
+        ) as module:
+            module.run_with_tensors(
+                inputs={
+                    "input_x": x_pt,
+                    "input_z": z_pt,
+                    "input_p": p_pt,
+                },
+                outputs={"output": y_ait},
+            )
+            graph = module.debug_sorted_graph
+            nrelu = self._count_fused_elementwise_ops(graph, [FuncEnum.RELU])
+            ngelu = self._count_fused_elementwise_ops(graph, [FuncEnum.GELU])
+            self.assertEqual(nrelu, 2)
+            self.assertEqual(ngelu, 1)
+            self.assertTrue(torch.allclose(y_pt, y_ait, atol=1e-2, rtol=1e-2))
+
+    def test_same_and_different_input_accessors(self):
+        """
+        Before _fuse_slice_and_strided_op, the fused_elementwise ops have
+        different input tensors. After _fuse_slice_and_strided_op, the
+        fused_elementwise ops have the same input tensor and, depending on
+        the slice indices, the same or different input accessors.
+        """
+
+        # Input accessors are the same -- fuse them!
+        self._test_input_accessors_impl(
+            slice1_start=[0, 0, 0],
+            slice1_end=[32, 64, 50],
+            slice2_start=[0, 0, 0],
+            slice2_end=[32, 64, 50],
+            should_fuse=True,
+        )
+        # Input accessors are different -- don't fuse.
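+        # (The second slice reads input_x starting at offset 50 in the last dim.)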
+        self._test_input_accessors_impl(
+            slice1_start=[0, 0, 0],
+            slice1_end=[32, 64, 50],
+            slice2_start=[0, 0, 50],
+            slice2_end=[32, 64, 100],
+            should_fuse=False,
+        )
+
+    def _test_input_accessors_impl(
+        self,
+        slice1_start: List[int],
+        slice1_end: List[int],
+        slice2_start: List[int],
+        slice2_end: List[int],
+        should_fuse: bool,
+    ):
+        x = gen_input_tensor(shape=self.SHAPE, name="input_x")
+        x_sliced_1 = ops.dynamic_slice()(x, slice1_start, slice1_end)
+        x_sliced_2 = ops.dynamic_slice()(x, slice2_start, slice2_end)
+        sigmoid1 = ops.elementwise(FuncEnum.SIGMOID)(x_sliced_1)
+        sigmoid2 = ops.elementwise(FuncEnum.SIGMOID)(x_sliced_2)
+        softmax1 = ops.softmax()(sigmoid1, dim=0)
+        softmax2 = ops.softmax()(sigmoid2, dim=0)
+        model_output = softmax1 + softmax2
+        model_output._attrs["is_output"] = True
+        model_output._attrs["name"] = "output"
+
+        x_pt = get_random_torch_tensor(self.SHAPE)
+        x_sliced_1_pt = x_pt[tuple(slice(s, e) for s, e in zip(slice1_start, slice1_end))]
+        x_sliced_2_pt = x_pt[tuple(slice(s, e) for s, e in zip(slice2_start, slice2_end))]
+        sigmoid1_pt = torch.sigmoid(x_sliced_1_pt)
+        sigmoid2_pt = torch.sigmoid(x_sliced_2_pt)
+        softmax1_pt = torch.nn.functional.softmax(sigmoid1_pt, dim=0)
+        softmax2_pt = torch.nn.functional.softmax(sigmoid2_pt, dim=0)
+        y_pt = softmax1_pt + softmax2_pt
+        y_ait = torch.empty_like(y_pt)
+
+        with compile_model(
+            model_output,
+            detect_target(),
+            "/tmp",
+            "fuse_duplicate_fused_elementwise_same_input_different_input_accessors",
+        ) as module:
+            module.run_with_tensors({"input_x": x_pt}, {"output": y_ait})
+            nsigmoid = self._count_fused_elementwise_ops(
+                module.debug_sorted_graph, [FuncEnum.SIGMOID]
+            )
+            self.assertEqual(nsigmoid, 1 if should_fuse else 2)
+            self.assertTrue(torch.allclose(y_pt, y_ait, atol=1e-2, rtol=1e-2))
diff --git a/tests/unittest/compiler/test_move_view_ops.py b/tests/unittest/compiler/test_move_view_ops.py
index 3493f7c83..f380e0e90 100644
--- a/tests/unittest/compiler/test_move_view_ops.py
+++ b/tests/unittest/compiler/test_move_view_ops.py
@@ -861,7 +861,7 @@ def _test_move_strided_reshape_cat_3(
         module = compile_model(Y, target, "./tmp", test_name)
         sorted_graph = module.debug_sorted_graph
         sorted_ops = graph_utils.get_sorted_ops(sorted_graph)
-        self.assertEqual(len(sorted_ops), 5)
+        self.assertEqual(len(sorted_ops), 4)
         concat_cnt = 0
         for sorted_op in sorted_ops:
             if sorted_op._attrs["op"] == "concatenate":