diff --git a/byte_micro_perf/backends/AMD/backend_amd.py b/byte_micro_perf/backends/AMD/backend_amd.py
index 2a109730..57b8bf62 100644
--- a/byte_micro_perf/backends/AMD/backend_amd.py
+++ b/byte_micro_perf/backends/AMD/backend_amd.py
@@ -296,4 +296,7 @@ def setup_2d_group(self):
         torch.distributed.barrier()
 
     def destroy_process_group(self):
-        dist.destroy_process_group()
\ No newline at end of file
+        dist.destroy_process_group()
+
+    def barrier(self):
+        dist.barrier(self.group)
\ No newline at end of file
diff --git a/byte_micro_perf/backends/GPU/backend_gpu.py b/byte_micro_perf/backends/GPU/backend_gpu.py
index 9f09b781..19bfb94e 100644
--- a/byte_micro_perf/backends/GPU/backend_gpu.py
+++ b/byte_micro_perf/backends/GPU/backend_gpu.py
@@ -304,4 +304,7 @@ def setup_2d_group(self):
         torch.distributed.barrier()
 
     def destroy_process_group(self):
-        dist.destroy_process_group()
\ No newline at end of file
+        dist.destroy_process_group()
+
+    def barrier(self):
+        dist.barrier(self.group)
\ No newline at end of file
diff --git a/byte_micro_perf/backends/backend.py b/byte_micro_perf/backends/backend.py
index d8c041b6..e684ab41 100644
--- a/byte_micro_perf/backends/backend.py
+++ b/byte_micro_perf/backends/backend.py
@@ -92,6 +92,9 @@ def setup_2d_group(self):
     def destroy_process_group(self):
         pass
 
+    @abstractmethod
+    def barrier(self):
+        pass
 
     # communication ops
 
@@ -229,6 +232,11 @@ def perf(self, input_shapes: List[List[int]], dtype):
         for _ in range(num_warm_up):
             self._run_operation(self.op, tensor_list[0])
 
+
+        # ccl ops need a barrier so all ranks start timing together
+        if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]:
+            self.barrier()
+
         # test perf
         num_test_perf = 5
         self.device_synchronize()
@@ -241,7 +249,6 @@ def perf(self, input_shapes: List[List[int]], dtype):
         self.device_synchronize()
         end_time = time.perf_counter_ns()
 
-        prefer_iterations = self.iterations
         max_perf_seconds = 10.0
         op_duration = (end_time - start_time) / num_test_perf / 1e9
 
@@ -250,6 +257,11 @@ def perf(self, input_shapes: List[List[int]], dtype):
         else:
             prefer_iterations = min(max(int(max_perf_seconds // op_duration), 10), self.iterations)
 
+
+        # ccl ops need a barrier so all ranks start timing together
+        if self.op_name in ["allreduce", "allgather", "reducescatter", "alltoall", "broadcast", "p2p"]:
+            self.barrier()
+
         # perf
         self.device_synchronize()
         start_time = time.perf_counter_ns()
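Note (illustration, not part of the patch): the new barrier() hook exists because collective (CCL) op timings are only meaningful when all ranks enter the timed loop together; without a barrier, rank skew left over from warm-up is absorbed by the first collective and inflates its measured latency. The sketch below is a minimal standalone demonstration of the same measure-behind-a-barrier pattern, assuming a torchrun launch and the gloo backend; the name timed_allreduce is hypothetical and not part of ByteMLPerf.

import time

import torch
import torch.distributed as dist


def timed_allreduce(tensor: torch.Tensor, iterations: int, group=None) -> float:
    # Align all ranks before starting the clock, mirroring the
    # self.barrier() call the patch adds before the perf loops.
    dist.barrier(group)
    start = time.perf_counter_ns()
    for _ in range(iterations):
        dist.all_reduce(tensor, group=group)
    end = time.perf_counter_ns()
    return (end - start) / iterations / 1e9  # seconds per all_reduce


if __name__ == "__main__":
    # Assumes torchrun has set RANK/WORLD_SIZE; gloo keeps this CPU-only.
    dist.init_process_group(backend="gloo")
    t = torch.ones(1 << 20)
    print(f"rank {dist.get_rank()}: {timed_allreduce(t, 100):.6f} s/op")
    dist.destroy_process_group()

Launched with, e.g., torchrun --nproc_per_node=2 sketch.py, the ranks report nearly identical per-op times, which is what the barrier before the timed loop is meant to guarantee.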