AIMET PyTorch Bias Correction API

Bias Correction API

aimet_torch.bias_correction.correct_bias(model, quant_params, num_quant_samples, data_loader, num_bias_correct_samples, conv_bn_dict=None, perform_only_empirical_bias_corr=True, layers_to_ignore=None)

Corrects the bias of each Conv layer in the model (unless the layer is ignored). A combination of analytical and empirical bias correction is used: every layer that can be corrected with Analytical Bias Correction is corrected analytically, and the remaining layers are corrected with the Empirical method.

Corrects the model in place and returns the corrected floating-point model

Parameters
  • model (Module) – Model to be corrected

  • quant_params (QuantParams) – Named tuple holding the quantization simulation parameters for bias correction

  • num_quant_samples (int) – Number of input samples to pass through the quantization simulation for bias correction

  • data_loader – Data loader for the model; when iterated, it must yield (input, label) batch tuples (see the data loader format example below)

  • num_bias_correct_samples (int) – Number of input samples to use for bias correction

  • conv_bn_dict (Optional[Dict[Module, ConvBnInfoType]]) – Dict of Conv modules and their BN/activation information. If None, the function computes it automatically

  • perform_only_empirical_bias_corr (bool) – Default True. If True, performs only empirical bias correction for all layers, even those eligible for analytical bias correction

  • layers_to_ignore (Optional[List[Module]]) – List of layers for which bias correction should be skipped (a sketch follows this list)

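A hedged sketch of using layers_to_ignore (conv_layer_to_skip is a hypothetical module reference; model, params, and data_loader are assumed to be set up as in the Code Examples section below):

# Hypothetical: skip bias correction for one specific Conv module
conv_layer_to_skip = model.features[0]  # assumed module path; adjust for your model

bias_correction.correct_bias(model, params, num_quant_samples=1000,
                             data_loader=data_loader, num_bias_correct_samples=512,
                             layers_to_ignore=[conv_layer_to_skip])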

ConvBnInfoType

class aimet_common.bias_correction.ConvBnInfoType(input_bn=None, output_bn=None, in_activation_type=<ActivationType.no_activation: 0>, out_activation_type=<ActivationType.no_activation: 0>)

Type for holding Conv modules with BN info and activation types. Supported activation types are ReLU and ReLU6.

Parameters
  • input_bn – Reference to the BatchNorm at the input of the layer

  • output_bn – Reference to the BatchNorm at the output of the layer

  • in_activation_type (ActivationType) – Type of activation at the input of the layer

  • out_activation_type (ActivationType) – Type of activation at the output of the layer

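A minimal sketch of building a conv_bn_dict entry by hand (the attribute names model.bn1 and model.conv2 are hypothetical; substitute references from your own model):

from aimet_common.bias_correction import ConvBnInfoType
from aimet_common.defs import ActivationType

# Hypothetical layout: model.conv2 is preceded by model.bn1 and a ReLU
conv_bn_dict = {
    model.conv2: ConvBnInfoType(input_bn=model.bn1,
                                output_bn=None,
                                in_activation_type=ActivationType.relu,
                                out_activation_type=ActivationType.no_activation)
}

The resulting dict can then be passed to correct_bias via the conv_bn_dict parameter.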

ActivationType

class aimet_common.defs.ActivationType

Enum to identify the activation type

no_activation = 0

No activation

relu = 1

ReLU activation

relu6 = 2

ReLU6 activation

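The enum members carry the integer values listed above and can be used directly when constructing a ConvBnInfoType, for example:

from aimet_common.defs import ActivationType

# Values match the definitions above
assert ActivationType.no_activation.value == 0
assert ActivationType.relu.value == 1
assert ActivationType.relu6.value == 2
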
Quantization Params

class aimet_torch.quantsim.QuantParams(weight_bw=8, act_bw=8, round_mode='nearest', quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, config_file=None)

Data type to hold quantization-related params.

Constructor

Parameters
  • weight_bw (int) – Weight bitwidth (4-31) to use for quantizing layer weights. Default = 8

  • act_bw (int) – Activation bitwidth (4-31) to use for quantizing layer activations. Default = 8

  • round_mode (str) – Rounding mode. Supported options are 'nearest' or 'stochastic'

  • quant_scheme (Union[QuantScheme, str]) – Quantization scheme. Supported options are the strings 'tf' or 'tf_enhanced', or the enum values QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced

  • config_file (Optional[str]) – Path to Configuration file for model quantizers

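As a brief illustration (the bitwidths here are just example values), the scheme can be given either as a string or as a QuantScheme enum member:

from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantParams

# These two constructions select the same quantization scheme
params_str = QuantParams(weight_bw=8, act_bw=8, round_mode='nearest',
                         quant_scheme='tf_enhanced')
params_enum = QuantParams(weight_bw=8, act_bw=8, round_mode='nearest',
                          quant_scheme=QuantScheme.post_training_tf_enhanced)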

Code Examples

Required imports

# Bias Correction related imports
from aimet_torch import bias_correction
from aimet_torch.quantsim import QuantParams
from aimet_torch.examples.mobilenet import MobileNetV2
from aimet_torch.utils import create_fake_data_loader

Empirical Bias correction

def bias_correction_empirical():
    dataset_size = 2000
    batch_size = 64

    data_loader = create_fake_data_loader(dataset_size=dataset_size, batch_size=batch_size, image_size=(3, 224, 224))

    model = MobileNetV2()
    model.eval()

    params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')

    # Perform Bias Correction
    bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
                                 data_loader=data_loader, num_bias_correct_samples=512)

Analytical + Empirical Bias correction

def bias_correction_analytical_and_empirical():

    dataset_size = 2000
    batch_size = 64

    data_loader = create_fake_data_loader(dataset_size=dataset_size, batch_size=batch_size, image_size=(3, 224, 224))

    model = MobileNetV2()
    model.eval()

    # Find all BN + Conv pairs for analytical BC and remaining Conv for Empirical BC
    module_prop_dict = bias_correction.find_all_conv_bn_with_activation(model, input_shape=(1, 3, 224, 224))

    params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')

    # Perform Bias Correction
    bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
                                 data_loader=data_loader, num_bias_correct_samples=512,
                                 conv_bn_dict=module_prop_dict, perform_only_empirical_bias_corr=False)

Bias correction data loader format example

# The data loader passed to correct_bias only needs to be iterable and yield
# (batch, label) tuples; this wrapper shows the minimal expected interface.
class BatchIterator:
    def __init__(self, data_loader):
        self.data_loader = data_loader

    def __iter__(self):
        # Yield one (batch, label) tuple per iteration
        for batch, label in self.data_loader:
            yield (batch, label)
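
As a hedged usage sketch (the torchvision dataset, transform, and path below are assumptions, not part of the AIMET example), a standard PyTorch DataLoader can be wrapped with BatchIterator and passed to correct_bias:

import torch
from torchvision import datasets, transforms

# Assumed dataset location and preprocessing; adjust for your data
dataset = datasets.ImageFolder('/path/to/images',
                               transform=transforms.Compose([transforms.Resize((224, 224)),
                                                             transforms.ToTensor()]))
torch_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

data_loader = BatchIterator(torch_loader)

# model and params as in the earlier examples
bias_correction.correct_bias(model, params, num_quant_samples=1000,
                             data_loader=data_loader, num_bias_correct_samples=512)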