From 89f35da5d11a63ecd11705341478fddb5a139362 Mon Sep 17 00:00:00 2001 From: Arne Symons <34397549+asyms@users.noreply.github.com> Date: Fri, 25 Oct 2024 10:04:54 +0200 Subject: [PATCH] Schedule visualization upgrade (#54) * update Plotly visualization to display energy breakdown and spatial utilization in hover * update main files to load CostModelEvaluationLUT to display extra performance info in visualization * fix small typo in simd parser: capitalization of K dimension * allow 'all' as tiling dimension size in mapping * update tpu mapping intra_core_tiling and inter_core_tiling * rename node_hw_performances to cost_lut and rename all variables; restructure output saving paths * add required and used link bandwidth in schedule visualization * change print in memory usage visualization to logging INFO statement * ignore all .pkl and .pickle files * remove output pickle files --- .gitignore | 4 +- docs/source/stages.rst | 4 +- main_stream_co.py | 34 +++--- main_stream_ga.py | 16 ++- stream/api.py | 79 ++++--------- stream/cost_model/communication_manager.py | 9 +- .../examples/mapping/tpu_like_quad_core.yaml | 28 ++++- .../constraint_optimization/allocation.py | 6 +- .../constraint_optimization/utils.py | 10 +- .../genetic_algorithm/fitness_evaluator.py | 26 ++-- stream/parser/mapping_factory.py | 8 +- stream/parser/mapping_validator.py | 2 +- stream/parser/onnx/simd.py | 4 +- .../constraint_optimization_allocation.py | 47 ++++---- .../genetic_algorithm_allocation.py | 14 +-- .../stream_cost_model_evaluation.py | 4 +- .../zigzag_core_mapping_estimation.py | 56 ++++----- .../set_fixed_allocation_performance.py | 27 ++--- stream/utils.py | 12 ++ .../visualization/constraint_optimization.py | 10 +- ...mances.py => cost_model_evaluation_lut.py} | 4 +- stream/visualization/memory_usage.py | 4 +- stream/visualization/schedule.py | 111 ++++++++++++++++-- .../workload/computation/computation_node.py | 2 +- stream/workload/node.py | 6 +- 25 files changed, 305 insertions(+), 222 deletions(-) rename stream/visualization/{node_hw_performances.py => cost_model_evaluation_lut.py} (97%) diff --git a/.gitignore b/.gitignore index 008cbfb..adb0043 100644 --- a/.gitignore +++ b/.gitignore @@ -136,8 +136,8 @@ dmypy.json .vscode/ # result pickle files -output*/*.pkl -output*/*.pickle +*.pkl +*.pickle # result json files outputs*/*.json diff --git a/docs/source/stages.rst b/docs/source/stages.rst index 1fe383d..aab2584 100644 --- a/docs/source/stages.rst +++ b/docs/source/stages.rst @@ -27,7 +27,7 @@ Stages within Stream are used to modularly and easily adapt the functionality of loma_lpf_limit=6, # required by LomaEngine nb_ga_individuals=32, # number of individuals in each genetic algorithm generation nb_ga_generations=100, # number of genetic algorithm generations - node_hw_performances_path=node_hw_performances_path, # saved node_hw_performances to skip re-computation + cost_lut_path=cost_lut_path, # saved CostModelEvaluationLUT to skip re-computation plot_hof=True, # Save schedule and memory usage plot of each individual in the Genetic Algorithm hall of fame plot_file_name=plot_file_name, plot_full_schedule=plot_full_schedule, @@ -74,7 +74,7 @@ Multiple modes are applicable through the `cn_define_mode` parameter in conjunct `InterCoreMappingStage `_ ---------------------------------------------------------------------------------------------------------------------------------- -Stage that finds the best inter-core mapping using a genetic algorithm. 
From the IntraCoreMappingStage we receive the `node_hw_performances`, containing for each node and its valid core allocations the best CME. We then initialize the genetic algorithm. +Stage that finds the best inter-core mapping using a genetic algorithm. From the IntraCoreMappingStage we receive the `CostModelEvaluationLUT`, containing for each node and its valid core allocations the best CME. We then initialize the genetic algorithm. `IntraCoreMappingStage `_ ----------------------------------------------------------------------------------------------------------------------------------- diff --git a/main_stream_co.py b/main_stream_co.py index 84e0e8e..3cb4518 100644 --- a/main_stream_co.py +++ b/main_stream_co.py @@ -2,6 +2,7 @@ import re from stream.api import optimize_allocation_co +from stream.utils import CostModelEvaluationLUT from stream.visualization.memory_usage import plot_memory_usage from stream.visualization.schedule import ( visualize_timeline_plotly, @@ -27,8 +28,18 @@ experiment_id = f"{hw_name}-{wl_name}-{mode}-constraint_optimization" ###################################################################### +scme = optimize_allocation_co( + hardware=accelerator, + workload=workload_path, + mapping=mapping_path, + mode=mode, + layer_stacks=layer_stacks, + experiment_id=experiment_id, + output_path="outputs", + skip_if_exists=False, +) + ############PLOTTING############# -plot_file_name = f"-{experiment_id}-" plot_full_schedule = True draw_dependencies = True plot_data_transfer = True @@ -36,21 +47,15 @@ percent_shown = (100,) ################################# - -################################PATHS################################ -timeline_fig_path_plotly = f"outputs/{experiment_id}-schedule.html" -memory_fig_path = f"outputs/{experiment_id}-memory.png" +#########################PLOTTING PATHS############################## +timeline_fig_path_plotly = f"outputs/{experiment_id}/schedule.html" +memory_fig_path = f"outputs/{experiment_id}/memory.png" ##################################################################### -scme = optimize_allocation_co( - hardware=accelerator, - workload=workload_path, - mapping=mapping_path, - mode=mode, - layer_stacks=layer_stacks, - experiment_id=experiment_id, - output_path="outputs", -) +#####################CostModelEvaluationLUT LOAD############################# +cost_lut_path = f"outputs/{experiment_id}/cost_lut_post_co.pickle" +cost_lut = CostModelEvaluationLUT(cost_lut_path) +############################################################################# # Plotting schedule timeline of best SCME visualize_timeline_plotly( @@ -58,6 +63,7 @@ draw_dependencies=draw_dependencies, draw_communication=plot_data_transfer, fig_path=timeline_fig_path_plotly, + cost_lut=cost_lut, ) # Plotting memory usage of best SCME plot_memory_usage(scme, section_start_percent, percent_shown, fig_path=memory_fig_path) diff --git a/main_stream_ga.py b/main_stream_ga.py index 713c3d2..805f3e4 100644 --- a/main_stream_ga.py +++ b/main_stream_ga.py @@ -2,6 +2,7 @@ import re from stream.api import optimize_allocation_ga +from stream.utils import CostModelEvaluationLUT from stream.visualization.memory_usage import plot_memory_usage from stream.visualization.schedule import ( visualize_timeline_plotly, @@ -17,8 +18,8 @@ mapping_path = "stream/inputs/examples/mapping/tpu_like_quad_core.yaml" mode = "fused" layer_stacks = [tuple(range(0, 11)), tuple(range(11, 22))] + list((i,) for i in range(22, 49)) -nb_ga_generations = 16 -nb_ga_individuals = 16 +nb_ga_generations = 4 
+nb_ga_individuals = 4 ############################################################################################## ################################PARSING############################### @@ -40,8 +41,8 @@ ################################PATHS################################ -timeline_fig_path_plotly = f"outputs/{experiment_id}-schedule.html" -memory_fig_path = f"outputs/{experiment_id}-memory.png" +timeline_fig_path_plotly = f"outputs/{experiment_id}/schedule.html" +memory_fig_path = f"outputs/{experiment_id}/memory.png" ##################################################################### scme = optimize_allocation_ga( @@ -54,15 +55,20 @@ nb_ga_individuals=nb_ga_individuals, experiment_id=experiment_id, output_path="outputs", - skip_if_exists=False, + skip_if_exists=True, ) +# Load in the CostModelEvaluationLUT from the run +cost_lut_path = f"outputs/{experiment_id}/cost_lut.pickle" +cost_lut = CostModelEvaluationLUT(cost_lut_path) + # Plotting schedule timeline of best SCME visualize_timeline_plotly( scme, draw_dependencies=draw_dependencies, draw_communication=plot_data_transfer, fig_path=timeline_fig_path_plotly, + cost_lut=cost_lut, ) # Plotting memory usage of best SCME plot_memory_usage(scme, section_start_percent, percent_shown, fig_path=memory_fig_path) diff --git a/stream/api.py b/stream/api.py index 06edbeb..7c3f37d 100644 --- a/stream/api.py +++ b/stream/api.py @@ -40,6 +40,7 @@ def _sanity_check_gurobi_license(): try: # Try to create a simple optimization model model = gp.Model() + model.setParam("OutputFlag", 0) # Check if the model was successfully created (license check) model.optimize() # If model.optimize() runs without a license issue, return @@ -67,12 +68,17 @@ def optimize_allocation_ga( ) -> StreamCostModelEvaluation: _sanity_check_inputs(hardware, workload, mapping, mode, output_path) - logger = _logging.getLogger(__name__) + # Create experiment_id path + os.makedirs(f"{output_path}/{experiment_id}", exist_ok=True) # Output paths - node_hw_performances_path = f"{output_path}/{experiment_id}-saved_cn_hw_cost.pickle" - scme_path = f"{output_path}/{experiment_id}-scme.pickle" + cost_lut_path = f"{output_path}/{experiment_id}/cost_lut.pickle" + scme_path = f"{output_path}/{experiment_id}/scme.pickle" + + # Get logger + logger = _logging.getLogger(__name__) + # Load SCME if it exists and skip_if_exists is True if os.path.exists(scme_path) and skip_if_exists: scme = pickle_load(scme_path) logger.info(f"Loaded SCME from {scme_path}") @@ -93,11 +99,11 @@ def optimize_allocation_ga( workload_path=workload, # required by ModelParserStage mapping_path=mapping, # required by ModelParserStage loma_lpf_limit=6, # required by LomaEngine - nb_ga_generations=nb_ga_generations, # number of genetic algorithm generations - nb_ga_individuals=nb_ga_individuals, # number of individuals in each genetic algorithm generation + nb_ga_generations=nb_ga_generations, # number of genetic algorithm (ga) generations + nb_ga_individuals=nb_ga_individuals, # number of individuals in each ga generation mode=mode, layer_stacks=layer_stacks, - node_hw_performances_path=node_hw_performances_path, + cost_lut_path=cost_lut_path, operands_to_prefetch=[], # required by GeneticAlgorithmAllocationStage ) # Launch the MainStage @@ -120,14 +126,19 @@ def optimize_allocation_co( _sanity_check_inputs(hardware, workload, mapping, mode, output_path) _sanity_check_gurobi_license() + # Create experiment_id path + os.makedirs(f"{output_path}/{experiment_id}", exist_ok=True) + # Output paths - node_hw_performances_path 
= f"{output_path}/{experiment_id}-saved_cn_hw_cost.pickle" - scme_path = f"{output_path}/{experiment_id}-scme.pickle" - # After constraint optimization paths - node_hw_performances_path_with_split = f"outputs/{experiment_id}-saved_cn_hw_cost-with_split.pickle" + cost_lut_path = f"{output_path}/{experiment_id}/cost_lut.pickle" + allocations_path = f"{output_path}/{experiment_id}/waco/" + cost_lut_post_co_path = f"outputs/{experiment_id}/cost_lut_post_co.pickle" + scme_path = f"{output_path}/{experiment_id}/scme.pickle" + # Get logger logger = _logging.getLogger(__name__) + # Load SCME if it exists and skip_if_exists is True if os.path.exists(scme_path) and skip_if_exists: scme = pickle_load(scme_path) logger.info(f"Loaded SCME from {scme_path}") @@ -150,8 +161,9 @@ def optimize_allocation_co( loma_lpf_limit=6, # required by LomaEngine mode=mode, layer_stacks=layer_stacks, - node_hw_performances_path=node_hw_performances_path, - node_hw_performances_path_with_split=node_hw_performances_path_with_split, + cost_lut_path=cost_lut_path, + allocations_path=allocations_path, + cost_lut_post_co_path=cost_lut_post_co_path, operands_to_prefetch=[], # required by ConstraintOptimizationAllocationStage ) # Launch the MainStage @@ -159,46 +171,3 @@ def optimize_allocation_co( scme = answers[0][0] pickle_save(scme, scme_path) return scme - - -if __name__ == "__main__": - from stream.visualization.memory_usage import plot_memory_usage - from stream.visualization.schedule import visualize_timeline_plotly - - accelerator = "stream/inputs/examples/hardware/tpu_like_quad_core.yaml" - workload = "stream/inputs/examples/workload/resnet18.yaml" - mapping = "stream/inputs/examples/mapping/tpu_like_quad_core.yaml" - - hw_name = "tpu_like_quad_core" - wl_name = "resnet18" - mode = "fused" - experiment_id = f"{hw_name}-{wl_name}" - output_path = "outputs" - layer_stacks = [tuple(range(0, 11)), tuple(range(11, 22))] + list((i,) for i in range(22, 49)) - - scme, _ = optimize_allocation_ga( - accelerator, - workload, - mapping, - mode, - layer_stacks, - experiment_id, - output_path, - ) - - plot_full_schedule = True - draw_dependencies = True - plot_data_transfer = True - section_start_percent = (0,) - percent_shown = (100,) - schedule_fig_path = f"{output_path}/schedule_plot.png" - memory_fig_path = f"{output_path}/memory_plot.png" - energy_fig_path = f"{output_path}/energy_plot.png" - visualize_timeline_plotly( - scme=scme, - draw_dependencies=draw_dependencies, - draw_communication=True, - fig_path=schedule_fig_path, - ) - plot_memory_usage(scme.accelerator.memory_manager, fig_path=memory_fig_path) - # bar_plot_stream_cost_model_evaluations_breakdown([scme], fig_path=energy_fig_path) diff --git a/stream/cost_model/communication_manager.py b/stream/cost_model/communication_manager.py index d91464c..1ee67c5 100644 --- a/stream/cost_model/communication_manager.py +++ b/stream/cost_model/communication_manager.py @@ -46,13 +46,11 @@ class CommunicationLinkEvent: - a list of tensors relevant for the event: * the tensor being transferred * the tensor(s) for which we are blocking - - an activity percentage: - * the percentage of the link bandwidth used + - an activity: + * the bits per clock cycle used of the link bandwidth """ - def __init__( - self, type: str, start: int, end: int, tensors: list[Tensor], energy: float, activity: float = 100 - ) -> None: + def __init__(self, type: str, start: int, end: int, tensors: list[Tensor], energy: float, activity: float) -> None: self.type = type self.start = start self.end = end @@ 
-163,6 +161,7 @@ def update_links( end=end_timestep, tensors=[tensor], energy=duration * link.unit_energy_cost, + activity=link.bandwidth, ) for link in links ] diff --git a/stream/inputs/examples/mapping/tpu_like_quad_core.yaml b/stream/inputs/examples/mapping/tpu_like_quad_core.yaml index b896be8..676dd10 100644 --- a/stream/inputs/examples/mapping/tpu_like_quad_core.yaml +++ b/stream/inputs/examples/mapping/tpu_like_quad_core.yaml @@ -1,31 +1,55 @@ - name: default core_allocation: [0, 1, 2, 3] intra_core_tiling: - - D, 64 + - D, all inter_core_tiling: - K, * - name: Conv core_allocation: [0, 1, 2, 3] intra_core_tiling: - - K, 8 + - OY, all inter_core_tiling: - K, * - name: Gemm core_allocation: [0, 1, 2, 3] + intra_core_tiling: + - D, all + inter_core_tiling: + - H, * - name: Pool core_allocation: [4] + intra_core_tiling: + - OY, all + inter_core_tiling: + - K, * - name: MaxPool core_allocation: [4] + intra_core_tiling: + - OY, all + inter_core_tiling: + - K, * - name: AveragePool core_allocation: [4] + intra_core_tiling: + - OY, all + inter_core_tiling: + - K, * - name: GlobalAveragePool core_allocation: [4] + intra_core_tiling: + - OY, all + inter_core_tiling: + - K, * - name: Add core_allocation: [5] + intra_core_tiling: + - D, all + inter_core_tiling: + - H, * diff --git a/stream/opt/allocation/constraint_optimization/allocation.py b/stream/opt/allocation/constraint_optimization/allocation.py index a1f5077..5a59ea1 100644 --- a/stream/opt/allocation/constraint_optimization/allocation.py +++ b/stream/opt/allocation/constraint_optimization/allocation.py @@ -22,7 +22,7 @@ def get_optimal_allocations( workload: ComputationNodeWorkload, accelerator: Accelerator, - node_hw_performances: CostModelEvaluationLUT, + cost_lut: CostModelEvaluationLUT, iterations: int, gap: float = 0.5, time_limit: int = 600, @@ -34,9 +34,9 @@ def get_optimal_allocations( ids = convert_ids(nodes) latencies, possible_allocation_splits = get_latencies( - nodes, core_ids, accelerator, node_hw_performances, impossible_lat=0, ids=ids + nodes, core_ids, accelerator, cost_lut, impossible_lat=0, ids=ids ) - energies = get_energies(nodes, core_ids, accelerator, node_hw_performances, impossible_energy=0, ids=ids) + energies = get_energies(nodes, core_ids, accelerator, cost_lut, impossible_energy=0, ids=ids) output_operand = LayerOperand("O") dependencies = { (ids[p], ids[c]): p.operand_size_bit[output_operand] for p, c in workload.edges() if p in nodes and c in nodes diff --git a/stream/opt/allocation/constraint_optimization/utils.py b/stream/opt/allocation/constraint_optimization/utils.py index 2ab95c3..263e62a 100644 --- a/stream/opt/allocation/constraint_optimization/utils.py +++ b/stream/opt/allocation/constraint_optimization/utils.py @@ -47,7 +47,7 @@ def get_latencies( nodes: list[ComputationNode], core_ids: list[int], accelerator: Accelerator, - node_hw_performances: CostModelEvaluationLUT, + cost_lut: CostModelEvaluationLUT, impossible_lat: float = 1e11, ids: dict[ComputationNode, int] = {}, ) -> tuple[dict[tuple[int, str, int], int], dict]: @@ -64,9 +64,9 @@ def get_latencies( for core_id, core_name in zip(core_ids, core_names): core = accelerator.get_core(core_id) try: - equal_node = node_hw_performances.get_equal_node(node) + equal_node = cost_lut.get_equal_node(node) assert equal_node, f"No equal node for {node} found in CostModelEvaluationLUT" - cme = node_hw_performances.get_cme(equal_node, core) + cme = cost_lut.get_cme(equal_node, core) output_operand = LayerOperand("O") temporal_loops = [ i for tm_level 
in cme.temporal_mapping.mapping_dic_stationary[output_operand] for i in tm_level @@ -110,7 +110,7 @@ def get_energies( nodes: list[ComputationNode], core_ids: list[int], accelerator: Accelerator, - node_hw_performances: CostModelEvaluationLUT, + cost_lut: CostModelEvaluationLUT, impossible_energy: float = 1e11, ids: dict[ComputationNode, int] = {}, ) -> dict[tuple[int, str], float]: @@ -123,7 +123,7 @@ def get_energies( for core_id, core_name in zip(core_ids, core_names): core = accelerator.get_core(core_id) try: - cme = node_hw_performances.get_cme(node, core) + cme = cost_lut.get_cme(node, core) en = getattr(cme, "energy_total") except ValueError: en = impossible_energy diff --git a/stream/opt/allocation/genetic_algorithm/fitness_evaluator.py b/stream/opt/allocation/genetic_algorithm/fitness_evaluator.py index 842e659..59c9d81 100644 --- a/stream/opt/allocation/genetic_algorithm/fitness_evaluator.py +++ b/stream/opt/allocation/genetic_algorithm/fitness_evaluator.py @@ -1,10 +1,9 @@ from zigzag.datatypes import LayerOperand -from zigzag.mapping.data_movement import MemoryAccesses from zigzag.utils import pickle_deepcopy from stream.cost_model.cost_model import StreamCostModelEvaluation from stream.hardware.architecture.accelerator import Accelerator -from stream.utils import CostModelEvaluationLUT, get_too_large_operands +from stream.utils import CostModelEvaluationLUT, get_required_offchip_bandwidth, get_too_large_operands from stream.workload.computation.computation_node import ComputationNode from stream.workload.onnx_workload import ComputationNodeWorkload @@ -14,11 +13,11 @@ def __init__( self, workload: ComputationNodeWorkload, accelerator: Accelerator, - node_hw_performances: CostModelEvaluationLUT, + cost_lut: CostModelEvaluationLUT, ) -> None: self.workload = workload self.accelerator = accelerator - self.node_hw_performances = node_hw_performances + self.cost_lut = cost_lut # self.num_cores = len(inputs.accelerator.cores) def get_fitness(self): @@ -32,12 +31,12 @@ def __init__( self, workload: ComputationNodeWorkload, accelerator: Accelerator, - node_hw_performances: CostModelEvaluationLUT, + cost_lut: CostModelEvaluationLUT, layer_groups_flexible, operands_to_prefetch: list[LayerOperand], scheduling_order: list[tuple[int, int]], ) -> None: - super().__init__(workload, accelerator, node_hw_performances) + super().__init__(workload, accelerator, cost_lut) self.weights = (-1.0, -1.0) self.metrics = ["energy", "latency"] @@ -85,9 +84,9 @@ def set_node_core_allocations(self, core_allocations: list[int]): if isinstance(node, ComputationNode) and node.id == layer_id and node.group == group_id ) for node in nodes: - equal_unique_node = self.node_hw_performances.get_equal_node(node) - assert equal_unique_node is not None, "Node not found in node_hw_performances" - cme = self.node_hw_performances.get_cme(equal_unique_node, core) + equal_unique_node = self.cost_lut.get_equal_node(node) + assert equal_unique_node is not None, "Node not found in CostModelEvaluationLUT" + cme = self.cost_lut.get_cme(equal_unique_node, core) onchip_energy = cme.energy_total # Initialize on-chip energy as total energy latency = cme.latency_total1 too_large_operands = get_too_large_operands(cme, self.accelerator, core_id=core_allocation) @@ -101,15 +100,10 @@ def set_node_core_allocations(self, core_allocations: list[int]): offchip_energy += layer_operand_offchip_energy onchip_energy -= layer_operand_offchip_energy # If there was offchip memory added for too_large_operands, get the offchip bandwidth - if 
self.accelerator.offchip_core_id is not None: - offchip_core = self.accelerator.get_core(self.accelerator.offchip_core_id) - offchip_instance = next(v for k, v in offchip_core.mem_hierarchy_dict.items())[-1].memory_instance - offchip_bw = cme.get_total_inst_bandwidth(offchip_instance) - else: - offchip_bw = MemoryAccesses(0, 0, 0, 0) + required_offchip_bandwidth = get_required_offchip_bandwidth(cme, too_large_operands) node.set_onchip_energy(onchip_energy) node.set_offchip_energy(offchip_energy) node.set_runtime(int(latency)) node.set_chosen_core_allocation(core_allocation) node.set_too_large_operands(too_large_operands) - node.set_offchip_bandwidth(offchip_bw) + node.set_offchip_bandwidth(required_offchip_bandwidth) diff --git a/stream/parser/mapping_factory.py b/stream/parser/mapping_factory.py index 44bd97a..fdf0388 100644 --- a/stream/parser/mapping_factory.py +++ b/stream/parser/mapping_factory.py @@ -72,6 +72,12 @@ def __convert_layer_dim_int_pair(self, pair: str): """Convert strings such as `D, 4` into a LayerDim and int""" layer_dim_str = pair.split(",")[0] unrolling_str = pair.split(",")[-1] - unrolling = int(unrolling_str) if "*" not in unrolling_str else "*" + match unrolling_str.strip(" "): + case "all": + unrolling = "all" + case "*": + unrolling = "*" + case _: + unrolling = int(unrolling_str) layer_dim = LayerDim(layer_dim_str) return layer_dim, unrolling diff --git a/stream/parser/mapping_validator.py b/stream/parser/mapping_validator.py index 0c4dae8..cef80fc 100644 --- a/stream/parser/mapping_validator.py +++ b/stream/parser/mapping_validator.py @@ -9,7 +9,7 @@ class MappingValidator: """Class to validate user-given mappings from yaml file""" - TILING_REGEX = r"^[A-Z]+, ([0-9]+|\*)$" + TILING_REGEX = r"^[A-Z]+, ([0-9]+|all|\*)$" SPATIAL_MAPPING_REGEX = r"^[A-Z]+, [0-9]+$" SPATIAL_MAPPING_HINT_REGEX = r"^[A-Z]+$" diff --git a/stream/parser/onnx/simd.py b/stream/parser/onnx/simd.py index 8248f8f..2c5ae21 100644 --- a/stream/parser/onnx/simd.py +++ b/stream/parser/onnx/simd.py @@ -35,10 +35,10 @@ def get_layer_node_user_format(self, input_shape: list[int], output_shape: list[ data["loop_dims"] = ["D", "K"] case 3: data["equation"] = f"O[b][d][k]+=I[b][d][k]*W{'[]' if has_single_input else '[b][d][k]'}" - data["loop_dims"] = ["B", "D", "k"] + data["loop_dims"] = ["B", "D", "K"] case 4: data["equation"] = f"O[b][h][d][k]+=I[b][h][d][k]*W{'[]' if has_single_input else '[b][h][d][k]'}" - data["loop_dims"] = ["B", "H", "D", "k"] + data["loop_dims"] = ["B", "H", "D", "K"] case _: raise NotImplementedError diff --git a/stream/stages/allocation/constraint_optimization_allocation.py b/stream/stages/allocation/constraint_optimization_allocation.py index d5096ed..632268a 100644 --- a/stream/stages/allocation/constraint_optimization_allocation.py +++ b/stream/stages/allocation/constraint_optimization_allocation.py @@ -39,15 +39,18 @@ class ConstraintOptimizationAllocationStage(Stage): This stages requires a CostModelEvaluationLUT, containing for each node and its valid core allocations the best CME. """ + CO_TIME_LIMIT = 600 + def __init__( self, list_of_callables: list[StageCallable], *, workload: ComputationNodeWorkload, accelerator: Accelerator, - node_hw_performances: CostModelEvaluationLUT, + cost_lut: CostModelEvaluationLUT, layer_stacks: list[tuple[int, ...]], - node_hw_performances_path_with_split: str, + allocations_path: str, + cost_lut_post_co_path: str, **kwargs: Any, ): """Initialize the ResourceAllocationStage. 
@@ -56,29 +59,23 @@ def __init__( list_of_callables (list): List of the substages to be called. This should be empty as this is a leaf stage. workload (DiGraph): The NetworkX DiGraph representing the workload to be scheduled accelerator (Accelerator): The hardware accelerator onto which we schedule the workload - node_hw_performances (dict): A nested dict containing for each node a dict with for each valid core its best HW performance + cost_lut (CostModelEvaluationLUT): A lookup table containing for each node the best CME for each core layer_stacks (list): List of tuples with each tuple containing the layer ids to fuse together + allocations_path (str): Path to the directory where the optimal allocations are stored + cost_lut_post_co_path (str): Path to the file where the cost LUT after CO is stored """ super().__init__(list_of_callables, **kwargs) self.workload = workload self.accelerator = accelerator - self.node_hw_performances = node_hw_performances + self.cost_lut = cost_lut self.layer_stacks = layer_stacks self.original_workload: ComputationNodeWorkload = kwargs["original_workload"] self.mode = kwargs.get("mode", "fused") # assume default is fused - self.steady_state_visualization_path = kwargs.get("steady_state_visualization_path", "outputs/") - self.node_hw_performances_path_with_split = node_hw_performances_path_with_split - if "visualize_node_hw_performances_path_with_split" in kwargs: - self.visualize_node_hw_performances_path_with_split = kwargs[ - "visualize_node_hw_performances_path_with_split" - ] - else: - node_hw_performances_visualization_path = ( - os.path.splitext(self.node_hw_performances_path_with_split)[0] + ".png" - ) - self.visualize_node_hw_performances_path_with_split = node_hw_performances_visualization_path - self.co_time_limit: int = kwargs.get("co_time_limit", 600) + self.allocations_path = allocations_path + os.makedirs(self.allocations_path, exist_ok=True) + self.cost_lut_post_co_path = cost_lut_post_co_path + self.co_time_limit: int = kwargs.get("co_time_limit", self.CO_TIME_LIMIT) # Which CME attribute to use for the node latencies self.latency_attr = kwargs.get("latency_attr", "latency_total1") @@ -209,23 +206,22 @@ def find_best_allocation( """# TODO: Implement overhead of tensor transfers between cores""" # Check if the allocation is already cached, if not: find it stack_str = "_".join([str(id) for id in stack]) - allocations_path = os.path.join(self.steady_state_visualization_path, f"steady_state-{stack_str}.pickle") - if os.path.exists(allocations_path): - allocation = pickle_load(allocations_path) + stack_allocations_path = os.path.join(self.allocations_path, f"steady_state-{stack_str}.pickle") + if os.path.exists(stack_allocations_path): + allocation = pickle_load(stack_allocations_path) else: sg = self.workload.subgraph(to_compute) logger.info(f"Optimizing allocation for {iterations} iterations of {len(to_compute)} ss nodes.") allocation = get_optimal_allocations( sg, self.accelerator, - self.node_hw_performances, + self.cost_lut, iterations, time_limit=time_limit, ) - pickle_save(allocation, allocations_path) - fig_path = os.path.join(self.steady_state_visualization_path, f"steady_state-{stack_str}.html") - print(f"stack = {stack}") - visualize_waco(allocation, self.node_hw_performances, self.accelerator, fig_path, iterations) + pickle_save(allocation, stack_allocations_path) + fig_path = stack_allocations_path.replace(".pickle", ".html") + visualize_waco(allocation, self.cost_lut, self.accelerator, fig_path, iterations) return allocation def 
get_scheduling_order(self, unpartitioned_workload: DNNWorkloadStream) -> SCHEDULE_ORDER_T: @@ -401,8 +397,7 @@ def schedule_allocation(self, allocation: ALLOCATION_T) -> StreamCostModelEvalua kwargs["accelerator"] = self.accelerator kwargs["workload"] = unpartitioned_sub_workload kwargs["scheduling_order"] = scheduling_order - kwargs["node_hw_performances_path"] = self.node_hw_performances_path_with_split - kwargs["visualize_node_hw_performances_path"] = self.visualize_node_hw_performances_path_with_split + kwargs["cost_lut_path"] = self.cost_lut_post_co_path kwargs["latency_attr"] = self.latency_attr # Create stages that will run a single cost model evaluation (fixed core allocations) diff --git a/stream/stages/allocation/genetic_algorithm_allocation.py b/stream/stages/allocation/genetic_algorithm_allocation.py index 3c8f2d4..1b55c77 100644 --- a/stream/stages/allocation/genetic_algorithm_allocation.py +++ b/stream/stages/allocation/genetic_algorithm_allocation.py @@ -21,7 +21,7 @@ class GeneticAlgorithmAllocationStage(Stage): """ Class that finds the best inter-core mapping using a genetic algorithm. - From the IntraCoreMappingStage we receive the `node_hw_performances`, containing for each node and its valid core + From the IntraCoreMappingStage we receive the `CostModelEvaluationLUT`, containing for each node and its valid core allocations the best CME. We then initialize the genetic algorithm. TODO A separate "GeneticAlgorithmStage" should be added where we parse all GA-related info and this stage then calls @@ -34,7 +34,7 @@ def __init__( *, workload: ComputationNodeWorkload, accelerator: Accelerator, - node_hw_performances: CostModelEvaluationLUT, + cost_lut: CostModelEvaluationLUT, nb_ga_generations: int, nb_ga_individuals: int, operands_to_prefetch: list[LayerOperand], @@ -47,14 +47,14 @@ def __init__( list_of_callables (list): List of the substages to be called. This should be empty as this is a leaf stage. 
workload (DiGraph): The NetworkX DiGraph representing the workload to be scheduled accelerator (Accelerator): The hardware accelerator onto which we schedule the workload - node_hw_performances (CostModelEvaluationLUT): A LUT of CMEs for each unique node and their valid cores + cost_lut (CostModelEvaluationLUT): A LUT of CMEs for each unique node and their valid cores nb_ga_generations (int): The number of generations considered by the genetic algorithm nb_ga_individuals (int): The number of individuals in each genetic algorithm generation """ super().__init__(list_of_callables, **kwargs) self.workload = workload self.accelerator = accelerator - self.node_hw_performances = node_hw_performances + self.cost_lut = cost_lut self.nb_generations = nb_ga_generations self.nb_individuals = nb_ga_individuals self.operands_to_prefetch = operands_to_prefetch @@ -75,7 +75,7 @@ def __init__( self.unique_nodes_flexible.append(n) self.coarse_node_ids_flexible: list[int] = [n.id for n in self.unique_nodes_flexible] - # For each unique node get the possible core allocations by getting the ids of the cores in node_hw_performances + # For each unique node get the possible core allocations by getting the ids of the cores in cost_lut self.valid_allocations: list[list[int]] = [] # Save all the layer group combinations that are flexible self.layer_groups_flexible: list[tuple[int, int]] = [] @@ -84,7 +84,7 @@ def __init__( # This assumes all the nodes of this layer are identical unique_node = next((n for n in self.unique_nodes if n.id == layer_id)) if unique_node in self.unique_nodes_flexible: - cores = self.node_hw_performances.get_cores(unique_node) + cores = self.cost_lut.get_cores(unique_node) valid_core_ids = [core.id for core in cores if core.id < len(self.unique_nodes_flexible)] self.layer_groups_flexible.append((layer_id, group_id)) self.valid_allocations.append(valid_core_ids) @@ -93,7 +93,7 @@ def __init__( self.fitness_evaluator = StandardFitnessEvaluator( self.workload, self.accelerator, - self.node_hw_performances, + self.cost_lut, self.layer_groups_flexible, self.operands_to_prefetch, self.scheduling_order, diff --git a/stream/stages/estimation/stream_cost_model_evaluation.py b/stream/stages/estimation/stream_cost_model_evaluation.py index aa79214..9b131eb 100644 --- a/stream/stages/estimation/stream_cost_model_evaluation.py +++ b/stream/stages/estimation/stream_cost_model_evaluation.py @@ -31,9 +31,7 @@ def __init__( list_of_callables (list): List of the substages to be called. This should be empty as this is a leaf stage. 
workload (DiGraph): The NetworkX DiGraph representing the workload to be scheduled
             accelerator (Accelerator): The hardware accelerator onto which we schedule the workload
-            node_hw_performances (CostModelEvaluationLUT): A LUT of CMEs for each unique node and their valid cores
-            nb_ga_generations (int): The number of generations considered by the genetic algorithm
-            nb_ga_individuals (int): The number of individuals in each genetic algorithm generation
+            operands_to_prefetch (list[LayerOperand]): A list of LayerOperands whose tensors should be prefetched
         """
         super().__init__(list_of_callables, **kwargs)
         self.workload = workload
diff --git a/stream/stages/estimation/zigzag_core_mapping_estimation.py b/stream/stages/estimation/zigzag_core_mapping_estimation.py
index b035b9a..acbb506 100644
--- a/stream/stages/estimation/zigzag_core_mapping_estimation.py
+++ b/stream/stages/estimation/zigzag_core_mapping_estimation.py
@@ -16,8 +16,8 @@
 from stream.hardware.architecture.accelerator import Accelerator
 from stream.stages.stage import MainStage, Stage, StageCallable
 from stream.utils import CostModelEvaluationLUT, get_unique_nodes
-from stream.visualization.node_hw_performances import (
-    visualize_node_hw_performances_pickle,
+from stream.visualization.cost_model_evaluation_lut import (
+    visualize_cost_lut_pickle,
 )
 from stream.workload.computation.computation_node import ComputationNode
 from stream.workload.onnx_workload import ComputationNodeWorkload
@@ -37,7 +37,7 @@ def __init__(
         workload: ComputationNodeWorkload,
         accelerator: Accelerator,
         loma_lpf_limit: int,
-        node_hw_performances_path: str,
+        cost_lut_path: str,
         **kwargs: dict[str, Any],
     ):
         """
@@ -49,11 +49,8 @@
         self.workload = workload
         self.accelerator = accelerator
         self.loma_lpf_limit = loma_lpf_limit
-        self.node_hw_performances_path = node_hw_performances_path
-        if "visualize_node_hw_performances_path" in kwargs:
-            self.visualize_node_hw_performances_path = kwargs["visualize_node_hw_performances_path"]
-        else:
-            self.visualize_node_hw_performances_path = os.path.splitext(self.node_hw_performances_path)[0] + ".png"
+        self.cost_lut_path = cost_lut_path
+        self.visualize_cost_lut_path = os.path.splitext(self.cost_lut_path)[0] + ".png"
         self.loma_show_progress_bar: bool = kwargs.get("loma_show_progress_bar", False)
 
         # Extract all unique nodes that will have to be evaluated
@@ -69,7 +66,7 @@
             self.valid_allocations[node] = node.possible_core_allocation
 
         # Initialize CostModelEvaluationLUT
-        self.node_hw_performances = CostModelEvaluationLUT(self.node_hw_performances_path)
+        self.cost_lut = CostModelEvaluationLUT(self.cost_lut_path)
 
     def run(self):
         logger.info("Start ZigZagCoreMappingEstimationStage.")
@@ -85,22 +82,22 @@
             if core.operational_array.total_unit_count == 0:
                 continue
             # If the (node, core) combination has already been optimized, we skip it
-            if self.node_hw_performances.has_cme(node, core):
+            if self.cost_lut.has_cme(node, core):
                 continue
             # If an equal performance has already been computed, we take it
-            equal_node = self.node_hw_performances.get_equal_node(node)
-            equal_core = self.node_hw_performances.get_equal_core(equal_node, core) if equal_node else None
+            equal_node = self.cost_lut.get_equal_node(node)
+            equal_core = self.cost_lut.get_equal_core(equal_node, core) if equal_node else None
             if equal_node and equal_core:
-                cme = pickle_deepcopy(self.node_hw_performances.get_cme(equal_node, equal_core))
+                cme = pickle_deepcopy(self.cost_lut.get_cme(equal_node, equal_core))
                 # Update the CME
attributes for this node-core combination cme.layer.core_allocation = [core_id] cme.core_id = core_id - self.node_hw_performances.add_cme(node, core, cme, allow_overwrite=False) + self.cost_lut.add_cme(node, core, cme, allow_overwrite=False) continue else: node_duplicate = pickle_deepcopy(node) # Remove duplicate cores with same id in case the core definition has changed - self.node_hw_performances.remove_cores_with_same_id(node, core) + self.cost_lut.remove_cores_with_same_id(node, core) # We need to compute the optimal performance for this node-core combination # It's possible this node might not fully fit within the core's top level memories. # If so, we update the core @@ -122,33 +119,28 @@ def run(self): assert len(answers) == 1, "ZigZagCoreMappingEstimationStage's subflow returned more than one CME" cme: CostModelEvaluation = answers[0][0] # type: ignore node_duplicate.set_chosen_core_allocation(None) # Reset the node's chosen core allocation - self.node_hw_performances.add_cme(node, core, cme, allow_overwrite=False) - self.node_hw_performances.save() + self.cost_lut.add_cme(node, core, cme, allow_overwrite=False) + self.cost_lut.save() - self.visualize_node_hw_performances() + self.visualize_cost_lut() kwargs = self.kwargs.copy() kwargs["workload"] = self.workload kwargs["accelerator"] = self.accelerator - kwargs["node_hw_performances"] = self.node_hw_performances + kwargs["cost_lut"] = self.cost_lut logger.info("Finished ZigZagCoreMappingEstimationStage.") sub_stage = self.list_of_callables[0](self.list_of_callables[1:], **kwargs) for cme, extra_info in sub_stage.run(): yield cme, extra_info - def visualize_node_hw_performances(self): - if "visualize_node_hw_performances_path" in self.kwargs: - # Get the scale factors - scale_factors = { - n: len([cn for cn in self.workload.node_list if cn.has_same_performance(n)]) - for n in self.node_hw_performances.get_nodes() - } - # Run the visualization - visualize_node_hw_performances_pickle( - self.node_hw_performances, - scale_factors, - self.kwargs["visualize_node_hw_performances_path"], - ) + def visualize_cost_lut(self): + # Get the scale factors + scale_factors = { + n: len([cn for cn in self.workload.node_list if cn.has_same_performance(n)]) + for n in self.cost_lut.get_nodes() + } + # Run the visualization + visualize_cost_lut_pickle(self.cost_lut, scale_factors, self.visualize_cost_lut_path) def get_intra_core_mapping_flow(self, node: ComputationNode, too_large_operands: list[MemoryOperand], core_id: int): logger.info(f"Launching intra-core mapping optimization for {node} -> core {core_id} ...") diff --git a/stream/stages/set_fixed_allocation_performance.py b/stream/stages/set_fixed_allocation_performance.py index 4a59e16..8938833 100644 --- a/stream/stages/set_fixed_allocation_performance.py +++ b/stream/stages/set_fixed_allocation_performance.py @@ -3,11 +3,10 @@ from zigzag.cost_model.cost_model import CostModelEvaluation from zigzag.datatypes import MemoryOperand -from zigzag.mapping.data_movement import MemoryAccesses from stream.hardware.architecture.accelerator import Accelerator from stream.stages.stage import Stage, StageCallable -from stream.utils import CostModelEvaluationLUT, get_too_large_operands +from stream.utils import CostModelEvaluationLUT, get_required_offchip_bandwidth, get_too_large_operands from stream.workload.computation.computation_node import ComputationNode from stream.workload.onnx_workload import ComputationNodeWorkload @@ -21,13 +20,13 @@ def __init__( *, workload: ComputationNodeWorkload, accelerator: 
Accelerator, - node_hw_performances: CostModelEvaluationLUT, + cost_lut: CostModelEvaluationLUT, **kwargs: Any, ): super().__init__(list_of_callables, **kwargs) self.accelerator = accelerator self.workload = workload - self.node_hw_performances = node_hw_performances + self.cost_lut = cost_lut self.latency_attr = kwargs.get("latency_attr", "latency_total2") def run(self): @@ -39,7 +38,7 @@ def run(self): kwargs = self.kwargs.copy() kwargs["workload"] = self.workload kwargs["accelerator"] = self.accelerator - kwargs["node_hw_performances"] = self.node_hw_performances + kwargs["cost_lut"] = self.cost_lut sub_stage = self.list_of_callables[0]( self.list_of_callables[1:], **kwargs, @@ -56,15 +55,15 @@ def set_fixed_allocation_performance(self): core_id = node.chosen_core_allocation if core_id is None: raise ValueError(f"Node {node} has fixed allocation but the chosen_core_allocation was not set.") - equal_node = self.node_hw_performances.get_equal_node(node) + equal_node = self.cost_lut.get_equal_node(node) assert equal_node is not None, f"{node} has fixed allocation but no equal node found." core = self.accelerator.get_core(core_id) - cme = self.node_hw_performances.get_cme(equal_node, core) + cme = self.cost_lut.get_cme(equal_node, core) latency = getattr(cme, self.latency_attr) too_large_operands = get_too_large_operands(cme, self.accelerator, core_id=core_id) onchip_energy, offchip_energy = self.get_energy_distribution(cme, too_large_operands) # Get the required offchip bandwidth during the execution of the node for all directions - offchip_bandwidth = self.get_offchip_bandwidth(cme, too_large_operands) + offchip_bandwidth = get_required_offchip_bandwidth(cme, too_large_operands) self.set_hw_performance_node(node, onchip_energy, offchip_energy, latency, core_id) node.set_too_large_operands(too_large_operands.copy()) node.set_offchip_bandwidth(offchip_bandwidth) @@ -82,18 +81,6 @@ def get_energy_distribution( onchip_energy -= layer_operand_offchip_energy return onchip_energy, offchip_energy - def get_offchip_bandwidth( - self, cme: CostModelEvaluation, too_large_operands: list[MemoryOperand] - ) -> MemoryAccesses: - if not too_large_operands: - return MemoryAccesses(0, 0, 0, 0) - # If there was offchip memory added for some operands, get the offchip bandwidth required - assert self.accelerator.offchip_core_id is not None, "Off-chip core id is not set." 
- offchip_core = self.accelerator.get_core(self.accelerator.offchip_core_id) - offchip_instance = next(iter(offchip_core.mem_hierarchy_dict.values()))[-1].memory_instance - offchip_bandwidth = cme.get_total_inst_bandwidth(offchip_instance) - return offchip_bandwidth - @staticmethod def set_hw_performance_node( node: ComputationNode, diff --git a/stream/utils.py b/stream/utils.py index 1bfc435..286c19f 100644 --- a/stream/utils.py +++ b/stream/utils.py @@ -8,6 +8,7 @@ from zigzag.cost_model.cost_model import CostModelEvaluation from zigzag.datatypes import MemoryOperand from zigzag.hardware.architecture.accelerator import Accelerator as Core +from zigzag.mapping.data_movement import FourWayDataMoving from zigzag.parser.onnx.utils import get_onnx_tensor_type if TYPE_CHECKING: @@ -119,6 +120,17 @@ def get_unique_nodes(workload: "ComputationNodeWorkload") -> list["ComputationNo return unique_nodes +def get_required_offchip_bandwidth( + cme: CostModelEvaluation, too_large_operands: list[MemoryOperand] +) -> FourWayDataMoving: + if not too_large_operands: + return FourWayDataMoving(0, 0, 0, 0) + # If there was offchip memory added for some operands, get the offchip bandwidth required + offchip_level = cme.accelerator.get_memory_level(too_large_operands[0], -1) + req_offchip_bw = cme.get_total_inst_bandwidth(offchip_level) + return req_offchip_bw + + class CostModelEvaluationLUT: """A class to store the cost model evaluations in a look-up table. The look-up table is a dictionary with the following structure: diff --git a/stream/visualization/constraint_optimization.py b/stream/visualization/constraint_optimization.py index 0eebffa..71ed182 100644 --- a/stream/visualization/constraint_optimization.py +++ b/stream/visualization/constraint_optimization.py @@ -20,14 +20,14 @@ def visualize_waco( allocation: ALLOCATION_T, - node_hw_performances: CostModelEvaluationLUT, + cost_lut: CostModelEvaluationLUT, accelerator: Accelerator, fig_path: str, iterations: int, ): """ Allocation is a list of tuples, with each tuple being of form (timestep, allocation, node_id). Allocation is a core. - node_hw_performances is a nested dict storing for each node and each core the hardware performance. + cost_lut is a CostModelEvaluationLUT storing for each node and each core the hardware performance. 
""" # Extract the number of allocations (k splits) of all nodes k_splits: dict[int, list[Core]] = {} @@ -46,8 +46,8 @@ def visualize_waco( layer_ids.add(id[0]) ids.append(id) resources.add(a) - node = next(n for n in node_hw_performances.get_nodes() if n.id == id[0]) - latencies, _ = get_latencies([node], core_ids, accelerator, node_hw_performances) + node = next(n for n in cost_lut.get_nodes() if n.id == id[0]) + latencies, _ = get_latencies([node], core_ids, accelerator, cost_lut) nb_k_splits = len(k_splits[id]) lat = latencies[(node.id, a, nb_k_splits)] node_latencies[id, a] = lat @@ -66,7 +66,7 @@ def visualize_waco( starts[id, a] = start _, total_lat_str = calculate_total_latency(starts, timestep_latencies, node_timesteps, iterations) # Plot the nodes using Plotly rectangles - color_cycle = cycle(sample_colorscale("rainbow", np.linspace(0, 1, len(node_hw_performances.get_nodes())))) + color_cycle = cycle(sample_colorscale("rainbow", np.linspace(0, 1, len(cost_lut.get_nodes())))) colors = {layer_id: c for (layer_id, c) in zip(layer_ids, color_cycle)} fig = go.Figure() bars = [] diff --git a/stream/visualization/node_hw_performances.py b/stream/visualization/cost_model_evaluation_lut.py similarity index 97% rename from stream/visualization/node_hw_performances.py rename to stream/visualization/cost_model_evaluation_lut.py index 77746c4..b03164d 100644 --- a/stream/visualization/node_hw_performances.py +++ b/stream/visualization/cost_model_evaluation_lut.py @@ -51,7 +51,7 @@ def autolabel(rects, ax, indices=[], labels=[], offsets=None): index += 1 -def visualize_node_hw_performances_pickle(pickle_filepath, scale_factors=None, fig_path=None): +def visualize_cost_lut_pickle(pickle_filepath, scale_factors=None, fig_path=None): plt.rc("font", size=SMALL_SIZE) # controls default text sizes plt.rc("axes", titlesize=SMALL_SIZE) # fontsize of the axes title plt.rc("axes", labelsize=BIGGER_SIZE) # fontsize of the x and y labels @@ -140,4 +140,4 @@ def visualize_node_hw_performances_pickle(pickle_filepath, scale_factors=None, f axs[1].set_ylabel("Energy [pJ]") fig.tight_layout() plt.savefig(fig_path, bbox_inches="tight") - logger.info(f"Saved node_hw_performances visualization to: {fig_path}") + logger.info(f"Saved CostModelEvaluationLUT visualization to: {fig_path}") diff --git a/stream/visualization/memory_usage.py b/stream/visualization/memory_usage.py index 34ac184..3ebf087 100644 --- a/stream/visualization/memory_usage.py +++ b/stream/visualization/memory_usage.py @@ -1,3 +1,4 @@ +import logging from typing import TYPE_CHECKING import numpy as np @@ -8,6 +9,7 @@ if TYPE_CHECKING: from stream.cost_model.cost_model import StreamCostModelEvaluation +logger = logging.getLogger(__name__) SMALL_SIZE = 8 MEDIUM_SIZE = 10 @@ -156,4 +158,4 @@ def plot_memory_usage( # ax.set_xlabel("Cycles") # Set xlabel of last axis (bottom one) # plt.show(block=True) fig.savefig(fig_path) - print(f"Saved memory usage fig to {fig_path}") + logger.info(f"Saved memory usage fig to {fig_path}") diff --git a/stream/visualization/schedule.py b/stream/visualization/schedule.py index 68fdfaa..a48567d 100644 --- a/stream/visualization/schedule.py +++ b/stream/visualization/schedule.py @@ -15,6 +15,8 @@ from plotly.express.colors import sample_colorscale from zigzag.datatypes import LayerOperand +from stream.utils import CostModelEvaluationLUT +from stream.workload.computation.computation_node import ComputationNode from stream.workload.tensor import Tensor if TYPE_CHECKING: @@ -435,6 +437,7 @@ def 
get_communication_dicts(scme):
             Type=task_type,
             Activity=activity,
             Energy=energy,
+            LinkBandwidth=cl.bandwidth,
         )
         dicts.append(d)
     return dicts
@@ -447,7 +450,54 @@ def get_real_input_tensors(n, G):
     return inputs
 
 
-def get_dataframe_from_scme(scme: "StreamCostModelEvaluation", layer_ids: list[int], add_communication: bool = False):
+def get_spatial_utilizations(
+    scme: "StreamCostModelEvaluation", node: "ComputationNode", cost_lut: "CostModelEvaluationLUT"
+):
+    if cost_lut:
+        equal_node = cost_lut.get_equal_node(node)
+        assert (
+            equal_node
+        ), f"No equal node for {node} found in CostModelEvaluationLUT. Check LUT path (use the post-CO LUT when using CO)."
+        core = scme.accelerator.get_core(node.chosen_core_allocation)
+        cme = cost_lut.get_cme(equal_node, core)
+        return cme.mac_spatial_utilization, cme.mac_utilization1
+    return np.nan, np.nan
+
+
+def get_energy_breakdown(
+    scme: "StreamCostModelEvaluation", node: "ComputationNode", cost_lut: "CostModelEvaluationLUT"
+):
+    if cost_lut:
+        equal_node = cost_lut.get_equal_node(node)
+        assert (
+            equal_node
+        ), f"No equal node for {node} found in CostModelEvaluationLUT. Check LUT path (use the post-CO LUT when using CO)."
+        core = scme.accelerator.get_core(node.chosen_core_allocation)
+        cme = cost_lut.get_cme(equal_node, core)
+        total_ops = cme.layer.total_mac_count
+        en_total_per_op = cme.energy_total / total_ops
+        en_breakdown = cme.mem_energy_breakdown
+        en_breakdown_per_op = {}
+        energy_sum_check = 0
+        for layer_op, energies_for_all_levels in en_breakdown.items():
+            d = {}
+            mem_op = cme.layer.memory_operand_links[layer_op]
+            for mem_level_idx, en in enumerate(energies_for_all_levels):
+                mem_name = cme.accelerator.get_memory_level(mem_op, mem_level_idx).name
+                d[mem_name] = en / total_ops
+                energy_sum_check += en
+            en_breakdown_per_op[layer_op] = d
+        assert np.isclose(energy_sum_check, cme.mem_energy)
+        return en_total_per_op, en_breakdown_per_op
+    return np.nan, np.nan
+
+
+def get_dataframe_from_scme(
+    scme: "StreamCostModelEvaluation",
+    layer_ids: list[int],
+    add_communication: bool = False,
+    cost_lut: "CostModelEvaluationLUT" = None,
+):
     nodes = scme.workload.topological_sort()
     dicts = []
     for node in nodes:
@@ -459,6 +509,8 @@ def get_dataframe_from_scme(scme: "StreamCostModelEvaluation", layer_ids: list[i
         start = node.start
         end = node.end
         runtime = node.runtime
+        su_perfect_temporal, su_nonperfect_temporal = get_spatial_utilizations(scme, node, cost_lut)
+        en_total_per_op, en_breakdown_per_op = get_energy_breakdown(scme, node, cost_lut)
         energy = node.onchip_energy
         tensors = get_real_input_tensors(node, scme.workload)
         task_type = "compute"
@@ -471,10 +523,14 @@
             Resource=f"Core {core_id}",
             Layer=layer,
             Runtime=runtime,
+            SpatialUtilization=su_perfect_temporal,
+            SpatialUtilizationWithTemporal=su_nonperfect_temporal,
             Tensors=tensors,
             Type=task_type,
             Activity=np.nan,
             Energy=energy,
+            EnergyTotalPerOp=en_total_per_op,
+            EnergyBreakdownPerOp=en_breakdown_per_op,
         )
         dicts.append(d)
         if add_communication:
@@ -509,7 +565,38 @@ def format_tensors(tensors: list[Tensor]):
     else:
         formatted_tensors.append(", ".join(map(str, tensor_list)))
 
-    return "<br>[".join(formatted_tensors) + "]<br>"
+    return "[<br>" + "<br>".join(formatted_tensors) + "<br>]"
+
+
+def add_spatial_util_to_hovertext(hovertext: str, su_perfect_temporal: float, su_imperfect_temporal: float):
+    if not isnan(su_perfect_temporal):
+        hovertext += "<br>Spatial Utilization:<br>"
+        hovertext += f"    Without memory stalls: {su_perfect_temporal:.4f}<br>"
+        hovertext += f"    With memory stalls: {su_imperfect_temporal:.4f}<br>"
+    return hovertext
+
+
+def add_energy_breakdown_to_hovertext(
+    hovertext: str, energy_total: float, energy_per_operation: float, energy_breakdown_per_op: dict
+):
+    if not isnan(energy_per_operation):
+        hovertext += f"Energy total: {energy_total:.4e}<br>"
+        hovertext += f"Energy per operation: {energy_per_operation:.4e}<br>"
+        for layer_op, energy_dict in energy_breakdown_per_op.items():
+            hovertext += f"Energy breakdown for {layer_op}:<br>"
+            for mem_level, en in energy_dict.items():
+                hovertext += f"    {mem_level}: {en:.4e}<br>"
+    return hovertext
+
+
+def add_activity_to_hovertext(hovertext: str, required_bandwidth: int, link_bandwidth: int):
+    if not isnan(required_bandwidth) and not isnan(link_bandwidth):
+        required_bandwidth = int(required_bandwidth)
+        link_bandwidth = int(link_bandwidth)
+        used_bandwidth = min(required_bandwidth, link_bandwidth)
+        hovertext += f"Required bandwidth: {required_bandwidth} bits/cc<br>"
+        hovertext += f"Used bandwidth: {used_bandwidth} bits/cc<br>"
+    return hovertext
 
 
 def visualize_timeline_plotly(
@@ -518,10 +605,11 @@
     draw_communication: bool = True,
     fig_path: str = "outputs/schedule.html",
     layer_ids: list[int] | None = None,
+    cost_lut: CostModelEvaluationLUT = None,
 ):
     if not layer_ids:
         layer_ids = sorted(set(n.id for n in scme.workload.node_list))
-    df = get_dataframe_from_scme(scme, layer_ids, draw_communication)
+    df = get_dataframe_from_scme(scme, layer_ids, draw_communication, cost_lut)
     # We get all the layer ids to get a color mapping for them
     layer_ids = sorted(list(set(df["Layer"].tolist())))
     color_cycle = cycle(sample_colorscale("rainbow", np.linspace(0, 1, len(layer_ids))))
@@ -534,11 +622,19 @@
         sub_id = row["Sub_id"]
         start = row["Start"]
         runtime = row["Runtime"]
+        su_perfect_temporal = row["SpatialUtilization"]
+        su_imperfect_temporal = row["SpatialUtilizationWithTemporal"]
         energy = row["Energy"]
+        energy_total_per_op = row["EnergyTotalPerOp"]
+        energy_breakdown_per_op = row["EnergyBreakdownPerOp"]
+        activity = row["Activity"]
+        link_bandwidth = row["LinkBandwidth"]
         resource = row["Resource"]
         layer = row["Layer"]
         color = colors[layer]
         name = row["Task"]
+        if isinstance((row["Id"]), str) and isinstance(row["Sub_id"], str):
+            name += f" Id: {id} Sub_id: {sub_id}"
         legendgroup = f"Layer {layer}"
         legendgrouptitle_text = legendgroup
         tensors = format_tensors(row["Tensors"])
@@ -551,13 +647,10 @@
             f"Runtime: {runtime:.2e}<br>"
             f"Start: {start:.4e}<br>"
             f"End: {start+runtime:.4e}<br>"
-            f"Energy: {energy:.4e}"
         )
-        if not isnan(row["Activity"]):
-            activity = int(row["Activity"])
-            hovertext += f"<br>Activity: {activity} %"
-        if isinstance((row["Id"]), str) and isinstance(row["Sub_id"], str):
-            hovertext += f"<br>Id: {id} Sub_id: {sub_id}"
+        hovertext = add_activity_to_hovertext(hovertext, activity, link_bandwidth)
+        hovertext = add_spatial_util_to_hovertext(hovertext, su_perfect_temporal, su_imperfect_temporal)
+        hovertext = add_energy_breakdown_to_hovertext(hovertext, energy, energy_total_per_op, energy_breakdown_per_op)
         bar = go.Bar(
             base=[start],
             x=[runtime],
diff --git a/stream/workload/computation/computation_node.py b/stream/workload/computation/computation_node.py
index 70f2fee..50c4030 100644
--- a/stream/workload/computation/computation_node.py
+++ b/stream/workload/computation/computation_node.py
@@ -264,7 +264,7 @@ def reshape_operand_tensor(self, tensor: NodeTensor, operand: LayerOperand):
         new_shape = self.operand_tensor_reshape[operand]
         return tensor.reshape(new_shape)
 
-    def set_too_large_operands(self, too_large_operands: list[LayerOperand]):
+    def set_too_large_operands(self, too_large_operands: list[MemoryOperand]):
         self.too_large_operands = too_large_operands
 
     def update_loop_ranges(self, new_ranges: LoopRanges):
diff --git a/stream/workload/node.py b/stream/workload/node.py
index 9f956f7..06f216a 100644
--- a/stream/workload/node.py
+++ b/stream/workload/node.py
@@ -1,6 +1,6 @@
 from abc import ABCMeta
 
-from zigzag.mapping.data_movement import MemoryAccesses
+from zigzag.mapping.data_movement import FourWayDataMoving
 from zigzag.workload.layer_node_abc import LayerNodeABC
 
 
@@ -51,7 +51,7 @@ def __init__(
         self.data_produced_unique = 0
 
         # will be set together with the core allocation
-        self.offchip_bw = MemoryAccesses(0, 0, 0, 0)
+        self.offchip_bw = FourWayDataMoving(0, 0, 0, 0)
 
     def get_total_energy(self) -> float:
         """Get the total energy of running this node, including off-chip energy."""
@@ -131,7 +131,7 @@ def has_end(self) -> bool:
         """
         return self.end is not None
 
-    def set_offchip_bandwidth(self, offchip_bw: MemoryAccesses):
+    def set_offchip_bandwidth(self, offchip_bw: FourWayDataMoving):
         self.offchip_bw = offchip_bw
 
     def __str__(self):
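
For reference, a minimal usage sketch of the flow this patch enables, mirroring the updated main_stream_ga.py above (not part of the patch itself). The experiment_id value is an arbitrary placeholder; the input paths are the example files referenced in the patch, and the LUT is loaded from the new outputs/{experiment_id}/cost_lut.pickle location so the Plotly schedule can show the energy breakdown, spatial utilization, and link bandwidth on hover:

    from stream.api import optimize_allocation_ga
    from stream.utils import CostModelEvaluationLUT
    from stream.visualization.memory_usage import plot_memory_usage
    from stream.visualization.schedule import visualize_timeline_plotly

    experiment_id = "tpu_like_quad_core-resnet18-fused-genetic_algorithm"  # placeholder identifier
    scme = optimize_allocation_ga(
        hardware="stream/inputs/examples/hardware/tpu_like_quad_core.yaml",
        workload="stream/inputs/examples/workload/resnet18.yaml",
        mapping="stream/inputs/examples/mapping/tpu_like_quad_core.yaml",
        mode="fused",
        layer_stacks=[tuple(range(0, 11)), tuple(range(11, 22))] + list((i,) for i in range(22, 49)),
        nb_ga_generations=4,
        nb_ga_individuals=4,
        experiment_id=experiment_id,
        output_path="outputs",
        skip_if_exists=True,  # reuse outputs/{experiment_id}/scme.pickle from a previous run
    )

    # The run saves its CostModelEvaluationLUT next to the SCME; loading it and passing it
    # to the visualization enables the extra hover info introduced by this patch.
    cost_lut = CostModelEvaluationLUT(f"outputs/{experiment_id}/cost_lut.pickle")
    visualize_timeline_plotly(
        scme,
        draw_dependencies=True,
        draw_communication=True,
        fig_path=f"outputs/{experiment_id}/schedule.html",
        cost_lut=cost_lut,
    )
    plot_memory_usage(scme, (0,), (100,), fig_path=f"outputs/{experiment_id}/memory.png")

When using the constraint-optimization flow (optimize_allocation_co), load the post-CO LUT (outputs/{experiment_id}/cost_lut_post_co.pickle) instead, as done in main_stream_co.py above.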