diff --git a/hta/analyzers/critical_path_analysis.py b/hta/analyzers/critical_path_analysis.py index 307d890..68525a7 100644 --- a/hta/analyzers/critical_path_analysis.py +++ b/hta/analyzers/critical_path_analysis.py @@ -11,7 +11,7 @@ from copy import deepcopy from dataclasses import dataclass from enum import Enum -from functools import cached_property +from functools import cached_property, wraps from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple, Union @@ -34,6 +34,24 @@ CP_LAUNCH_EDGE_ENV = "CRITICAL_PATH_ADD_ZERO_WEIGHT_LAUNCH_EDGE" CP_LAUNCH_EDGE_SHOW_ENV = "CRITICAL_PATH_SHOW_ZERO_WEIGHT_LAUNCH_EDGE" +PROFILE_TIMES = {} + + +# Enable per function timing +def timeit(func): + global PROFILE_TIMES + + @wraps(func) + def timeit_wrapper(*args, **kwargs): + start_time = time.perf_counter() + result = func(*args, **kwargs) + end_time = time.perf_counter() + total_time = end_time - start_time + PROFILE_TIMES[func.__name__] = total_time + return result + + return timeit_wrapper + @dataclass class CPNode: @@ -181,8 +199,7 @@ def __init__( # this is the attributed event for an edge self.edge_to_event_map: Dict[Tuple[int, int], int] = {} - cg = CallGraph(t, ranks=[rank]) - self._construct_graph(cg) + self._construct_graph() def _add_node(self, node: CPNode) -> int: """Adds a node to the graph. @@ -357,7 +374,7 @@ def get_edges_attributed_to_event(self, ev_idx: int) -> List[CPEdge]: """ return self._event_to_attributed_edges_map.get(ev_idx, []) - def _construct_graph(self, cg: CallGraph) -> None: + def _construct_graph(self) -> None: if self._add_zero_weight_launch_edges(): logger.info( "Adding zero weight launch edges to retain causality in subsequent simulations." @@ -365,13 +382,11 @@ def _construct_graph(self, cg: CallGraph) -> None: self._create_event_nodes() - cpu_call_stacks = ( - csg for csg in cg.call_stacks if csg.device_type == DeviceType.CPU - ) - for csg in cpu_call_stacks: - self._construct_graph_from_call_stack(csg) + self._construct_graph_from_call_stacks() + self._construct_graph_from_kernels() + @timeit def _create_event_nodes(self) -> None: """Generates a start and end node for every event we would like to represent in our graph""" @@ -417,6 +432,17 @@ def create_cpnode(row): self.event_to_end_node_map = dict(zip(_df["ev_idx"], _df["idx"])) assert len(self.event_to_start_node_map) == len(self.event_to_end_node_map) + @timeit + def _construct_graph_from_call_stacks(self) -> None: + cg = CallGraph(self.t, ranks=[self.rank]) + + cpu_call_stacks = ( + csg for csg in cg.call_stacks if csg.device_type == DeviceType.CPU + ) + + for csg in cpu_call_stacks: + self._construct_graph_from_call_stack(csg) + def _construct_graph_from_call_stack( self, csg: CallStackGraph, link_operators: bool = True ) -> None: @@ -577,6 +603,7 @@ def _get_cuda_event_to_stream_df(self) -> pd.DataFrame: # Sanity checking to do. return cuda_record_stream_df.set_index("correlation") + @timeit def _get_cuda_event_record_df(self) -> Optional[pd.DataFrame]: """For Event based synchronization we need to track the last kernel/memcpy launched on a CPU thread just before the cudaEventRecord @@ -689,6 +716,7 @@ def find_previous_launch(gpu, stream): return cuda_record_calls + @timeit def _get_cuda_stream_wait_event_df(self) -> Optional[pd.DataFrame]: """For Event based synchronization we need to track the next kernel/memcpy launched on a CPU thread just after cudaStreamWaitEvent @@ -850,6 +878,7 @@ def _add_kernel_launch_delay_edge( ) return True + @timeit def _construct_graph_from_kernels(self) -> None: """Create nodes and edges for GPU kernels""" sym_id_map = self.symbol_table.get_sym_id_map() @@ -1488,6 +1517,7 @@ def critical_path_analysis( CPGraph is also a subinstance of a networkx.DiGraph. Run 'CPGraph?' for more info and APIs. """ + global PROFILE_TIMES t0 = time.perf_counter() trace_df: pd.DataFrame = t.get_trace(rank) sym_index = t.symbol_table.get_sym_id_map() @@ -1572,6 +1602,9 @@ def critical_path_analysis( t2 = time.perf_counter() logger.info(f"CPGraph construction took {t2 - t1:.2f} seconds") + for func, total_time in PROFILE_TIMES.items(): + logger.info(f" Function {func} Took {total_time:.4f} seconds") + return cp_graph, cp_graph.critical_path() @staticmethod