aimet_torch.quantsim¶
- class aimet_torch.quantsim.QuantizationSimModel(model, dummy_input, quant_scheme=None, rounding_mode=None, default_output_bw=8, default_param_bw=8, in_place=False, config_file=None, default_data_type=QuantizationDataType.int)[source]¶
Class that simulates the quantized model execution on a target hardware backend.
QuantizationSimModel simulates quantization of a given model by converting all PyTorch modules into quantized modules with input/output/parameter quantizers as necessary.
Example
>>> model = torchvision.models.resnet18()
>>> dummy_input = torch.randn(1, 3, 224, 224)
>>> sim = QuantizationSimModel(model, dummy_input)
>>> print(model)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  ...
)
>>> print(sim.model)
ResNet(
  (conv1): QuantizedConv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    (param_quantizers): ModuleDict(
      (weight): QuantizeDequantize(shape=(), qmin=-128, qmax=127, symmetric=True)
    )
    (input_quantizers): ModuleList(
      (0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
    )
    (output_quantizers): ModuleList(
      (0): None
    )
  )
  ...
)
Warning
The rounding_mode parameter is deprecated. Passing rounding_mode will raise a runtime error in versions >= 1.35.
Warning
The default value of quant_scheme will change from QuantScheme.post_training_tf_enhanced to QuantScheme.training_range_learning_with_tf_init in future versions, and quant_scheme will eventually be deprecated in the longer term.
- Parameters:
model (torch.nn.Module) – Model to simulate the quantized execution of
dummy_input (Tensor | Sequence[Tensor]) – Dummy input to be used to capture the computational graph of the model. All input tensors are expected to be already placed on the appropriate devices to run forward pass of the model.
quant_scheme (QuantScheme, optional) – Quantization scheme that indicates how to observe and calibrate the quantization encodings (Default: QuantScheme.post_training_tf_enhanced)
rounding_mode – Deprecated
default_output_bw (int, optional) – Default bitwidth (4-31) to use for quantizing all layer inputs and outputs unless otherwise specified in the config file. (Default: 8)
default_param_bw (int, optional) – Default bitwidth (4-31) to use for quantizing all layer parameters unless otherwise specified in the config file. (Default: 8)
in_place (bool, optional) – If True, then the given model is modified in-place into a quantized model. (Default: False)
config_file (str, optional) – Path to the quantization simulation config file (Default: None)
default_data_type (QuantizationDataType, optional) – Default data type to use for quantizing all inputs, outputs and parameters unless otherwise specified in the config file. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that the mode default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 or 32 and default_param_bw=16 or 32. (Default: QuantizationDataType.int)
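For illustration, a minimal sketch of constructing a simulation with non-default bitwidths and data type, following the parameter descriptions above (the model and dummy input are reused from the earlier example; config_file is left as None here):
>>> from aimet_common.defs import QuantizationDataType
>>> # 8-bit integer simulation; config_file=None falls back to the default quantsim config
>>> int8_sim = QuantizationSimModel(model, dummy_input,
...                                 default_output_bw=8,
...                                 default_param_bw=8,
...                                 config_file=None)
>>> # Floating-point simulation; QuantizationDataType.float requires bitwidths of 16 or 32
>>> fp16_sim = QuantizationSimModel(model, dummy_input,
...                                 default_output_bw=16,
...                                 default_param_bw=16,
...                                 default_data_type=QuantizationDataType.float)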
The following API can be used to compute encodings for calibration:
- QuantizationSimModel.compute_encodings(forward_pass_callback, forward_pass_callback_args=<class 'aimet_torch.v2.quantsim.quantsim._NOT_SPECIFIED'>)[source]¶
Computes encodings for all quantizers in the model.
This API will invoke forward_pass_callback, a function written by the user that runs forward pass(es) of the quantized model with a small, representative subset of the training dataset. By doing so, the quantizers in the quantized model will observe the inputs and initialize their quantization encodings according to the observed input statistics.
This function is overloaded with the following signatures:
- compute_encodings(forward_pass_callback)[source]
- Parameters:
forward_pass_callback (Callable[[torch.nn.Module], Any]) – A function that takes a quantized model and runs forward passes with a small, representative subset of training dataset
- compute_encodings(forward_pass_callback, forward_pass_callback_args)
- Parameters:
forward_pass_callback (Callable[[torch.nn.Module, T], Any]) – A function that takes a quantized model and runs forward passes with a small, representative subset of training dataset
forward_pass_callback_args (T) – The second argument to forward_pass_callback.
Example
>>> sim = QuantizationSimModel(...)
>>> _ = sim.model(input)  # Can't run forward until quantizer encodings are initialized
RuntimeError: Failed to run QuantizeDequantize since quantization parameters are not initialized. Please initialize the quantization parameters using `compute_encodings()`.
>>> def run_forward_pass(quantized_model: torch.nn.Module):
...     for input in train_dataloader:
...         with torch.no_grad():
...             _ = quantized_model(input)
...
>>> sim.compute_encodings(run_forward_pass)
>>> _ = sim.model(input)  # Now runs successfully!
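The two-argument overload passes forward_pass_callback_args through to the callback unmodified. A minimal sketch, where calibration_loader is an assumed user-provided DataLoader of representative data:
>>> def pass_calibration_data(quantized_model: torch.nn.Module, num_batches: int):
...     with torch.no_grad():
...         for i, (images, _) in enumerate(calibration_loader):  # assumed DataLoader
...             if i >= num_batches:
...                 break
...             _ = quantized_model(images)
...
>>> sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=20)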
The following APIs can be used to save and restore the quantized model
- quantsim.save_checkpoint(quant_sim_model, file_path)¶
This API provides a way for the user to save a checkpoint of the quantized model, which can be loaded at a later point to continue fine-tuning. See also load_checkpoint().
- Parameters:
quant_sim_model (_QuantizationSimModelInterface) – QuantizationSimModel to save a checkpoint for
file_path (str) – Path to the file where you want to save the checkpoint
- Returns:
None
- quantsim.load_checkpoint(file_path)¶
Load the quantized model from a previously saved checkpoint.
- Parameters:
file_path (str) – Path to the checkpoint file to load
- Return type:
_QuantizationSimModelInterface
- Returns:
A new instance of the QuantizationSimModel created after loading the checkpoint
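A brief usage sketch of the checkpoint round trip, assuming both functions are importable from aimet_torch.quantsim and using a placeholder file path:
>>> from aimet_torch.quantsim import save_checkpoint, load_checkpoint
>>> save_checkpoint(sim, '/tmp/quantsim_checkpoint.pth')   # placeholder path
>>> sim = load_checkpoint('/tmp/quantsim_checkpoint.pth')  # returns a new QuantizationSimModel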
The following API can be used to export the quantized model to target:
- QuantizationSimModel.export(path, filename_prefix, dummy_input, *args, **kwargs)[source]¶
This method exports the quant-sim model so it is ready to be run on target.
Specifically, the following are saved:
The sim-model is exported to a regular PyTorch model without any simulation ops
The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)
Optionally, an equivalent model in ONNX format is exported. In addition, nodes in the ONNX model are named the same as the corresponding PyTorch module names. This helps with matching ONNX nodes to their quantization encodings in the exported encodings file.
- Parameters:
path (str) – Path where to store the model pth and encodings
filename_prefix (str) – Prefix to use for filenames of the model pth and encodings files
dummy_input (Union[Tensor, Tuple]) – Dummy input to the model. Used to parse the model graph. The dummy_input is required to be placed on CPU.
onnx_export_args – Optional export argument with ONNX-specific overrides provided as a dictionary or OnnxExportApiArgs object. If not provided, defaults to "opset_version" = None, "input_names" = None, "output_names" = None, and for torch version < 1.10.0, "enable_onnx_checker" = False.
propagate_encodings – If True, encoding entries for intermediate ops (when one PyTorch op results in multiple ONNX nodes) are filled with the same bitwidth and data_type as the output tensor for that series of ops. Defaults to False.
export_to_torchscript – If True, export to TorchScript; otherwise export to ONNX. Defaults to False.
use_embedded_encodings – If True, an additional ONNX model with embedded fake-quantization nodes will be exported
export_model – If True, the ONNX model is exported. When False, only the encodings are exported. The user should set this flag to False only if the corresponding ONNX model already exists in the specified path
filename_prefix_encodings – File name prefix to be used when saving the encodings. If None, defaults to the filename_prefix value
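Putting the export arguments together, a minimal sketch that writes the model and encodings to a placeholder directory (note that dummy_input must be placed on CPU):
>>> import os
>>> os.makedirs('/tmp/quantsim_export', exist_ok=True)  # placeholder output directory
>>> sim.export(path='/tmp/quantsim_export',
...            filename_prefix='resnet18_int8',
...            dummy_input=dummy_input.cpu())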
Quant Scheme Enum
- class aimet_common.defs.QuantScheme(value)[source]¶
Enumeration of quantization schemes
- post_training_percentile = 6¶
For a Tensor, adjusted minimum and maximum values are selected based on the percentile value passed. The Quantization encodings are calculated using the adjusted minimum and maximum value.
- post_training_tf = 1¶
For a Tensor, the absolute minimum and maximum value of the Tensor are used to compute the Quantization encodings.
- post_training_tf_enhanced = 2¶
For a Tensor, searches and selects the optimal minimum and maximum value that minimizes the Quantization Noise. The Quantization encodings are calculated using the selected minimum and maximum value.
- training_range_learning_with_tf_enhanced_init = 4¶
For a Tensor, the encoding values are initialized with the post_training_tf_enhanced scheme. Then, the encodings are learned during training.
- training_range_learning_with_tf_init = 3¶
For a Tensor, the encoding values are initialized with the post_training_tf scheme. Then, the encodings are learned during training.
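As a sketch, a scheme from this enum is selected via the quant_scheme argument of QuantizationSimModel; range-learning schemes are typically followed by fine-tuning:
>>> from aimet_common.defs import QuantScheme
>>> sim = QuantizationSimModel(model, dummy_input,
...                            quant_scheme=QuantScheme.training_range_learning_with_tf_init)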