Cross-Layer Equalization

This notebook shows how to use AIMET to apply Cross-Layer Equalization (CLE) and Bias Correction (BC). CLE and BC are post-training quantization techniques for improving the quantized accuracy of a model. They help most when the accuracy drop comes from parameter quantization rather than from activation quantization.

CLE does not need any data samples. BC may optionally need unlabeled data samples.

Cross-layer equalization

AIMET performs the following steps when running CLE:

  1. Batch norm (BN) folding: Folds BN layers into the convolution (Conv) layers immediately before or after them.

  2. Cross-layer scaling: For a set of consecutive Conv layers, equalizes the range of tensor values per-channel by scaling their weight tensor values.

  3. High bias folding: Cross-layer scaling may result in high bias parameter values for some layers. This technique folds some of the bias of a layer into the subsequent layer’s parameters.
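The cross-layer scaling step can be illustrated with a minimal NumPy sketch. It is not AIMET's implementation: it assumes two fully connected layers separated by a ReLU, and it uses the scale factor s_i = sqrt(r1_i / r2_i) from the paper cited at the end of this notebook. The function name `cross_layer_scale` and all tensor names are made up for illustration.

```python
import numpy as np

def cross_layer_scale(w1, b1, w2):
    """Equalize per-channel weight ranges of two consecutive layers.

    w1: (hidden, in) weights of the first layer, b1: (hidden,) bias,
    w2: (out, hidden) weights of the second layer.
    """
    r1 = np.abs(w1).max(axis=1)   # range of each output channel of layer 1
    r2 = np.abs(w2).max(axis=0)   # range of each matching input channel of layer 2
    s = np.sqrt(r1 / r2)          # per-channel scale factor
    # Divide layer 1 by s and multiply layer 2 by s. Because
    # relu(a / s) * s == relu(a) for s > 0, the network output is unchanged.
    return w1 / s[:, None], b1 / s, w2 * s[None, :], s

def relu(t):
    return np.maximum(t, 0.0)

rng = np.random.default_rng(0)
# Layer 1 has very different ranges per output channel (hard to quantize per-tensor)
w1 = rng.normal(size=(4, 3)) * np.array([0.01, 1.0, 10.0, 100.0])[:, None]
b1 = rng.normal(size=4)
w2 = rng.normal(size=(2, 4))
x = rng.normal(size=(5, 3))

y_before = relu(x @ w1.T + b1) @ w2.T
w1_eq, b1_eq, w2_eq, s = cross_layer_scale(w1, b1, w2)
y_after = relu(x @ w1_eq.T + b1_eq) @ w2_eq.T

print("output unchanged:", np.allclose(y_before, y_after))
print("layer 1 channel ranges:", np.abs(w1_eq).max(axis=1))
print("layer 2 channel ranges:", np.abs(w2_eq).max(axis=0))
```

After scaling, each channel has the same range in both layers, so a single per-tensor quantization grid wastes fewer levels on small-range channels.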

Overall flow

This example performs the following steps:

  1. Instantiate the example evaluation and training pipeline

  2. Convert an FP32 PyTorch model to ONNX and evaluate the model’s baseline FP32 accuracy

  3. Create a quantization simulation model (with fake quantization ops inserted) and evaluate this simulation model to get a quantized accuracy score

  4. Apply CLE and evaluate the simulation model to get a post-CLE quantized accuracy score

Note

This notebook does not show state-of-the-art results. For example, it uses a relatively quantization-friendly model (ResNet-18). Also, some parameters, such as the number of calibration samples, are chosen to improve execution speed in the notebook.


Dataset

This example does image classification on the ImageNet dataset. If you already have a copy of the dataset, use that. Otherwise download the dataset, for example from https://image-net.org/challenges/LSVRC/2012/index.

Note

The dataloader provided in this example relies on these features of the ImageNet data set:

  • Subfolders named train for the training samples and val for the validation samples. See the PyTorch dataset description for more details.

  • One subdirectory per class, and one file per image sample.

Note

To speed up the execution of this notebook, you can use a reduced subset of the ImageNet dataset. For example: The entire ILSVRC2012 dataset has 1000 classes, 1000 training samples per class and 50 validation samples per class. However, for the purpose of running this notebook, you can reduce the dataset to, say, two samples per class.
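One way to build such a reduced subset is sketched below. It assumes the directory layout described in the note above (train and val subfolders, one subdirectory per class, one file per image). The helper name `make_subset` and its arguments are hypothetical and are not part of AIMET or of the example utilities.

```python
import os
import shutil

def make_subset(src_dir, dst_dir, samples_per_class=2):
    """Copy the first `samples_per_class` files of every class folder in
    `src_dir/train` and `src_dir/val` into the same layout under `dst_dir`."""
    for split in ("train", "val"):
        split_src = os.path.join(src_dir, split)
        if not os.path.isdir(split_src):
            continue  # skip a split that is not present
        for class_name in sorted(os.listdir(split_src)):
            class_src = os.path.join(split_src, class_name)
            if not os.path.isdir(class_src):
                continue  # ignore stray files next to the class folders
            class_dst = os.path.join(dst_dir, split, class_name)
            os.makedirs(class_dst, exist_ok=True)
            # Sorting makes the selection deterministic across runs
            for file_name in sorted(os.listdir(class_src))[:samples_per_class]:
                shutil.copy2(os.path.join(class_src, file_name), class_dst)
```

After running the helper on your own source and destination directories, point DATASET_DIR at the destination directory.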

Edit the cell below to specify the directory where the downloaded ImageNet dataset is saved.

[ ]:
DATASET_DIR = '/path/to/dataset/'         # Replace this path with a real directory

1. Instantiate the example training and validation pipeline

Use the following training and validation loop for the image classification task.

Things to note:

  • AIMET does not put limitations on how the training and validation pipeline is written. AIMET modifies the user’s model to create a QuantizationSim model. In this ONNX example, the QuantizationSim model exposes an ONNX Runtime session (sim.session) that can be used in place of the original session when evaluating the model.

  • AIMET does not put limitations on the interface of the evaluate() or train() methods. You should be able to use your existing evaluate and train routines as-is.

[ ]:
import torch
import onnxruntime as ort
from Examples.common import image_net_config
from Examples.onnx.utils.image_net_evaluator import ImageNetEvaluator
from Examples.torch.utils.image_net_data_loader import ImageNetDataLoader

class ImageNetDataPipeline:

    @staticmethod
    def get_val_dataloader() -> torch.utils.data.DataLoader:
        """
        Instantiates a validation dataloader for ImageNet dataset and returns it
        """
        data_loader = ImageNetDataLoader(DATASET_DIR,
                                         image_size=image_net_config.dataset['image_size'],
                                         batch_size=image_net_config.evaluation['batch_size'],
                                         is_training=False,
                                         num_workers=image_net_config.evaluation['num_workers']).data_loader
        return data_loader

    @staticmethod
    def evaluate(sess: ort.InferenceSession) -> float:
        """
        Given an ONNX Runtime inference session, evaluates its Top-1 accuracy on the dataset
        :param sess: the ONNX Runtime session that wraps the model to evaluate
        """
        evaluator = ImageNetEvaluator(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
                                      batch_size=image_net_config.evaluation['batch_size'],
                                      num_workers=image_net_config.evaluation['num_workers'])

        return evaluator.evaluate(sess, iterations=None)


2. Convert an FP32 PyTorch model to ONNX, simplify & then evaluate baseline FP32 accuracy

2.1 Load a pretrained resnet18 model from torchvision.

You can load any pretrained PyTorch model instead.

[ ]:
from torchvision.models import resnet18
import onnx

input_shape = (1, 3, 224, 224)    # Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)
dummy_input = torch.randn(input_shape)
filename = "./resnet18.onnx"

# Load a pretrained ResNet-18 model in torch
pt_model = resnet18(pretrained=True)

2.2 Export the model to ONNX.

[ ]:

# Export the torch model to onnx
torch.onnx.export(pt_model.eval(),
                  dummy_input,
                  filename,
                  export_params=True,
                  do_constant_folding=False,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={
                      'input' : {0 : 'batch_size'},
                      'output' : {0 : 'batch_size'},
                  }
                  )
model = onnx.load_model(filename)

2.3 It is recommended to simplify the model before using AIMET

[ ]:
from onnxsim import simplify

try:
    model, _ = simplify(model)
except Exception:
    print('ONNX Simplifier failed. Proceeding with unsimplified model')

2.4 Decide whether to place the model on a CPU or CUDA device.

This example uses CUDA if it is available. You can change this logic and force a device placement if needed.

[ ]:
# Pin cudnn_conv_algo_search to DEFAULT so that outputs and accuracy do not vary between inference runs
if 'CUDAExecutionProvider' in ort.get_available_providers():
    providers = [('CUDAExecutionProvider', {'cudnn_conv_algo_search': 'DEFAULT'}), 'CPUExecutionProvider']
    use_cuda = True
else:
    providers = ['CPUExecutionProvider']
    use_cuda = False

2.5 Create an ONNX runtime session and compute the floating point 32-bit (FP32) accuracy of this model using the evaluate() routine.

[ ]:
sess = ort.InferenceSession(model.SerializeToString(), providers=providers)
accuracy = ImageNetDataPipeline.evaluate(sess)
print(accuracy)

3. Create a quantization simulation model and determine quantized accuracy

Fold Batch Norm layers

Before calculating the simulated quantized accuracy using QuantizationSimModel, fold the BatchNorm (BN) layers into adjacent Convolutional layers. The BN layers that cannot be folded are left as they are.

BN folding improves inference performance on quantized runtimes but can degrade accuracy on these platforms. This step simulates this on-target drop in accuracy.
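The arithmetic behind BN folding can be sketched in NumPy as follows. This is an illustration under simplifying assumptions (a fully connected layer followed by a BN layer in inference mode), not the code AIMET runs. The function name `fold_bn_into_linear` is hypothetical.

```python
import numpy as np

def fold_bn_into_linear(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold y = gamma * (z - mean) / sqrt(var + eps) + beta, with z = w @ x + b,
    into a single layer y = w_folded @ x + b_folded."""
    scale = gamma / np.sqrt(var + eps)        # one factor per output channel
    w_folded = w * scale[:, None]             # scale each output row of the weights
    b_folded = (b - mean) * scale + beta      # shift and scale the bias
    return w_folded, b_folded

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 3))
b = rng.normal(size=4)
gamma = rng.uniform(0.5, 2.0, size=4)
beta = rng.normal(size=4)
mean = rng.normal(size=4)
var = rng.uniform(0.5, 2.0, size=4)
x = rng.normal(size=(5, 3))

# Reference: linear layer followed by batch norm (inference mode)
z = x @ w.T + b
y_ref = gamma * (z - mean) / np.sqrt(var + 1e-5) + beta

# Folded: a single linear layer
w_folded, b_folded = fold_bn_into_linear(w, b, gamma, beta, mean, var)
y_folded = x @ w_folded.T + b_folded

print("folded output matches:", np.allclose(y_ref, y_folded))
```

In floating point the two versions agree. The folded weights are what get quantized, though, and their per-channel ranges can differ widely from those of the original weights, which is one plausible source of the accuracy drop mentioned above.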

3.1 Use the following code to call AIMET to fold the BN layers in-place on the given model.

[ ]:
from aimet_onnx.batch_norm_fold import fold_all_batch_norms_to_weight

_ = fold_all_batch_norms_to_weight(model)

Create the Quantization Sim Model

3.2 Use AIMET to create a QuantizationSimModel.

In this step, AIMET inserts fake quantization ops in the model graph and configures them.

Key parameters:

  • Setting default_activation_bw to 8 performs all activation quantizations in the model using integer 8-bit precision

  • Setting default_param_bw to 8 performs all parameter quantizations in the model using integer 8-bit precision

  • The amount of calibration data is not set here. It is passed to compute_encodings() in step 3.4, where only 1000 samples are used for the sake of speed

See QuantizationSimModel in the AIMET API documentation for a full explanation of the parameters.

[ ]:
from aimet_common.defs import QuantScheme
from aimet_onnx.quantsim import QuantizationSimModel

sim = QuantizationSimModel(model=model,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_activation_bw=8,
                           default_param_bw=8,
                           use_cuda=use_cuda)

AIMET has added quantizer nodes to the model graph, but before the sim model can be used for inference or training, scale and offset quantization parameters must be calculated for each quantizer node by passing unlabeled data samples through the model to collect range statistics. This process is sometimes referred to as calibration. AIMET refers to it as “computing encodings”.
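To make scale and offset concrete, here is a minimal sketch of computing an encoding from observed range statistics and applying fake quantization with it. It assumes a plain min/max range and an unsigned 8-bit grid. The tf_enhanced scheme used in this notebook selects the range differently, and the function names `compute_encoding` and `fake_quantize` are made up for illustration.

```python
import numpy as np

def compute_encoding(x_min, x_max, bitwidth=8):
    """Derive (scale, offset) from an observed value range."""
    # The range must contain zero so that zero is exactly representable
    x_min, x_max = min(float(x_min), 0.0), max(float(x_max), 0.0)
    num_steps = 2 ** bitwidth - 1
    scale = max(x_max - x_min, 1e-8) / num_steps   # guard against an empty range
    offset = int(np.round(x_min / scale))          # integer grid position of x_min
    return scale, offset

def fake_quantize(x, scale, offset, bitwidth=8):
    """Quantize to the integer grid, then map back to floating point."""
    q = np.clip(np.round(x / scale) - offset, 0, 2 ** bitwidth - 1)
    return (q + offset) * scale

rng = np.random.default_rng(0)
activations = rng.normal(size=1000)    # stand-in for values seen during calibration

scale, offset = compute_encoding(activations.min(), activations.max())
dequantized = fake_quantize(activations, scale, offset)

print("scale:", scale, "offset:", offset)
print("max abs error:", np.max(np.abs(dequantized - activations)))
```

Values inside the calibrated range are rounded to the nearest grid point, so their error is at most about half a step. Values outside the range are clipped, which is why representative calibration data matters.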

3.3 Create a routine to pass unlabeled data samples through the model.

The following code is one way to write a routine that passes unlabeled samples through the model to compute encodings. It uses the existing train or validation data loader to extract samples and pass them to the model. Since there is no need to compute loss metrics, it ignores the model output.

[ ]:
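def pass_calibration_data(session, samples):
    """Run roughly `samples` unlabeled images through `session` to collect range statistics."""
    data_loader = ImageNetDataPipeline.get_val_dataloader()
    batch_size = data_loader.batch_size
    # Query the input name from the session being calibrated, not the global FP32 session
    input_name = session.get_inputs()[0].name

    batch_cntr = 0
    for input_data, _ in data_loader:
        # Labels are ignored; only the forward pass matters for calibration
        inputs_batch = input_data.numpy()
        session.run(None, {input_name : inputs_batch})

        batch_cntr += 1
        # Stop once at least `samples` images have been processed
        if (batch_cntr * batch_size) >= samples:
            break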

A few notes regarding the data samples:

  • Only a very small fraction of the data samples is needed. For example, the training dataset for ImageNet has about 1M samples; 500 to 1000 suffice to compute encodings.

  • The samples should be reasonably well distributed. While it’s not necessary to cover all classes, avoid extreme scenarios like using only dark or only light samples. That is, using only pictures captured at night, say, could skew the results.
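One simple way to keep the calibration set reasonably well distributed is to draw samples round-robin across classes. The sketch below assumes you have a list of (item, label) pairs. The helper name `pick_calibration_samples` is hypothetical and is not part of AIMET. Labels are only used to spread the selection; they are not needed for calibration itself.

```python
import random
from collections import defaultdict

def pick_calibration_samples(labeled_items, num_samples, seed=0):
    """Pick up to `num_samples` items, cycling over classes so that no
    single class dominates the calibration set."""
    by_class = defaultdict(list)
    for item, label in labeled_items:
        by_class[label].append(item)

    rng = random.Random(seed)            # fixed seed for a reproducible selection
    for items in by_class.values():
        rng.shuffle(items)

    picked = []
    classes = sorted(by_class)
    # Take one item per class per round until enough items are collected
    while len(picked) < num_samples and any(by_class[c] for c in classes):
        for c in classes:
            if by_class[c] and len(picked) < num_samples:
                picked.append(by_class[c].pop())
    return picked

# Toy example: 10 classes with 100 items each, 50 calibration samples
items = [("img_%d_%d" % (c, i), c) for c in range(10) for i in range(100)]
subset = pick_calibration_samples(items, num_samples=50)
print(len(subset))
```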


3.4 Call AIMET to use the routine to pass data through the model and compute the quantization encodings.

Encodings here refer to scale and offset quantization parameters.

[ ]:
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=1000)

The QuantizationSim model is now ready to be used for inference or training.

3.5 Pass the model to the same evaluation routine as before to calculate a simulated quantized accuracy score for INT8 quantization for comparison with the FP32 score.

[ ]:
accuracy = ImageNetDataPipeline.evaluate(sim.session)
print(accuracy)

4. Apply CLE

4.1 Perform CLE

The next cell performs cross-layer equalization on the model. As noted before, the function folds batch norms, applies cross-layer scaling, and then folds high biases.

Note

The CLE procedure needs BN statistics. If a BN folded model is provided, CLE runs the cross-layer scaling (CLS) optimization step but skips the high-bias absorption (HBA) step. To avoid this, load the original model again before running CLE.
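The following NumPy sketch suggests why high-bias absorption depends on BN statistics. It follows the formulation in the paper cited at the end of this notebook, where the absorbed amount is c = max(0, beta - 3 * gamma), with beta and gamma the BN shift and scale of the first layer. AIMET's implementation may differ in detail, and the function name `absorb_high_bias` is hypothetical.

```python
import numpy as np

def absorb_high_bias(b1, w2, b2, bn_beta, bn_gamma):
    """Move part of layer 1's bias into layer 2's bias.

    Relies on relu(z - c) == relu(z) - c whenever z >= c, which holds for
    almost all inputs if z is roughly Gaussian with mean beta and std gamma.
    """
    c = np.maximum(0.0, bn_beta - 3.0 * np.abs(bn_gamma))
    return b1 - c, b2 + w2 @ c, c

def relu(t):
    return np.maximum(t, 0.0)

rng = np.random.default_rng(0)
beta = np.array([10.0, 0.5, 4.0])     # BN shift per channel
gamma = np.array([1.0, 1.0, 0.5])     # BN scale per channel

# Stand-in for the first layer: pre = W1 @ x (zero mean), folded bias b1 = beta.
# Noise is clipped to 3 sigma so the identity above holds exactly in this demo.
pre = gamma * np.clip(rng.normal(size=(1000, 3)), -3.0, 3.0)
b1 = beta.copy()
w2 = rng.normal(size=(2, 3))
b2 = rng.normal(size=2)

y_before = relu(pre + b1) @ w2.T + b2
b1_new, b2_new, c = absorb_high_bias(b1, w2, b2, beta, gamma)
y_after = relu(pre + b1_new) @ w2.T + b2_new

print("absorbed amount per channel:", c)
print("output unchanged:", np.allclose(y_before, y_after))
```

Without beta and gamma there is no estimate of how much bias can be removed safely, which is consistent with the step being skipped for a model whose BN layers are already folded.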

Note

CLE equalizes the model in-place.

[ ]:
filename = "./resnet18.onnx"
model = onnx.load_model(filename)

It is recommended to simplify the model before using AIMET

[ ]:
from onnxsim import simplify

try:
    model, _ = simplify(model)
except Exception:
    print('ONNX Simplifier failed. Proceeding with unsimplified model')

[ ]:
from aimet_onnx.cross_layer_equalization import equalize_model

equalize_model(model)

4.2 Compute the accuracy of the equalized model.

Create a simulation model as before and evaluate it to determine simulated quantized accuracy.

[ ]:
sim = QuantizationSimModel(model=model,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_activation_bw=8,
                           default_param_bw=8,
                           use_cuda=use_cuda)

sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=1000)

accuracy = ImageNetDataPipeline.evaluate(sim.session)
print(accuracy)

For more information

See the AIMET API docs for details about the AIMET APIs and optional parameters.

See the other example notebooks to learn how to use other AIMET post-training quantization techniques.

To learn more about these techniques, see “Data-Free Quantization Through Weight Equalization and Bias Correction” from ICCV 2019.