DingXiaoH/RepOptimizers

RepOptimizers (ICLR 2023)

Accepted by ICLR 2023!

(update: released RepOpt-GhostNet model and reproducible code)

This is the official repository of Re-parameterizing Your Optimizers rather than Architectures.

If you find the paper or this repository helpful, please consider citing

    @article{ding2022re,
    title={Re-parameterizing Your Optimizers rather than Architectures},
    author={Ding, Xiaohan and Chen, Honghao and Zhang, Xiangyu and Huang, Kaiqi and Han, Jungong and Ding, Guiguang},
    journal={arXiv preprint arXiv:2205.15242},
    year={2022}
    }

Highlights

RepOptimizer and RepOpt-VGG have been used in YOLOv6 (paper, code) and deployed in business. The methodology of Structural Re-parameterization also plays a critical role in YOLOv7 (paper, code).

Catalog

  • Code
  • PyTorch pretrained models
  • PyTorch training code

Verify the equivalency GR = CSLA

Tired of reading proofs? We provide a script that demonstrates the equivalency GR = CSLA (Gradient Re-parameterization versus Constant-Scale Linear Addition) on a small experimental model, in both the SGD and AdamW cases. It runs without a GPU:

    python check_equivalency.py

We also provide a more complicated example (RepGhostNet) to verify the equivalency:

    python check_equivalency_repghost.py
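For intuition, the equivalency can also be reproduced on a single scalar weight. The sketch below is not check_equivalency.py. It is a minimal illustration that assumes plain SGD (no momentum, no weight decay), a squared-error loss, and a CSLA block with two branches combined by the constant scales alpha_a and alpha_b. All function and variable names are hypothetical. Under these assumptions, training the two branches should match training one merged weight whose gradient is multiplied by alpha_a^2 + alpha_b^2.

```python
def grad_wrt_weight(w, x, target):
    """Gradient of 0.5 * (w * x - target) ** 2 with respect to w."""
    return (w * x - target) * x


def train_csla(w_a, w_b, alpha_a, alpha_b, data, lr, epochs):
    """Train two branches whose outputs are added with constant scales."""
    for _ in range(epochs):
        for x, target in data:
            g = grad_wrt_weight(alpha_a * w_a + alpha_b * w_b, x, target)
            # Chain rule: each branch sees the shared gradient times its scale.
            w_a -= lr * alpha_a * g
            w_b -= lr * alpha_b * g
    return alpha_a * w_a + alpha_b * w_b


def train_gr(w, grad_mult, data, lr, epochs):
    """Train one merged weight, multiplying its gradient by a constant."""
    for _ in range(epochs):
        for x, target in data:
            w -= lr * grad_mult * grad_wrt_weight(w, x, target)
    return w


data = [(1.0, 2.0), (2.0, 3.0), (-1.0, 0.5)]
alpha_a, alpha_b = 0.5, 2.0
w_a0, w_b0 = 0.3, -0.7

final_csla = train_csla(w_a0, w_b0, alpha_a, alpha_b, data, lr=0.05, epochs=20)
final_gr = train_gr(
    alpha_a * w_a0 + alpha_b * w_b0,  # merged initial weight (Rule of Initialization)
    alpha_a ** 2 + alpha_b ** 2,      # constant gradient multiplier (Rule of Iteration)
    data, lr=0.05, epochs=20,
)
print(abs(final_csla - final_gr))  # should be at floating-point noise level
```

The two runs differ only by rounding error, because every CSLA step changes the merged weight by exactly the scaled gradient step that the second loop applies directly.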

Design

RepOptimizers currently support two update rules (SGD with momentum and AdamW) and two models (RepOpt-VGG and RepOpt-GhostNet). While re-designing the code of RepOptimizer, I decided to separate the update-rule-related behaviors from the model-specific behaviors.

The key components of the new implementation (please see repoptimizer/) include:

Model: repoptvgg_model.py and repoptghostnet_model.py define the model architectures, including the target and search structures.

Model-specific Handler: a RepOptimizerHandler defines the model-specific behavior of RepOptimizer given the searched scales, which includes 1) re-initializing the model (i.e., Rule of Initialization) and 2) generating the Grad Mults (i.e., Rule of Iteration).

For example, RepOptVGGHandler (see repoptvgg_impl.py) implements the formulas presented in the paper.
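To make the two handler responsibilities concrete, here is a simplified sketch. It is not RepOptVGGHandler. It assumes a block with only two branches (a 3x3 convolution scaled per output channel by s and a 1x1 convolution scaled by t) and deliberately omits the identity branch, so it does not reproduce the paper's full formulas. Nested Python lists stand in for tensors, and the function names are hypothetical.

```python
def merged_kernel_init(w3, w1, s, t):
    """Two-branch Rule of Initialization: W' = s * W3 + t * pad(W1).

    w3: nested lists shaped [out][in][3][3]; w1: nested lists shaped [out][in];
    s, t: per-output-channel scales.
    """
    merged = []
    for c, (s_c, t_c) in enumerate(zip(s, t)):
        out_c = []
        for d in range(len(w3[c])):
            plane = [[s_c * v for v in row] for row in w3[c][d]]
            plane[1][1] += t_c * w1[c][d]  # the 1x1 kernel lands on the centre tap
            out_c.append(plane)
        merged.append(out_c)
    return merged


def grad_mult(s, t, in_channels):
    """Two-branch Rule of Iteration: s_c^2 everywhere, plus t_c^2 at the centre."""
    mult = []
    for s_c, t_c in zip(s, t):
        plane = [[s_c ** 2] * 3 for _ in range(3)]
        plane[1][1] += t_c ** 2
        mult.append([[row[:] for row in plane] for _ in range(in_channels)])
    return mult


# One output channel, one input channel.
w3 = [[[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]]
w1 = [[10.0]]
s, t = [2.0], [3.0]

init = merged_kernel_init(w3, w1, s, t)
mult = grad_mult(s, t, in_channels=1)
print(init[0][0])
print(mult[0][0])
```

The centre position receives contributions from both branches, which is why its multiplier is larger than that of the surrounding positions.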

Update rule: repoptimizer_sgd.py and repoptimizer_adamw.py define the behavior of RepOptimizers based on different update rules. The differences between a RepOptimizer and its regular counterpart (torch.optim.SGD or torch.optim.AdamW) include

  1. RepOptimizers take one more argument, grad_mult_map, which is the output from RepOptimizerHandler and will be stored in memory. It is a dict where the key is the parameter (torch.nn.Parameter) and the value is the corresponding Grad Mult (torch.Tensor).

  2. In the step function, RepOptimizers will use the Grad Mults properly. For SGD, please see here. For AdamW, please see here and here.
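The idea of an optimizer that carries a Grad Mult map can be sketched without PyTorch. The class below is a hypothetical toy, not repoptimizer_sgd.py. It assumes parameters are stored as lists of floats keyed by name (the real map is keyed by torch.nn.Parameter), that the multiplier is applied elementwise to the raw gradient, and that weight decay is added afterwards without being multiplied, which is what the CSLA equivalence would require if every branch is decayed at the same rate.

```python
class ToyRepSGD:
    """SGD with momentum that scales selected gradients by constant multipliers."""

    def __init__(self, params, grad_mult_map, lr, momentum=0.9, weight_decay=0.0):
        self.params = params                # dict: name -> list of floats
        self.grad_mult_map = grad_mult_map  # dict: name -> list of floats
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.buffers = {name: [0.0] * len(v) for name, v in params.items()}

    def step(self, grads):
        for name, values in self.params.items():
            mult = self.grad_mult_map.get(name)  # None for ordinary parameters
            buf = self.buffers[name]
            for i, g in enumerate(grads[name]):
                if mult is not None:
                    g = g * mult[i]                   # apply the Grad Mult
                g = g + self.weight_decay * values[i]  # decay is not multiplied
                buf[i] = self.momentum * buf[i] + g
                values[i] -= self.lr * buf[i]


params = {"w": [1.0, 2.0], "b": [0.5]}
opt = ToyRepSGD(params, grad_mult_map={"w": [2.0, 0.5]}, lr=0.1, momentum=0.0)
opt.step({"w": [1.0, 1.0], "b": [1.0]})
print(params)
```

Parameters that are absent from the map, such as "b" here, are updated exactly as a regular SGD step would update them.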

Pre-trained RepOpt-VGG Models

We have released the models pre-trained with this codebase.

name            ImageNet-1K acc (%)   #params   download
RepOpt-VGG-B1   78.62                 51.8M     Google Drive, Baidu Cloud
RepOpt-VGG-B2   79.68                 80.3M     Google Drive, Baidu Cloud
RepOpt-VGG-L1   79.82                 76.0M     Google Drive, Baidu Cloud
RepOpt-VGG-L2   80.47                 118.1M    Google Drive, Baidu Cloud

Use cases

The following cases use RepOpt-VGG-B1 as an example. You may replace RepOpt-VGG-B1 with RepOpt-VGG-B2, RepOpt-VGG-L1, or RepOpt-VGG-L2 as you wish.

Evaluation

You may test our released models by

    python -m torch.distributed.launch --nproc_per_node {your_num_gpus} --master_port 12349 \
        main_repopt.py --arch RepOpt-VGG-B1-target --tag test --eval \
        --resume RepOpt-VGG-B1-acc78.62.pth --data-path /path/to/imagenet --batch-size 32 \
        --opts DATA.DATASET imagenet

Training

To reproduce RepOpt-VGG-B1, you may build a RepOptimizer with our released constants RepOpt-VGG-B1-scales.pth and train a RepOpt-VGG-B1 with it.

    python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 \
        main_repopt.py --data-path /path/to/imagenet --arch RepOpt-VGG-B1-target \
        --batch-size 32 --tag experiment --scales-path RepOpt-VGG-B1-scales.pth \
        --opts TRAIN.EPOCHS 120 TRAIN.BASE_LR 0.1 TRAIN.WEIGHT_DECAY 4e-5 \
        TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 AUG.PRESET raug15 DATA.DATASET imagenet

The log and weights will be saved to output/RepOpt-VGG-B1-target/experiment/

Hyper-Search

Besides using our released scales, you may Hyper-Search by

    python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 \
        main_repopt.py --data-path /path/to/search/dataset --arch RepOpt-VGG-B1-hs \
        --batch-size 32 --tag search \
        --opts TRAIN.EPOCHS 240 TRAIN.BASE_LR 0.1 TRAIN.WEIGHT_DECAY 4e-5 \
        TRAIN.WARMUP_EPOCHS 10 MODEL.LABEL_SMOOTHING 0.1 DATA.DATASET cf100 TRAIN.CLIP_GRAD 5.0

(Note that since the model seems too big for such a small dataset, we use grad clipping to stabilize the training. But do not use grad clipping while training with RepOptimizer! That would break the equivalency.)

The weights of the search model will be saved to output/RepOpt-VGG-B1-hs/search/latest.pth

Then you may train with it by

    python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 \
        main_repopt.py --data-path /path/to/imagenet --arch RepOpt-VGG-B1-target \
        --batch-size 32 --tag experiment --scales-path output/RepOpt-VGG-B1-hs/search/latest.pth \
        --opts TRAIN.EPOCHS 120 TRAIN.BASE_LR 0.1 TRAIN.WEIGHT_DECAY 4e-5 \
        TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 AUG.PRESET raug15 DATA.DATASET imagenet
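The warning about grad clipping in the Hyper-Search note can be illustrated with a toy calculation. This is a sketch under stated assumptions, not code from this repository: a single scalar weight, a two-branch CSLA block with constant scales alpha_a and alpha_b, norm-based clipping, and clipping that acts on the raw gradients before the optimizer applies its Grad Mult. All names are hypothetical.

```python
import math


def clip_coefficient(grads, max_norm):
    """Scaling factor used by norm-based gradient clipping."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    return min(1.0, max_norm / (total_norm + 1e-6))


def effective_updates(g, alpha_a, alpha_b, lr, max_norm):
    """Change of the merged weight after one clipped SGD step, CSLA versus GR."""
    # CSLA: clipping sees the two branch gradients.
    branch_grads = [alpha_a * g, alpha_b * g]
    c = clip_coefficient(branch_grads, max_norm)
    delta_csla = -lr * c * (alpha_a * branch_grads[0] + alpha_b * branch_grads[1])
    # GR: clipping sees the single raw gradient; the multiplier is applied afterwards.
    c = clip_coefficient([g], max_norm)
    delta_gr = -lr * c * (alpha_a ** 2 + alpha_b ** 2) * g
    return delta_csla, delta_gr


unclipped = effective_updates(10.0, 0.5, 2.0, lr=0.1, max_norm=1e9)
clipped = effective_updates(10.0, 0.5, 2.0, lr=0.1, max_norm=5.0)
print(unclipped)  # both entries agree when the threshold is never reached
print(clipped)    # the entries diverge once clipping becomes active
```

Clipping rescales by a gradient norm, and the norm of the branch gradients generally differs from the norm of the merged gradient, so the two training processes stop matching as soon as the threshold is hit.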

Use RepOptimizer and RepOpt-VGG in your code

Given the searched scales (saved in a .pth file), you may conveniently build a RepOptimizer and a RepOpt-VGG model and use them just like you use the common optimizers and models.

Please see build_RepOptVGG_and_SGD_optimizer_from_pth here.

Another example: RepOpt-GhostNet

RepGhostNet is a recently proposed lightweight model built with Structural Re-parameterization. The training-time forward function of a block can be formulated as output = batch_norm(depthwise_convolution(x)) + batch_norm(x). With RepOptimizer, the parallel batch norm (referred to as the "fusion layer" in the RepGhostNet paper) can be removed even during training. As with RepVGG and RepOpt-VGG, we design the CSLA model by replacing the batch norm layers with constant or trainable scaling layers, and design the Grad Mults of RepOptimizer accordingly.
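Once the batch norm layers are replaced by scaling layers, the two-branch block collapses into a single depthwise kernel. The sketch below shows this in a deliberately reduced setting: one channel, a 1-D signal, zero padding, and constant scales conv_scale and identity_scale standing in for the scaling layers. These names and the 1-D simplification are assumptions for illustration, not the repository's implementation.

```python
def depthwise_conv1d(x, kernel):
    """Zero-padded, same-length 1-D correlation of one channel with an odd-length kernel."""
    pad = len(kernel) // 2
    padded = [0.0] * pad + list(x) + [0.0] * pad
    return [
        sum(k * padded[i + j] for j, k in enumerate(kernel))
        for i in range(len(x))
    ]


def csla_block(x, kernel, conv_scale, identity_scale):
    """Two branches: a scaled depthwise convolution plus a scaled identity."""
    conv_out = depthwise_conv1d(x, kernel)
    return [conv_scale * c + identity_scale * v for c, v in zip(conv_out, x)]


def merge_kernel(kernel, conv_scale, identity_scale):
    """Fold both branches into one kernel: scale it, then add the identity at the centre."""
    merged = [conv_scale * k for k in kernel]
    merged[len(kernel) // 2] += identity_scale
    return merged


x = [1.0, 2.0, 3.0, 4.0]
kernel = [0.5, -1.0, 0.25]

two_branch = csla_block(x, kernel, conv_scale=2.0, identity_scale=3.0)
single = depthwise_conv1d(x, merge_kernel(kernel, conv_scale=2.0, identity_scale=3.0))
print(two_branch)
print(single)
```

Both printed lists should agree up to rounding, which is the structural fact that lets the target model keep only the single depthwise convolution.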

name                                    ImageNet-1K acc (%)   download
RepGhostNet-0.5x (our implementation)   66.51                 Google Drive, Baidu Cloud
RepOpt-GhostNet-0.5x                    66.50                 Google Drive, Baidu Cloud

We trained the original RepGhostNet-0.5x with this codebase and got a top-1 accuracy of 66.51%.

    python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 \
        main_repopt.py --data-path /path/to/imagenet --arch ghost-rep \
        --batch-size 128 --tag reproduce \
        --opts TRAIN.EPOCHS 300 TRAIN.BASE_LR 0.6 TRAIN.WEIGHT_DECAY 1e-5 \
        TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 DATA.DATASET imagenet \
        TRAIN.OPTIMIZER.NAME sgd TRAIN.WARMUP_LR 1e-4

The log and weights will be saved to output/ghost-rep/reproduce/

You may reproduce RepOpt-GhostNet with our released scales:

    python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 \
        main_repopt.py --data-path /path/to/imagenet --arch ghost-target \
        --batch-size 128 --tag reproduce --scales-path RepOptGhostNet_0_5x_scales.pth \
        --opts TRAIN.EPOCHS 300 TRAIN.BASE_LR 0.6 TRAIN.WEIGHT_DECAY 1e-5 \
        TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 DATA.DATASET imagenet \
        TRAIN.OPTIMIZER.NAME sgd TRAIN.WARMUP_LR 1e-4

Or first Hyper-Search and then use the searched scales:

    python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 \
        main_repopt.py --data-path /path/to/cifar100 --arch ghost-hs \
        --batch-size 128 --tag reproduce \
        --opts TRAIN.EPOCHS 600 TRAIN.BASE_LR 0.6 TRAIN.WEIGHT_DECAY 1e-5 \
        TRAIN.WARMUP_EPOCHS 10 MODEL.LABEL_SMOOTHING 0.1 DATA.DATASET cf100 TRAIN.CLIP_GRAD 5.0

    python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 \
        main_repopt.py --data-path /path/to/imagenet --arch ghost-target \
        --batch-size 128 --tag reproduce --scales-path output/ghost-hs/reproduce/latest.pth \
        --opts TRAIN.EPOCHS 300 TRAIN.BASE_LR 0.6 TRAIN.WEIGHT_DECAY 1e-5 \
        TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 DATA.DATASET imagenet \
        TRAIN.OPTIMIZER.NAME sgd TRAIN.WARMUP_LR 1e-4

License

This project is released under the MIT license. Please see the LICENSE file for more information.

Contact

[email protected] (The original Tsinghua mailbox [email protected] will expire in several months)

Google Scholar Profile: https://scholar.google.com/citations?user=CIjw0KoAAAAJ&hl=en

Homepage: https://dingxiaohan.xyz/

My open-sourced papers and repos:

The Structural Re-parameterization Universe:

  1. RepLKNet (CVPR 2022) Powerful efficient architecture with very large kernels (31x31) and guidelines for using large kernels in modern CNNs
    Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs
    code.

  2. RepOptimizer (ICLR 2023) uses Gradient Re-parameterization to train powerful models efficiently. The training-time RepOpt-VGG is as simple as the inference-time. It also addresses the problem of quantization.
    Re-parameterizing Your Optimizers rather than Architectures
    code.

  3. RepVGG (CVPR 2021) A super simple and powerful VGG-style ConvNet architecture. Up to 84.16% ImageNet top-1 accuracy!
    RepVGG: Making VGG-style ConvNets Great Again
    code.

  4. RepMLP (CVPR 2022) MLP-style building block and Architecture
    RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality
    code.

  5. ResRep (ICCV 2021) State-of-the-art channel pruning (Res50, 55% FLOPs reduction, 76.15% acc)
    ResRep: Lossless CNN Pruning via Decoupling Remembering and Forgetting
    code.

  6. ACB (ICCV 2019) is a CNN component without any inference-time costs. The first work of our Structural Re-parameterization Universe.
    ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks.
    code.

  7. DBB (CVPR 2021) is a CNN component with higher performance than ACB and still no inference-time costs. Sometimes I call it ACNet v2 because "DBB" is 2 bits larger than "ACB" in ASCII (lol).
    Diverse Branch Block: Building a Convolution as an Inception-like Unit
    code.

Model compression and acceleration:

  1. (CVPR 2019) Channel pruning: Centripetal SGD for Pruning Very Deep Convolutional Networks with Complicated Structure
    code

  2. (ICML 2019) Channel pruning: Approximated Oracle Filter Pruning for Destructive CNN Width Optimization
    code

  3. (NeurIPS 2019) Unstructured pruning: Global Sparse Momentum SGD for Pruning Very Deep Neural Networks
    code
