aimet_torch.nn

Quantized modules

To simulate the effects of running networks at a reduced bitwidth, AIMET introduces quantized modules: extensions of the standard torch.nn.Module with additional quantization capabilities. These quantized modules serve as drop-in replacements for their PyTorch counterparts, but can hold input, output, and parameter quantizers that quantize tensors during the module’s forward pass and compute quantization encodings.

More specifically, a quantized module inherits from both QuantizationMixin and a native torch.nn.Module type, and is typically named by prepending “Quantized” to the original class name, such as QuantizedConv2d for torch.nn.Conv2d or QuantizedSoftmax for torch.nn.Softmax.

class aimet_torch.nn.QuantizationMixin(*args, **kwargs)

Quantization mixin class for torch.nn.Module.

Specifically, during the forward() method, a quantized module quantizes its input, output, and parameter tensors with the QuantizerBase objects it holds, and uses the inherited torch.nn.Module forward method to compute the layer operation. If all input, output, and parameter quantizers are None, a quantized module behaves exactly like its parent torch.nn.Module.

input_quantizers

torch.nn.ModuleList containing QuantizerBase objects to be applied to the layer’s input tensors

output_quantizers

torch.nn.ModuleList containing QuantizerBase objects to be applied to the layer’s output tensors

param_quantizers

torch.nn.ModuleDict mapping parameter names to associated QuantizerBase objects

Examples

>>> from aimet_torch.nn import QuantizedLinear
>>> qlinear = QuantizedLinear(in_features=10, out_features=10)
>>> print(qlinear)
QuantizedLinear(
  in_features=10, out_features=10, bias=True
  (param_quantizers): ModuleDict(
    (weight): None
    (bias): None
  )
  (input_quantizers): ModuleList(
    (0): None
  )
  (output_quantizers): ModuleList(
    (0): None
  )
)
__quant_init__()

Initializer for quantized module. This method will be invoked right after __init__().

This method initializes the input_quantizers, output_quantizers, and param_quantizers structures to the appropriate sizes based on the number of input tensors, output tensors, and parameters of the base nn.Module class. All quantizers are initialized to None.

For custom quantized classes, this method should be overridden to set the appropriate lengths of input_quantizers and output_quantizers for the given base class.
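As a rough illustration of the "invoked right after __init__()" ordering and of overriding the number of quantizer slots, the following plain-Python sketch uses hypothetical stand-ins (ToyAdd, ToyQuantizationMixin, ToyQuantizedAdd) rather than AIMET or torch classes. Plain lists and dicts take the place of torch.nn.ModuleList and torch.nn.ModuleDict; this is not AIMET's implementation.

```python
# Hypothetical plain-Python stand-ins; these are NOT AIMET or torch classes.
# Lists/dicts play the role of torch.nn.ModuleList / torch.nn.ModuleDict.


class ToyAdd:
    """Stand-in base "module" whose forward takes two inputs."""

    def forward(self, a, b):
        return a + b


class ToyQuantizationMixin:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # build the base module first
        self.__quant_init__()              # then create the quantizer slots

    def __quant_init__(self):
        # Default layout: one input, one output, no parameters, all None
        self.input_quantizers = [None]
        self.output_quantizers = [None]
        self.param_quantizers = {}


class ToyQuantizedAdd(ToyQuantizationMixin, ToyAdd):
    def __quant_init__(self):
        super().__quant_init__()
        # The base op consumes two tensors, so it needs two input slots
        self.input_quantizers = [None, None]


qadd = ToyQuantizedAdd()
print(qadd.input_quantizers, qadd.output_quantizers)  # [None, None] [None]
```

Because the mixin comes first in the base-class list, its __init__ runs the base constructor and then dispatches to the most-derived __quant_init__, which is where a custom class adjusts its slot counts.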

compute_encodings()

Enters the compute_encodings() context for all QuantizerBase objects in the layer.

Inside this context, each quantizer will observe all inputs passed to the quantizer and will compute quantization encodings upon exiting the context.

Example

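>>> import torch
>>> from aimet_torch.nn import QuantizedLinear
>>> from aimet_torch.quantization.affine import Quantize
>>> qlinear = QuantizedLinear(10, 10)
>>> qlinear.output_quantizers[0] = Quantize((), 8, symmetric=False)
>>> with qlinear.compute_encodings():
...     _ = qlinear(torch.randn(16, 10))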
>>> print(qlinear.output_quantizers[0].is_initialized())
True
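The context-manager behavior described above (observe while inside, compute on exit, entered for every quantizer the layer holds) can be sketched in plain Python. ToyMinMaxQuantizer and compute_encodings_all are hypothetical names; the "encoding" here is only the observed min/max, and the sketch omits the actual quantization that AIMET's quantizers perform outside the context.

```python
import contextlib


class ToyMinMaxQuantizer:
    """Hypothetical quantizer that only records a min/max range (not AIMET's)."""

    def __init__(self):
        self.min = None
        self.max = None
        self._observing = False
        self._seen = []

    def is_initialized(self):
        return self.min is not None

    @contextlib.contextmanager
    def compute_encodings(self):
        self._observing, self._seen = True, []
        try:
            yield self
        finally:
            self._observing = False
        # Runs only when the block exits normally: turn statistics into an encoding
        if self._seen:
            self.min, self.max = min(self._seen), max(self._seen)

    def __call__(self, values):
        if self._observing:
            self._seen.extend(values)  # observe; values pass through unchanged
        return values


@contextlib.contextmanager
def compute_encodings_all(quantizers):
    """Enter compute_encodings() on every non-None quantizer at once."""
    with contextlib.ExitStack() as stack:
        for q in quantizers:
            if q is not None:
                stack.enter_context(q.compute_encodings())
        yield


in_q, out_q = ToyMinMaxQuantizer(), ToyMinMaxQuantizer()
with compute_encodings_all([in_q, None, out_q]):
    for batch in ([0.5, -1.0], [2.0, 0.25]):
        out_q([2 * v for v in in_q(batch)])  # toy "layer": y = 2x

print(in_q.is_initialized(), in_q.min, in_q.max)  # True -1.0 2.0
print(out_q.min, out_q.max)                       # -2.0 4.0
```

contextlib.ExitStack is one convenient way to enter a variable number of contexts; None slots are skipped, mirroring the rule that only existing quantizers take part.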
abstract forward(*args, **kwargs)

Computes a quantized version of the parent module’s forward method.

The forward() method should perform the following logic in order:

  1. Apply existing input quantizers to input tensors

  2. Apply existing param quantizers to the layer’s parameters

  3. Call the inherited torch.nn.Module forward method with quantized inputs and parameters

  4. Apply existing output quantizers to the outputs of the forward method

If all input, output, and parameter quantizers are None, this method will behave exactly the same as its parent module’s forward pass.
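The four steps can be sketched in plain Python as follows. ToyScale, ToyQuantizedScale and round_to_grid are hypothetical stand-ins rather than AIMET classes, and temporarily swapping a parameter for its quantized value is just one way to feed quantized parameters to the parent forward; AIMET's own mechanism may differ.

```python
class ToyScale:
    """Stand-in base "module": y = weight * x, applied to a list of floats."""

    def __init__(self, weight):
        self.weight = weight

    def forward(self, x):
        return [self.weight * v for v in x]


def round_to_grid(step):
    """Hypothetical quantizer: snap a float or a list of floats to multiples of step."""
    def quantize(x):
        if isinstance(x, list):
            return [round(v / step) * step for v in x]
        return round(x / step) * step
    return quantize


class ToyQuantizedScale(ToyScale):
    def __init__(self, weight):
        super().__init__(weight)
        self.input_quantizers = [None]
        self.output_quantizers = [None]
        self.param_quantizers = {"weight": None}

    def forward(self, x):
        # 1. Apply existing input quantizers
        if self.input_quantizers[0] is not None:
            x = self.input_quantizers[0](x)
        # 2. Apply existing param quantizers (swap in quantized values temporarily)
        saved = {}
        for name, quantizer in self.param_quantizers.items():
            if quantizer is not None:
                saved[name] = getattr(self, name)
                setattr(self, name, quantizer(saved[name]))
        try:
            # 3. Call the parent forward with quantized inputs and parameters
            out = super().forward(x)
        finally:
            for name, value in saved.items():
                setattr(self, name, value)  # restore the float parameters
        # 4. Apply existing output quantizers
        if self.output_quantizers[0] is not None:
            out = self.output_quantizers[0](out)
        return out


layer = ToyQuantizedScale(weight=1.3)
# All quantizers are None: identical to the parent module
assert layer.forward([1.1, 2.2]) == ToyScale(1.3).forward([1.1, 2.2])

layer.input_quantizers[0] = round_to_grid(0.25)
layer.param_quantizers["weight"] = round_to_grid(0.5)
layer.output_quantizers[0] = round_to_grid(0.5)
print(layer.forward([1.1, 2.2]))  # [1.5, 3.5]
print(layer.weight)               # 1.3 (float parameter restored)
```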

Configuration

The quantization behavior of a quantized module is controlled by the quantizers contained within the input, output, and parameter quantizer attributes listed below.

Attribute           Type                  Description
input_quantizers    torch.nn.ModuleList   List of quantizers for input tensors
param_quantizers    torch.nn.ModuleDict   Dict mapping parameter names to quantizers
output_quantizers   torch.nn.ModuleList   List of quantizers for output tensors

By assigning quantizers to these structures and configuring them, we define the type of quantization applied to the corresponding input index, output index, or parameter name. By default, all quantizers are set to None, meaning that no quantization is applied to the respective tensor.

Example: Create a linear layer which performs only per-channel weight quantization
>>> import aimet_torch
>>> import aimet_torch.quantization as Q
>>> qlinear = aimet_torch.nn.QuantizedLinear(out_features=10, in_features=5)
>>> # Per-channel weight quantization is performed over the `out_features` dimension, so the encodings have shape (10, 1)
>>> per_channel_quantizer = Q.affine.QuantizeDequantize(shape=(10, 1), bitwidth=8, symmetric=True)
>>> qlinear.param_quantizers["weight"] = per_channel_quantizer
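To see why per-channel quantization uses one encoding per output channel, the following plain-Python sketch quantize-dequantizes a small weight matrix with one symmetric scale per row (the out_features dimension). The helper name and the scale = max|w| / (2**(bitwidth-1) - 1) convention are assumptions chosen for illustration and may differ in detail from AIMET's QuantizeDequantize.

```python
def per_channel_symmetric_qdq(weight, bitwidth=8):
    """Hypothetical helper: quantize-dequantize each row with its own scale.

    weight is a list of rows, one per output channel, like an (out, in) matrix.
    """
    qmax = 2 ** (bitwidth - 1) - 1      # e.g. 127 for 8 bits
    qmin = -(2 ** (bitwidth - 1))       # e.g. -128 for 8 bits
    scales, result = [], []
    for row in weight:
        absmax = max(abs(v) for v in row)
        scale = absmax / qmax if absmax > 0 else 1.0
        scales.append(scale)
        result.append([min(max(round(v / scale), qmin), qmax) * scale for v in row])
    return scales, result


# 3 output channels x 2 input features; row 1 is much smaller than the others
w = [[0.5, -1.27], [0.004, 0.02], [2.54, 0.0]]
scales, w_qdq = per_channel_symmetric_qdq(w)
print(len(scales))  # 3: one scale per output channel, i.e. encoding shape (3, 1)
print(w_qdq[1])     # close to [0.004, 0.02]
```

With a single per-tensor scale of about 2.54 / 127 = 0.02, the value 0.004 would round to 0.0; giving the small row its own scale keeps it representable.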
Example: Create an elementwise multiply layer which quantizes only the output and the second input
>>> qmul = aimet_torch.nn.custom.QuantizedMultiply()
>>> qmul.output_quantizers[0] = Q.affine.QuantizeDequantize(shape=(), bitwidth=8, symmetric=False)
>>> qmul.input_quantizers[1] = Q.affine.QuantizeDequantize(shape=(), bitwidth=8, symmetric=False)

In some cases, it may make sense for multiple tensors to share the same quantizer. In this case, we can assign the same quantizer to multiple indices.

Example: Create an elementwise add layer which shares the same quantizer between its inputs
>>> qadd = aimet_torch.nn.custom.QuantizedAdd()
>>> quantizer = Q.affine.QuantizeDequantize(shape=(), bitwidth=8, symmetric=False)
>>> qadd.input_quantizers[0] = quantizer
>>> qadd.input_quantizers[1] = quantizer
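Sharing works because the same object sits in both slots, so the statistics it collects come from both inputs and it ends up with a single encoding. One typical motivation for an add is that both operands then lie on the same quantization grid. A minimal plain-Python sketch with a hypothetical range observer (not an AIMET class):

```python
class ToyRangeObserver:
    """Hypothetical observer tracking a running min/max; not an AIMET class."""

    def __init__(self):
        self.min = float("inf")
        self.max = float("-inf")

    def __call__(self, values):
        self.min = min(self.min, *values)
        self.max = max(self.max, *values)
        return values


shared = ToyRangeObserver()
input_quantizers = [shared, shared]                 # the same object in both slots
print(input_quantizers[0] is input_quantizers[1])  # True

input_quantizers[0]([0.1, 0.4])    # first operand
input_quantizers[1]([-2.0, 3.0])   # second operand
print(shared.min, shared.max)      # -2.0 3.0: one range covering both inputs
```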

Computing encodings

Before a module can compute a quantized forward pass, all quantizers must first be calibrated inside a compute_encodings context. When a quantized module enters the compute_encodings context, it first disables all input and output quantization while the quantizers observe the statistics of the activation tensors passing through them. Upon exiting the context, the quantizers calculate appropriate quantization encodings based on these statistics (exactly how the encodings are computed is determined by each quantizer’s encoding analyzer).

Example:
>>> qlinear = aimet_torch.nn.QuantizedLinear(out_features=10, in_features=5)
>>> qlinear.output_quantizers[0] = Q.affine.QuantizeDequantize((1, ), bitwidth=8, symmetric=False)
>>> qlinear.param_quantizers["weight"] = Q.affine.QuantizeDequantize((10, 1), bitwidth=8, symmetric=True)
>>> with qlinear.compute_encodings():
...     # Pass several samples through the layer to ensure representative statistics
...     for x, _ in calibration_data_loader:
...         qlinear(x)
>>> print(qlinear.output_quantizers[0].is_initialized())
True
>>> print(qlinear.param_quantizers["weight"].is_initialized())
True
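Exactly how statistics become an encoding depends on the encoding analyzer. As one illustration, a simple min/max scheme for an asymmetric 8-bit quantizer could look like the plain-Python sketch below. The function names are hypothetical, and the convention x_int = clamp(round(x / scale) - offset, 0, 2**bitwidth - 1), x_dq = (x_int + offset) * scale is an assumption that may not match AIMET's analyzers in every detail.

```python
def minmax_encoding(values, bitwidth=8):
    """Hypothetical min/max analyzer for an asymmetric (unsigned) grid."""
    lo = min(min(values), 0.0)  # widen the range so 0.0 is exactly representable
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (2 ** bitwidth - 1) if hi > lo else 1.0
    offset = round(lo / scale)  # integer offset, <= 0
    return scale, offset


def quantize_dequantize(x, scale, offset, bitwidth=8):
    x_int = min(max(round(x / scale) - offset, 0), 2 ** bitwidth - 1)
    return (x_int + offset) * scale


calib = [-1.0, 0.5, 2.0, 3.1]          # statistics observed during calibration
scale, offset = minmax_encoding(calib)
print(offset)                                          # -62
print(quantize_dequantize(0.0, scale, offset))         # 0.0
print(quantize_dequantize(5.0, scale, offset) < 3.2)   # True: out-of-range input is clamped
```

Values inside the calibrated range are reproduced to within half a quantization step, while values outside it saturate at the range limits, which is why representative calibration data matters.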

API reference

Built-in quantized modules

QuantizationMixin

Quantization mixin class for torch.nn.Module.

QuantizedAdaptiveAvgPool1d

Quantized subclass of torch.nn.AdaptiveAvgPool1d

QuantizedAdaptiveAvgPool2d

Quantized subclass of torch.nn.AdaptiveAvgPool2d

QuantizedAdaptiveAvgPool3d

Quantized subclass of torch.nn.AdaptiveAvgPool3d

QuantizedAdaptiveMaxPool1d

Quantized subclass of torch.nn.AdaptiveMaxPool1d

QuantizedAdaptiveMaxPool2d

Quantized subclass of torch.nn.AdaptiveMaxPool2d

QuantizedAdaptiveMaxPool3d

Quantized subclass of torch.nn.AdaptiveMaxPool3d

QuantizedAlphaDropout

Quantized subclass of torch.nn.AlphaDropout

QuantizedAvgPool1d

Quantized subclass of torch.nn.AvgPool1d

QuantizedAvgPool2d

Quantized subclass of torch.nn.AvgPool2d

QuantizedAvgPool3d

Quantized subclass of torch.nn.AvgPool3d

QuantizedBCELoss

Quantized subclass of torch.nn.BCELoss

QuantizedBCEWithLogitsLoss

Quantized subclass of torch.nn.BCEWithLogitsLoss

QuantizedBatchNorm1d

Quantized subclass of torch.nn.BatchNorm1d

QuantizedBatchNorm2d

Quantized subclass of torch.nn.BatchNorm2d

QuantizedBatchNorm3d

Quantized subclass of torch.nn.BatchNorm3d

QuantizedBilinear

Quantized subclass of torch.nn.Bilinear

QuantizedCELU

Quantized subclass of torch.nn.CELU

QuantizedCTCLoss

Quantized subclass of torch.nn.CTCLoss

QuantizedChannelShuffle

Quantized subclass of torch.nn.ChannelShuffle

QuantizedCircularPad1d

Quantized subclass of torch.nn.CircularPad1d

QuantizedCircularPad2d

Quantized subclass of torch.nn.CircularPad2d

QuantizedCircularPad3d

Quantized subclass of torch.nn.CircularPad3d

QuantizedConstantPad1d

Quantized subclass of torch.nn.ConstantPad1d

QuantizedConstantPad2d

Quantized subclass of torch.nn.ConstantPad2d

QuantizedConstantPad3d

Quantized subclass of torch.nn.ConstantPad3d

QuantizedConv1d

Quantized subclass of torch.nn.Conv1d

QuantizedConv2d

Quantized subclass of torch.nn.Conv2d

QuantizedConv3d

Quantized subclass of torch.nn.Conv3d

QuantizedConvTranspose1d

Quantized subclass of torch.nn.ConvTranspose1d

QuantizedConvTranspose2d

Quantized subclass of torch.nn.ConvTranspose2d

QuantizedConvTranspose3d

Quantized subclass of torch.nn.ConvTranspose3d

QuantizedCosineEmbeddingLoss

Quantized subclass of torch.nn.CosineEmbeddingLoss

QuantizedCosineSimilarity

Quantized subclass of torch.nn.CosineSimilarity

QuantizedCrossEntropyLoss

Quantized subclass of torch.nn.CrossEntropyLoss

QuantizedDropout

Quantized subclass of torch.nn.Dropout

QuantizedDropout1d

Quantized subclass of torch.nn.Dropout1d

QuantizedDropout2d

Quantized subclass of torch.nn.Dropout2d

QuantizedDropout3d

Quantized subclass of torch.nn.Dropout3d

QuantizedELU

Quantized subclass of torch.nn.ELU

QuantizedEmbedding

Quantized subclass of torch.nn.Embedding

QuantizedEmbeddingBag

Quantized subclass of torch.nn.EmbeddingBag

QuantizedFeatureAlphaDropout

Quantized subclass of torch.nn.FeatureAlphaDropout

QuantizedFlatten

Quantized subclass of torch.nn.Flatten

QuantizedFold

Quantized subclass of torch.nn.Fold

QuantizedFractionalMaxPool2d

Quantized subclass of torch.nn.FractionalMaxPool2d

QuantizedFractionalMaxPool3d

Quantized subclass of torch.nn.FractionalMaxPool3d

QuantizedGELU

Quantized subclass of torch.nn.GELU

QuantizedGLU

Quantized subclass of torch.nn.GLU

QuantizedGRU

Quantized subclass of torch.nn.GRU

QuantizedGRUCell

Quantized subclass of torch.nn.GRUCell

QuantizedGaussianNLLLoss

Quantized subclass of torch.nn.GaussianNLLLoss

QuantizedGroupNorm

Quantized subclass of torch.nn.GroupNorm

QuantizedHardshrink

Quantized subclass of torch.nn.Hardshrink

QuantizedHardsigmoid

Quantized subclass of torch.nn.Hardsigmoid

QuantizedHardswish

Quantized subclass of torch.nn.Hardswish

QuantizedHardtanh

Quantized subclass of torch.nn.Hardtanh

QuantizedHingeEmbeddingLoss

Quantized subclass of torch.nn.HingeEmbeddingLoss

QuantizedHuberLoss

Quantized subclass of torch.nn.HuberLoss

QuantizedInstanceNorm1d

Quantized subclass of torch.nn.InstanceNorm1d

QuantizedInstanceNorm2d

Quantized subclass of torch.nn.InstanceNorm2d

QuantizedInstanceNorm3d

Quantized subclass of torch.nn.InstanceNorm3d

QuantizedKLDivLoss

Quantized subclass of torch.nn.KLDivLoss

QuantizedL1Loss

Quantized subclass of torch.nn.L1Loss

QuantizedLPPool1d

Quantized subclass of torch.nn.LPPool1d

QuantizedLPPool2d

Quantized subclass of torch.nn.LPPool2d

QuantizedLSTM

Quantized subclass of torch.nn.LSTM

QuantizedLSTMCell

Quantized subclass of torch.nn.LSTMCell

QuantizedLayerNorm

Quantized subclass of torch.nn.LayerNorm

QuantizedLeakyReLU

Quantized subclass of torch.nn.LeakyReLU

QuantizedLinear

Quantized subclass of torch.nn.Linear

QuantizedLocalResponseNorm

Quantized subclass of torch.nn.LocalResponseNorm

QuantizedLogSigmoid

Quantized subclass of torch.nn.LogSigmoid

QuantizedLogSoftmax

Quantized subclass of torch.nn.LogSoftmax

QuantizedMSELoss

Quantized subclass of torch.nn.MSELoss

QuantizedMarginRankingLoss

Quantized subclass of torch.nn.MarginRankingLoss

QuantizedMaxPool1d

Quantized subclass of torch.nn.MaxPool1d

QuantizedMaxPool2d

Quantized subclass of torch.nn.MaxPool2d

QuantizedMaxPool3d

Quantized subclass of torch.nn.MaxPool3d

QuantizedMaxUnpool1d

Quantized subclass of torch.nn.MaxUnpool1d

QuantizedMaxUnpool2d

Quantized subclass of torch.nn.MaxUnpool2d

QuantizedMaxUnpool3d

Quantized subclass of torch.nn.MaxUnpool3d

QuantizedMish

Quantized subclass of torch.nn.Mish

QuantizedMultiLabelMarginLoss

Quantized subclass of torch.nn.MultiLabelMarginLoss

QuantizedMultiLabelSoftMarginLoss

Quantized subclass of torch.nn.MultiLabelSoftMarginLoss

QuantizedMultiMarginLoss

Quantized subclass of torch.nn.MultiMarginLoss

QuantizedNLLLoss

Quantized subclass of torch.nn.NLLLoss

QuantizedNLLLoss2d

Quantized subclass of torch.nn.NLLLoss2d

QuantizedPReLU

Quantized subclass of torch.nn.PReLU

QuantizedPairwiseDistance

Quantized subclass of torch.nn.PairwiseDistance

QuantizedPixelShuffle

Quantized subclass of torch.nn.PixelShuffle

QuantizedPixelUnshuffle

Quantized subclass of torch.nn.PixelUnshuffle

QuantizedPoissonNLLLoss

Quantized subclass of torch.nn.PoissonNLLLoss

QuantizedRNN

Quantized subclass of torch.nn.RNN

QuantizedRNNCell

Quantized subclass of torch.nn.RNNCell

QuantizedRReLU

Quantized subclass of torch.nn.RReLU

QuantizedReLU

Quantized subclass of torch.nn.ReLU

QuantizedReLU6

Quantized subclass of torch.nn.ReLU6

QuantizedReflectionPad1d

Quantized subclass of torch.nn.ReflectionPad1d

QuantizedReflectionPad2d

Quantized subclass of torch.nn.ReflectionPad2d

QuantizedReflectionPad3d

Quantized subclass of torch.nn.ReflectionPad3d

QuantizedReplicationPad1d

Quantized subclass of torch.nn.ReplicationPad1d

QuantizedReplicationPad2d

Quantized subclass of torch.nn.ReplicationPad2d

QuantizedReplicationPad3d

Quantized subclass of torch.nn.ReplicationPad3d

QuantizedSELU

Quantized subclass of torch.nn.SELU

QuantizedSiLU

Quantized subclass of torch.nn.SiLU

QuantizedSigmoid

Quantized subclass of torch.nn.Sigmoid

QuantizedSmoothL1Loss

Quantized subclass of torch.nn.SmoothL1Loss

QuantizedSoftMarginLoss

Quantized subclass of torch.nn.SoftMarginLoss

QuantizedSoftmax

Quantized subclass of torch.nn.Softmax

QuantizedSoftmax2d

Quantized subclass of torch.nn.Softmax2d

QuantizedSoftmin

Quantized subclass of torch.nn.Softmin

QuantizedSoftplus

Quantized subclass of torch.nn.Softplus

QuantizedSoftshrink

Quantized subclass of torch.nn.Softshrink

QuantizedSoftsign

Quantized subclass of torch.nn.Softsign

QuantizedTanh

Quantized subclass of torch.nn.Tanh

QuantizedTanhshrink

Quantized subclass of torch.nn.Tanhshrink

QuantizedThreshold

Quantized subclass of torch.nn.Threshold

QuantizedTripletMarginLoss

Quantized subclass of torch.nn.TripletMarginLoss

QuantizedTripletMarginWithDistanceLoss

Quantized subclass of torch.nn.TripletMarginWithDistanceLoss

QuantizedUnflatten

Quantized subclass of torch.nn.Unflatten

QuantizedUnfold

Quantized subclass of torch.nn.Unfold

QuantizedUpsample

Quantized subclass of torch.nn.Upsample

QuantizedUpsamplingBilinear2d

Quantized subclass of torch.nn.UpsamplingBilinear2d

QuantizedUpsamplingNearest2d

Quantized subclass of torch.nn.UpsamplingNearest2d

QuantizedZeroPad1d

Quantized subclass of torch.nn.ZeroPad1d

QuantizedZeroPad2d

Quantized subclass of torch.nn.ZeroPad2d

QuantizedZeroPad3d

Quantized subclass of torch.nn.ZeroPad3d