
Commit

Add conditional on torch version for scaled_dot_product_attention (#6517)

Changes from #4724 broke support for torch<2.0 in the flops profiler, since
`scaled_dot_product_attention` [wasn't
added](https://pytorch.org/docs/2.0/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention)
until torch 2.0 (as a beta API).

Resolved: #5534

Todo:
- [ ] Test this
- [ ] Follow up with affected users in the linked issue.
loadams committed Sep 11, 2024
1 parent 659f6be commit 170b46e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions deepspeed/profiling/flops_profiler/profiler.py
@@ -15,6 +15,7 @@
 from deepspeed.utils import logger
 from deepspeed.moe.layer import MoE
 from deepspeed.utils.timer import FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, STEP_GLOBAL_TIMER
+from deepspeed.utils.torch import required_torch_version
 
 Tensor = torch.Tensor
 
@@ -908,8 +909,9 @@ def _patch_functionals():
     # embedding
     F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)
 
-    # attn
-    F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, _attn_flops_compute)
+    # attn - scaled_dot_product_attention added in torch 2.0+
+    if required_torch_version(min_version=2.0):
+        F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, _attn_flops_compute)
 
 
 def _patch_tensor_methods():
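The pattern in the diff can be sketched without torch installed. This is a hypothetical, simplified stand-in: `required_torch_version` here takes the current version as an explicit argument (DeepSpeed's real helper reads `torch.__version__` itself), `FakeFunctional` stands in for `torch.nn.functional`, and `wrapFunc` is a bare-bones version of the profiler's wrapper.

```python
def required_torch_version(current, min_version):
    # Hypothetical, simplified version check: compare dotted numeric
    # version components, ignoring local suffixes like "+cu121".
    def parts(v):
        return [int(x) for x in str(v).split("+")[0].split(".") if x.isdigit()]
    return parts(current) >= parts(min_version)

class FakeFunctional:
    """Stand-in for torch.nn.functional on a torch>=2.0 install."""
    def scaled_dot_product_attention(self, q, k, v):
        return "attn"

def wrapFunc(func, flops_compute):
    # Simplified stand-in for the profiler's wrapper: count calls,
    # then delegate to the original function.
    def wrapped(*args, **kwargs):
        wrapped.calls += 1
        return func(*args, **kwargs)
    wrapped.calls = 0
    return wrapped

F = FakeFunctional()
torch_version = "1.13.1"  # pretend we are on an old torch

# Mirror of the patched code path: only wrap the attention function
# when the running torch is new enough to actually have it.
if required_torch_version(torch_version, "2.0"):
    F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, None)
```

On the pretend 1.13 install the attribute is left untouched, which is exactly the point of the fix: the old unconditional `wrapFunc(F.scaled_dot_product_attention, ...)` raised `AttributeError` on torch<2.0 because the attribute did not exist.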
