From 170b46e8b146b7ed93cd5f28d194bf8412ba0cbf Mon Sep 17 00:00:00 2001
From: Logan Adams <114770087+loadams@users.noreply.github.com>
Date: Wed, 11 Sep 2024 16:21:43 -0700
Subject: [PATCH] Add conditional on torch version for scaled_dot_product_attention (#6517)

Changes from #4724 broke support for torch<2.0 in the flops profiler, since
scaled_dot_product_attention [wasn't
added](https://pytorch.org/docs/2.0/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention)
until torch 2.0, where it shipped as a beta feature.

Resolves: #5534

Todo:
- [ ] Test this
- [ ] Confirm resolution with affected users
---
 deepspeed/profiling/flops_profiler/profiler.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py
index de847e59e82e..96306184e42c 100644
--- a/deepspeed/profiling/flops_profiler/profiler.py
+++ b/deepspeed/profiling/flops_profiler/profiler.py
@@ -15,6 +15,7 @@
 from deepspeed.utils import logger
 from deepspeed.moe.layer import MoE
 from deepspeed.utils.timer import FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, STEP_GLOBAL_TIMER
+from deepspeed.utils.torch import required_torch_version
 
 Tensor = torch.Tensor
 
@@ -908,8 +909,9 @@ def _patch_functionals():
     # embedding
     F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)
 
-    # attn
-    F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, _attn_flops_compute)
+    # attn - scaled_dot_product_attention added in torch 2.0+
+    if required_torch_version(min_version=2.0):
+        F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, _attn_flops_compute)
 
 
 def _patch_tensor_methods():
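
For context, the version-gated monkey-patch the diff introduces can be illustrated in isolation. The sketch below is a minimal approximation, not DeepSpeed's code: `_wrap_with_counter` and `counter` are hypothetical stand-ins for the profiler's `wrapFunc` and flops-compute machinery, and the `packaging.version` check stands in for `deepspeed.utils.torch.required_torch_version(min_version=2.0)` imported in the patch.

```python
# Minimal sketch of version-gated patching of F.scaled_dot_product_attention.
# _wrap_with_counter and counter are illustrative stand-ins, not DeepSpeed APIs.
from packaging.version import Version

import torch
import torch.nn.functional as F


def _wrap_with_counter(fn, counter):
    """Wrap fn so every call is counted before delegating to the original."""

    def wrapped(*args, **kwargs):
        counter["calls"] += 1
        return fn(*args, **kwargs)

    return wrapped


counter = {"calls": 0}
_original_sdpa = None

# scaled_dot_product_attention only exists in torch >= 2.0, so guard the patch.
if Version(torch.__version__.split("+")[0]) >= Version("2.0"):
    _original_sdpa = F.scaled_dot_product_attention
    F.scaled_dot_product_attention = _wrap_with_counter(_original_sdpa, counter)

# ... run the model / profiled region here ...

# Restoring the original must be guarded the same way: on torch < 2.0 the
# attribute was never patched (and never existed).
if _original_sdpa is not None:
    F.scaled_dot_product_attention = _original_sdpa
```

The guard has to wrap the attribute lookup as well as the assignment, because on torch<2.0 merely reading `F.scaled_dot_product_attention` raises `AttributeError`, which is the failure the patch fixes.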