diff --git a/models/common.py b/models/common.py index 30e7319f98a0..2d24672a6b44 100644 --- a/models/common.py +++ b/models/common.py @@ -29,11 +29,6 @@ def autopad(k, p=None): # kernel, padding return p -def DWConv(c1, c2, k=1, s=1, act=True): - # Depth-wise convolution function - return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act) - - class Conv(nn.Module): # Standard convolution def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups @@ -49,11 +44,10 @@ def forward_fuse(self, x): return self.act(self.conv(x)) -class DWConvClass(Conv): +class DWConv(Conv): # Depth-wise convolution class def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups - super().__init__(c1, c2, k, s, act) - self.conv = nn.Conv2d(c1, c2, k, s, autopad(k), groups=math.gcd(c1, c2), bias=False) + super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act) class TransformerLayer(nn.Module): diff --git a/models/yolo.py b/models/yolo.py index 9f05c8329f38..380f3401e5b9 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -202,7 +202,7 @@ def _print_biases(self): def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers LOGGER.info('Fusing layers... ') for m in self.model.modules(): - if isinstance(m, (Conv, DWConvClass)) and hasattr(m, 'bn'): + if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'): m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv delattr(m, 'bn') # remove batchnorm m.forward = m.forward_fuse # update forward