AIMET PyTorch Cross Layer Equalization Primitive API
Introduction
If the user wants to change the order of the Cross Layer Equalization steps, skip some of them, or manually adjust the list of layers to be equalized, the following APIs can be used.
Higher-level APIs apply one or more of the features one after the other. They automatically find the layers to be folded or scaled.
Lower-level APIs let the user manually choose the layers to be folded. The user must pass the layers in the order in which they appear in the model.
Note: Before using High Bias Fold, Cross Layer Scaling (CLS) must be applied, and the scaling factors obtained from CLS must be passed to High Bias Fold. In addition, if the model contains batch norm layers, they must be folded first and the folding information saved so that it can be passed to the High Bias Fold API.
ClsSetInfo Definition
class aimet_torch.cross_layer_equalization.ClsSetInfo(cls_pair_1, cls_pair_2=None)
- This class holds information about the layers in a CLS set, along with the corresponding scaling factors and other information, such as whether there is a ReLU activation function between the CLS set layers. The constructor takes 2 pairs if a depth-wise separable layer is being folded.
- Parameters
- cls_pair_1 (ClsSetLayerPairInfo) – Pair between two conv layers, or between a conv and a depth-wise conv layer
- cls_pair_2 (Optional[ClsSetLayerPairInfo]) – Pair between a depth-wise conv and a point-wise conv layer
 
class ClsSetLayerPairInfo(layer1, layer2, scale_factor, relu_activation_between_layers)
- Models a pair of layers that were scaled using CLS, along with related information. This class is nested in ClsSetInfo and is accessed as ClsSetInfo.ClsSetLayerPairInfo.
- Parameters
- layer1 (Conv2d) – Layer whose bias is folded
- layer2 (Conv2d) – Layer into which the previous layer’s bias is folded
- scale_factor (ndarray) – Scale factor found by Cross Layer Scaling, used to scale BN parameters
- relu_activation_between_layers (bool) – True if the activation between layer1 and layer2 is ReLU
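To make the relationship between the two classes concrete, the sketch below mirrors the documented fields with plain Python dataclasses. These are hypothetical stand-ins for illustration only, not AIMET's classes: layers are represented by strings, and the `pairs` attribute name is an assumption.

```python
from dataclasses import dataclass
from typing import List, Optional

import numpy as np


@dataclass
class LayerPairInfo:
    """Hypothetical stand-in for ClsSetLayerPairInfo (fields mirror the documented parameters)."""
    layer1: str                          # layer whose bias is folded
    layer2: str                          # layer that receives the folded bias
    scale_factor: np.ndarray             # per-channel CLS scale factors
    relu_activation_between_layers: bool


class SetInfo:
    """Hypothetical stand-in for ClsSetInfo: one pair, or two for a depth-wise separable set."""

    def __init__(self, cls_pair_1: LayerPairInfo, cls_pair_2: Optional[LayerPairInfo] = None):
        # Keep the pairs in model order: (conv, depthwise) first, then (depthwise, pointwise)
        self.pairs: List[LayerPairInfo] = [cls_pair_1]
        if cls_pair_2 is not None:
            self.pairs.append(cls_pair_2)


# A plain conv -> conv set has a single pair
plain = SetInfo(LayerPairInfo("conv_a", "conv_b", np.ones(8), True))

# A conv -> depthwise -> pointwise set has two pairs that share the depthwise layer
s12, s23 = np.full(16, 0.5), np.full(16, 2.0)
separable = SetInfo(LayerPairInfo("conv", "dw_conv", s12, True),
                    LayerPairInfo("dw_conv", "pw_conv", s23, True))
```

The shared middle layer is what links the two pairs of a depth-wise separable set.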
 
 
 
Higher Level APIs for Cross Layer Equalization
API for Batch Norm Folding
aimet_torch.batch_norm_fold.fold_all_batch_norms(model, input_shapes, dummy_input=None)
- Fold all batch_norm layers in a model into the weights and biases of the corresponding conv or linear layers
- Parameters
- model (Module) – Model
- input_shapes (Union[Tuple, List[Tuple]]) – Input shapes for the model (can be one or multiple inputs)
- dummy_input (Union[Tensor, Tuple, None]) – A dummy input to the model. Can be a Tensor or a Tuple of Tensors

- Return type
- List[Tuple[Union[Linear, Conv1d, Conv2d, ConvTranspose2d], Union[BatchNorm1d, BatchNorm2d]]]
- Returns
- A list of pairs of layers [(Conv/Linear, BN layer that got folded)] 
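The arithmetic behind folding a BN layer into the layer that precedes it can be sketched in NumPy. This illustrates the standard inference-time fold for a linear (fully connected) layer and is not AIMET's implementation; for a conv layer, the same per-output-channel scaling is applied to each filter.

```python
import numpy as np


def fold_bn_into_linear(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold y = BN(W x + b) into a single affine layer y = W' x + b'.

    weight: (out, in); bias, gamma, beta, mean, var: (out,)
    """
    std = np.sqrt(var + eps)
    scale = gamma / std                       # per-output-channel BN scale
    folded_weight = weight * scale[:, None]   # scale each output row
    folded_bias = (bias - mean) * scale + beta
    return folded_weight, folded_bias


def linear_then_bn(x, weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Reference: linear layer followed by BN in inference mode."""
    z = x @ weight.T + bias
    return gamma * (z - mean) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(0)
out_ch, in_ch = 4, 3
W = rng.normal(size=(out_ch, in_ch))
b = rng.normal(size=out_ch)
gamma, beta = rng.uniform(0.5, 2.0, out_ch), rng.normal(size=out_ch)
mean, var = rng.normal(size=out_ch), rng.uniform(0.1, 1.0, out_ch)

x = rng.normal(size=(5, in_ch))
W_f, b_f = fold_bn_into_linear(W, b, gamma, beta, mean, var)
# The folded layer reproduces linear + BN without a separate BN op
assert np.allclose(x @ W_f.T + b_f, linear_then_bn(x, W, b, gamma, beta, mean, var))
```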
 
API for Cross Layer Scaling
aimet_torch.cross_layer_equalization.CrossLayerScaling.scale_model(model, input_shapes, dummy_input=None)
- Uses cross-layer scaling to scale all applicable layers in the given model
- Parameters
- model (Module) – Model to scale
- input_shapes (Union[Tuple, List[Tuple]]) – Input shape for the model (can be one or multiple inputs)
- dummy_input (Union[Tensor, List[Tensor], None]) – Dummy input to the model, used to parse the model graph. The user is expected to place the tensors on the appropriate device.

- Return type
- List[ClsSetInfo]
- Returns
- CLS information for each CLS set 
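Cross Layer Scaling relies on the positive scaling equivariance of ReLU: for s > 0, ReLU(z / s) = ReLU(z) / s. Clipping at 6 breaks this identity, which is presumably why the code examples replace ReLU6 with ReLU. The sketch below equalizes one pair of linear layers in NumPy using the per-channel scale s_i = sqrt(r1_i / r2_i), equivalently (1 / r2_i) * sqrt(r1_i * r2_i) as in the Data-Free Quantization paper (Nagel et al., 2019), taking a channel's range as its maximum absolute weight. It illustrates the idea and is not claimed to match AIMET's implementation exactly.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def cls_pair(w1, b1, w2):
    """Equalize layer1 (out=C) and layer2 (in=C) around a ReLU.

    w1: (C, in), b1: (C,), w2: (out, C). Returns scaled copies and the scale s.
    """
    r1 = np.abs(w1).max(axis=1)               # range of each output channel of layer1
    r2 = np.abs(w2).max(axis=0)               # range of each input channel of layer2
    s = np.sqrt(r1 / r2)                      # s_i = (1 / r2_i) * sqrt(r1_i * r2_i)
    return w1 / s[:, None], b1 / s, w2 * s[None, :], s


rng = np.random.default_rng(1)
w1, b1 = rng.normal(size=(6, 4)) * rng.uniform(0.1, 10, (6, 1)), rng.normal(size=6)
w2, b2 = rng.normal(size=(3, 6)), rng.normal(size=3)

w1_s, b1_s, w2_s, s = cls_pair(w1, b1, w2)

x = rng.normal(size=(10, 4))
before = relu(x @ w1.T + b1) @ w2.T + b2
after = relu(x @ w1_s.T + b1_s) @ w2_s.T + b2
assert np.allclose(before, after)             # the network function is unchanged
# After scaling, channel i has the same range, sqrt(r1_i * r2_i), in both layers
assert np.allclose(np.abs(w1_s).max(axis=1), np.abs(w2_s).max(axis=0))
```

The returned s plays the role of a per-set scale factor, like the ones collected in each ClsSetInfo.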
 
API for High Bias Folding
aimet_torch.cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_layers)
- Folds bias values greater than 3 * sigma into the next layer’s bias
- Parameters
- cls_set_info_list (List[ClsSetInfo]) – List of info elements for each CLS set
- bn_layers (Dict[Union[Conv2d, ConvTranspose2d], BatchNorm2d]) – Key: conv/linear layer; value: the corresponding folded BN layer
 
- Returns
- None 
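The idea behind high bias folding can be sketched in NumPy. Assume the pre-activations of layer1's channel i are roughly Gaussian with mean beta_i and standard deviation gamma_i, taken from the folded BN layer. Then c_i = max(0, beta_i - 3 * gamma_i) is almost always below the pre-activation, so ReLU(z - c) + c = ReLU(z) holds for nearly all inputs and c can be moved into layer2's bias. Because CLS divided channel i of layer1 by s_i, the BN parameters are divided by the same factor first, which is what the documented scale_factor is used for. This follows the description in the Data-Free Quantization paper rather than AIMET's code, and the function and argument names are hypothetical.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def high_bias_fold(b1, w2, b2, bn_gamma, bn_beta, scale_factor):
    """Move the part of layer1's bias that ReLU almost never clips into layer2's bias.

    b1: (C,), w2: (out, C), b2: (out,); bn_gamma, bn_beta, scale_factor: (C,)
    """
    # CLS divided channel i of layer1 by scale_factor[i]; scale the BN statistics the same way
    gamma = np.abs(bn_gamma) / scale_factor
    beta = bn_beta / scale_factor
    c = np.maximum(0.0, beta - 3.0 * gamma)   # bias above mean - 3 * sigma
    return b1 - c, b2 + w2 @ c


rng = np.random.default_rng(2)
C = 5
w2, b2 = rng.normal(size=(2, C)), rng.normal(size=2)
bn_gamma = rng.uniform(0.1, 0.5, C)
bn_beta = rng.uniform(2.0, 4.0, C)            # large positive means -> something to fold
scale_factor = rng.uniform(0.5, 2.0, C)
# Toy setup: pre-activation = b1 + zero-mean noise, so its mean is beta / s as assumed above
b1 = bn_beta / scale_factor

b1_new, b2_new = high_bias_fold(b1, w2, b2, bn_gamma, bn_beta, scale_factor)

# Noise kept within 3 sigma, so the fold is exact for these inputs
gamma_s = bn_gamma / scale_factor
noise = rng.uniform(-3.0, 3.0, size=(100, C)) * gamma_s
before = relu(noise + b1) @ w2.T + b2
after = relu(noise + b1_new) @ w2.T + b2_new
assert np.allclose(before, after)
```

For inputs beyond 3 sigma the fold is approximate, which is the trade-off the 3 * sigma threshold accepts.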
 
Code Examples for Higher Level APIs
Required imports
import torch
from torchvision import models
from aimet_torch import batch_norm_fold
from aimet_torch import cross_layer_equalization
from aimet_torch import utils
from aimet_torch.cross_layer_equalization import equalize_model
Cross Layer Equalization in auto mode, calling each API in turn
def cross_layer_equalization_auto_step_by_step():
    model = models.resnet18(pretrained=True)
    model = model.eval()
    input_shape = (1, 3, 224, 224)
    # Fold BatchNorm layers
    folded_pairs = batch_norm_fold.fold_all_batch_norms(model, input_shape)
    bn_dict = {}
    for conv_bn in folded_pairs:
        bn_dict[conv_bn[0]] = conv_bn[1]
    # Replace any ReLU6 layers with ReLU
    utils.replace_modules_of_type1_with_type2(model, torch.nn.ReLU6, torch.nn.ReLU)
    # Perform cross-layer scaling on applicable layer sets
    cls_set_info_list = cross_layer_equalization.CrossLayerScaling.scale_model(model, input_shape)
    # Perform high-bias fold
    cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_dict)
Lower Level APIs for Cross Layer Equalization
API for Batch Norm Folding
aimet_torch.batch_norm_fold.fold_given_batch_norms(model, layer_pairs)
- Fold a given set of batch_norm layers into conv layers
- Parameters
- model – Model
- layer_pairs – Pairs of conv and batch_norm layers to use for folding
 
- Returns
- None 
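The manual code example below builds its layer pairs "for folding forward or backward". When the BN layer comes before the linear layer instead of after it, the BN can still be folded; the NumPy sketch below shows the arithmetic. It is an illustration only and does not show how AIMET orders the tuples in layer_pairs. For convolutions with zero padding, this direction is generally not exact at the borders, because the padded zeros bypass the BN shift.

```python
import numpy as np


def fold_preceding_bn_into_linear(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold y = W BN(x) + b into y = W' x + b'.

    weight: (out, in); bias: (out,); gamma, beta, mean, var: (in,)
    """
    scale = gamma / np.sqrt(var + eps)        # BN(x) = scale * x + shift
    shift = beta - mean * scale
    folded_weight = weight * scale[None, :]   # scale each input column
    folded_bias = bias + weight @ shift
    return folded_weight, folded_bias


rng = np.random.default_rng(3)
out_ch, in_ch = 3, 4
W, b = rng.normal(size=(out_ch, in_ch)), rng.normal(size=out_ch)
gamma, beta = rng.uniform(0.5, 2.0, in_ch), rng.normal(size=in_ch)
mean, var = rng.normal(size=in_ch), rng.uniform(0.1, 1.0, in_ch)

x = rng.normal(size=(6, in_ch))
bn_out = gamma * (x - mean) / np.sqrt(var + 1e-5) + beta
W_f, b_f = fold_preceding_bn_into_linear(W, b, gamma, beta, mean, var)
# BN followed by the linear layer equals the single folded layer
assert np.allclose(bn_out @ W.T + b, x @ W_f.T + b_f)
```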
 
API for Cross Layer Scaling
aimet_torch.cross_layer_equalization.CrossLayerScaling.scale_cls_sets(cls_sets)
- Scale multiple CLS sets
- Parameters
- cls_sets (List[Union[Tuple[Conv2d, Conv2d], Tuple[Conv2d, Conv2d, Conv2d]]]) – List of CLS sets. Each set is a pair of consecutive conv layers, or a triple (conv, depth-wise conv, point-wise conv) for depth-wise separable layers
- Return type
- List[Union[ndarray, Tuple[ndarray]]]
- Returns
- Scaling factors calculated and applied for each CLS set in order 
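For a three-layer set such as (conv, depth-wise conv, point-wise conv), the depth-wise code example indexes the result as scaling_factor_list[0][0] and scaling_factor_list[0][1], that is, one scale vector per adjacent pair. One way to compute such a pair of vectors, sketched below in NumPy, is to equalize all three per-channel ranges to their geometric mean c = (r1 * r2 * r3)^(1/3). The formula is an assumption for illustration, the depth-wise conv is reduced to one weight per channel, and the sketch is not claimed to be AIMET's implementation.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def cls_depthwise_set(w1, b1, dw, b_dw, w3):
    """Equalize conv (w1: (C, in)) -> depthwise (dw: (C,)) -> pointwise (w3: (out, C)).

    Returns scaled copies plus the two scale vectors (s12, s23).
    """
    r1 = np.abs(w1).max(axis=1)               # output-channel ranges of the first conv
    r2 = np.abs(dw)                           # one weight per channel in this simplified depthwise layer
    r3 = np.abs(w3).max(axis=0)               # input-channel ranges of the pointwise conv
    c = np.cbrt(r1 * r2 * r3)                 # common target range
    s12 = r1 / c                              # scale between conv and depthwise
    s23 = c / r3                              # scale between depthwise and pointwise
    return (w1 / s12[:, None], b1 / s12,
            dw * s12 / s23, b_dw / s23,
            w3 * s23[None, :], (s12, s23))


def forward(x, w1, b1, dw, b_dw, w3, b3):
    h1 = relu(x @ w1.T + b1)
    h2 = relu(h1 * dw + b_dw)                 # depthwise: each channel scaled independently
    return h2 @ w3.T + b3


rng = np.random.default_rng(4)
C = 6
w1, b1 = rng.normal(size=(C, 3)) * rng.uniform(0.1, 10, (C, 1)), rng.normal(size=C)
dw, b_dw = rng.normal(size=C) * 5.0, rng.normal(size=C)
w3, b3 = rng.normal(size=(2, C)), rng.normal(size=2)

w1_s, b1_s, dw_s, b_dw_s, w3_s, (s12, s23) = cls_depthwise_set(w1, b1, dw, b_dw, w3)
x = rng.normal(size=(8, 3))
assert np.allclose(forward(x, w1, b1, dw, b_dw, w3, b3),
                   forward(x, w1_s, b1_s, dw_s, b_dw_s, w3_s, b3))
# All three layers now share the same per-channel range
assert np.allclose(np.abs(w1_s).max(axis=1), np.abs(dw_s))
assert np.allclose(np.abs(dw_s), np.abs(w3_s).max(axis=0))
```

Both scale vectors stay positive, so the ReLUs between the layers keep the network function unchanged.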
 
API for High Bias Folding
aimet_torch.cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_layers)
- Folds bias values greater than 3 * sigma into the next layer’s bias
- Parameters
- cls_set_info_list (List[ClsSetInfo]) – List of info elements for each CLS set
- bn_layers (Dict[Union[Conv2d, ConvTranspose2d], BatchNorm2d]) – Key: conv/linear layer; value: the corresponding folded BN layer
 
- Returns
- None 
 
Code Examples for Lower Level APIs
Required imports
import torch
from torchvision import models
from aimet_torch import batch_norm_fold
from aimet_torch import cross_layer_equalization
from aimet_torch import utils
from aimet_torch.cross_layer_equalization import equalize_model
Cross Layer Equalization in manual mode
def cross_layer_equalization_manual():
    model = models.resnet18(pretrained=True)
    model = model.eval()
    # Batch Norm Fold
    # Create a list of conv/linear and BN layers for folding forward or backward
    layer_list = [(model.conv1, model.bn1),
                  (model.layer1[0].conv1, model.layer1[0].bn1)]
    # Save the corresponding BN layers (needed only for high bias folding)
    bn_dict = {}
    for conv_bn in layer_list:
        bn_dict[conv_bn[0]] = conv_bn[1]
    batch_norm_fold.fold_given_batch_norms(model, layer_list)
    # Replace any ReLU6 layers with ReLU
    utils.replace_modules_of_type1_with_type2(model, torch.nn.ReLU6, torch.nn.ReLU)
    # Cross Layer Scaling
    # Create a list of consecutive conv layers to be equalized
    consecutive_layer_list = [(model.conv1, model.layer1[0].conv1),
                              (model.layer1[0].conv1, model.layer1[0].conv2)]
    scaling_factor_list = cross_layer_equalization.CrossLayerScaling.scale_cls_sets(consecutive_layer_list)
    # High Bias Fold
    # Describe each pair of consecutive conv layers in which the first layer's bias is folded into the second layer's bias
    ClsSetInfo = cross_layer_equalization.ClsSetInfo
    ClsPairInfo = cross_layer_equalization.ClsSetInfo.ClsSetLayerPairInfo
    cls_set_info_list = [ClsSetInfo(ClsPairInfo(model.conv1, model.layer1[0].conv1, scaling_factor_list[0], True)),
                         ClsSetInfo(ClsPairInfo(model.layer1[0].conv1, model.layer1[0].conv2, scaling_factor_list[1], True))]
    cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_dict)
Cross Layer Equalization in manual mode for Depthwise Separable layers
def cross_layer_equalization_depthwise_layers():
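    # MobileNetV2 must be a model definition whose module layout matches the indices used below
    # (e.g. features[1].conv[3]); torchvision's mobilenet_v2 uses a different layout
    model = MobileNetV2().to(torch.device('cpu'))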
    model.eval()
    # Batch Norm Fold
    # Create a list of conv/linear and BN layers for folding forward or backward
    layer_list = [(model.features[0][0], model.features[0][1]),
                  (model.features[1].conv[0], model.features[1].conv[1]),
                  (model.features[1].conv[3], model.features[1].conv[4])]
    # Save the corresponding BN layers (needed only for high bias folding)
    bn_dict = {}
    for conv_bn in layer_list:
        bn_dict[conv_bn[0]] = conv_bn[1]
    batch_norm_fold.fold_given_batch_norms(model, layer_list)
    # Replace any ReLU6 layers with ReLU
    utils.replace_modules_of_type1_with_type2(model, torch.nn.ReLU6, torch.nn.ReLU)
    # Cross Layer Scaling
    # Create a list of consecutive conv layers to be equalized
    consecutive_layer_list = [(model.features[0][0], model.features[1].conv[0], model.features[1].conv[3])]
    scaling_factor_list = cross_layer_equalization.CrossLayerScaling.scale_cls_sets(consecutive_layer_list)
    # High Bias Fold
    # Describe each pair of consecutive conv layers in which the first layer's bias is folded into the second layer's bias
    ClsSetInfo = cross_layer_equalization.ClsSetInfo
    ClsPairInfo = cross_layer_equalization.ClsSetInfo.ClsSetLayerPairInfo
    cls_set_info_list = [ClsSetInfo(ClsPairInfo(model.features[0][0], model.features[1].conv[0], scaling_factor_list[0][0], True)),
                         ClsSetInfo(ClsPairInfo(model.features[1].conv[0], model.features[1].conv[3], scaling_factor_list[0][1], True))]
    cross_layer_equalization.HighBiasFold.bias_fold(cls_set_info_list, bn_dict)