AIMET PyTorch Bias Correction API

Bias Correction API


ConvBnInfoType

class aimet_common.bias_correction.ConvBnInfoType(input_bn=None, output_bn=None, in_activation_type=<ActivationType.no_activation: 0>, out_activation_type=<ActivationType.no_activation: 0>)

Type for holding convs with BN info and activation types. Supported activation types are ReLU and ReLU6.

Parameters
  • input_bn – Reference to the input BatchNorm of the layer

  • output_bn – Reference to the output BatchNorm of the layer

  • in_activation_type (ActivationType) – Type of the input activation

  • out_activation_type (ActivationType) – Type of the output activation


ActivationType

class aimet_common.defs.ActivationType

Enum identifying the activation type

no_activation = 0

No activation

relu = 1

ReLU activation

relu6 = 2

ReLU6 activation

Quantization Params


Code Examples

Required imports

# Bias Correction related imports
from aimet_torch import bias_correction
from aimet_torch.quantsim import QuantParams
from aimet_torch.examples.mobilenet import MobileNetV2
from aimet_torch.utils import create_fake_data_loader

Empirical Bias correction

def bias_correction_empirical():
    """Run empirical bias correction on a MobileNetV2 model fed with fake data.

    Weights and activations are quantized to 4 bits (nearest rounding,
    'tf_enhanced' scheme) for the correction pass. The model is moved to
    CUDA before correction, so a GPU is required to run this example.
    """
    dataset_size = 2000
    batch_size = 64

    data_loader = create_fake_data_loader(dataset_size=dataset_size, batch_size=batch_size, image_size=(3, 224, 224))

    model = MobileNetV2()
    model.eval()

    params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')

    # Perform Bias Correction.
    # create_fake_data_loader already returns the loader object, so it is passed
    # as-is (consistent with the analytical + empirical example) instead of
    # reaching for a non-existent `.train_loader` attribute.
    bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
                                 data_loader=data_loader, num_bias_correct_samples=512)

Analytical + Empirical Bias correction

def bias_correction_analytical_and_empirical():
    """Run combined analytical and empirical bias correction on MobileNetV2.

    Conv layers that have an associated BatchNorm are corrected analytically.
    The remaining conv layers are corrected empirically with fake data.
    Weights and activations use 4-bit quantization, and the model is moved to
    CUDA for the correction pass, so a GPU is required.
    """
    loader = create_fake_data_loader(dataset_size=2000, batch_size=64, image_size=(3, 224, 224))

    net = MobileNetV2()
    net.eval()

    # Discover BN + Conv pairs (analytical BC) and leftover Convs (empirical BC);
    # the input shape matches the fake data's (3, 224, 224) images.
    conv_bn_info = bias_correction.find_all_conv_bn_with_activation(net, input_shape=(1, 3, 224, 224))

    quant_params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')

    gpu_net = net.to(device="cuda")
    # perform_only_empirical_bias_corr=False enables the analytical path for the
    # BN-backed convs found above.
    bias_correction.correct_bias(
        gpu_net,
        quant_params,
        num_quant_samples=1000,
        data_loader=loader,
        num_bias_correct_samples=512,
        conv_bn_dict=conv_bn_info,
        perform_only_empirical_bias_corr=False,
    )

Bias correction data loader format example

class BatchIterator:
    """Iterable adapter around a data loader.

    Every item produced by the wrapped loader is unpacked into a batch and a
    label. The two are then re-emitted together as a ``(batch, label)`` tuple.
    """

    def __init__(self, data_loader):
        # The wrapped loader is iterated afresh on each call to __iter__.
        self.data_loader = data_loader

    def __iter__(self):
        for inputs, targets in self.data_loader:
            pair = (inputs, targets)
            yield pair