
susan0199/StacNAS


StacNAS is a stable and consistent optimization strategy for differentiable architecture search. Compared with DARTS, it substantially improves mean accuracy and significantly reduces the variance across repeated experiments, without resorting to the strategy of running the search several times and picking the best result.

[Figure: Two Stages (two-stage DAG)]

Note: keeping the search-stage ZERO operation in the final architecture is optional. If kept, it can act as an attention mechanism that learns the relative contribution of the two paths feeding into the same node.
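To make the note concrete, here is a minimal sketch of one possible reading, assuming a DARTS-style mixed edge whose architecture weights (alpha) are turned into probabilities by a softmax over candidate operations that include ZERO. On each edge the strongest non-ZERO operation is kept; if ZERO is also kept, the edge output is scaled by its non-ZERO probability mass, so two edges entering the same node are weighted by learned relative contributions. The operation names, alpha values, and the helpers `derive_edge` and `combine_at_node` are hypothetical illustrations, not code from this repository.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def derive_edge(ops, alphas, keep_zero=False):
    """Pick the strongest non-ZERO op on one edge (hypothetical helper).

    Returns (op_name, scale). With keep_zero=False the scale is 1.0, i.e.
    ZERO is simply discarded. With keep_zero=True the scale is 1 - p(ZERO),
    the probability mass the search assigned to actually using this edge.
    """
    probs = softmax(alphas)
    candidates = [(p, op) for p, op in zip(probs, ops) if op != "zero"]
    _, best_op = max(candidates)
    scale = 1.0
    if keep_zero:
        scale = 1.0 - probs[ops.index("zero")]
    return best_op, scale


def combine_at_node(edge_outputs, scales):
    """Node output: sum of edge outputs, each multiplied by its edge scale."""
    return sum(s * x for s, x in zip(scales, edge_outputs))


OPS = ["zero", "sep_conv_3x3", "skip_connect"]  # hypothetical op list
op1, s1 = derive_edge(OPS, [2.0, 1.0, 0.0], keep_zero=True)
op2, s2 = derive_edge(OPS, [0.0, 0.5, 1.5], keep_zero=True)
```

Here the second edge gets a larger scale than the first because the search put less probability on ZERO for it; normalizing `s1 / (s1 + s2)` would express the same information as a relative share.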

[Figure: StacNAS, Two Stages]

Requirements

Python >= 3.5.5, PyTorch >= 1.0.0, torchvision >= 0.2.0

Datasets

CIFAR-10 and CIFAR-100 are downloaded automatically by torchvision; ImageNet must be downloaded manually (preferably to an SSD) following the instructions here.

Pretrained models

The easiest way to get started is to evaluate our pretrained StacNAS models.

CIFAR-10/CIFAR-100 (cifar10.pth.tar) (cifar100.pth.tar)

python test.py \
    --name job_name \
    --dataset cifar10 \
    --data_dir /path/to/dataset/ \
    --save_dir /path/to/results/ \
    --seed 1 \
    --stage test \
    --batch_size 96 \
    --load_model_dir "models/cifar10.pth.tar"
  • Expected result: 1.88% test error rate with 4.48M model params.
  • For CIFAR-100, replace --dataset cifar10 with --dataset cifar100 and point --load_model_dir at cifar100.pth.tar. Expected result: 12.9% test error rate with 4.36M model params.

ImageNet (mobile setting) (imagenet.pth.tar)

  • Expected result: 1.88% test error rate with 4.48M model params. (Code coming soon)

Architecture search stage 1

To carry out the stage-1 architecture search, run

python search.py \
    --name job_name \
    --dataset cifar10 \
    --data_dir /path/to/dataset/ \
    --save_dir /path/to/results/ \
    --seed 6 \
    --stage search1 \
    --batch_size 64 \
    --init_channels 16 \
    --num_cells 14 \
    --epochs 80 \
    --alpha_share

Note that train_ratio is the train-valid split ratio; train_ratio=1 means all 50,000 training images are used for the search, and the architecture obtained at the end of the search is taken as the final result.
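A minimal sketch of a split controlled by train_ratio, assuming the ratio is the fraction of the 50,000 CIFAR training images placed in the first (training) split and the remainder is held out as the validation split, as in DARTS. The `split_indices` helper, its shuffling, and its seed handling are hypothetical, not the repository's code; how the architecture weights are updated when train_ratio=1 is not specified here.

```python
import random


def split_indices(num_train=50000, train_ratio=0.5, seed=6):
    """Split training-set indices into (train, valid) index lists.

    Hypothetical helper. With train_ratio=1 every image lands in the train
    split and the valid split is empty, so there is no held-out set to pick
    a "best" epoch from; the architecture at the end of search is used.
    """
    if not 0.0 < train_ratio <= 1.0:
        raise ValueError("train_ratio must be in (0, 1]")
    indices = list(range(num_train))
    random.Random(seed).shuffle(indices)  # reproducible shuffle per seed
    split = int(num_train * train_ratio)
    return indices[:split], indices[split:]


train_idx, valid_idx = split_indices(train_ratio=0.5)
full_idx, empty_idx = split_indices(train_ratio=1.0)
```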

Architecture search stage 2

To carry out the stage-2 architecture search, run

python search.py \
    --name job_name \
    --dataset cifar10 \
    --data_dir /path/to/dataset/ \
    --save_dir /path/to/results/ \
    --seed 6 \
    --stage search2 \
    --batch_size 64 \
    --init_channels 16 \
    --num_cells 20 \
    --epochs 80 \
    --alpha_share

As in stage 1, train_ratio sets the train-valid split; with train_ratio=1 all 50,000 training images are used and the architecture obtained at the end of the search is taken as the final result.

Architecture evaluation (using full-sized models)

To evaluate our best cells by training from scratch, run

python augment.py \
    --name job_name \
    --dataset cifar10 \
    --data_dir /path/to/dataset/ \
    --save_dir /path/to/results/ \
    --seed 6 \
    --stage augment \
    --batch_size 96 \
    --init_channels 36 \
    --num_cells 20 \
    --grad_clip 5 \
    --aux_weight 0.4 \
    --epochs 600 \
    --alpha_share

Note that aux_weight=0.4 attaches an auxiliary classifier head at two-thirds of the network depth and adds aux_weight * auxiliary_loss to the final loss.
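The rule above can be sketched in a few lines, assuming the DARTS convention of feeding the auxiliary head from the output of cell index 2 * num_cells // 3 (0-based). The helper names are hypothetical, and the losses are plain floats rather than tensors to keep the sketch self-contained.

```python
def aux_head_position(num_cells):
    """0-based index of the cell whose output feeds the auxiliary head."""
    return 2 * num_cells // 3


def total_loss(main_loss, aux_loss, aux_weight=0.4):
    """Training loss: main loss plus the down-weighted auxiliary loss.

    In DARTS-style training the auxiliary head is used only during training.
    """
    return main_loss + aux_weight * aux_loss


position = aux_head_position(20)  # num_cells=20 as in the command above
loss = total_loss(1.0, 2.0, aux_weight=0.4)
```

With --num_cells 20 the head sits after cell 13, and with aux_weight=0 the auxiliary term vanishes.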

Feature clustering

To visualize feature clustering, run

python feature.py \
    --name job_name \
    --dataset cifar10 \
    --data_dir /path/to/dataset/ \
    --save_dir /path/to/results/ \
    --seed 1 \
    --stage search1 \
    --batch_size 64

Note that feature.py will load the saved model of search stage1.
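For intuition, a minimal sketch of grouping candidate operations by the similarity of their output features, assuming each operation's feature map has been flattened into a vector and that two operations belong to the same group when the Pearson correlation of their features exceeds a threshold. The threshold, the greedy grouping rule, the `cluster_ops` helper, and the toy vectors are all hypothetical; feature.py may cluster and visualize features differently.

```python
import math


def pearson(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def cluster_ops(features, threshold=0.9):
    """Greedily group op names whose features correlate above threshold.

    features: dict mapping op name -> flattened feature vector.
    Each op joins the first group whose representative (first member) it
    correlates with above the threshold; otherwise it starts a new group.
    """
    groups = []
    for name, vec in features.items():
        for group in groups:
            if pearson(features[group[0]], vec) > threshold:
                group.append(name)
                break
        else:
            groups.append([name])
    return groups


toy_features = {  # toy vectors for illustration only, not real feature maps
    "sep_conv_3x3": [1.0, 2.0, 3.0, 4.0],
    "sep_conv_5x5": [1.1, 2.0, 3.2, 3.9],
    "max_pool_3x3": [4.0, 1.0, 3.0, 2.0],
    "avg_pool_3x3": [4.2, 1.1, 2.9, 2.0],
}
groups = cluster_ops(toy_features)
```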

Run all

To run all of the above steps with a single script, run

python run_all.py

Note: edit the parameters in run_all.py before running it.
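For illustration only, a driver in the spirit of run_all.py could chain the four commands above with subprocess. This is a hypothetical sketch, not the contents of run_all.py; it reuses only the flags and values shown in this README, and by default it prints the commands instead of running them.

```python
import shlex
import subprocess

# Flags shared by every stage, copied from the commands above; edit these.
COMMON = {"--name": "job_name", "--dataset": "cifar10",
          "--data_dir": "/path/to/dataset/", "--save_dir": "/path/to/results/"}

# (script, stage-specific flags) in README order; None marks a bare switch.
STAGES = [
    ("search.py", {"--seed": "6", "--stage": "search1", "--batch_size": "64",
                   "--init_channels": "16", "--num_cells": "14",
                   "--epochs": "80", "--alpha_share": None}),
    ("search.py", {"--seed": "6", "--stage": "search2", "--batch_size": "64",
                   "--init_channels": "16", "--num_cells": "20",
                   "--epochs": "80", "--alpha_share": None}),
    ("augment.py", {"--seed": "6", "--stage": "augment", "--batch_size": "96",
                    "--init_channels": "36", "--num_cells": "20",
                    "--grad_clip": "5", "--aux_weight": "0.4",
                    "--epochs": "600", "--alpha_share": None}),
    ("feature.py", {"--seed": "1", "--stage": "search1", "--batch_size": "64"}),
]


def build_commands(python="python"):
    """Return one argv list per stage."""
    commands = []
    for script, flags in STAGES:
        argv = [python, script]
        for key, value in {**COMMON, **flags}.items():
            argv.append(key)
            if value is not None:
                argv.append(value)
        commands.append(argv)
    return commands


def run_all(dry_run=True):
    """Print each command; execute it too when dry_run is False."""
    for argv in build_commands():
        print(shlex.join(argv))
        if not dry_run:
            subprocess.run(argv, check=True)  # stop at the first failing stage


commands = build_commands()
```

Running the stages strictly in sequence with `check=True` matters because stage 2 and feature.py depend on artifacts saved by earlier stages.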

Results

CIFAR10 Benchmark
We report the mean and standard deviation over 8 single runs (i.e., search once and evaluate once, repeating the procedure 8 times with different seeds). "base" means the final architecture is trained with exactly the same procedure as DARTS, for a fair comparison; "fancy" means the final training adopts additional training tricks, for comparison with other NAS methods.
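The reported statistic can be computed from per-run test error rates as below. This sketch assumes the sample standard deviation (n - 1 denominator) returned by Python's `statistics.stdev`; the README does not state which estimator was used, and the values in the usage line are placeholders, not results.

```python
import statistics


def summarize(errors):
    """Return (mean, sample std) of per-run test error rates in percent."""
    if len(errors) < 2:
        raise ValueError("need at least two runs for a standard deviation")
    return statistics.mean(errors), statistics.stdev(errors)


def format_summary(errors):
    """Format as 'mean ± std' with two decimals."""
    mean, std = summarize(errors)
    return f"{mean:.2f} ± {std:.2f}"


placeholder_errors = [2.0] * 4 + [3.0] * 4  # 8 placeholder runs, not real data
summary = format_summary(placeholder_errors)
```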

[Figures: CIFAR10 Benchmark; IMAGENET mobile setting Benchmark]

Can architecture weights really learn relative importance?

[Figure: alpha prediction correlation with stand-alone model accuracy]
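One common way to check whether architecture weights rank candidates the same way stand-alone training does is Kendall's tau rank correlation. The sketch below assumes one alpha value and one stand-alone accuracy per candidate; the README does not say which correlation measure the figure uses, and the helper name is hypothetical.

```python
from itertools import combinations


def kendall_tau(xs, ys):
    """Kendall's tau-a between two equal-length sequences (tied pairs count 0)."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    concordant = discordant = 0
    for i, j in combinations(range(len(xs)), 2):
        sign = (xs[i] - xs[j]) * (ys[i] - ys[j])
        if sign > 0:
            concordant += 1
        elif sign < 0:
            discordant += 1
    n_pairs = len(xs) * (len(xs) - 1) // 2
    return (concordant - discordant) / n_pairs
```

A value near 1 means higher alpha goes with higher stand-alone accuracy; a value near 0 means alpha carries little ranking information.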

Found architectures

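CIFAR-10: normal and reduction cells [figures: cifar10_normal, cifar10_reduce]

CIFAR-100: normal and reduction cells [figures: cifar100_normal, cifar100_reduce]

ImageNet: normal and reduction cells [figures: imagenet_normal, imagenet_reduce]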

Reference

https://github.com/quark0/darts
