yhwang-hub/yolov7_QAT

Quantize yolov7 using pytorch_quantization.πŸš€πŸš€πŸš€


Description


This is a repository for QAT fine-tuning of yolov7 using TensorRT's pytorch_quantization tool.

| Method            | Calibration method                    | mAPval@0.5 | mAPval@0.5:0.95 | batch-1 fps (Jetson Orin-X) | batch-16 fps (Jetson Orin-X) | weight |
|-------------------|---------------------------------------|------------|-----------------|-----------------------------|------------------------------|--------|
| pytorch FP16      | -                                     | 0.6972     | 0.5120          | -                           | -                            | yolov7.pt |
| pytorch PTQ-INT8  | Histogram (MSE)                       | 0.6957     | 0.5100          | -                           | -                            | yolov7_ptq.pt, yolov7_ptq_640.onnx |
| pytorch QAT-INT8  | Histogram (MSE)                       | 0.6961     | 0.5111          | -                           | -                            | yolov7_qat.pt |
| TensorRT FP16     | -                                     | 0.6973     | 0.5124          | 140                         | 168                          | yolov7.onnx |
| TensorRT PTQ-INT8 | TensorRT built-in EntropyCalibratorV2 | 0.6317     | 0.4573          | 207                         | 264                          | - |
| TensorRT QAT-INT8 | Histogram (MSE)                       | 0.6962     | 0.5113          | 207                         | 266                          | yolov7_qat_640.onnx |

How to Run QAT Training

1. Setup

We suggest using a Docker environment.

Download docker image:

docker pull longxiaowyh/yolov7:v1

Create docker container:

nvidia-docker run -itu root:root --name yolov7 --gpus all -v /your_path:/target_path -v /tmp/.X11-unix/:/tmp/.X11-unix/ -e DISPLAY=unix$DISPLAY -e GDK_SCALE -e GDK_DPI_SCALE  -e NVIDIA_VISIBLE_DEVICES=all -e NVIDIA_DRIVER_CAPABILITIES=compute,utility --shm-size=64g yolov7:v1 /bin/bash

1.1 Clone and apply patch

git clone git@github.com:yhwang-hub/yolov7_quantization.git

1.2 Install dependencies

pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com

1.3 Prepare the COCO dataset

.
β”œβ”€β”€ annotations
β”‚Β Β  β”œβ”€β”€ captions_train2017.json
β”‚Β Β  β”œβ”€β”€ captions_val2017.json
β”‚Β Β  β”œβ”€β”€ instances_train2017.json
β”‚Β Β  β”œβ”€β”€ instances_val2017.json
β”‚Β Β  β”œβ”€β”€ person_keypoints_train2017.json
β”‚Β Β  └── person_keypoints_val2017.json
β”œβ”€β”€ coco -> coco
β”œβ”€β”€ coco128
β”‚Β Β  β”œβ”€β”€ images
β”‚Β Β  β”œβ”€β”€ labels
β”‚Β Β  β”œβ”€β”€ LICENSE
β”‚Β Β  └── README.txt
β”œβ”€β”€ images
β”‚Β Β  β”œβ”€β”€ train2017
β”‚Β Β  └── val2017
β”œβ”€β”€ labels
β”‚Β Β  β”œβ”€β”€ train2017
β”‚Β Β  β”œβ”€β”€ train2017.cache
β”‚Β Β  └── val2017
β”œβ”€β”€ train2017.cache
β”œβ”€β”€ train2017.txt
β”œβ”€β”€ val2017.cache
└── val2017.txt

2. Start PTQ

2.1 Start sensitive layer analysis

python ptq.py --weights ./weights/yolov7s.pt --cocodir /home/wyh/disk/coco/ --batch_size 5 --save_ptq True --eval_origin --eval_ptq --start_ptq False --sensitive True

Based on the sensitivity results, modify the ignore_layers parameter in ptq.py as follows:

parser.add_argument("--ignore_layers", type=str, default="model\.105\.m\.(.*)", help="regx")
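
The sensitivity pass ranks layers by how much accuracy they cost when quantized, and the flagged layer names are what go into --ignore_layers. As a rough, hypothetical sketch of how such a regex is usually applied with pytorch_quantization (the real logic lives in ptq.py), the matched layers simply have their TensorQuantizer modules disabled so they stay in higher precision:

import re
from pytorch_quantization import nn as quant_nn

def disable_ignored_quantizers(model, ignore_regex=r"model\.105\.m\.(.*)"):
    """Hypothetical helper: keep regex-matched layers un-quantized (FP16)."""
    pattern = re.compile(ignore_regex)
    for name, module in model.named_modules():
        # Quantizer names look like "model.105.m.0._input_quantizer".
        if isinstance(module, quant_nn.TensorQuantizer) and pattern.match(name):
            module.disable()
            print(f"Quantization disabled for: {name}")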

2.2 Start PTQ

python ptq.py --weights ./weights/yolov7s.pt --cocodir /home/wyh/disk/coco/ --batch_size 5 --save_ptq True --eval_origin --eval_ptq --start_ptq True --sensitive False

3. Start QAT Training

python qat.py --weights ./weights/yolov7s.pt --cocodir /home/wyh/disk/coco/ --batch_size 5 --save_ptq True --save_qat True --eval_origin --eval_ptq --eval_qat

This script includes the steps below:

  • Insert Q&DQ nodes to get a fake-quant PyTorch model. The pytorch_quantization tool can insert QDQ nodes automatically, but for the yolov7 model this does not reach the same performance as PTQ: in explicit (QAT) mode, TensorRT uses the placement of the Q/DQ nodes to restrict the precision of the model, and some of the automatically added Q&DQ nodes cannot be fused with other layers, which causes extra, useless precision conversions. In our script, QDQ nodes are analyzed and configured in a rule-based manner tailored to yolov7, so that their placement is optimal for TensorRT and all nodes run in INT8 (confirmed with the trt-engine-explorer tool, see scripts/draw-engine.py). For the details of this part, please refer to quantization/rules.py; for guidance on Q&DQ insertion, please refer to Guidance_of_QAT_performance_optimization. (A minimal sketch of the generic insertion-and-calibration flow follows this list.)

  • PTQ calibration. After inserting the Q&DQ nodes, we recommend running PTQ calibration first. Per our experiments, Histogram (MSE) is the best PTQ calibration method for yolov7. Note: if you are satisfied with the PTQ result, you can also skip QAT.

  • QAT training. After calibration, fine-tune the model with QAT training; once the accuracy is satisfactory, save the weights to file.
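
The sketch below shows the generic pytorch_quantization flow behind these steps: automatic QDQ insertion, Histogram calibration with MSE amax computation, and a short QAT fine-tune. It is a minimal, hedged illustration only; the repository refines node placement with its own rules (quantization/rules.py), and load_yolov7, data_loader and compute_loss are hypothetical stand-ins, not names from this codebase.

import torch
from pytorch_quantization import calib
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization.tensor_quant import QuantDescriptor

# 1. Ask for histogram calibrators on activations, then monkey-patch torch.nn so
#    that Conv layers constructed afterwards carry fake-quant (Q&DQ) nodes.
quant_nn.QuantConv2d.set_default_quant_desc_input(QuantDescriptor(calib_method="histogram"))
quant_modules.initialize()
model = load_yolov7("./weights/yolov7.pt").cuda().eval()    # hypothetical loader

# 2. PTQ calibration: switch quantizers to "collect statistics" mode and feed images.
for module in model.modules():
    if isinstance(module, quant_nn.TensorQuantizer):
        if module._calibrator is not None:
            module.disable_quant()
            module.enable_calib()
        else:
            module.disable()

with torch.no_grad():
    for i, (imgs, targets) in enumerate(data_loader):       # hypothetical DataLoader
        model(imgs.cuda())
        if i >= 100:                                         # a few hundred images suffice
            break

# 3. Compute scales from the histograms; MSE gave the best yolov7 results above.
for module in model.modules():
    if isinstance(module, quant_nn.TensorQuantizer):
        if module._calibrator is not None:
            if isinstance(module._calibrator, calib.MaxCalibrator):
                module.load_calib_amax()
            else:
                module.load_calib_amax(method="mse")
            module.enable_quant()
            module.disable_calib()
        else:
            module.enable()

# 4. QAT fine-tuning: a short, ordinary training loop on the fake-quant model.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
model.train()
for imgs, targets in data_loader:
    loss = compute_loss(model(imgs.cuda()), targets)         # hypothetical yolov7 loss fn
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()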

See the run_qat.log file for the running results; ptq.onnx and qat.onnx will be generated in this path.
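
Exporting the fake-quant model to ONNX follows the usual pytorch_quantization pattern; a minimal sketch is below (the file and tensor names are assumptions for illustration, qat.py/ptq.py choose their own paths):

import torch
from pytorch_quantization import nn as quant_nn

# Export the Q/DQ nodes as standard ONNX QuantizeLinear/DequantizeLinear ops.
quant_nn.TensorQuantizer.use_fb_fake_quant = True

model.eval()
dummy = torch.zeros(1, 3, 640, 640, device="cuda")
torch.onnx.export(
    model, dummy, "qat.onnx",
    opset_version=13,                # opset >= 13 is needed for per-channel Q/DQ export
    input_names=["images"],
    output_names=["outputs"],
)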

Benchmark

# η”Ÿζˆengine
trtexec --onnx=./outdir-no-rule/ptq.onnx --fp16 --int8 --verbose --saveEngine=./outdir-no-rule/yolo_ptq.engine --workspace=1024000 --warmUp=500 --duration=10 --useCudaGraph --useSpinWait --noDataTransfers --exportLayerInfo=./outdir-no-rule/yolov7_ptq_layer.json --profilingVerbosity=detailed --exportProfile=./outdir-no-rule/yolov7_ptq_profile.json

trtexec --onnx=./outdir-no-rule/qat.onnx --fp16 --int8 --verbose --saveEngine=./outdir-no-rule/yolo_qat.engine --workspace=1024000 --warmUp=500 --duration=10 --useCudaGraph --useSpinWait --noDataTransfers --exportLayerInfo=./outdir-no-rule/yolov7_qat_layer.json --profilingVerbosity=detailed --exportProfile=./outdir-no-rule/yolov7_qat_profile.json

# Measure performance
trtexec --loadEngine=./outdir-no-rule/yolo_ptq.engine --batch=1
trtexec --loadEngine=./outdir-no-rule/yolo_qat.engine --batch=1

# engineε―θ§†εŒ–
python scripts/draw-engine.py --layer=./outdir-no-rule/yolov7_ptq_layer.json --profile=./outdir-no-rule/yolov7_ptq_profile.json
python scripts/draw-engine.py --layer=./outdir-no-rule/yolov7_qat_layer.json --profile=./outdir-no-rule/yolov7_qat_profile.json

QPS test results on an RTX 3060 are as follows:

(PTQ vs. QAT benchmark screenshots)

Official YOLOv7

Implementation of paper - YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors



Performance

MS COCO

| Model      | Test Size | APtest | AP50test | AP75test | batch 1 fps | batch 32 average time |
|------------|-----------|--------|----------|----------|-------------|-----------------------|
| YOLOv7     | 640       | 51.4%  | 69.7%    | 55.9%    | 161 fps     | 2.8 ms                |
| YOLOv7-X   | 640       | 53.1%  | 71.2%    | 57.8%    | 114 fps     | 4.3 ms                |
| YOLOv7-W6  | 1280      | 54.9%  | 72.6%    | 60.1%    | 84 fps      | 7.6 ms                |
| YOLOv7-E6  | 1280      | 56.0%  | 73.5%    | 61.2%    | 56 fps      | 12.3 ms               |
| YOLOv7-D6  | 1280      | 56.6%  | 74.0%    | 61.8%    | 44 fps      | 15.0 ms               |
| YOLOv7-E6E | 1280      | 56.8%  | 74.4%    | 62.1%    | 36 fps      | 18.7 ms               |

Installation

Docker environment (recommended)

# create the docker container, you can change the share memory size if you have more.
nvidia-docker run --name yolov7 -it -v your_coco_path/:/coco/ -v your_code_path/:/yolov7 --shm-size=64g nvcr.io/nvidia/pytorch:21.08-py3

# apt install required packages
apt update
apt install -y zip htop screen libgl1-mesa-glx

# pip install required packages
pip install seaborn thop

# go to code folder
cd /yolov7

Testing

yolov7.pt yolov7x.pt yolov7-w6.pt yolov7-e6.pt yolov7-d6.pt yolov7-e6e.pt

python test.py --data data/coco.yaml --img 640 --batch 32 --conf 0.001 --iou 0.65 --device 0 --weights yolov7.pt --name yolov7_640_val

You will get the results:

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.51206
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.69730
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.55521
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.35247
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.55937
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.66693
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.38453
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.63765
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.68772
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.53766
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.73549
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.83868

To measure accuracy, download the COCO annotations for pycocotools to ./coco/annotations/instances_val2017.json.

Training

Data preparation

bash scripts/get_coco.sh

  • Download MS COCO dataset images (train, val, test) and labels. If you have previously used a different version of YOLO, we strongly recommend that you delete the train2017.cache and val2017.cache files and re-download the labels.

Single GPU training

# train p5 models
python train.py --workers 8 --device 0 --batch-size 32 --data data/coco.yaml --img 640 640 --cfg cfg/training/yolov7.yaml --weights '' --name yolov7 --hyp data/hyp.scratch.p5.yaml

# train p6 models
python train_aux.py --workers 8 --device 0 --batch-size 16 --data data/coco.yaml --img 1280 1280 --cfg cfg/training/yolov7-w6.yaml --weights '' --name yolov7-w6 --hyp data/hyp.scratch.p6.yaml

Multiple GPU training

# train p5 models
python -m torch.distributed.launch --nproc_per_node 4 --master_port 9527 train.py --workers 8 --device 0,1,2,3 --sync-bn --batch-size 128 --data data/coco.yaml --img 640 640 --cfg cfg/training/yolov7.yaml --weights '' --name yolov7 --hyp data/hyp.scratch.p5.yaml

# train p6 models
python -m torch.distributed.launch --nproc_per_node 8 --master_port 9527 train_aux.py --workers 8 --device 0,1,2,3,4,5,6,7 --sync-bn --batch-size 128 --data data/coco.yaml --img 1280 1280 --cfg cfg/training/yolov7-w6.yaml --weights '' --name yolov7-w6 --hyp data/hyp.scratch.p6.yaml

Transfer learning

yolov7_training.pt yolov7x_training.pt yolov7-w6_training.pt yolov7-e6_training.pt yolov7-d6_training.pt yolov7-e6e_training.pt

Single GPU finetuning for custom dataset

# finetune p5 models
python train.py --workers 8 --device 0 --batch-size 32 --data data/custom.yaml --img 640 640 --cfg cfg/training/yolov7-custom.yaml --weights 'yolov7_training.pt' --name yolov7-custom --hyp data/hyp.scratch.custom.yaml

# finetune p6 models
python train_aux.py --workers 8 --device 0 --batch-size 16 --data data/custom.yaml --img 1280 1280 --cfg cfg/training/yolov7-w6-custom.yaml --weights 'yolov7-w6_training.pt' --name yolov7-w6-custom --hyp data/hyp.scratch.custom.yaml

Re-parameterization

See reparameterization.ipynb

Inference

On video:

python detect.py --weights yolov7.pt --conf 0.25 --img-size 640 --source yourvideo.mp4

On image:

python detect.py --weights yolov7.pt --conf 0.25 --img-size 640 --source inference/images/horses.jpg

Export

Pytorch to CoreML (and inference on MacOS/iOS)

Pytorch to ONNX with NMS (and inference)

python export.py --weights yolov7-tiny.pt --grid --end2end --simplify \
        --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640 --max-wh 640

Pytorch to TensorRT with NMS (and inference)

wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7-tiny.pt
python export.py --weights ./yolov7-tiny.pt --grid --end2end --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640
git clone https://github.com/Linaom1214/tensorrt-python.git
python ./tensorrt-python/export.py -o yolov7-tiny.onnx -e yolov7-tiny-nms.trt -p fp16

Pytorch to TensorRT, another way


wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7-tiny.pt
python export.py --weights yolov7-tiny.pt --grid --include-nms
git clone https://github.com/Linaom1214/tensorrt-python.git
python ./tensorrt-python/export.py -o yolov7-tiny.onnx -e yolov7-tiny-nms.trt -p fp16

# Or use trtexec to convert ONNX to TensorRT engine
/usr/src/tensorrt/bin/trtexec --onnx=yolov7-tiny.onnx --saveEngine=yolov7-tiny-nms.trt --fp16

Tested with: Python 3.7.13, Pytorch 1.12.0+cu113

Pose estimation

code yolov7-w6-pose.pt

See keypoint.ipynb.

Instance segmentation (with NTU)

code yolov7-mask.pt

See instance.ipynb.

Instance segmentation

code yolov7-seg.pt

YOLOv7 for instance segmentation (YOLOR + YOLOv5 + YOLACT)

| Model      | Test Size | APbox | AP50box | AP75box | APmask | AP50mask | AP75mask |
|------------|-----------|-------|---------|---------|--------|----------|----------|
| YOLOv7-seg | 640       | 51.4% | 69.4%   | 55.8%   | 41.5%  | 65.5%    | 43.7%    |

Anchor free detection head

code yolov7-u6.pt

YOLOv7 with decoupled TAL head (YOLOR + YOLOv5 + YOLOv6)

| Model     | Test Size | APval | AP50val | AP75val |
|-----------|-----------|-------|---------|---------|
| YOLOv7-u6 | 640       | 52.6% | 69.7%   | 57.3%   |

Citation

@article{wang2022yolov7,
  title={{YOLOv7}: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors},
  author={Wang, Chien-Yao and Bochkovskiy, Alexey and Liao, Hong-Yuan Mark},
  journal={arXiv preprint arXiv:2207.02696},
  year={2022}
}
@article{wang2022designing,
  title={Designing Network Design Strategies Through Gradient Path Analysis},
  author={Wang, Chien-Yao and Liao, Hong-Yuan Mark and Yeh, I-Hau},
  journal={arXiv preprint arXiv:2211.04800},
  year={2022}
}

Teaser

YOLOv7-semantic & YOLOv7-panoptic & YOLOv7-caption

YOLOv7-semantic & YOLOv7-detection & YOLOv7-depth (with NTUT)

YOLOv7-3d-detection & YOLOv7-lidar & YOLOv7-road (with NTUT)

Acknowledgements
