Skip to content

Commit

Permalink
removing unnessary avoid_bfloat16_autocast_context (#6481)
Browse files Browse the repository at this point in the history
Signed-off-by: Dima Rekesh <[email protected]>
  • Loading branch information
bmwshop authored May 3, 2023
1 parent d3dee47 commit 72994bf
Showing 1 changed file with 1 addition and 8 deletions.
9 changes: 1 addition & 8 deletions nemo/collections/asr/parts/submodules/subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from torch.nn import LayerNorm

from nemo.collections.asr.parts.submodules.causal_convs import CausalConv2D
from nemo.utils import avoid_bfloat16_autocast_context


class StackingSubsampling(torch.nn.Module):
Expand Down Expand Up @@ -265,13 +264,7 @@ def forward(self, x, lengths):
)
x = x.unsqueeze(1)

if self._subsampling in ['striding', 'dw_striding']:
# added in order to prevent slowdown in torch.nn.Conv2d with bfloat16 / CUDNN v8 API
# to be removed once the above is fixed in cudnn
with avoid_bfloat16_autocast_context():
x = self.conv(x)
else:
x = self.conv(x)
x = self.conv(x)

b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).reshape(b, t, -1))
Expand Down

0 comments on commit 72994bf

Please sign in to comment.