
TooTouch/BalancedSoftmax


BalancedSoftmax

Balanced Softmax for classification.

This repository implements Balanced Softmax only and compares it against a plain cross-entropy baseline.

Environment

I use the Docker image nvcr.io/nvidia/pytorch:22.12-py3 and install the remaining dependencies with:

pip install -r requirements.txt

Datasets

datasets/build.py

  • CIFAR10-LT
  • CIFAR100-LT
from datasets import CIFAR10LT

trainset = CIFAR10LT(
    root             = '/datasets/CIFAR10',
    train            = True,
    download         = True,
    imb_type         = 'exp',
    imbalance_factor = 200,
)

print(trainset.num_per_cls)
>> {0: 5000,
    1: 2775,
    2: 1540,
    3: 854,
    4: 474,
    5: 263,
    6: 146,
    7: 81,
    8: 45,
    9: 25}
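The counts above decay exponentially from the head class to the tail class. The sketch below shows one way to generate such per-class counts. It assumes the convention used in the LDAM-DRW CIFAR-LT reference code: for exp, n_i = int(n_max * (1/IF)^(i/(C-1))); for step, the first half of the classes keeps n_max samples and the second half keeps n_max / IF. `get_num_per_cls` is a hypothetical helper, not necessarily what datasets/build.py does. With n_max=5000, 10 classes and IF=200, the exp branch reproduces the counts printed above.

```python
def get_num_per_cls(num_classes: int, n_max: int, imb_type, imbalance_factor: float) -> dict:
    """Return {class_index: sample_count} for a long-tailed split (hypothetical helper)."""
    ratio = 1.0 / imbalance_factor  # n_min / n_max
    if imb_type == "exp":
        # Counts decay geometrically from n_max (class 0) to n_max / IF (last class);
        # int() truncates, as in the LDAM-DRW reference code (assumed convention).
        counts = [int(n_max * ratio ** (i / (num_classes - 1))) for i in range(num_classes)]
    elif imb_type == "step":
        # First half of the classes keeps n_max samples, the second half keeps n_max / IF.
        half = num_classes // 2
        counts = [n_max] * half + [int(n_max * ratio)] * (num_classes - half)
    else:
        # No imbalance, e.g. imbalance_type: null in configs.yaml.
        counts = [n_max] * num_classes
    return dict(enumerate(counts))


print(get_num_per_cls(10, 5000, "exp", 200))
print(get_num_per_cls(10, 5000, "step", 200))
```

The step branch is only a guess at how a step split might look; check datasets/build.py for the exact rule.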

Balanced Softmax

losses.py

from losses import BalancedSoftmax

num_per_cls = list(trainset.num_per_cls.values())
criterion = BalancedSoftmax(num_per_cls=num_per_cls)
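Balanced Softmax compensates for the gap between a long-tailed training distribution and a balanced test distribution by weighting each class's exponentiated logit by its training count: loss = -log( n_y * exp(z_y) / sum_j n_j * exp(z_j) ), where z are the logits, y the target and n_j the number of training samples of class j. This is the same as ordinary cross-entropy on the shifted logits z_j + log(n_j). The plain-Python sketch below computes it for a single sample; `balanced_softmax_loss` is a hypothetical name and not necessarily how losses.py implements BalancedSoftmax. Because adding the same constant to every logit leaves softmax unchanged, equal class counts reduce the loss to plain cross-entropy, which is consistent with run.sh skipping BalancedSoftmax at IF=1.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample, computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def balanced_softmax_loss(logits, target, num_per_cls):
    """Balanced Softmax: cross-entropy on logits shifted by log(class count)."""
    adjusted = [z + math.log(n) for z, n in zip(logits, num_per_cls)]
    return cross_entropy(adjusted, target)


logits = [2.0, 1.0, 0.5]
counts = [5000, 500, 50]  # head, medium and tail class
print(cross_entropy(logits, 2), balanced_softmax_loss(logits, 2, counts))
```

For a tail-class target the Balanced Softmax loss is larger than plain cross-entropy (the model must push tail logits harder), and for a head-class target it is smaller. In PyTorch the same idea is usually written as F.cross_entropy(logits + log_prior, target), where log_prior = torch.log(counts) is a float tensor of shape (num_classes,) that broadcasts over the batch; predictions at test time use the unshifted logits.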

Experiments

1. Experiment setting

configs.yaml

DEFAULT:
  seed: 0
  savedir: ./results
  exp_name: CE-IF_1
DATASET:
  datadir: /datasets
  batch_size: 32
  test_batch_size: 2048
  num_workers: 12
  imbalance_type: null
  imbalance_factor: 1
  aug_info:
    - RandomCrop
    - RandomHorizontalFlip
LOSS:
  name: CrossEntropyLoss
OPTIMIZER:
  name: SGD
  lr: 0.1
SCHEDULER:
  sched_name: cosine_annealing
  params:
    t_mult: 1
    eta_min: 0.00001
TRAIN:
  epochs: 50
  grad_accum_steps: 1
  mixed_precision: fp16
  log_interval: 10
  ckp_metric: bcr
  wandb:
    use: true
    entity: tootouch
    project_name: Balanced Softmax
MODEL:
  name: resnet18
  pretrained: false
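The SCHEDULER block anneals the learning rate from OPTIMIZER.lr (0.1) toward eta_min (0.00001). The sketch below uses the standard cosine-annealing formula and assumes a single cycle spanning all TRAIN.epochs (50) with one update per epoch; t_mult: 1 is read here as "restart cycles do not grow". The scheduler that main.py actually builds from sched_name may differ (per-step updates, warmup, restarts), so `cosine_lr` is hypothetical.

```python
import math


def cosine_lr(epoch: int, total_epochs: int = 50, base_lr: float = 0.1, eta_min: float = 1e-5) -> float:
    """Learning rate at `epoch` under single-cycle cosine annealing (assumed schedule)."""
    progress = epoch / total_epochs  # 0.0 at the start of training, 1.0 at the end
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * progress))


for epoch in (0, 10, 25, 40, 50):
    print(epoch, round(cosine_lr(epoch), 6))
```

The rate starts at lr, passes roughly the midpoint between lr and eta_min halfway through training, and reaches eta_min at the last epoch.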

2. Run

run.sh

dataname='CIFAR10LT CIFAR100LT'
IF='1 10 50 100 200'
losses='CrossEntropyLoss BalancedSoftmax'

for d in $dataname
do
    for f in $IF
    do
        for l in $losses
        do
            if [ "$f" = '1' ] && [ "$l" = 'BalancedSoftmax' ]; then
                continue
            else
                echo "dataset: $d, loss: $l, IF: $f"
                python main.py --config configs.yaml \
                            DEFAULT.exp_name $l-IF_$f \
                            DATASET.name $d \
                            DATASET.imbalance_type exp \
                            DATASET.imbalance_factor $f \
                            LOSS.name $l
            fi
        done
    done
done
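run.sh overrides configs.yaml by passing KEY VALUE pairs after --config, where each key is a dotted path into a YAML section (for example DATASET.imbalance_factor 200). main.py is not shown here, so the following is only a sketch of how such overrides could be merged into the nested config; `apply_overrides` and `parse_value` are hypothetical helpers. Config libraries offer similar utilities, e.g. yacs's CfgNode.merge_from_list, which takes the same alternating key/value list.

```python
import ast


def parse_value(text: str):
    """Interpret a CLI string as a Python literal when possible ("200" -> 200), else keep the string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_overrides(cfg: dict, opts: list) -> dict:
    """Apply alternating [dotted.key, value, ...] overrides to a nested dict in place."""
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    for key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node.setdefault(name, {})  # create missing sections, e.g. DATASET.name
        node[leaf] = parse_value(raw)
    return cfg


cfg = {
    "DEFAULT": {"exp_name": "CE-IF_1"},
    "DATASET": {"imbalance_type": None, "imbalance_factor": 1},
    "LOSS": {"name": "CrossEntropyLoss"},
}
# The arguments run.sh would pass for d=CIFAR10LT, f=200, l=BalancedSoftmax.
opts = [
    "DEFAULT.exp_name", "BalancedSoftmax-IF_200",
    "DATASET.name", "CIFAR10LT",
    "DATASET.imbalance_type", "exp",
    "DATASET.imbalance_factor", "200",
    "LOSS.name", "BalancedSoftmax",
]
apply_overrides(cfg, opts)
print(cfg)
```

Parsing values as literals keeps numeric settings such as imbalance_factor numeric instead of leaving them as the string "200".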

3. Results

3.1 Imbalance type - exp

Experiment logs: [ wandb ]


Figure 1. Results by imbalance factor.

Table 1. Results by imbalance factor (imbalance type: exp). BalancedSoftmax is not run at IF=1 (see run.sh), marked "-".
Dataset           CIFAR10LT                               CIFAR100LT
Imbalance factor  1       10      50      100     200     1       10      50      100     200
CrossEntropyLoss  0.9283  0.8717  0.7779  0.7065  0.6426  0.7313  0.5865  0.4544  0.4060  0.3492
BalancedSoftmax   -       0.8694  0.7992  0.7601  0.7034  -       0.5999  0.4845  0.4447  0.3823


Figure 2. Per-class performance by imbalance factor.


Figure 3. Confusion matrices of cross entropy vs. Balanced Softmax on CIFAR10LT with imbalance factor (IF) 200.
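configs.yaml selects checkpoints with ckp_metric: bcr. Assuming bcr stands for balanced classification rate, i.e. the mean of the per-class recalls (the diagonal of a row-normalized confusion matrix such as those in Figure 3), it can be computed as in the sketch below; `per_class_recall` and `balanced_classification_rate` are hypothetical helpers. On the class-balanced CIFAR test sets every class has the same number of test images, so this mean equals plain top-1 accuracy; the two diverge only when the evaluation set is itself imbalanced.

```python
def per_class_recall(confusion):
    """Recall of each class; rows are true classes, columns are predicted classes."""
    recalls = []
    for c, row in enumerate(confusion):
        total = sum(row)
        recalls.append(row[c] / total if total else 0.0)
    return recalls


def balanced_classification_rate(confusion):
    """Mean per-class recall (assumed meaning of `bcr`)."""
    recalls = per_class_recall(confusion)
    return sum(recalls) / len(recalls)


def accuracy(confusion):
    """Fraction of all samples on the diagonal."""
    correct = sum(confusion[c][c] for c in range(len(confusion)))
    return correct / sum(sum(row) for row in confusion)


# Toy 3-class confusion matrix (made-up numbers, 100 test images per class).
cm = [
    [95, 4, 1],
    [20, 75, 5],
    [30, 20, 50],
]
print(per_class_recall(cm))
print(balanced_classification_rate(cm), accuracy(cm))
```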

3.2 Imbalance type - step


Figure 4. Results by imbalance factor.

Table 2. Results by imbalance factor (imbalance type: step). BalancedSoftmax is not run at IF=1 (see run.sh), marked "-".
Dataset           CIFAR10LT                               CIFAR100LT
Imbalance factor  1       10      50      100     200     1       10      50      100     200
CrossEntropyLoss  0.9283  0.8525  0.7078  0.6421  0.5570  0.7313  0.5696  0.4440  0.4067  0.3921
BalancedSoftmax   -       0.8762  0.8027  0.7633  0.7070  -       0.6058  0.5202  0.4715  0.4301


Figure 5. Per-class performance by imbalance factor.


Figure 6. Confusion matrices of cross entropy vs. Balanced Softmax on CIFAR10LT with imbalance factor (IF) 200.
