AIMET PyTorch AdaRound API

Top-level API
- aimet_torch.adaround.adaround_weight.Adaround.apply_adaround(model, dummy_input, params, path, filename_prefix, default_param_bw=4, param_bw_override_list=None, ignore_quant_ops_list=None, default_quant_scheme=QuantScheme.post_training_tf_enhanced, default_config_file=None)

  Returns a model with optimized weight rounding of every module (Conv and Linear) and saves the corresponding quantization encodings to a separate JSON-formatted file that can then be imported by QuantSim for inference or QAT.

  - Parameters
    - model (Module) – Model to apply AdaRound to
    - dummy_input (Union[Tensor, Tuple]) – Dummy input to the model, used to parse the model graph. If the model has more than one input, pass a tuple. The user is expected to place the tensors on the appropriate device.
    - params (AdaroundParameters) – Parameters for AdaRound
    - path (str) – Path where the parameter encodings are stored
    - filename_prefix (str) – Prefix to use for the filename of the encodings file
    - default_param_bw (int) – Default bitwidth (4-31) to use for quantizing layer parameters
    - param_bw_override_list (Optional[List[Tuple[Module, int]]]) – List of tuples. Each tuple is a module and the corresponding parameter bitwidth to be used for that module.
    - ignore_quant_ops_list (Optional[List[Module]]) – Ops listed here are skipped during the quantization needed for AdaRound. Do not specify Conv and Linear modules in this list; doing so will affect accuracy.
    - default_quant_scheme (QuantScheme) – Quantization scheme. Supported options are QuantScheme.post_training_tf and QuantScheme.post_training_tf_enhanced.
    - default_config_file (Optional[str]) – Default configuration file for model quantizers
  - Return type
    - Module
  - Returns
    - Model with AdaRounded weights; the corresponding parameter encodings JSON file is saved at the provided path
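The optional per-module arguments are easiest to see in a short sketch. The snippet below is illustrative only: the module choices (conv1, maxpool) and the 8-bit override are assumptions for a torchvision ResNet18, not recommendations from this page, and a fake data loader stands in for real training data.

import torch
from torchvision import models
from aimet_common.defs import QuantScheme
from aimet_torch.utils import create_fake_data_loader
from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters

model = models.resnet18(pretrained=True).eval()
dummy_input = torch.randn(1, 3, 224, 224)
data_loader = create_fake_data_loader(dataset_size=32, batch_size=16, image_size=(3, 224, 224))
params = AdaroundParameters(data_loader=data_loader, num_batches=2)

adarounded = Adaround.apply_adaround(
    model, dummy_input, params, path='./', filename_prefix='resnet18',
    default_param_bw=4,
    # Assumption for illustration: quantize the first conv's weights at 8 bits instead of 4
    param_bw_override_list=[(model.conv1, 8)],
    # Assumption for illustration: skip this op during AdaRound quantization.
    # Per the note above, never list Conv or Linear modules here.
    ignore_quant_ops_list=[model.maxpool],
    default_quant_scheme=QuantScheme.post_training_tf_enhanced)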
The key parameters that users of the API typically configure are summarized below.
- Parameters
  model - Model to apply AdaRound to
  params - AdaroundParameters, described below
  default_param_bw - Default bitwidth (4-31) to use for initializing the encodings used for adaptive rounding. Default: 4
  default_quant_scheme - Default quantization scheme used for initializing the encodings used for adaptive rounding. Supported options are QuantScheme.post_training_tf and QuantScheme.post_training_tf_enhanced. Default: QuantScheme.post_training_tf_enhanced
  default_config_file - Configuration file for model quantizers
- AdaroundParameters
  data_loader - The data loader containing training data
  num_batches - The number of batches to use for AdaRounding
  default_num_iterations - Number of iterations to AdaRound each layer. Default: 10000
  default_reg_param - Regularization parameter, trading off rounding loss vs. reconstruction loss. Default: 0.01
  default_beta_range - Start and stop beta parameter for annealing of the rounding loss (start_beta, end_beta). Default: (20, 2)
  default_warm_start - Warm-up period, during which the rounding loss has zero effect. Default: 20% (0.2)
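A minimal sketch of constructing AdaroundParameters with a real training loader. The dataset path and preprocessing below are assumptions for illustration; substitute your own training pipeline.

import torch
from torchvision import datasets, transforms
from aimet_torch.adaround.adaround_weight import AdaroundParameters

# Assumption: './train' is an ImageFolder-style directory of training images
train_set = datasets.ImageFolder('./train', transform=transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()]))
data_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True)

# 4 batches of 16 images = 64 samples used to optimize each layer's rounding
params = AdaroundParameters(data_loader=data_loader, num_batches=4,
                            default_num_iterations=10000,  # per-layer optimization steps
                            default_reg_param=0.01,        # rounding vs. reconstruction trade-off
                            default_beta_range=(20, 2),    # beta annealing: (start, end)
                            default_warm_start=0.2)        # first 20% of iterations: no rounding loss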
Enum Definition

Quant Scheme Enum

- class aimet_common.defs.QuantScheme

  Enumeration of Quant schemes

  - post_training_tf = 1
    TF scheme
  - post_training_tf_enhanced = 2
    TF-enhanced scheme
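Since QuantScheme is a standard Python enum (as the member values above show), selecting a scheme is a plain attribute access:

from aimet_common.defs import QuantScheme

scheme = QuantScheme.post_training_tf_enhanced
print(scheme.name, scheme.value)  # post_training_tf_enhanced 2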
Code Examples
Required imports
import logging
import torch
import torch.cuda
from torchvision import models
from aimet_common.utils import AimetLogger
from aimet_common.defs import QuantScheme
from aimet_torch.utils import create_fake_data_loader
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters
Evaluation function
def dummy_forward_pass(model: torch.nn.Module, forward_pass_callback_args) -> float:
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's eval function does not
    match this signature, please create a simple wrapper.
    :param model: Model to evaluate
    :param forward_pass_callback_args: These argument(s) are passed to the forward_pass_callback as-is. It is
        up to the user to determine the type of this parameter. E.g. it could simply be an integer representing
        the number of data samples to use, or a tuple of parameters, or an object representing something more
        complex. If set to None, forward_pass_callback will be invoked with no parameters.
    :return: single float number (accuracy) representing model's performance
    """
    return 0.5
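For reference, a more realistic callback would run a labeled loader through the model and return an actual accuracy. The sketch below is an assumption about what such a callback could look like; in it, forward_pass_callback_args is assumed to be a (data_loader, device) tuple.

def eval_callback(model: torch.nn.Module, forward_pass_callback_args) -> float:
    """Hypothetical evaluation callback: top-1 accuracy over a small labeled loader.
    Assumes forward_pass_callback_args is a (data_loader, device) tuple."""
    data_loader, device = forward_pass_callback_args
    correct = total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total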
Apply AdaRound to ResNet18: the AdaRounded model is returned, and the associated encodings are saved to disk
def apply_adaround_example():
    AimetLogger.set_level_for_all_areas(logging.DEBUG)
    torch.cuda.empty_cache()

    model = models.resnet18(pretrained=True).eval()
    model = model.to(torch.device('cuda'))
    input_shape = (1, 3, 224, 224)
    dummy_input = torch.randn(input_shape).to(torch.device('cuda'))

    # As an illustrative example, a fake data loader is used here.
    # For AdaRound, the user should provide the training data loader.
    data_loader = create_fake_data_loader(dataset_size=64, batch_size=16, image_size=input_shape[1:])

    params = AdaroundParameters(data_loader=data_loader, num_batches=4, default_num_iterations=50,
                                default_reg_param=0.01, default_beta_range=(20, 2))

    # Returns the model with AdaRounded weights and saves the corresponding encodings
    adarounded_model = Adaround.apply_adaround(model, dummy_input, params, path='./',
                                               filename_prefix='resnet18', default_param_bw=4,
                                               default_quant_scheme=QuantScheme.post_training_tf_enhanced,
                                               default_config_file=None)

    # Create QuantSim using the AdaRounded model. The parameter bitwidth and quant scheme
    # must match those used during AdaRound so the frozen encodings line up with the grid.
    sim = QuantizationSimModel(adarounded_model, quant_scheme=QuantScheme.post_training_tf_enhanced,
                               default_param_bw=4, default_output_bw=8, dummy_input=dummy_input)

    # Set and freeze the parameter encodings so QuantSim uses the same quantization grid,
    # then compute encodings for the remaining quantizers
    sim.set_and_freeze_param_encodings(encoding_path='./resnet18.encodings')
    sim.compute_encodings(dummy_forward_pass, forward_pass_callback_args=None)
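    # The steps below are an optional, illustrative continuation (not part of the original
    # example): evaluate the quantization-simulated model and export it.
    accuracy = dummy_forward_pass(sim.model, None)

    # Export the simulated model and its encodings; dummy_input must be on the CPU for export.
    # The filename prefix 'quantized_resnet18' is an assumption for illustration.
    sim.export(path='./', filename_prefix='quantized_resnet18', dummy_input=dummy_input.cpu())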