AIMET PyTorch Bias Correction API
Bias Correction API
aimet_torch.bias_correction.correct_bias(model, quant_params, num_quant_samples, data_loader, num_bias_correct_samples, conv_bn_dict=None, perform_only_empirical_bias_corr=True, layers_to_ignore=None)
Corrects the bias of each Conv layer of the model (unless the layer is ignored). A combination of Analytical and Empirical Bias Correction is used: all layers that can be corrected with Analytical Bias Correction are corrected analytically, and the remaining layers are corrected with the Empirical method. Because perform_only_empirical_bias_corr defaults to True, only Empirical Bias Correction is applied unless that flag is set to False.
Returns an in-place corrected floating-point model.
- Parameters
model (Module) – Model to be corrected
quant_params (QuantParams) – Named tuple for quantization simulation for bias correction
num_quant_samples (int) – Number of image samples to pass through the quantization sim for bias correction
data_loader – Data loader for the model
num_bias_correct_samples (int) – Number of samples for bias correction
conv_bn_dict (Optional[Dict[Module, ConvBnInfoType]]) – Dict of conv and BN layers with information related to activation. If None, the function computes it
perform_only_empirical_bias_corr (bool) – Default True. If True, performs only Empirical Bias Correction for all layers, even those eligible for Analytical Bias Correction
layers_to_ignore (Optional[List[Module]]) – List of layers (modules) for which bias correction is skipped
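A minimal, framework-free sketch of the layer-selection rule described above follows. It is not AIMET's implementation: plan_bias_correction is a hypothetical helper, layers are identified by name strings instead of Module objects, plain dicts stand in for ConvBnInfoType, and the eligibility test (a layer qualifies for Analytical Bias Correction when its entry records an input BatchNorm) is an assumption.

```python
from typing import Dict, List, Optional


def plan_bias_correction(
    layer_names: List[str],
    conv_bn_dict: Optional[Dict[str, dict]] = None,
    perform_only_empirical_bias_corr: bool = True,
    layers_to_ignore: Optional[List[str]] = None,
) -> Dict[str, str]:
    """Hypothetical helper (not an AIMET API): pick a correction method per layer.

    Assumption: a layer is eligible for analytical correction when its
    conv_bn_dict entry has a non-None 'input_bn'. The real correct_bias
    computes conv_bn_dict itself when it is None; this sketch treats None as empty.
    """
    ignored = set(layers_to_ignore or [])
    conv_bn_dict = conv_bn_dict or {}
    plan = {}
    for name in layer_names:
        eligible = conv_bn_dict.get(name, {}).get("input_bn") is not None
        if name in ignored:
            plan[name] = "skip"
        elif eligible and not perform_only_empirical_bias_corr:
            plan[name] = "analytical"
        else:
            plan[name] = "empirical"
    return plan


layers = ["conv1", "conv2", "conv3"]
bn_info = {"conv2": {"input_bn": "bn1"}, "conv3": {"input_bn": None}}

# Default flags: every non-ignored layer is corrected empirically
print(plan_bias_correction(layers, bn_info))
# -> {'conv1': 'empirical', 'conv2': 'empirical', 'conv3': 'empirical'}

# Analytical enabled and conv3 ignored: only conv2 has an input BN
print(plan_bias_correction(layers, bn_info, perform_only_empirical_bias_corr=False,
                           layers_to_ignore=["conv3"]))
# -> {'conv1': 'empirical', 'conv2': 'analytical', 'conv3': 'skip'}
```

The sketch only shows how perform_only_empirical_bias_corr, conv_bn_dict and layers_to_ignore interact; the actual eligibility check in correct_bias may consider more than the presence of an input BatchNorm.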
ConvBnInfoType
class aimet_common.bias_correction.ConvBnInfoType(input_bn=None, output_bn=None, in_activation_type=<ActivationType.no_activation: 0>, out_activation_type=<ActivationType.no_activation: 0>)
Type for holding convs with BN info and activation types. Supported activation types are ReLU and ReLU6.
- Parameters
input_bn – Reference to the input BatchNorm of the layer
output_bn – Reference to the output BatchNorm of the layer
in_activation_type (ActivationType) – Type of activation at the layer's input
out_activation_type (ActivationType) – Type of activation at the layer's output
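The fields above are what analytical correction needs. Assuming each input channel leaving a BatchNorm is roughly Gaussian with mean beta (the BN shift) and standard deviation |gamma| (the BN scale), the expected input E[x] has a closed form for no activation and for ReLU. The expected output error caused by quantized weights is then (W_q - W) E[x], so the bias is corrected as b' = b - (W_q - W) E[x]. The sketch below does this for a small fully connected layer in plain Python; for a conv layer the same idea applies with the weight error summed over each kernel's spatial positions (ignoring padding effects). The function names are hypothetical, not AIMET APIs, and ReLU6 (which needs the clipped-normal expectation) is deliberately left out.

```python
import math
from typing import List


def expected_relu_of_gaussian(mu: float, sigma: float) -> float:
    """E[max(0, X)] for X ~ N(mu, sigma^2): mu * Phi(mu/sigma) + sigma * phi(mu/sigma)."""
    if sigma == 0.0:
        return max(0.0, mu)
    z = mu / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
    return mu * cdf + sigma * pdf


def expected_input(bn_beta: List[float], bn_gamma: List[float], activation: str) -> List[float]:
    """Per-channel E[x], assuming the BN output of channel c is ~N(beta_c, gamma_c^2)."""
    if activation == "none":
        return list(bn_beta)
    if activation == "relu":
        return [expected_relu_of_gaussian(b, abs(g)) for b, g in zip(bn_beta, bn_gamma)]
    # ReLU6 requires the clipped-normal expectation; left out of this sketch
    raise NotImplementedError(f"activation {activation!r} is not handled here")


def analytical_bias_correction(weight: List[List[float]], weight_q: List[List[float]],
                               bias: List[float], e_x: List[float]) -> List[float]:
    """b' = b - (W_q - W) E[x] for a fully connected layer with weight[out][in]."""
    corrected = []
    for w_row, wq_row, b in zip(weight, weight_q, bias):
        expected_error = sum((wq - w) * ex for w, wq, ex in zip(w_row, wq_row, e_x))
        corrected.append(b - expected_error)
    return corrected


# Hypothetical 1x2 layer whose second weight was quantized from -0.25 to -0.5
e_x = expected_input(bn_beta=[0.0, 1.0], bn_gamma=[1.0, 0.5], activation="relu")
print(analytical_bias_correction([[0.5, -0.25]], [[0.5, -0.5]], [0.125], e_x))
# -> roughly [0.376]
```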
ActivationType
Quantization Params
class aimet_torch.quantsim.QuantParams(weight_bw=8, act_bw=8, round_mode='nearest', quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, config_file=None)
Data type to hold quantization related params.
Constructor
- Parameters
weight_bw (int) – Weight bitwidth (4-31) to use for quantizing layer weights. Default = 8
act_bw (int) – Activation bitwidth (4-31) to use for quantizing layer activations. Default = 8
round_mode (str) – Rounding mode. Supported options are ‘nearest’ or ‘stochastic’
quant_scheme (Union[QuantScheme, str]) – Quantization scheme. Supported options are ‘tf_enhanced’ or ‘tf’, or the QuantScheme enum values QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced
config_file (Optional[str]) – Path to the configuration file for model quantizers
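To make weight_bw/act_bw and round_mode concrete, the following is a simplified, hypothetical quantize-dequantize routine, not AIMET's quantizer. It assumes an asymmetric integer grid of 2^bitwidth levels whose range comes from the observed min/max, which is assumed here to approximate the ‘tf’ scheme; the range search that ‘tf_enhanced’ performs to reduce quantization noise is not reproduced.

```python
import math
import random
from typing import List, Optional


def quantize_dequantize(values: List[float], bitwidth: int = 8, round_mode: str = "nearest",
                        rng: Optional[random.Random] = None) -> List[float]:
    """Hypothetical simulated asymmetric uniform quantization over min/max of `values`."""
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)  # range always contains 0.0
    levels = 2 ** bitwidth - 1                             # integer grid 0..levels
    scale = (hi - lo) / levels if hi > lo else 1.0
    offset = math.floor(-lo / scale + 0.5)                 # grid point that represents 0.0
    rng = rng or random.Random(0)
    out = []
    for v in values:
        x = v / scale + offset
        if round_mode == "nearest":
            q = math.floor(x + 0.5)
        elif round_mode == "stochastic":
            # Round up with probability equal to the fractional part (unbiased before clamping)
            frac = x - math.floor(x)
            q = math.floor(x) + (1 if rng.random() < frac else 0)
        else:
            raise ValueError(f"unsupported round_mode {round_mode!r}")
        q = max(0, min(levels, q))                         # clamp to the grid
        out.append((q - offset) * scale)
    return out


# 2-bit grid over [0, 3] has step 1.0, so 2.4 rounds to 2.0
print(quantize_dequantize([0.0, 2.4, 3.0], bitwidth=2))
# -> [0.0, 2.0, 3.0]
```

With round_mode="stochastic", a value a quarter of the way between two grid points is rounded up about a quarter of the time, so its average dequantized value stays close to the original.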
Code Examples
Required imports
# Bias Correction related imports
from aimet_torch import bias_correction
from aimet_torch.quantsim import QuantParams
from aimet_torch.examples.mobilenet import MobileNetV2
from aimet_torch.utils import create_fake_data_loader
Empirical Bias correction
def bias_correction_empirical():
    dataset_size = 2000
    batch_size = 64
    # Synthetic 3x224x224 image data for illustration
    data_loader = create_fake_data_loader(dataset_size=dataset_size, batch_size=batch_size, image_size=(3, 224, 224))
    model = MobileNetV2()
    model.eval()
    # 4-bit weights and activations, nearest rounding, TF-enhanced encodings
    params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')
    # Perform Bias Correction (empirical only, the default)
    bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
                                 data_loader=data_loader, num_bias_correct_samples=512)
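Conceptually, the empirical method compares layer outputs: the same samples are passed through the floating-point layer and its quantized counterpart, the mean difference per output channel is measured, and that difference is subtracted from the bias. The plain-Python sketch below shows only this arithmetic for a tiny fully connected layer; linear, empirical_bias_correction, and the hand-picked quantized weights are hypothetical, whereas correct_bias is documented above as doing the equivalent for Conv layers using its quantization sim and the supplied data loader.

```python
from typing import List


def linear(weight: List[List[float]], bias: List[float], x: List[float]) -> List[float]:
    """y = W x + b for a single input sample (weight[out][in])."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def empirical_bias_correction(weight: List[List[float]], weight_q: List[List[float]],
                              bias: List[float], samples: List[List[float]]) -> List[float]:
    """Hypothetical helper: b' = b - (mean quantized output - mean float output)."""
    n = len(samples)
    mean_fp = [0.0] * len(bias)
    mean_q = [0.0] * len(bias)
    for x in samples:
        y_fp, y_q = linear(weight, bias, x), linear(weight_q, bias, x)
        for i in range(len(bias)):
            mean_fp[i] += y_fp[i] / n
            mean_q[i] += y_q[i] / n
    return [b - (mq - mf) for b, mq, mf in zip(bias, mean_q, mean_fp)]


# Hypothetical layer and samples; weight_q stands in for the quantized weights
weight, weight_q, bias = [[0.5, -0.25]], [[0.5, -0.5]], [0.125]
samples = [[1.0, 1.0], [3.0, 3.0]]
new_bias = empirical_bias_correction(weight, weight_q, bias, samples)
print(new_bias)  # -> [0.625]
```

After correction, the quantized layer's mean output on these samples (0.625) equals the floating-point mean, which is the property the correction targets.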
Analytical + Empirical Bias correction
def bias_correction_analytical_and_empirical():
    dataset_size = 2000
    batch_size = 64
    data_loader = create_fake_data_loader(dataset_size=dataset_size, batch_size=batch_size, image_size=(3, 224, 224))
    model = MobileNetV2()
    model.eval()
    # Find all BN + Conv pairs for analytical BC and remaining Conv for Empirical BC
    module_prop_dict = bias_correction.find_all_conv_bn_with_activation(model, input_shape=(1, 3, 224, 224))
    params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')
    # Perform Bias Correction
    bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
                                 data_loader=data_loader, num_bias_correct_samples=512,
                                 conv_bn_dict=module_prop_dict, perform_only_empirical_bias_corr=False)
Bias correction data loader format example
class BatchIterator:
    def __init__(self, data_loader):
        self.data_loader = data_loader

    def __iter__(self):
        for batch, label in self.data_loader:
            yield (batch, label)
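BatchIterator simply re-yields (batch, label) pairs, which indicates the tuple shape the data loader is expected to produce. If a source yields only input batches, a wrapper along the following lines could adapt it. This is a hypothetical sketch: the class name is illustrative, and the assumption that the label is not used for bias correction is not stated in this documentation.

```python
from typing import Any, Iterable, Iterator, Tuple


class ImageOnlyBatchIterator:
    """Hypothetical adapter: wraps a loader that yields bare input batches and
    re-emits them as (batch, label) tuples, mirroring BatchIterator."""

    def __init__(self, image_loader: Iterable[Any], dummy_label: Any = None):
        self.image_loader = image_loader
        self.dummy_label = dummy_label  # placeholder label; assumed unused by bias correction

    def __iter__(self) -> Iterator[Tuple[Any, Any]]:
        # A generator-based __iter__ returns a fresh iterator on every pass
        for batch in self.image_loader:
            yield (batch, self.dummy_label)


batches = [[0.1, 0.2], [0.3, 0.4]]  # stand-ins for image tensors
wrapped = ImageOnlyBatchIterator(batches)
print(list(wrapped))  # -> [([0.1, 0.2], None), ([0.3, 0.4], None)]
print(list(wrapped) == list(wrapped))  # -> True (re-iterable)
```

Implementing __iter__ as a generator, as BatchIterator also does, means each for loop gets a fresh pass over the data.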