Support Activation Checkpointing for ConvNeXt
nijkah committed Nov 1, 2022
1 parent 91b85bb commit 61af216
Showing 1 changed file with 30 additions and 16 deletions.
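In short, the change threads a new `with_cp` argument from the `ConvNeXt` backbone down to every `ConvNeXtBlock`; when it is set, the block body is wrapped in `torch.utils.checkpoint`, so intermediate activations are recomputed during the backward pass instead of being stored. A minimal usage sketch, assuming the existing mmcls `ConvNeXt` constructor (`arch='tiny'`, etc.); only `with_cp` is new in this commit:

import torch
from mmcls.models.backbones import ConvNeXt

# Build a ConvNeXt-Tiny backbone with activation checkpointing enabled.
# `with_cp=True` is the flag added by this commit; everything else keeps
# its defaults.
model = ConvNeXt(arch='tiny', with_cp=True)
model.train()

# Checkpointing only kicks in when the block input requires grad
# (see the `x.requires_grad` guard in the diff below).
x = torch.randn(2, 3, 224, 224, requires_grad=True)
outs = model(x)  # tuple of stage outputs, numerically the same as with_cp=False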
46 changes: 30 additions & 16 deletions mmcls/models/backbones/convnext.py
@@ -6,6 +6,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils.checkpoint as cp
 from mmcv.cnn.bricks import (NORM_LAYERS, DropPath, build_activation_layer,
                              build_norm_layer)
 from mmcv.runner import BaseModule
@@ -77,8 +78,11 @@ def __init__(self,
                  mlp_ratio=4.,
                  linear_pw_conv=True,
                  drop_path_rate=0.,
-                 layer_scale_init_value=1e-6):
+                 layer_scale_init_value=1e-6,
+                 with_cp=False):
         super().__init__()
+        self.with_cp = with_cp
+
         self.depthwise_conv = nn.Conv2d(
             in_channels,
             in_channels,
@@ -108,24 +112,33 @@ def __init__(self,
             drop_path_rate) if drop_path_rate > 0. else nn.Identity()
 
     def forward(self, x):
-        shortcut = x
-        x = self.depthwise_conv(x)
-        x = self.norm(x)
 
-        if self.linear_pw_conv:
-            x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        def _inner_forward(x):
+            shortcut = x
+            x = self.depthwise_conv(x)
+            x = self.norm(x)
 
-        x = self.pointwise_conv1(x)
-        x = self.act(x)
-        x = self.pointwise_conv2(x)
+            if self.linear_pw_conv:
+                x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
 
-        if self.linear_pw_conv:
-            x = x.permute(0, 3, 1, 2)  # permute back
+            x = self.pointwise_conv1(x)
+            x = self.act(x)
+            x = self.pointwise_conv2(x)
 
-        if self.gamma is not None:
-            x = x.mul(self.gamma.view(1, -1, 1, 1))
+            if self.linear_pw_conv:
+                x = x.permute(0, 3, 1, 2)  # permute back
+
+            if self.gamma is not None:
+                x = x.mul(self.gamma.view(1, -1, 1, 1))
+
+            x = shortcut + self.drop_path(x)
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
 
-        x = shortcut + self.drop_path(x)
         return x


@@ -206,6 +219,7 @@ def __init__(self,
                  out_indices=-1,
                  frozen_stages=0,
                  gap_before_final_norm=True,
+                 with_cp=False,
                  init_cfg=None):
         super().__init__(init_cfg=init_cfg)

@@ -288,8 +302,8 @@ def __init__(self,
                     norm_cfg=norm_cfg,
                     act_cfg=act_cfg,
                     linear_pw_conv=linear_pw_conv,
-                    layer_scale_init_value=layer_scale_init_value)
-                for j in range(depth)
+                    layer_scale_init_value=layer_scale_init_value,
+                    with_cp=with_cp) for j in range(depth)
             ])
             block_idx += depth
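The forward() change follows the checkpointing dispatch used by other OpenMMLab backbones such as ResNet: the block body moves into a local `_inner_forward`, and `cp.checkpoint` replays it during backward so only the block input needs to be kept. A standalone sketch of that pattern; `ToyBlock` here is illustrative, not the actual `ConvNeXtBlock`:

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class ToyBlock(nn.Module):
    """Minimal residual block showing the `with_cp` dispatch."""

    def __init__(self, channels, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):

        def _inner_forward(x):
            # Recomputed during backward when checkpointing is active, so
            # its intermediate activations are not kept in memory.
            return x + self.conv(x)

        if self.with_cp and x.requires_grad:
            # Stores only the input; trades extra compute for less memory.
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)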

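A quick way to gauge the memory/compute trade-off is to compare peak GPU memory for a forward+backward pass with the flag on and off. A rough sketch, assuming a CUDA device and the mmcls `ConvNeXt` API as above; actual numbers depend on the arch, input resolution and batch size:

import torch
from mmcls.models.backbones import ConvNeXt


def peak_mem_mb(with_cp):
    torch.cuda.reset_peak_memory_stats()
    model = ConvNeXt(arch='small', with_cp=with_cp).cuda().train()
    x = torch.randn(8, 3, 224, 224, device='cuda')
    model(x)[-1].sum().backward()
    return torch.cuda.max_memory_allocated() / 2**20


# Expect a noticeably lower peak with with_cp=True, paid for by the extra
# forward recomputation inside each block during backward.
print(peak_mem_mb(False), peak_mem_mb(True))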