AIMET PyTorch AdaRound API

Top-level API

aimet_torch.adaround.adaround_weight.Adaround.apply_adaround(model, dummy_input, params, path, filename_prefix, default_param_bw=4, param_bw_override_list=None, ignore_quant_ops_list=None, default_quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, default_config_file=None)

Returns a model with optimized weight rounding for every supported module (Conv and Linear), and saves the corresponding quantization encodings to a JSON-formatted file that can then be imported by QuantSim for inference or QAT

Parameters
  • model (Module) – Model to Adaround

  • dummy_input (Union[Tensor, Tuple]) – Dummy input to the model. Used to parse model graph. If the model has more than one input, pass a tuple. User is expected to place the tensors on the appropriate device.

  • params (AdaroundParameters) – Parameters for Adaround

  • path (str) – Path at which to store the parameter encodings

  • filename_prefix (str) – Prefix to use for filename of the encodings file

  • default_param_bw (int) – Default bitwidth (4-31) to use for quantizing layer parameters

  • param_bw_override_list (Optional[List[Tuple[Module, int]]]) – List of Tuples. Each Tuple is a module and the corresponding parameter bitwidth to be used for that module.

  • ignore_quant_ops_list (Optional[List[Module]]) – Modules listed here are skipped during the quantization needed for AdaRounding. Do not specify Conv or Linear modules in this list; doing so will affect accuracy.

  • default_quant_scheme (QuantScheme) – Quantization scheme. Supported options are QuantScheme.post_training_tf and QuantScheme.post_training_tf_enhanced

  • default_config_file (Optional[str]) – Default configuration file for model quantizers

Return type

Module

Returns

Model with AdaRounded weights. The corresponding parameter-encodings JSON file is saved at the provided path.
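
For illustration, individual modules can be given their own parameter bitwidth via param_bw_override_list. A minimal sketch follows; 'model', 'dummy_input', and 'params' are assumed to be set up as in the code examples below, and the 8-bit override for the first conv layer is an arbitrary illustrative choice:

# Quantize the first conv layer's weights at 8 bits instead of the 4-bit
# default. 'model' is assumed to be a torchvision ResNet18 here.
param_bw_override_list = [(model.conv1, 8)]

adarounded_model = Adaround.apply_adaround(model, dummy_input, params, path='./',
                                           filename_prefix='resnet18', default_param_bw=4,
                                           param_bw_override_list=param_bw_override_list)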

Adaround Parameters

class aimet_torch.adaround.adaround_weight.AdaroundParameters(data_loader, num_batches, default_num_iterations=10000, default_reg_param=0.01, default_beta_range=(20, 2), default_warm_start=0.2)

Configuration parameters for Adaround

Parameters
  • data_loader (DataLoader[+T_co]) – Data loader supplying the data used to optimize rounding (typically the training data loader)

  • num_batches (int) – Number of batches from the data loader to use for the optimization

  • default_num_iterations (int) – Number of iterations to adaround each layer. Default 10000

  • default_reg_param (float) – Regularization parameter, trading off between rounding loss vs reconstruction loss. Default 0.01

  • default_beta_range (Tuple) – Start and stop beta parameter for annealing of rounding loss (start_beta, end_beta). Default (20, 2)

  • default_warm_start (float) – Warm start period, during which the rounding loss has zero effect. Default 20% (0.2)
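
A minimal construction sketch (the dataset path and transform pipeline below are placeholder assumptions; any torch DataLoader over representative training data works):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hypothetical dataset location; substitute your own training data.
dataset = datasets.ImageFolder('/path/to/train_data',
                               transform=transforms.Compose([transforms.Resize(256),
                                                             transforms.CenterCrop(224),
                                                             transforms.ToTensor()]))
data_loader = DataLoader(dataset, batch_size=16, shuffle=True)

# 4 batches of 16 images -> 64 samples are used for the optimization.
params = AdaroundParameters(data_loader=data_loader, num_batches=4)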

Enum Definition

Quant Scheme Enum

class aimet_common.defs.QuantScheme

Enumeration of Quant schemes

post_training_tf = 1

TF scheme

post_training_tf_enhanced = 2

TF-enhanced scheme
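
For instance, to run AdaRound with the plain TF scheme instead of the default (a trivial sketch; 'model', 'dummy_input', and 'params' are assumed to be set up as in the code examples below):

from aimet_common.defs import QuantScheme

adarounded_model = Adaround.apply_adaround(model, dummy_input, params, path='./',
                                           filename_prefix='resnet18',
                                           default_quant_scheme=QuantScheme.post_training_tf)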


Code Examples

Required imports


import logging
import torch
import torch.cuda
from torchvision import models

from aimet_common.utils import AimetLogger
from aimet_common.defs import QuantScheme
from aimet_torch.utils import create_fake_data_loader
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters

Evaluation function

def dummy_forward_pass(model: torch.nn.Module, forward_pass_callback_args) -> float:
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's eval function does not
    match this signature, please create a simple wrapper.

    :param model: Model to evaluate
    :param forward_pass_callback_args: These argument(s) are passed to the forward_pass_callback as-is. Up to
            the user to determine the type of this parameter. E.g. could be simply an integer representing the number
            of data samples to use. Or could be a tuple of parameters or an object representing something more complex.
            If set to None, forward_pass_callback will be invoked with no parameters.
    :return: single float number (accuracy) representing model's performance
    """
    return 0.5

The following example applies AdaRound to ResNet18; the AdaRounded model is returned and the associated encodings are saved to disk for use with QuantSim

def apply_adaround_example():

    AimetLogger.set_level_for_all_areas(logging.DEBUG)
    torch.cuda.empty_cache()

    model = models.resnet18(pretrained=True).eval()
    model = model.to(torch.device('cuda'))
    input_shape = (1, 3, 224, 224)
    dummy_input = torch.randn(input_shape).to(torch.device('cuda'))

    # As an illustrating example, a fake data loader is used here.
    # For AdaRound, the user should provide the training data loader.
    data_loader = create_fake_data_loader(dataset_size=64, batch_size=16, image_size=input_shape[1:])

    params = AdaroundParameters(data_loader=data_loader, num_batches=4, default_num_iterations=50,
                                default_reg_param=0.01, default_beta_range=(20, 2))

    # Returns model with adarounded weights and their corresponding encodings
    adarounded_model = Adaround.apply_adaround(model, dummy_input, params, path='./',
                                               filename_prefix='resnet18', default_param_bw=4,
                                               default_quant_scheme=QuantScheme.post_training_tf_enhanced,
                                               default_config_file=None)

    # Create QuantSim using adarounded_model
    # Use the same quant scheme and parameter bitwidth as apply_adaround above
    # so the frozen encodings line up with the AdaRounded weights
    sim = QuantizationSimModel(adarounded_model, quant_scheme=QuantScheme.post_training_tf_enhanced,
                               default_param_bw=4, default_output_bw=8, dummy_input=dummy_input)

    # Set and freeze encodings to use same quantization grid and then invoke compute encodings
    sim.set_and_freeze_param_encodings(encoding_path='./resnet18.encodings')
    sim.compute_encodings(dummy_forward_pass, forward_pass_callback_args=None)
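
    # Optionally export the quantized model and the updated encodings.
    # Note: QuantizationSimModel.export expects the dummy input on the CPU,
    # and the filename prefix here is an illustrative choice.
    sim.export(path='./', filename_prefix='quantized_resnet18', dummy_input=dummy_input.cpu())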