Skip to content

Commit

Permalink
update setup and add instructions for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
by-liu committed Apr 30, 2022
1 parent 4e51e7f commit 224931c
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 17 deletions.
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Bingyuan Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
30 changes: 16 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,32 @@
</div><br/>


## Requirments:
## Install:

Here is some core dependences:
[option] create a new virtual env
```
torch==1.7.1+cu110
torchvision==0.8.2+cu110
albumentations==1.0.0
hydra-core==1.1.0
conda create -n mbls python=3.8.10
```

For detailed dependences, please check [requirements.txt](requirements.txt). (It maybe a little redundant. We will make it clean soon.)


## Install:
It's recommended to install [torch and torchvision](https://pytorch.org/) tailored to your environment in advance.
The torch versions I have tested are 1.10.0+cu111 and 1.7.1+cu110.

```
python setup.py develop
pip install -e .
```

The required libraies are included in [steup.py](setup.py).

## Data preparation

For CIFAR-10, our code can automatically download the data samples. For the others (Tiny-Imagenet, CUB-200 and VOC 2012), please refer to the official cites for downloading the datasets.

**Important Note** : Before you run the code, please add the absolute path of the data directory for the related data configs in [configs/data](configs/data/).
**Important Note** : Before you run the code, please add the absolute path of the data directory for the related data configs in [configs/data](configs/data/). Or you could pass it in the running commands.

## Usage:


### Command arguments :
### Training arguments:

<details><summary>python tools/train_net.py --help</summary>
<p>
Expand Down Expand Up @@ -155,13 +151,14 @@ Use --hydra-help to view Hydra specific help
</details>


### Example
### Traing Examples

Ours :
```python
python tools/train_net.py \
log_period=100 \
data=tiny_imagenet \
data.data_root=/home/bliu/work/Data/tiny-imagenet-200 \
model=resnet50_tiny model.num_classes=200 \
loss=logit_margin loss.margin=10.0 loss.alpha=0.1 \
optim=sgd optim.lr=0.1 optim.momentum=0.9 \
Expand All @@ -174,6 +171,7 @@ Cross entropy (CE) :
python tools/train_net.py \
log_period=100 \
data=tiny_imagenet \
data.data_root=/home/bliu/work/Data/tiny-imagenet-200 \
model=resnet50_tiny model.num_classes=200 \
loss=ce \
optim=sgd optim.lr=0.1 optim.momentum=0.9 \
Expand All @@ -186,13 +184,17 @@ Label smoothing (LS) :
python tools/train_net.py \
log_period=100 \
data=tiny_imagenet \
data.data_root=/home/bliu/work/Data/tiny-imagenet-200 \
model=resnet50_tiny model.num_classes=200 \
loss=ls \
optim=sgd optim.lr=0.1 optim.momentum=0.9 \
scheduler=multi_step scheduler.milestones="[40, 60]" \
train.max_epoch=100
```

[Testing Examples](docs/TEST.md)



## Support for extension or follow-up works

Expand Down
2 changes: 1 addition & 1 deletion configs/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ defaults:
- optim: sgd
- scheduler: multi_step
- wandb: my

- override hydra/job_logging: custom
- _self_

task: cv
device: cuda:0
Expand Down
98 changes: 98 additions & 0 deletions docs/TEST.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Testing with trained model

* Dataset : TinyImageNet
* Network : ResNet-50

### Running command
```python
python tools/test_net.py \
data=tiny_imagenet \
data.data_root=[The Root Directory Of The Dataset] \
model=resnet50_tiny \
model.num_classes=200 \
hydra.run.dir=[The Directory Of The Downloaded Checkpoint] \
test.checkpoint=[The Filename Of The Checkpoint]
```


### Running Examples

#### CE
```python
python tools/test_net.py \
data=tiny_imagenet \
data.data_root=/home/bliu/work/Data/tiny-imagenet-200 \
model=resnet50_tiny \
model.num_classes=200 \
hydra.run.dir=outputs/best_models/tiny_resnet50 \
test.checkpoint=resnet50_tiny-ce-best.pth
```

```
[2022-04-30 18:26:32,242 INFO][tester.py:123 - log_eval_epoch_info] -
+---------+---------+---------+---------+
| samples | acc | acc_5 | macc |
+---------+---------+---------+---------+
| 10000 | 0.65020 | 0.85960 | 0.65020 |
+---------+---------+---------+---------+
[2022-04-30 18:26:32,286 INFO][tester.py:124 - log_eval_epoch_info] -
+---------+---------+---------+---------+---------+
| samples | nll | ece | aece | cece |
+---------+---------+---------+---------+---------+
| 10000 | 1.40984 | 0.03728 | 0.03687 | 0.00137 |
+---------+---------+---------+---------+---------+
```

### LS
```python
python tools/test_net.py \
data=tiny_imagenet \
data.data_root=/home/bliu/work/Data/tiny-imagenet-200 \
model=resnet50_tiny \
model.num_classes=200 \
hydra.run.dir=outputs/best_models/tiny_resnet50 \
test.checkpoint=resnet50_tiny-ls-best.pth
```

```
[2022-04-30 18:27:34,880 INFO][tester.py:123 - log_eval_epoch_info] -
+---------+---------+---------+---------+
| samples | acc | acc_5 | macc |
+---------+---------+---------+---------+
| 10000 | 0.65780 | 0.86190 | 0.65780 |
+---------+---------+---------+---------+
[2022-04-30 18:27:34,880 INFO][tester.py:124 - log_eval_epoch_info] -
+---------+---------+---------+---------+---------+
| samples | nll | ece | aece | cece |
+---------+---------+---------+---------+---------+
| 10000 | 1.41337 | 0.03165 | 0.03159 | 0.00138 |
+---------+---------+---------+---------+---------+
```


#### MbLS
```python
python tools/test_net.py \
data=tiny_imagenet \
data.data_root=/home/bliu/work/Data/tiny-imagenet-200 \
model=resnet50_tiny \
model.num_classes=200 \
hydra.run.dir=outputs/best_models/tiny_resnet50 \
test.checkpoint=resnet50_tiny-mbls-best.pth
```

```
[2022-04-30 18:23:46,768 INFO][tester.py:123 - log_eval_epoch_info] -
+---------+---------+---------+---------+
| samples | acc | acc_5 | macc |
+---------+---------+---------+---------+
| 10000 | 0.64740 | 0.86030 | 0.64740 |
+---------+---------+---------+---------+
[2022-04-30 18:23:46,768 INFO][tester.py:124 - log_eval_epoch_info] -
+---------+---------+---------+---------+---------+
| samples | nll | ece | aece | cece |
+---------+---------+---------+---------+---------+
| 10000 | 1.43714 | 0.01641 | 0.01730 | 0.00140 |
+---------+---------+---------+---------+---------+
```

16 changes: 14 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,24 @@

setup(
name="calibrate",
version="0.1",
version="1.0",
author="",
description="For awesome calibration research",
packages=find_packages(),
python_requries=">=3.8",
install_requires=[
# Please install pytorch-related libraries and opencv by yourself based on your environment
# Please install torch and torchvision libraries before running this script
"torch",
"torchvision>=0.8.2",
"ipdb==0.13.9",
"albumentations==1.1.0",
"opencv-python==4.5.1.48",
"hydra-core==1.1.2",
"flake8==4.0.1",
"wandb==0.12.14",
"terminaltables==3.1.10",
"matplotlib==3.5.1",
"plotly==5.7.0",
"pandas==1.4.2"
],
)

0 comments on commit 224931c

Please sign in to comment.