AIMET PyTorch Bias Correction API¶
Bias Correction API¶
ConvBnInfoType¶
class aimet_common.bias_correction.ConvBnInfoType(input_bn=None, output_bn=None, in_activation_type=<ActivationType.no_activation: 0>, out_activation_type=<ActivationType.no_activation: 0>)¶
Type for holding convs with BN info and activation types. Activation types supported are Relu and Relu6.
Parameters
- input_bn – Reference to input BatchNorm to layer
- output_bn – Reference to output BatchNorm to layer
- in_activation_type (ActivationType) – Type of activation
- out_activation_type (ActivationType) – Type of activation
 
 
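A minimal sketch of building a ConvBnInfoType entry by hand may help illustrate the fields above. The ActivationType import path and the lowercase member names relu/relu6 are assumptions based on the no_activation member shown in the signature; in practice, find_all_conv_bn_with_activation (used in the code examples below) constructs this mapping automatically.
import torch.nn as nn
from aimet_common.bias_correction import ConvBnInfoType
from aimet_common.defs import ActivationType  # assumed import path for ActivationType

# Hypothetical BatchNorm layer standing in for a layer of a real model
bn_before_conv = nn.BatchNorm2d(32)

# Record that this conv has a BatchNorm on its input and a ReLU6 on its output
conv_bn_info = ConvBnInfoType(input_bn=bn_before_conv,
                              output_bn=None,
                              in_activation_type=ActivationType.no_activation,
                              out_activation_type=ActivationType.relu6)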
ActivationType¶
Quantization Params¶
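QuantParams bundles the quantization settings used during bias correction. A minimal sketch, using the same arguments that appear in the code examples below:
from aimet_torch.quantsim import QuantParams

# 4-bit weights and activations, nearest rounding, tf_enhanced quantization scheme
params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')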
Code Examples¶
Required imports
# Bias Correction related imports
from aimet_torch import bias_correction
from aimet_torch.quantsim import QuantParams
from aimet_torch.examples.mobilenet import MobileNetV2
from aimet_torch.utils import create_fake_data_loader
Empirical Bias correction
def bias_correction_empirical():
    dataset_size = 2000
    batch_size = 64
    data_loader = create_fake_data_loader(dataset_size=dataset_size, batch_size=batch_size, image_size=(3, 224, 224))
    model = MobileNetV2()
    model.eval()
    params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')
    # Perform Bias Correction
    bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
                                 data_loader=data_loader, num_bias_correct_samples=512)
Analytical + Empirical Bias correction
def bias_correction_analytical_and_empirical():
    dataset_size = 2000
    batch_size = 64
    data_loader = create_fake_data_loader(dataset_size=dataset_size, batch_size=batch_size, image_size=(3, 224, 224))
    model = MobileNetV2()
    model.eval()
    # Find all BN + Conv pairs for analytical BC and remaining Conv for Empirical BC
    module_prop_dict = bias_correction.find_all_conv_bn_with_activation(model, input_shape=(1, 3, 224, 224))
    params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')
    # Perform Bias Correction
    bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
                                 data_loader=data_loader, num_bias_correct_samples=512,
                                 conv_bn_dict=module_prop_dict, perform_only_empirical_bias_corr=False)
Bias correction data loader format example. The data loader passed to correct_bias is expected to yield (batch, label) tuples on iteration, as in the wrapper class below; a usage sketch follows.
class BatchIterator:
    """Wraps a data loader so that iteration yields (batch, label) tuples,
    the format expected by correct_bias."""
    def __init__(self, data_loader):
        self.data_loader = data_loader

    def __iter__(self):
        # Re-yield each batch and its labels from the underlying data loader
        for batch, label in self.data_loader:
            yield (batch, label)

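A hedged usage sketch showing how a standard torch DataLoader might be wrapped with the BatchIterator class above and passed to correct_bias. The dataset path and preprocessing are placeholders, not part of the original example; any DataLoader that yields (batch, label) tuples can be wrapped this way.
import torch
from torchvision import datasets, transforms

from aimet_torch import bias_correction
from aimet_torch.quantsim import QuantParams
from aimet_torch.examples.mobilenet import MobileNetV2

# Hypothetical dataset location and preprocessing
dataset = datasets.ImageFolder('/path/to/images',
                               transform=transforms.Compose([transforms.Resize((224, 224)),
                                                             transforms.ToTensor()]))
torch_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False)

# Wrap the torch DataLoader in the (batch, label) iterator format shown above
data_loader = BatchIterator(torch_loader)

model = MobileNetV2()
model.eval()
params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')

bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
                             data_loader=data_loader, num_bias_correct_samples=512)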