Quantized Modules
To simulate the effects of running networks at a reduced bitwidth, AIMET introduces quantized modules: extensions of standard torch.nn.Module classes with added capabilities for quantization. These quantized modules serve as drop-in replacements for their PyTorch counterparts, but can hold input, output, and parameter quantizers that perform quantization operations during the module's forward pass and compute quantization encodings.
More specifically, a quantized module inherits from both QuantizationMixin and a native torch.nn.Module type, and is typically named by prepending a "Quantized" prefix to the original class name, such as QuantizedConv2d for torch.nn.Conv2d or QuantizedSoftmax for torch.nn.Softmax. For a detailed API reference of the QuantizationMixin class, see the QuantizationMixin API reference. For the full list of built-in quantized modules in AIMET, see Quantized Module Classes below.
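As a rough illustration of what an affine quantizer does to a tensor during the forward pass, here is a minimal pure-PyTorch sketch of quantize-dequantize (fake quantization). The function name and formulas are illustrative only and do not reflect AIMET's internal implementation.

```python
import torch

def fake_quantize(x: torch.Tensor, bitwidth: int = 8,
                  symmetric: bool = True) -> torch.Tensor:
    """Quantize-dequantize a tensor to simulate reduced-precision storage.
    Illustrative only; assumes x has a nonzero range."""
    if symmetric:
        # Symmetric: zero point fixed at 0, scale set by the max magnitude
        qmax = 2 ** (bitwidth - 1) - 1
        scale = x.abs().max() / qmax
        x_int = torch.clamp(torch.round(x / scale), -qmax - 1, qmax)
        return x_int * scale
    # Asymmetric: scale and offset together cover the full observed range
    qmax = 2 ** bitwidth - 1
    scale = (x.max() - x.min()) / qmax
    offset = torch.round(x.min() / scale)
    x_int = torch.clamp(torch.round(x / scale) - offset, 0, qmax)
    return (x_int + offset) * scale
```

Rounding to an integer grid and clamping to the representable range is the source of the quantization error that a well-chosen scale and offset (see Computing Encodings) aim to minimize.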
Top-level API
- class aimet_torch.v2.nn.QuantizationMixin(*args, **kwargs)[source]
Mixin that adds quantization functionality on top of regular pytorch modules.
Specifically, a quantized module will quantize input, output, and parameter tensors with its held QuantizerBase objects during the forward() method and use the inherited torch.nn.Module forward method to compute the layer operation. If all input, output, and parameter quantizers are None, a quantized module will behave exactly the same as its parent torch.nn.Module.
- input_quantizers: torch.nn.ModuleList containing QuantizerBase objects to be applied to the layer's input tensors
- output_quantizers: torch.nn.ModuleList containing QuantizerBase objects to be applied to the layer's output tensors
- param_quantizers: torch.nn.ModuleDict mapping parameter names to associated QuantizerBase objects
Examples
```python
>>> qlinear = QuantizedLinear(in_features=10, out_features=10)
>>> print(qlinear)
QuantizedLinear(
  in_features=10, out_features=10, bias=True
  (param_quantizers): ModuleDict(
    (weight): None
    (bias): None
  )
  (input_quantizers): ModuleList(
    (0): None
  )
  (output_quantizers): ModuleList(
    (0): None
  )
)
```
- __quant_init__()
Initializer for quantized module. This method will be invoked right after __init__().
This method initializes the input_quantizers, output_quantizers, and param_quantizers structures to the appropriate sizes based on the number of input tensors, output tensors, and parameters of the base nn.Module class. All quantizers are initialized to None.
For custom quantized classes, this method should be overridden to set the appropriate lengths of input_quantizers and output_quantizers for the given base class.
- compute_encodings()[source]
Enters the compute_encodings() context for all QuantizerBase objects in the layer.
Inside this context, each quantizer will observe all inputs passed to the quantizer and will compute quantization encodings upon exiting the context.
Example
```python
>>> qlinear = QuantizedLinear(10, 10)
>>> qlinear.output_quantizers[0] = Quantize((), 8, symmetric=False)
>>> with qlinear.compute_encodings():
...     qlinear(torch.randn(16, 10))
>>> print(qlinear.output_quantizers[0].is_initialized())
True
```
- abstract forward(*args, **kwargs)[source]
Computes a quantized version of the parent module's forward method.
The forward() method should perform the following logic in order:
1. Apply existing input quantizers to input tensors
2. Apply existing param quantizers to the layer's parameters
3. Call the inherited torch.nn.Module forward method with quantized inputs and parameters
4. Apply existing output quantizers to the outputs of the forward method
If all input, output, and parameter quantizers are None, this method will behave exactly the same as its parent module's forward pass.
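The four steps above can be sketched in plain PyTorch. This is an illustrative mock, not AIMET's actual QuantizationMixin implementation: SketchQuantizedLinear and its quantizer attributes are hypothetical names, and step 3 calls torch.nn.functional.linear in place of the inherited forward so the quantized weight can be substituted.

```python
import torch
import torch.nn.functional as F

class SketchQuantizedLinear(torch.nn.Linear):
    """Illustrative mock that mirrors the four-step order of a quantized
    forward pass. Quantizers are plain callables or None."""

    def __init__(self, *args, input_quantizer=None, param_quantizer=None,
                 output_quantizer=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_quantizer = input_quantizer
        self.param_quantizer = param_quantizer
        self.output_quantizer = output_quantizer

    def forward(self, x):
        # 1. Apply existing input quantizers to input tensors
        if self.input_quantizer is not None:
            x = self.input_quantizer(x)
        # 2. Apply existing param quantizers to the layer's parameters
        weight = self.weight
        if self.param_quantizer is not None:
            weight = self.param_quantizer(weight)
        # 3. Compute the base layer operation with quantized inputs/parameters
        out = F.linear(x, weight, self.bias)
        # 4. Apply existing output quantizers to the outputs
        if self.output_quantizer is not None:
            out = self.output_quantizer(out)
        return out
```

With all three quantizers left as None, the sketch reduces to an ordinary torch.nn.Linear forward pass, matching the behavior described above.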
Configuration
The quantization behavior of a quantized module is controlled by the quantizers contained within the input, output, and parameter quantizer attributes listed below.
Attribute | Type | Description
---|---|---
input_quantizers | torch.nn.ModuleList | List of quantizers for input tensors
param_quantizers | torch.nn.ModuleDict | Dict mapping parameter names to quantizers
output_quantizers | torch.nn.ModuleList | List of quantizers for output tensors
By assigning and configuring quantizers to these structures, we define the type of quantization applied to the corresponding input index, output index, or parameter name. By default, all the quantizers are set to None, meaning that no quantization will be applied to the respective tensor.
- Example: Create a linear layer which performs only per-channel weight quantization
```python
>>> import aimet_torch.v2 as aimet
>>> import aimet_torch.quantization as Q
>>> qlinear = aimet.nn.QuantizedLinear(out_features=10, in_features=5)
>>> # Per-channel weight quantization is performed over the `out_features` dimension,
>>> # so encodings are shape (10, 1)
>>> per_channel_quantizer = Q.affine.QuantizeDequantize(shape=(10, 1), bitwidth=8, symmetric=True)
>>> qlinear.param_quantizers["weight"] = per_channel_quantizer
```
- Example: Create an elementwise multiply layer which quantizes only the output and the second input
```python
>>> qmul = aimet.nn.custom.QuantizedMultiply()
>>> qmul.output_quantizers[0] = Q.affine.QuantizeDequantize(shape=(), bitwidth=8, symmetric=False)
>>> qmul.input_quantizers[1] = Q.affine.QuantizeDequantize(shape=(), bitwidth=8, symmetric=False)
```
In some cases, it may make sense for multiple tensors to share the same quantizer. In this case, we can assign the same quantizer to multiple indices.
- Example: Create an elementwise add layer which shares the same quantizer between its inputs
```python
>>> qadd = aimet.nn.custom.QuantizedAdd()
>>> quantizer = Q.affine.QuantizeDequantize(shape=(), bitwidth=8, symmetric=False)
>>> qadd.input_quantizers[0] = quantizer
>>> qadd.input_quantizers[1] = quantizer
```
Computing Encodings
Before a module can compute a quantized forward pass, all quantizers must first be calibrated inside a compute_encodings context. When a quantized module enters the compute_encodings context, it first disables all input and output quantization while the quantizers observe the statistics of the activation tensors passing through them. Upon exiting the context, the quantizers calculate appropriate quantization encodings based on these statistics (exactly how the encodings are computed is determined by each quantizer’s encoding analyzer).
- Example:
```python
>>> qlinear = aimet.nn.QuantizedLinear(out_features=10, in_features=5)
>>> qlinear.output_quantizers[0] = Q.affine.QuantizeDequantize((1, ), bitwidth=8, symmetric=False)
>>> qlinear.param_quantizers["weight"] = Q.affine.QuantizeDequantize((10, 1), bitwidth=8, symmetric=True)
>>> with qlinear.compute_encodings():
...     # Pass several samples through the layer to ensure representative statistics
...     for x, _ in calibration_data_loader:
...         qlinear(x)
>>> print(qlinear.output_quantizers[0].is_initialized())
True
>>> print(qlinear.param_quantizers["weight"].is_initialized())
True
```
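As a rough sketch of what happens under the hood during calibration, the following pure-PyTorch mock tracks a running min/max across observed tensors and derives an affine encoding from it upon request. MinMaxObserver is a hypothetical stand-in for AIMET's encoding analyzers, which may use more sophisticated schemes than plain min/max.

```python
import torch

class MinMaxObserver:
    """Hypothetical stand-in for an encoding analyzer: tracks the running
    min/max of all observed tensors, then derives an affine encoding."""

    def __init__(self, bitwidth: int = 8):
        self.bitwidth = bitwidth
        self.min = float("inf")
        self.max = float("-inf")

    def observe(self, x: torch.Tensor):
        # Update running statistics; called once per calibration batch
        self.min = min(self.min, x.min().item())
        self.max = max(self.max, x.max().item())

    def compute_encodings(self):
        # Asymmetric affine encoding spanning the full observed range
        # (assumes a nonzero range was observed)
        qmax = 2 ** self.bitwidth - 1
        scale = (self.max - self.min) / qmax
        offset = round(self.min / scale)
        return scale, offset

# Calibration loop: observe several batches, then compute encodings
observer = MinMaxObserver()
for _ in range(4):
    observer.observe(torch.randn(16, 10))
scale, offset = observer.compute_encodings()
```

This mirrors the two phases described above: inside the context, quantizers only observe statistics; on exit, they turn those statistics into scale/offset encodings.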
Quantized Module Classes
nn.Module | QuantizationMixin
---|---
torch.nn.AdaptiveAvgPool1d | QuantizedAdaptiveAvgPool1d
torch.nn.AdaptiveAvgPool2d | QuantizedAdaptiveAvgPool2d
torch.nn.AdaptiveAvgPool3d | QuantizedAdaptiveAvgPool3d
torch.nn.AdaptiveMaxPool1d | QuantizedAdaptiveMaxPool1d
torch.nn.AdaptiveMaxPool2d | QuantizedAdaptiveMaxPool2d
torch.nn.AdaptiveMaxPool3d | QuantizedAdaptiveMaxPool3d
torch.nn.AlphaDropout | QuantizedAlphaDropout
torch.nn.AvgPool1d | QuantizedAvgPool1d
torch.nn.AvgPool2d | QuantizedAvgPool2d
torch.nn.AvgPool3d | QuantizedAvgPool3d
torch.nn.BatchNorm1d | QuantizedBatchNorm1d
torch.nn.BatchNorm2d | QuantizedBatchNorm2d
torch.nn.BatchNorm3d | QuantizedBatchNorm3d
torch.nn.CELU | QuantizedCELU
torch.nn.ChannelShuffle | QuantizedChannelShuffle
torch.nn.ConstantPad1d | QuantizedConstantPad1d
torch.nn.ConstantPad2d | QuantizedConstantPad2d
torch.nn.ConstantPad3d | QuantizedConstantPad3d
torch.nn.Conv1d | QuantizedConv1d
torch.nn.Conv2d | QuantizedConv2d
torch.nn.Conv3d | QuantizedConv3d
torch.nn.ConvTranspose1d | QuantizedConvTranspose1d
torch.nn.ConvTranspose2d | QuantizedConvTranspose2d
torch.nn.ConvTranspose3d | QuantizedConvTranspose3d
torch.nn.Dropout | QuantizedDropout
torch.nn.Dropout2d | QuantizedDropout2d
torch.nn.Dropout3d | QuantizedDropout3d
torch.nn.ELU | QuantizedELU
torch.nn.FeatureAlphaDropout | QuantizedFeatureAlphaDropout
torch.nn.Flatten | QuantizedFlatten
torch.nn.Fold | QuantizedFold
torch.nn.FractionalMaxPool2d | QuantizedFractionalMaxPool2d
torch.nn.FractionalMaxPool3d | QuantizedFractionalMaxPool3d
torch.nn.GELU | QuantizedGELU
torch.nn.GLU | QuantizedGLU
torch.nn.GroupNorm | QuantizedGroupNorm
torch.nn.Hardshrink | QuantizedHardshrink
torch.nn.Hardsigmoid | QuantizedHardsigmoid
torch.nn.Hardswish | QuantizedHardswish
torch.nn.Hardtanh | QuantizedHardtanh
torch.nn.InstanceNorm1d | QuantizedInstanceNorm1d
torch.nn.InstanceNorm2d | QuantizedInstanceNorm2d
torch.nn.InstanceNorm3d | QuantizedInstanceNorm3d
torch.nn.LPPool1d | QuantizedLPPool1d
torch.nn.LPPool2d | QuantizedLPPool2d
torch.nn.LayerNorm | QuantizedLayerNorm
torch.nn.LeakyReLU | QuantizedLeakyReLU
torch.nn.Linear | QuantizedLinear
torch.nn.LocalResponseNorm | QuantizedLocalResponseNorm
torch.nn.LogSigmoid | QuantizedLogSigmoid
torch.nn.LogSoftmax | QuantizedLogSoftmax
torch.nn.MaxPool1d | QuantizedMaxPool1d
torch.nn.MaxPool2d | QuantizedMaxPool2d
torch.nn.MaxPool3d | QuantizedMaxPool3d
torch.nn.MaxUnpool1d | QuantizedMaxUnpool1d
torch.nn.MaxUnpool2d | QuantizedMaxUnpool2d
torch.nn.MaxUnpool3d | QuantizedMaxUnpool3d
torch.nn.Mish | QuantizedMish
torch.nn.PReLU | QuantizedPReLU
torch.nn.PixelShuffle | QuantizedPixelShuffle
torch.nn.PixelUnshuffle | QuantizedPixelUnshuffle
torch.nn.RReLU | QuantizedRReLU
torch.nn.ReLU | QuantizedReLU
torch.nn.ReLU6 | QuantizedReLU6
torch.nn.ReflectionPad1d | QuantizedReflectionPad1d
torch.nn.ReflectionPad2d | QuantizedReflectionPad2d
torch.nn.ReplicationPad1d | QuantizedReplicationPad1d
torch.nn.ReplicationPad2d | QuantizedReplicationPad2d
torch.nn.ReplicationPad3d | QuantizedReplicationPad3d
torch.nn.SELU | QuantizedSELU
torch.nn.SiLU | QuantizedSiLU
torch.nn.Sigmoid | QuantizedSigmoid
torch.nn.Softmax | QuantizedSoftmax
torch.nn.Softmax2d | QuantizedSoftmax2d
torch.nn.Softmin | QuantizedSoftmin
torch.nn.Softplus | QuantizedSoftplus
torch.nn.Softshrink | QuantizedSoftshrink
torch.nn.Softsign | QuantizedSoftsign
torch.nn.Tanh | QuantizedTanh
torch.nn.Tanhshrink | QuantizedTanhshrink
torch.nn.Threshold | QuantizedThreshold
torch.nn.Unflatten | QuantizedUnflatten
torch.nn.Unfold | QuantizedUnfold
torch.nn.Upsample | QuantizedUpsample
torch.nn.UpsamplingBilinear2d | QuantizedUpsamplingBilinear2d
torch.nn.UpsamplingNearest2d | QuantizedUpsamplingNearest2d
torch.nn.ZeroPad2d | QuantizedZeroPad2d
torch.nn.BCELoss | QuantizedBCELoss
torch.nn.BCEWithLogitsLoss | QuantizedBCEWithLogitsLoss
torch.nn.Bilinear | QuantizedBilinear
torch.nn.CTCLoss | QuantizedCTCLoss
torch.nn.CosineSimilarity | QuantizedCosineSimilarity
torch.nn.CrossEntropyLoss | QuantizedCrossEntropyLoss
torch.nn.HingeEmbeddingLoss | QuantizedHingeEmbeddingLoss
torch.nn.HuberLoss | QuantizedHuberLoss
torch.nn.KLDivLoss | QuantizedKLDivLoss
torch.nn.L1Loss | QuantizedL1Loss
torch.nn.MSELoss | QuantizedMSELoss
torch.nn.MultiLabelMarginLoss | QuantizedMultiLabelMarginLoss
torch.nn.MultiLabelSoftMarginLoss | QuantizedMultiLabelSoftMarginLoss
torch.nn.MultiMarginLoss | QuantizedMultiMarginLoss
torch.nn.NLLLoss | QuantizedNLLLoss
torch.nn.NLLLoss2d | QuantizedNLLLoss2d
torch.nn.PairwiseDistance | QuantizedPairwiseDistance
torch.nn.PoissonNLLLoss | QuantizedPoissonNLLLoss
torch.nn.SmoothL1Loss | QuantizedSmoothL1Loss
torch.nn.SoftMarginLoss | QuantizedSoftMarginLoss
torch.nn.CosineEmbeddingLoss | QuantizedCosineEmbeddingLoss
torch.nn.GaussianNLLLoss | QuantizedGaussianNLLLoss
torch.nn.MarginRankingLoss | QuantizedMarginRankingLoss
torch.nn.TripletMarginLoss | QuantizedTripletMarginLoss
torch.nn.TripletMarginWithDistanceLoss | QuantizedTripletMarginWithDistanceLoss
torch.nn.Embedding | QuantizedEmbedding
torch.nn.EmbeddingBag | QuantizedEmbeddingBag
torch.nn.GRU | QuantizedGRU
torch.nn.RNN | QuantizedRNN
torch.nn.GRUCell | QuantizedGRUCell
torch.nn.RNNCell | QuantizedRNNCell
torch.nn.LSTM | QuantizedLSTM
torch.nn.LSTMCell | QuantizedLSTMCell
aimet_torch.v2.nn.custom.AvgPool2d | QuantizedAvgPool2d
aimet_torch.v2.nn.custom.CumSum | QuantizedCumSum
aimet_torch.v2.nn.custom.Sin | QuantizedSin
aimet_torch.v2.nn.custom.Cos | QuantizedCos
aimet_torch.v2.nn.custom.RSqrt | QuantizedRSqrt
aimet_torch.v2.nn.custom.Reshape | QuantizedReshape
aimet_torch.v2.nn.custom.MatMul | QuantizedMatMul
aimet_torch.v2.nn.custom.Add | QuantizedAdd
aimet_torch.v2.nn.custom.Multiply | QuantizedMultiply
aimet_torch.v2.nn.custom.Subtract | QuantizedSubtract
aimet_torch.v2.nn.custom.Divide | QuantizedDivide
aimet_torch.v2.nn.custom.Bmm | QuantizedBmm
aimet_torch.v2.nn.custom.Baddbmm | QuantizedBaddbmm
aimet_torch.v2.nn.custom.Addmm | QuantizedAddmm
aimet_torch.v2.nn.custom.Concat | QuantizedConcat