AIMET PyTorch Mixed Precision API

Top-level API

aimet_torch.mixed_precision.choose_mixed_precision(sim, dummy_input, candidates, eval_callback_for_phase1, eval_callback_for_phase2, allowed_accuracy_drop, results_dir, clean_start, forward_pass_callback, use_all_amp_candidates=False, phase2_reverse=False, phase1_optimize=True, amp_search_algo=AMPSearchAlgo.Binary)[source]

High-level API to perform in-place mixed precision evaluation on the given sim model. A Pareto list is created, and a curve of accuracy vs. BitOps is saved under the results directory.

Parameters:
  • sim (QuantizationSimModel) – Quantized sim model

  • dummy_input (Union[Tensor, Tuple]) – Dummy input to the model. If the model has more than one input, pass a tuple. User is expected to place the tensors on the appropriate device.

  • candidates (List[Tuple[Tuple[int, QuantizationDataType], Tuple[int, QuantizationDataType]]]) –

    List of tuples of all candidate bitwidths and data types for activations and parameters. For example, suppose the possible combinations are (activation bitwidth 8, activation data type int) with (parameter bitwidth 16, parameter data type int), and (activation bitwidth 16, activation data type float) with (parameter bitwidth 16, parameter data type float). Then candidates will be [((8, QuantizationDataType.int), (16, QuantizationDataType.int)), ((16, QuantizationDataType.float), (16, QuantizationDataType.float))]. A short construction sketch is shown after this API entry.

  • eval_callback_for_phase1 (CallbackFunc) – An object of CallbackFunc class which takes in an eval function (callable) and the eval function's parameters. This evaluation callback is used to measure the sensitivity of each quantizer group during phase 1, which involves finding the accuracy list (sensitivity) of each module. A user might therefore want to run phase 1 with a smaller dataset.

  • eval_callback_for_phase2 (CallbackFunc) – An object of CallbackFunc class which takes in an eval function (callable) and the eval function's parameters. This evaluation callback is used to get the accuracy of the quantized model for the phase 2 calculations, which involve finding the Pareto front curve.

  • allowed_accuracy_drop (Optional[float]) – Maximum allowed drop in accuracy from the FP32 baseline. The Pareto front curve is plotted only up to the point where the allowed accuracy drop is met. To get a complete plot for picking points on the curve, set the allowed accuracy drop to None.

  • results_dir (str) – Path to save results and cache intermediate results

  • clean_start (bool) – If true, any cached information from previous runs will be deleted prior to starting the mixed-precision analysis. If false, prior cached information will be used if applicable. Note it is the user’s responsibility to set this flag to true if anything in the model or quantization parameters changes compared to the previous run.

  • forward_pass_callback (CallbackFunc) – An object of CallbackFunc class which takes in Forward pass function (callable) and its function parameters. Forward pass callback used to compute quantization encodings

  • use_all_amp_candidates (bool) – A list of supported candidates can be specified using the “supported_kernels” field in the config file (under the defaults and op_type sections). Not all AMP candidates passed through the “candidates” argument may be supported, based on what is listed in “supported_kernels”. When “use_all_amp_candidates” is set to True, the AMP algorithm ignores “supported_kernels” in the config file and uses all candidates.

  • phase2_reverse (bool) – If set to True, phase 1 of the AMP algorithm (computing the accuracy list) is unchanged, but phase 2 (generating the Pareto list) runs in reverse: the algorithm starts with all quantizer groups at the lowest candidate and, one by one, moves quantizer groups to a higher candidate until the target accuracy is met.

  • phase1_optimize (bool) – If set to False, the default phase 1 logic is executed; otherwise, the optimized phase 1 logic is executed.

  • amp_search_algo (AMPSearchAlgo) – A valid value from the Enum AMPSearchAlgo. Defines the search algorithm to be used for the phase 2 of AMP.

Return type:

Optional[List[Tuple[int, float, QuantizerGroup, int]]]

Returns:

Pareto front list containing information including BitOps, QuantizerGroup candidates and corresponding eval scores. The Pareto front list can be used to plot a Pareto front curve, which shows how BitOps vary with accuracy. If the allowed accuracy drop is set to 100%, a user can use the Pareto front curve to pick points and re-run. Returns None if the mixed precision algorithm exits early.


Note: To enable phase-3, set the attribute GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True.

Currently only two candidates are supported for phase-3: ((8, int), (8, int)) and ((16, int), (8, int)).
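
To make the candidates structure concrete, here is a minimal sketch of building a candidates list that mixes integer and floating-point formats, matching the example combinations described for the candidates parameter above; it only constructs the argument and does not call choose_mixed_precision.

from aimet_common.defs import QuantizationDataType

# Each candidate is ((activation bitwidth, activation data type),
#                    (parameter bitwidth, parameter data type))
candidates = [((8, QuantizationDataType.int), (16, QuantizationDataType.int)),
              ((16, QuantizationDataType.float), (16, QuantizationDataType.float))]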


Quantizer Groups definition

class aimet_torch.amp.quantizer_groups.QuantizerGroup(input_quantizers=<factory>, output_quantizers=<factory>, parameter_quantizers=<factory>)[source]

Group of modules and quantizers

get_active_quantizers(name_to_quantizer_dict)[source]

Find all active tensor quantizers associated with this quantizer group

get_candidate(name_to_quantizer_dict)[source]

Gets activation & parameter bitwidth.

Parameters:

name_to_quantizer_dict (Dict) – Gets module from module name

Return type:

Tuple[Tuple[int, QuantizationDataType], Tuple[int, QuantizationDataType]]

Returns:

Tuple of activation and parameter bitwidth and data type

get_input_quantizer_modules()[source]

Helper method to get the module names corresponding to input_quantizers.

set_quantizers_to_candidate(name_to_quantizer_dict, candidate)[source]

Sets a quantizer group to a given candidate bitwidth.

Parameters:
  • name_to_quantizer_dict (Dict) – Gets module from module name

  • candidate (Tuple[Tuple[int, QuantizationDataType], Tuple[int, QuantizationDataType]]) – Candidate with activation and parameter bitwidths and data types

Return type:

None

to_list()[source]

Converts quantizer group to a list.

Return type:

List[Tuple[str, str]]

Returns:

List containing input/output quantizers & weight quantizers
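
For illustration, here is a minimal sketch of inspecting the QuantizerGroup objects contained in the Pareto front list returned by choose_mixed_precision using to_list(). The unpacking of each Pareto entry is inferred from the documented return type Optional[List[Tuple[int, float, QuantizerGroup, int]]] and should be verified against the AIMET version in use; the variable names are illustrative only.

def print_pareto_quantizer_groups(pareto_front_list):
    """Illustrative helper: print each quantizer group on the Pareto front."""
    if pareto_front_list is None:
        # choose_mixed_precision may exit early and return None
        return
    for bit_ops, eval_score, quantizer_group, _ in pareto_front_list:
        for entry in quantizer_group.to_list():
            # each entry is a (str, str) pair identifying a quantizer in this group
            print(entry, eval_score)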


CallbackFunc Definition

class aimet_common.defs.CallbackFunc(func, func_callback_args=None)[source]

Class encapsulating a callback function and its arguments

Parameters:
  • func (Callable) – Callable Function

  • func_callback_args – Arguments passed to the callable function
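
As a quick illustration, a CallbackFunc simply pairs a callable with the arguments that will later be passed to it along with the sim model (see the eval function example at the end of this page). The evaluate_fn below and its sample count are hypothetical placeholders.

from aimet_common.defs import CallbackFunc

def evaluate_fn(model, num_samples):
    """Hypothetical eval function: evaluate the model on num_samples samples and return accuracy."""
    raise NotImplementedError  # placeholder: users supply their own evaluation logic

# Pair the callable with the argument it will be invoked with later
eval_callback = CallbackFunc(evaluate_fn, func_callback_args=5000)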


class aimet_torch.amp.mixed_precision_algo.EvalCallbackFactory(data_loader, forward_fn=None)[source]

Factory class for various built-in eval callbacks

Parameters:
  • data_loader (DataLoader) – Data loader to be used for evaluation

  • forward_fn (Optional[Callable[[Module, Any], Tensor]]) – Function that runs forward pass and returns the output tensor. This function is expected to take 1) a model and 2) a single batch yielded from the data loader, and return a single torch.Tensor object which represents the output of the model. The default forward function is roughly equivalent to lambda model, batch: model(batch)

sqnr(num_samples=128)[source]

Returns SQNR eval callback.

Parameters:

num_samples (int) – Number of samples used for evaluation

Return type:

CallbackFunc

Returns:

A callback function that evaluates the input model’s SQNR between fp32 outputs and fake-quantized outputs
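
As a usage illustration, the factory can be combined with a data loader to build the built-in SQNR callback, which can then serve, for example, as the phase 1 evaluation callback. The (images, labels) batch structure assumed below is hypothetical; adapt forward_fn to whatever your data loader yields.

from aimet_torch.amp.mixed_precision_algo import EvalCallbackFactory

def forward_fn(model, batch):
    """Assumed batch structure: (images, labels); only the images are fed to the model."""
    images, _ = batch
    return model(images)

def make_sqnr_callback(data_loader):
    """Build the built-in SQNR eval callback from an evaluation data loader."""
    factory = EvalCallbackFactory(data_loader, forward_fn=forward_fn)
    return factory.sqnr(num_samples=128)  # CallbackFunc usable as eval_callback_for_phase1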


Code Examples

Required imports

import torch
from aimet_common.defs import QuantizationDataType, CallbackFunc
from aimet_torch.mixed_precision import choose_mixed_precision
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo

Quantization with mixed precision

def quantize_with_mixed_precision(model):
    """
    Code example showing the call flow for Auto Mixed Precision
    """
    # Define parameters to pass to mixed precision algo
    dummy_input = torch.randn(1, 1, 28, 28).cuda()
    default_bitwidth = 16
    # ((activation bitwidth, activation data type), (param bitwidth, param data type))
    candidates = [((16, QuantizationDataType.int), (16, QuantizationDataType.int)),
                 ((16, QuantizationDataType.int), (8, QuantizationDataType.int)),
                 ((8, QuantizationDataType.int), (16, QuantizationDataType.int))]
    # Allowed accuracy drop in absolute value
    allowed_accuracy_drop = 0.5 # Implies 50% drop

    eval_callback_for_phase_1 = CallbackFunc(eval_callback_func, func_callback_args=5000)
    eval_callback_for_phase_2 = CallbackFunc(eval_callback_func, func_callback_args=None)

    forward_pass_call_back = CallbackFunc(forward_pass_callback, func_callback_args=dummy_input)

    # Create quant sim
    sim = QuantizationSimModel(model, default_param_bw=default_bitwidth, default_output_bw=default_bitwidth,
                               dummy_input=dummy_input)
    sim.compute_encodings(forward_pass_callback, forward_pass_callback_args=dummy_input)

    # Enable phase-3 (optional)
    # GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True
    # Note: supported candidates ((8,int), (8,int)) & ((16,int), (8,int))

    # Call the mixed precision algo with clean start = True i.e. new accuracy list and pareto list will be generated
    # If set to False then pareto front list and accuracy list will be loaded from the provided directory path
    pareto_front_list = choose_mixed_precision(sim, dummy_input, candidates, eval_callback_for_phase_1,
                                               eval_callback_for_phase_2, allowed_accuracy_drop, results_dir='./data',
                                               clean_start=True, forward_pass_callback=forward_pass_call_back)

    print(pareto_front_list)
    sim.export("./data", str(allowed_accuracy_drop), dummy_input)

Quantization with mixed precision, starting from an existing cache

def quantize_with_mixed_precision_start_from_existing_cache(model):
    """
    Code example showing how to start from an existing cache when using the Auto Mixed Precision API
    """
    # Define parameters to pass to mixed precision algo
    dummy_input = torch.randn(1, 1, 28, 28).cuda()
    default_bitwidth = 16
    # ((activation bitwidth, activation data type), (param bitwidth, param data type))
    candidates = [((16, QuantizationDataType.int), (16, QuantizationDataType.int)),
                 ((16, QuantizationDataType.int), (8, QuantizationDataType.int)),
                 ((8, QuantizationDataType.int), (16, QuantizationDataType.int))]
    # Allowed accuracy drop in absolute value
    allowed_accuracy_drop = 0.5 # Implies 50% drop

    eval_callback_for_phase_1 = CallbackFunc(eval_callback_func, func_callback_args=5000)
    eval_callback_for_phase_2 = CallbackFunc(eval_callback_func, func_callback_args=None)

    forward_pass_call_back = CallbackFunc(forward_pass_callback, func_callback_args=dummy_input)

    # Create quant sim
    sim = QuantizationSimModel(model, default_param_bw=default_bitwidth, default_output_bw=default_bitwidth,
                               dummy_input=dummy_input)
    sim.compute_encodings(forward_pass_callback, forward_pass_callback_args=dummy_input)

    # Enable phase-3 (optional)
    GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True

    # Call the mixed precision algo with clean start = True i.e. new accuracy list and pareto list will be generated
    # If set to False then pareto front list and accuracy list will be loaded from the provided directory path
    # An allowed_accuracy_drop can be specified to export the final model with reference to the pareto list
    pareto_front_list = choose_mixed_precision(sim, dummy_input, candidates, eval_callback_for_phase_1,
                                               eval_callback_for_phase_2, allowed_accuracy_drop, results_dir='./data',
                                               clean_start=True, forward_pass_callback=forward_pass_call_back)

    print(pareto_front_list)
    sim.export("./data", str(allowed_accuracy_drop), dummy_input)

    # Set clean_start to False to start from an existing cache
    # Set allowed_accuracy_drop to 0.9 to export the 90% drop point in pareto list
    allowed_accuracy_drop = 0.9
    pareto_front_list = choose_mixed_precision(sim, dummy_input, candidates, eval_callback_for_phase_1,
                                               eval_callback_for_phase_2, allowed_accuracy_drop, results_dir='./data',
                                               clean_start=False, forward_pass_callback=forward_pass_call_back)
    print(pareto_front_list)
    sim.export("./data", str(allowed_accuracy_drop), dummy_input)

Eval function

def eval_callback_func(model, number_of_samples):
    """ Call eval function for model """
    # Note: A user can populate this function as per their model. This is a toy example to show how the API
    # for the function can look like
    # The eval callback is expected to return a scalar accuracy score for the model
    return model.perform_eval(number_of_samples)

Forward Pass

def forward_pass_callback(model, input_tensor):
    """ Call forward pass of model """
    # Note: A user can populate this function as per their model. This is a toy example to show how the API
    # for the function can look like
    return model(input_tensor)