Quantization-aware Training

Design

Quantization-aware training (QAT) simulates low-precision inference-time computation in the forward pass of the training process. With QAT, all weights and activations are "fake quantized" during both the forward and backward passes of training: that is, float values are rounded to mimic int8 values, but all computations are still done with floating point numbers. Thus, all the weight adjustments during training are made while "aware" of the fact that the model will ultimately be quantized; after quantizing, therefore, this method will usually yield higher accuracy than either dynamic quantization or post-training static quantization.

[Figure: fake quantize]
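
As a minimal illustration of the idea (not INC's implementation), the sketch below rounds a float tensor onto an int8 grid and immediately maps it back to float, so later computation stays in floating point but "sees" the rounding error; scale and zero_point are illustrative parameters.

import torch

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Round onto the int8 grid defined by (scale, zero_point) and clamp to range ...
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    # ... then map back to float so all downstream arithmetic remains floating point.
    return (q - zero_point) * scale

x = torch.randn(4)
print(fake_quantize(x, scale=0.1, zero_point=0))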

Usage

First, define a training function as below; accuracy evaluation is handled separately through the eval_dataloader configured on the quantizer (see the API example below).

def training_func_for_nc(model):
    epochs = 8
    iters = 30
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
    for nepoch in range(epochs):
        model.train()
        cnt = 0
        for image, target in train_loader:
            print('.', end='')
            cnt += 1
            output = model(image)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if cnt >= iters:
                break
        if nepoch > 3:
            # Freeze quantizer parameters
            model.apply(torch.quantization.disable_observer)
        if nepoch > 2:
            # Freeze batch norm mean and variance estimates
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
    return model
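
The function above assumes that criterion and train_loader already exist in the enclosing scope; a minimal sketch of that setup (the dataset path, transforms, and batch size are placeholders):

import torch
import torchvision
from torchvision import transforms

# Hypothetical setup for the objects training_func_for_nc relies on.
criterion = torch.nn.CrossEntropyLoss()
train_dataset = torchvision.datasets.ImageFolder(
    '/path/to/imagenet/train',                  # placeholder path
    transform=transforms.Compose([transforms.Resize(256),
                                  transforms.CenterCrop(224),
                                  transforms.ToTensor()]))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=30, shuffle=True)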

Next, fuse the modules and attach a quantization configuration:

model.fuse_model()                                                     # fuse Conv/BN/ReLU blocks
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')   # qconfig for the x86 backend
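
fuse_model() is a helper defined on the model itself (torchvision's quantizable models provide one); for a custom model the same fusion can be written explicitly with torch.quantization.fuse_modules. A sketch, where the module names are placeholders that must match the actual model definition:

import torch

# Fuse Conv + BN + ReLU sequences by module name so they are quantized as one unit.
torch.quantization.fuse_modules(
    model,
    [['conv1', 'bn1', 'relu1'], ['conv2', 'bn2', 'relu2']],  # placeholder names
    inplace=True)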

Finally, prepare_qat inserts the fake-quantization modules, preparing the model for quantization-aware training; in INC this call is already implemented as a hook:

torch.quantization.prepare_qat(model, inplace=True)
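
After this call the quantizable layers have been swapped for their QAT counterparts, each carrying a weight fake-quant module and an activation observer. A quick, model-agnostic way to verify this:

# List the modules that now fake-quantize their weights (exact types depend on the model).
for name, module in model.named_modules():
    if hasattr(module, 'weight_fake_quant'):
        print(name, type(module).__name__)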

Training a quantized model with high accuracy requires accurate modeling of numerics at inference time. INC therefore structures the training loop as follows:

  • Switch batch norm to use running mean and variance towards the end of training to better match inference numerics.
  • Freeze the quantizer parameters (scale and zero-point) and fine-tune the weights.

For example, a training loop that follows these practices (train_one_epoch, evaluate, data_loader, data_loader_test, num_eval_batches, and eval_batch_size are helpers assumed to be defined elsewhere, e.g. as in the PyTorch quantization tutorial):

num_train_batches = 20
# qat_model is the model prepared with prepare_qat above.
# Train and check accuracy after each epoch
for nepoch in range(8):
    train_one_epoch(qat_model, criterion, optimizer, data_loader, torch.device('cpu'), num_train_batches)
    if nepoch > 3:
        # Freeze quantizer parameters
        qat_model.apply(torch.quantization.disable_observer)
    if nepoch > 2:
        # Freeze batch norm mean and variance estimates
        qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
    # Check the accuracy after each epoch
    quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
    quantized_model.eval()
    top1, top5 = evaluate(quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
    print('Epoch %d: Evaluation accuracy on %d images, %2.2f' % (nepoch, num_eval_batches * eval_batch_size, top1.avg))

When using QAT in INC, you just need to use these APIs:

from neural_compressor.experimental import Quantization, common
quantizer = Quantization("./conf.yaml")
quantizer.model = common.Model(model)
quantizer.q_func = training_func_for_nc
quantizer.eval_dataloader = val_loader
q_model = quantizer.fit()

The quantizer.fit() function returns the best quantized model found within the timeout constraint.
The tuning behavior is defined in a YAML configuration file; for a complete definition, refer to the yaml example.
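
As a rough sketch (field values here are illustrative; follow the linked example for the authoritative schema), a conf.yaml for PyTorch QAT looks roughly like this:

model:
  name: imagenet_qat                  # illustrative name
  framework: pytorch

quantization:
  approach: quant_aware_training      # selects QAT rather than post-training quantization

tuning:
  accuracy_criterion:
    relative: 0.01                    # tolerate up to 1% relative accuracy loss
  exit_policy:
    timeout: 0                        # 0: stop as soon as the accuracy criterion is met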

Here, we perform quantization-aware training for only a small number of epochs. Nevertheless, quantization-aware training yields an accuracy of over 71% on the entire ImageNet dataset, which is close to the floating-point accuracy of 71.9%.

More on quantization-aware training:

  • QAT is a superset of post-training quantization techniques that allows for more debugging. For example, we can analyze whether the accuracy of the model is limited by weight or activation quantization.
  • We can simulate the accuracy of a quantized model in floating point, since fake quantization models the numerics of actual quantized arithmetic.
  • We can easily mimic post-training quantization, as sketched below.
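
For instance, a model prepared with prepare_qat can mimic post-training static quantization by only running calibration data through it, so the observers collect statistics without any optimizer steps, and then converting. In the sketch below, calib_loader is a hypothetical calibration data loader:

import torch

# No weight updates: the fake-quant observers only see calibration batches,
# then the model is converted to int8, mirroring post-training static quantization.
model.eval()
with torch.no_grad():
    for image, _ in calib_loader:     # calib_loader: hypothetical calibration data
        model(image)
ptq_like_model = torch.quantization.convert(model.eval(), inplace=False)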

Examples

For related examples, please refer to the QAT models.