AIMET PyTorch Bias Correction API¶
Bias Correction API¶
ConvBnInfoType¶
class aimet_common.bias_correction.ConvBnInfoType(input_bn=None, output_bn=None, in_activation_type=<ActivationType.no_activation: 0>, out_activation_type=<ActivationType.no_activation: 0>)¶
Type for holding convs with BN info and activation types. Activation types supported are Relu and Relu6.
- Parameters
input_bn – Reference to the input BatchNorm of the layer
output_bn – Reference to the output BatchNorm of the layer
in_activation_type (ActivationType) – Type of Activation
out_activation_type (ActivationType) – Type of Activation
ActivationType¶
Quantization Params¶
Code Examples¶
Required imports
# Bias Correction related imports
from aimet_torch import bias_correction
from aimet_torch.quantsim import QuantParams
from aimet_torch.examples.mobilenet import MobileNetV2
from aimet_torch.utils import create_fake_data_loader
Empirical Bias correction
def bias_correction_empirical():
    """
    Example: run empirical bias correction on a MobileNetV2 model with fake data.

    Creates a fake data loader (2000 images of size 3x224x224, batch size 64),
    instantiates MobileNetV2 in eval mode, builds 4-bit weight / 4-bit activation
    quantization parameters and calls ``bias_correction.correct_bias`` on the model
    after moving it to the "cuda" device.

    :return: None
    """
    dataset_size = 2000
    batch_size = 64

    data_loader = create_fake_data_loader(dataset_size=dataset_size, batch_size=batch_size,
                                          image_size=(3, 224, 224))

    model = MobileNetV2()
    model.eval()

    params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')

    # Perform Bias Correction
    # Pass the object returned by create_fake_data_loader directly: it is expected to be the
    # loader itself, with no `.train_loader` attribute (NOTE(review): confirm against
    # aimet_torch.utils.create_fake_data_loader).
    bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
                                 data_loader=data_loader, num_bias_correct_samples=512)
Analytical + Empirical Bias correction
def bias_correction_analytical_and_empirical():
    """
    Example: run analytical + empirical bias correction on a MobileNetV2 model with fake data.

    Creates a fake data loader (2000 images of size 3x224x224, batch size 64),
    instantiates MobileNetV2 in eval mode, collects the conv/BN information returned by
    ``bias_correction.find_all_conv_bn_with_activation`` and passes it to
    ``bias_correction.correct_bias`` as ``conv_bn_dict`` with
    ``perform_only_empirical_bias_corr=False``. The model is moved to the "cuda" device
    for the correction call.

    :return: None
    """
    dataset_size = 2000
    batch_size = 64

    data_loader = create_fake_data_loader(dataset_size=dataset_size, batch_size=batch_size,
                                          image_size=(3, 224, 224))

    model = MobileNetV2()
    model.eval()

    # Find all BN + Conv pairs for analytical BC and remaining Conv for Empirical BC
    module_prop_dict = bias_correction.find_all_conv_bn_with_activation(model, input_shape=(1, 3, 224, 224))

    params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')

    # Perform Bias Correction
    bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
                                 data_loader=data_loader, num_bias_correct_samples=512,
                                 conv_bn_dict=module_prop_dict, perform_only_empirical_bias_corr=False)
Bias correction Data loader format example
class BatchIterator:
    """
    Example data loader wrapper in the format used for bias correction:
    iterating it yields ``(batch, label)`` tuples taken from the wrapped loader.
    """

    def __init__(self, data_loader):
        """
        :param data_loader: Iterable whose items each unpack into a (batch, label) pair
        """
        self.data_loader = data_loader

    def __iter__(self):
        """
        Iterate over the wrapped loader.

        :return: Generator yielding one ``(batch, label)`` tuple per item of the loader
        """
        # Each item is unpacked and re-packed, so items of any two-element
        # sequence type (e.g. a list) come out as tuples.
        for batch, label in self.data_loader:
            yield (batch, label)