From 005bc473df5d5be3f261fa2d883e027608deb12d Mon Sep 17 00:00:00 2001 From: Yun Chen Date: Tue, 2 Jan 2018 19:34:32 +0800 Subject: [PATCH] make weight initialization optional to speed vgg-construction (#377) --- torchvision/models/vgg.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 5da34d536c3..4f112d96772 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -23,7 +23,7 @@ class VGG(nn.Module): - def __init__(self, features, num_classes=1000): + def __init__(self, features, num_classes=1000, init_weights=True): super(VGG, self).__init__() self.features = features self.classifier = nn.Sequential( @@ -35,7 +35,8 @@ def __init__(self, features, num_classes=1000): nn.Dropout(), nn.Linear(4096, num_classes), ) - self._initialize_weights() + if init_weights: + self._initialize_weights() def forward(self, x): x = self.features(x) @@ -88,6 +89,8 @@ def vgg11(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ + if pretrained: + kwargs['init_weights'] = False model = VGG(make_layers(cfg['A']), **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) @@ -100,6 +103,8 @@ def vgg11_bn(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ + if pretrained: + kwargs['init_weights'] = False model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) @@ -112,6 +117,8 @@ def vgg13(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ + if pretrained: + kwargs['init_weights'] = False model = VGG(make_layers(cfg['B']), **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) @@ -124,6 +131,8 @@ def vgg13_bn(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ + if pretrained: + kwargs['init_weights'] = False model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) @@ -136,6 +145,8 @@ def vgg16(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ + if pretrained: + kwargs['init_weights'] = False model = VGG(make_layers(cfg['D']), **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) @@ -148,6 +159,8 @@ def vgg16_bn(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ + if pretrained: + kwargs['init_weights'] = False model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) @@ -160,6 +173,8 @@ def vgg19(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ + if pretrained: + kwargs['init_weights'] = False model = VGG(make_layers(cfg['E']), **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) @@ -172,6 +187,8 @@ def vgg19_bn(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ + if pretrained: + kwargs['init_weights'] = False model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))