PyTorch Quantization Aware Training (QAT)

Description

Software environment

  • Ubuntu 20.04 x86_64

  • python 3.9

  • onnx==1.14.0

  • onnxsim==0.4.33

  • numpy==1.21.6

  • torch==2.0.0+cu117

  • torchvision==0.15.1+cu117
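
Since the results below depend on the environment, it can help to confirm that the installed packages match the pinned versions before training. The following is a minimal sketch using only the standard library; the helper name `compare_versions` and the `lookup` parameter are illustrative and not part of this repository.

```python
from importlib import metadata

# Versions listed in the software environment above.
EXPECTED = {
    "onnx": "1.14.0",
    "onnxsim": "0.4.33",
    "numpy": "1.21.6",
    "torch": "2.0.0+cu117",
    "torchvision": "0.15.1+cu117",
}


def compare_versions(expected, lookup=metadata.version):
    """Return {package: (expected, installed_or_None)} for every mismatch.

    `lookup` is a hypothetical hook that defaults to importlib.metadata.version,
    so the function can be exercised without the packages being installed.
    """
    mismatches = {}
    for name, want in expected.items():
        try:
            have = lookup(name)
        except metadata.PackageNotFoundError:
            have = None  # package is not installed at all
        if have != want:
            mismatches[name] = (want, have)
    return mismatches


if __name__ == "__main__":
    for name, (want, have) in compare_versions(EXPECTED).items():
        print(f"{name}: expected {want}, found {have or 'not installed'}")
```

A different version is not necessarily a problem, but it is one likely reason for numbers that differ from the log shown in the examples.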

Examples

  • Installation
git clone https://github.com/jnulzl/PyTorch-QAT
cd PyTorch-QAT
pip install -r requirements.txt 
cd data
chmod a+x download_data.sh
./download_data.sh
  • MobileNetV2 QAT
python cifar.py  # the default network is mobilenet_v2
Files already downloaded and verified
Files already downloaded and verified
Training Model...
Epoch: 000 Eval Loss: 2.303 Eval Acc: 0.097
Epoch: 001 Train Loss: 2.024 Train Acc: 0.255 Eval Loss: 1.725 Eval Acc: 0.356
Epoch: 002 Train Loss: 1.693 Train Acc: 0.380 Eval Loss: 1.520 Eval Acc: 0.438
Epoch: 003 Train Loss: 1.533 Train Acc: 0.439 Eval Loss: 1.437 Eval Acc: 0.472
Epoch: 004 Train Loss: 1.441 Train Acc: 0.478 Eval Loss: 1.354 Eval Acc: 0.514
Epoch: 005 Train Loss: 1.368 Train Acc: 0.506 Eval Loss: 1.257 Eval Acc: 0.549
Epoch: 006 Train Loss: 1.289 Train Acc: 0.537 Eval Loss: 1.193 Eval Acc: 0.573
......
Epoch: 095 Train Loss: 0.381 Train Acc: 0.863 Eval Loss: 0.619 Eval Acc: 0.799
Epoch: 096 Train Loss: 0.381 Train Acc: 0.863 Eval Loss: 0.616 Eval Acc: 0.803
Epoch: 097 Train Loss: 0.384 Train Acc: 0.861 Eval Loss: 0.619 Eval Acc: 0.799
Epoch: 098 Train Loss: 0.381 Train Acc: 0.863 Eval Loss: 0.616 Eval Acc: 0.802
Epoch: 099 Train Loss: 0.386 Train Acc: 0.860 Eval Loss: 0.622 Eval Acc: 0.800
Epoch: 100 Train Loss: 0.382 Train Acc: 0.861 Eval Loss: 0.619 Eval Acc: 0.800
Training QAT Model...
Epoch: 000 Eval Loss: 0.634 Eval Acc: 0.795
Epoch: 001 Train Loss: 0.407 Train Acc: 0.853 Eval Loss: 0.628 Eval Acc: 0.797
Epoch: 002 Train Loss: 0.407 Train Acc: 0.853 Eval Loss: 0.630 Eval Acc: 0.794
Epoch: 003 Train Loss: 0.406 Train Acc: 0.853 Eval Loss: 0.629 Eval Acc: 0.798
Epoch: 004 Train Loss: 0.404 Train Acc: 0.853 Eval Loss: 0.623 Eval Acc: 0.794
Epoch: 005 Train Loss: 0.406 Train Acc: 0.854 Eval Loss: 0.617 Eval Acc: 0.799
Epoch: 006 Train Loss: 0.399 Train Acc: 0.855 Eval Loss: 0.623 Eval Acc: 0.797
Epoch: 007 Train Loss: 0.401 Train Acc: 0.854 Eval Loss: 0.629 Eval Acc: 0.793
Epoch: 008 Train Loss: 0.392 Train Acc: 0.858 Eval Loss: 0.637 Eval Acc: 0.794
Epoch: 009 Train Loss: 0.397 Train Acc: 0.857 Eval Loss: 0.631 Eval Acc: 0.793
Epoch: 010 Train Loss: 0.394 Train Acc: 0.857 Eval Loss: 0.641 Eval Acc: 0.797
......
FP32 evaluation accuracy: 0.800
INT8 evaluation accuracy: 0.797
FP32 CPU Inference Latency: 2.32 ms / sample
FP32 CUDA Inference Latency: 3.22 ms / sample
INT8 CPU Inference Latency: 1.56 ms / sample
INT8 JIT CPU Inference Latency: 0.57 ms / sample
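
The log above has two phases, ordinary FP32 training followed by QAT fine-tuning, and then an INT8 evaluation. The sketch below illustrates that general flow with PyTorch's eager-mode quantization API (`torch.ao.quantization`). It is not the code in cifar.py: `TinyNet`, the random data, and the three-step training loop are assumptions made only to keep the example small, and module fusion is omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.quantization as tq


class TinyNet(nn.Module):
    """Hypothetical toy network; stands in for MobileNetV2 or ResNet18."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.quant = tq.QuantStub()      # FP32 -> INT8 at the input
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)
        self.dequant = tq.DeQuantStub()  # INT8 -> FP32 at the output

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        x = torch.flatten(self.pool(x), 1)
        return self.dequant(self.fc(x))


torch.manual_seed(0)

# Pick a quantized backend that this machine supports.
engines = torch.backends.quantized.supported_engines
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = backend

# 1. Start from a (normally pre-trained) FP32 model in training mode.
model = TinyNet().train()

# 2. Attach a QAT config and insert fake-quantization modules.
model.qconfig = tq.get_default_qat_qconfig(backend)
qat_model = tq.prepare_qat(model, inplace=False)

# 3. Fine-tune with fake quantization in the forward pass (random data here).
optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.01)
for _ in range(3):
    images = torch.randn(8, 3, 32, 32)
    labels = torch.randint(0, 10, (8,))
    loss = F.cross_entropy(qat_model(images), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 4. Convert to a real INT8 model for CPU inference.
qat_model.eval()
int8_model = tq.convert(qat_model, inplace=False)

with torch.no_grad():
    out = int8_model(torch.randn(1, 3, 32, 32))
print(out.shape)
```

In a real run, step 1 would load the weights produced by the FP32 training phase, and Conv/BN/ReLU groups would usually be fused before `prepare_qat`.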

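The results above show that FP32 and INT8 reach comparable test accuracy (0.800 vs 0.797), while INT8 inference on CPU is about 4x faster than FP32 once the model is JIT-compiled (2.32 ms -> 0.57 ms per sample); without JIT the gain is about 1.5x (2.32 ms -> 1.56 ms). QAT therefore gives a clear speedup at almost no cost in accuracy.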
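
The speedup figures follow directly from the latencies in the log. As a quick check, the arithmetic can be reproduced as follows; the numbers are copied from the log, and the variable names are only illustrative.

```python
# CPU latencies in ms per sample, copied from the log above.
latency_ms = {
    "FP32 CPU": 2.32,
    "INT8 CPU": 1.56,
    "INT8 JIT CPU": 0.57,
}

baseline = latency_ms["FP32 CPU"]
speedup = {name: baseline / ms for name, ms in latency_ms.items()}

for name, factor in speedup.items():
    print(f"{name}: {factor:.2f}x")
# FP32 CPU: 1.00x
# INT8 CPU: 1.49x
# INT8 JIT CPU: 4.07x

# Accuracy lost by going from FP32 to INT8, in absolute terms.
accuracy_drop = 0.800 - 0.797
print(f"accuracy drop: {accuracy_drop:.3f}")  # accuracy drop: 0.003
```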

These figures vary with the training parameters and the hardware environment.
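
To compare latency on your own hardware, a per-sample timing helper along these lines can be used. This is a sketch of a common measurement pattern (warm-up runs followed by timed runs), not necessarily how cifar.py measures it; the function name and parameters are assumptions.

```python
import time


def measure_latency_ms(fn, sample, num_warmups=10, num_runs=100):
    """Return the average wall-clock time of fn(sample) in milliseconds."""
    # Warm-up runs let caches, allocators, and JIT optimizations settle.
    for _ in range(num_warmups):
        fn(sample)

    start = time.perf_counter()
    for _ in range(num_runs):
        fn(sample)
    elapsed = time.perf_counter() - start

    return elapsed * 1000.0 / num_runs


if __name__ == "__main__":
    # Stand-in workload; replace with a model's forward pass.
    ms = measure_latency_ms(lambda n: sum(range(n)), 10_000)
    print(f"{ms:.3f} ms / sample")
```

For a PyTorch model, `fn` would typically call the model under `torch.no_grad()`. On CUDA, kernels launch asynchronously, so `torch.cuda.synchronize()` should be called before reading the timer.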

  • ResNet18 QAT

Set model_name = 'resnet18' in cifar.py, then run:

python cifar.py
......

The trained and exported ONNX models are saved in the save_models directory.
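
After training, the exported files can be sanity-checked with the `onnx` package from the environment list. The sketch below assumes the models carry an `.onnx` extension inside save_models; the exact file names are not stated here, so it simply globs for them. Both helper names are illustrative.

```python
from pathlib import Path


def find_onnx_models(directory="save_models"):
    """Return all *.onnx files in `directory`, sorted by name."""
    return sorted(Path(directory).glob("*.onnx"))


def check_onnx_model(path):
    """Load an ONNX file, validate it, and return its graph input names."""
    import onnx  # imported lazily so find_onnx_models works without onnx

    model = onnx.load(str(path))
    onnx.checker.check_model(model)  # raises if the model is malformed
    return [graph_input.name for graph_input in model.graph.input]


if __name__ == "__main__":
    for model_path in find_onnx_models():
        print(model_path, check_onnx_model(model_path))
```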

Useful links

Here are some links I referred to:
