Pytorch_Quantization_EX

0. Introduction

  • Goal : Quantize a model with PTQ (post-training quantization) and QAT (quantization-aware training)
  • Process :
    1. Train a PyTorch model on a custom dataset
    2. Calibrate the model with Pytorch-Quantization for PTQ
    3. Fine-tune the model with Pytorch-Quantization for QAT
    4. Generate a TensorRT INT8 model from the Pytorch-Quantization model
    5. Generate a TensorRT INT8 model using a TensorRT calibration class
  • Sample Model : ResNet18
  • Dataset : imagenet100
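The core idea behind both PTQ and QAT can be illustrated without any framework. The following is a minimal pure-Python sketch, assuming symmetric per-tensor INT8 quantization with max calibration (one of the calibrators Pytorch-Quantization provides). The helpers `calibrate_amax`, `fake_quant` and `ste_grad` are hypothetical names for illustration and are not the Pytorch-Quantization API. PTQ only needs the calibrated `amax`; QAT additionally trains through the rounding step, approximated here with a clipped straight-through estimator.

```python
# Sketch only: symmetric per-tensor INT8 fake quantization in pure Python.
# Hypothetical helpers for illustration; NOT the Pytorch-Quantization API.


def calibrate_amax(batches):
    """Max calibration: the largest absolute value seen over all calibration batches."""
    amax = 0.0
    for batch in batches:
        amax = max(amax, max(abs(v) for v in batch))
    return amax


def fake_quant(values, amax, num_bits=8):
    """Quantize to integers in [-qmax, qmax], then dequantize back to floats."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
    scale = amax / qmax
    out = []
    for v in values:
        q = round(v / scale)          # Python round(): round-half-to-even
        q = max(-qmax, min(qmax, q))  # clip values outside [-amax, amax]
        out.append(q * scale)
    return out


def ste_grad(values, grad_out, amax):
    """Clipped straight-through estimator for QAT: treat rounding as identity,
    but zero the gradient where the input was clipped (|v| > amax)."""
    return [g if abs(v) <= amax else 0.0 for v, g in zip(values, grad_out)]


if __name__ == "__main__":
    calib_data = [[0.1, -0.5, 1.27], [0.3, -1.0, 0.02]]
    amax = calibrate_amax(calib_data)  # 1.27
    x = [0.5, -2.0, 1.0]
    print(fake_quant(x, amax))         # -2.0 is clipped to about -1.27
    print(ste_grad(x, [1.0, 1.0, 1.0], amax))
```

Max calibration is the simplest choice; it is sensitive to outliers, which is why histogram-based calibrators are often preferred in practice.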

1. Development Environment

  • Device
    • Windows 10 laptop
    • CPU i7-11375H
    • GPU RTX-3060
  • Dependency
    • CUDA 12.1
    • cuDNN 8.9.2
    • TensorRT 8.6.1
    • PyTorch 2.1.0+cu121

2. Code Scheme

    Quantization_EX/
    ├── calibrator.py       # calibration class for TensorRT PTQ
    ├── common.py           # utils for TensorRT
    ├── infer.py            # base model inference
    ├── onnx_export.py      # ONNX export
    ├── ptq.py              # Post-Training Quantization
    ├── qat.py              # Quantization-Aware Training
    ├── quant_utils.py      # utils for quantization
    ├── train.py            # base model training
    ├── trt_infer.py        # TensorRT model inference
    ├── utils.py            # utils
    ├── LICENSE
    └── README.md

3. Performance Evaluation

  • Measured over 10,000 iterations with a single input of shape [1, 3, 224, 224]
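A measurement of this kind can be sketched as follows. This is an assumption about the method, not the repository's benchmark code: `benchmark` and `run_once` are hypothetical names, the 100 warm-up iterations are an arbitrary choice, and for GPU inference `run_once` must also wait for the device to finish (e.g. synchronize the CUDA stream) so the timer covers the full execution. Average FPS follows from average latency as 1000 / latency_ms; for example, 1000 / 1.188 ≈ 841.75, which agrees with the TRT FP32 result in the table to within rounding.

```python
import statistics
import time


def benchmark(run_once, iterations=10000, warmup=100):
    """Call run_once repeatedly and return (average latency in ms, average FPS).

    run_once is a hypothetical zero-argument callable performing one inference.
    For GPU work it must block until the device has finished (e.g. synchronize
    the CUDA stream); otherwise only the kernel launch is timed.
    """
    for _ in range(warmup):  # warm-up runs are excluded from the statistics
        run_once()
    latencies_ms = []
    for _ in range(iterations):
        start = time.perf_counter()
        run_once()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    avg_ms = statistics.fmean(latencies_ms)
    return avg_ms, 1000.0 / avg_ms  # FPS = 1000 / average latency in ms


if __name__ == "__main__":
    avg_ms, fps = benchmark(lambda: sum(range(10_000)), iterations=1000)
    print(f"Avg latency {avg_ms:.3f} ms, Avg FPS {fps:.2f}")
```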
| | TRT | TRT | TRT PTQ | PT-Q PTQ | PT-Q PTQ w bnf | PT-Q QAT | PT-Q QAT w bnf |
|---|---|---|---|---|---|---|---|
| Precision | FP32 | FP16 | Int8 | Int8 | Int8 | Int8 | Int8 |
| Acc Top-1 [%] | 83.08 | 83.04 | 83.12 | 83.18 | 82.64 | 83.42 | 82.80 |
| Avg Latency [ms] | 1.188 | 0.527 | 0.418 | 0.566 | 0.545 | 0.577 | 0.534 |
| Avg FPS [frame/sec] | 841.74 | 1896.01 | 2388.33 | 1764.55 | 1834.69 | 1730.89 | 1870.99 |
| GPU Memory [MB] | 179 | 135 | 123 | 129 | 129 | 129 | 129 |
  • PT-Q : Pytorch-Quantization
  • TRT : TensorRT
  • bnf : batch normalization folding (conv + bn -> conv'); "w bnf" means with bnf
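Batch normalization folding can be checked numerically. The sketch below assumes inference-mode BN with fixed running statistics and uses a toy single-input-channel 1D convolution; all helpers are hypothetical, pure-Python illustrations. For each output channel it computes scale = gamma / sqrt(var + eps), W' = scale * W and b' = scale * (b - mean) + beta, then confirms that the folded convolution reproduces conv followed by BN. For real PyTorch modules, `torch.nn.utils.fusion.fuse_conv_bn_eval` performs the equivalent fusion.

```python
import math


def conv1d(x, weights, biases):
    """Valid 1D convolution (cross-correlation) with one input channel.
    weights[c] is the kernel and biases[c] the bias of output channel c."""
    out = []
    for w, b in zip(weights, biases):
        k = len(w)
        out.append([sum(w[j] * x[i + j] for j in range(k)) + b
                    for i in range(len(x) - k + 1)])
    return out


def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode batch norm applied per output channel with running stats."""
    return [[g * (val - m) / math.sqrt(v + eps) + bt for val in channel]
            for channel, g, bt, m, v in zip(y, gamma, beta, mean, var)]


def fold_bn(weights, biases, gamma, beta, mean, var, eps=1e-5):
    """Return (W', b') such that conv1d(x, W', b') equals
    batch_norm(conv1d(x, W, b), gamma, beta, mean, var)."""
    folded_w, folded_b = [], []
    for w, b, g, bt, m, v in zip(weights, biases, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)
        folded_w.append([scale * wi for wi in w])
        folded_b.append(scale * (b - m) + bt)
    return folded_w, folded_b


if __name__ == "__main__":
    x = [1.0, 2.0, -1.0, 0.5]
    W, b = [[0.2, -0.1], [1.0, 0.5]], [0.0, 0.3]
    gamma, beta, mean, var = [1.5, 0.8], [0.1, -0.2], [0.05, 0.4], [0.9, 2.0]
    reference = batch_norm(conv1d(x, W, b), gamma, beta, mean, var)
    folded = conv1d(x, *fold_bn(W, b, gamma, beta, mean, var))
    print(reference)
    print(folded)  # matches reference up to floating-point rounding
```

Folding removes the BN layer at inference time, so a quantizer placed on the convolution sees the already-scaled weights W' rather than the original W.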

4. Guide

  • infer -> train -> ptq -> qat -> onnx_export -> trt_infer -> trt_infer_acc

5. Reference