FakeQuantizationMixin

class aimet_torch.v2.nn.FakeQuantizationMixin(*args, **kwargs)

Mixin that implements fake-quantization on top of regular pytorch modules.

Specifically, during the forward() method a fake-quantized module quantizes its input, output, and parameter tensors with the QuantizerBase objects it holds, and uses the inherited torch.nn.Module forward method to compute the layer operation. If all input, output, and parameter quantizers are None, a fake-quantized module behaves exactly the same as its parent torch.nn.Module.

A fake-quantized module can be initialized from scratch using the same syntax as the parent module, or can be formed from an existing module using the from_module() method.
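Fake quantization means each value is rounded onto an integer grid and immediately mapped back to floating point, so the layer still computes in floating point but sees the rounding and clamping error that integer execution would introduce. The following is a minimal, framework-free sketch of that quantize-dequantize round trip on a signed symmetric grid. It uses plain Python lists in place of tensors, and the function `fake_quantize` with its `scale` and `bitwidth` parameters is a hypothetical illustration, not part of the aimet_torch API; AIMET's quantizers may use other grids or rounding rules.

```python
def fake_quantize(values, scale, bitwidth=8):
    """Hypothetical helper: quantize-dequantize floats on a signed symmetric grid."""
    qmin = -(2 ** (bitwidth - 1))   # e.g. -128 for 8 bits
    qmax = 2 ** (bitwidth - 1) - 1  # e.g. 127 for 8 bits
    out = []
    for v in values:
        q = round(v / scale)         # snap to the nearest integer level (ties to even)
        q = max(qmin, min(qmax, q))  # clamp to the representable range
        out.append(q * scale)        # map back to floating point
    return out


# With 4 bits and scale 0.1, the grid spans -0.8 .. 0.7 in steps of 0.1
print(fake_quantize([0.04, -0.26, 1.0, 5.0], scale=0.1, bitwidth=4))
```

Small values are rounded to the nearest grid point, while values outside the range are clamped to the largest level, which is the error a fake-quantized layer lets you observe before deploying integer kernels.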

input_quantizers

ModuleList containing QuantizerBase objects to be applied to the layer’s input tensors

Type:

nn.ModuleList

output_quantizers

ModuleList containing QuantizerBase objects to be applied to the layer’s output tensors

Type:

nn.ModuleList

param_quantizers

ModuleDict mapping parameter names to associated QuantizerBase objects

Type:

nn.ModuleDict

Examples

>>> qlinear = FakeQuantizedLinear(in_features=10, out_features=20, bias=False)
>>> print(qlinear)
FakeQuantizedLinear(
  in_features=10, out_features=20, bias=False
  (param_quantizers): ModuleDict(
    (weight): None
  )
  (input_quantizers): ModuleList(
    (0): None
  )
  (output_quantizers): ModuleList(
    (0): None
  )
)
>>> linear = torch.nn.Linear(in_features=10, out_features=20, bias=True)
>>> qlinear = FakeQuantizationMixin.from_module(linear)
>>> print(qlinear)
FakeQuantizedLinear(
  in_features=10, out_features=20, bias=True
  (param_quantizers): ModuleDict(
    (weight): None
    (bias): None
  )
  (input_quantizers): ModuleList(
    (0): None
  )
  (output_quantizers): ModuleList(
    (0): None
  )
)
>>> qlinear.weight is linear.weight
True
abstract forward(*args, **kwargs)

Computes a fake-quantized version of the parent module’s forward method.

The forward() method should perform the following logic in order:

  1. Apply existing input quantizers to input tensors

  2. Apply existing param quantizers to the layer’s parameters

  3. Call the inherited torch.nn.Module forward method with quantized inputs and parameters

  4. Apply existing output quantizers to the outputs of the forward method

If all input, output, and parameter quantizers are None, this method will behave exactly the same as its parent module’s forward pass.
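The four steps above can be sketched in plain Python. Here `ToyLinear` stands in for a regular torch.nn.Linear, `ToyFakeQuantLinear` plays the role of a fake-quantized subclass, and quantizers are ordinary callables; all of these names are hypothetical. Temporarily swapping the stored weight for its quantized copy, and restoring it afterwards, is just one way to feed quantized parameters to the inherited forward; aimet_torch's actual mechanism may differ.

```python
class ToyLinear:
    """Stand-in for a regular layer: y[i] = sum_j weight[i][j] * x[j]."""

    def __init__(self, weight):
        self.weight = weight  # list of rows, one per output feature

    def forward(self, x):
        return [sum(w * xj for w, xj in zip(row, x)) for row in self.weight]


class ToyFakeQuantLinear(ToyLinear):
    """Hypothetical fake-quantized subclass; every quantizer slot starts as None."""

    def __init__(self, weight):
        super().__init__(weight)
        self.input_quantizers = [None]
        self.output_quantizers = [None]
        self.param_quantizers = {"weight": None}

    def forward(self, x):
        # 1. Apply the input quantizer, if one is set
        if self.input_quantizers[0] is not None:
            x = self.input_quantizers[0](x)
        # 2. Temporarily replace the weight with its quantized copy
        #    (row by row, because the toy quantizers work on flat lists)
        original_weight = self.weight
        weight_quantizer = self.param_quantizers["weight"]
        if weight_quantizer is not None:
            self.weight = [weight_quantizer(row) for row in original_weight]
        try:
            # 3. Run the inherited forward on quantized inputs and parameters
            y = super().forward(x)
        finally:
            self.weight = original_weight  # the stored weight is never modified
        # 4. Apply the output quantizer, if one is set
        if self.output_quantizers[0] is not None:
            y = self.output_quantizers[0](y)
        return y


def snap_to_tenths(values):
    """Crude stand-in quantizer: round every value to a multiple of 0.1."""
    return [round(v * 10) / 10 for v in values]


W = [[0.12, -0.33], [0.5, 0.25]]
x = [1.0, 2.0]
plain = ToyLinear(W)
fq = ToyFakeQuantLinear(W)
assert fq.forward(x) == plain.forward(x)  # all quantizers None: identical output

fq.param_quantizers["weight"] = snap_to_tenths
fq.output_quantizers[0] = snap_to_tenths
print(fq.forward(x))
```

The first assertion mirrors the documented guarantee: until a quantizer is assigned, the subclass returns exactly what the base class returns.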

__quant_init__()

Initializer for quantized module. This method will be invoked right after __init__().

This method initializes the input_quantizers, output_quantizers, and param_quantizers structures to the appropriate sizes based on the number of input tensors, output tensors, and parameters of the base nn.Module class. All quantizers are initialized to None.

For custom quantized classes, this method should be overridden to set the appropriate lengths of input_quantizers and output_quantizers for the given base class.
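As a sketch of such an override, assume a hypothetical base module `ToyAdd` whose forward takes two tensors. The hypothetical mixin `ToyQuantMixin` calls `__quant_init__()` right after the base `__init__()` and by default reserves one input slot and one output slot; the subclass extends that default so `input_quantizers` has one slot per forward input. Plain lists and dicts stand in for nn.ModuleList and nn.ModuleDict.

```python
class ToyAdd:
    """Stand-in for a base module whose forward() takes two input tensors."""

    def forward(self, a, b):
        return [x + y for x, y in zip(a, b)]


class ToyQuantMixin:
    """Hypothetical mixin that runs __quant_init__() right after the base __init__()."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__quant_init__()

    def __quant_init__(self):
        # Default layout: one input slot, one output slot, all set to None
        self.input_quantizers = [None]
        self.output_quantizers = [None]
        self.param_quantizers = {}  # ToyAdd has no parameters


class ToyFakeQuantAdd(ToyQuantMixin, ToyAdd):
    def __quant_init__(self):
        super().__quant_init__()
        # ToyAdd.forward() consumes two tensors, so reserve one input slot per tensor
        self.input_quantizers = [None, None]


qadd = ToyFakeQuantAdd()
print(len(qadd.input_quantizers), len(qadd.output_quantizers))  # 2 1
```

Calling `super().__quant_init__()` first keeps the default sizing for everything the subclass does not need to change.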

compute_encodings()

Enters the compute_encodings() context for all QuantizerBase objects in the layer.

Inside this context, each quantizer will observe all inputs passed to the quantizer and will compute quantization encodings upon exiting the context.

Example

>>> qlinear = QuantizedLinear(10, 10)
>>> qlinear.output_quantizers[0] = Quantize((), 8, symmetric=False)
>>> with qlinear.compute_encodings():
...     qlinear(torch.randn(16, 10))
>>> print(qlinear.output_quantizers[0].is_initialized())
True
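The calibration behaviour of compute_encodings() can be illustrated with a small min/max observer. In this sketch, the hypothetical `ToyMinMaxQuantizer` only records the running minimum and maximum of its inputs while its context is active, then derives an 8-bit asymmetric scale and offset when the context exits; the layer-level `compute_encodings()` enters every non-None quantizer's context through contextlib.ExitStack. Min/max calibration and this scale/offset convention are assumptions chosen for illustration, not a description of AIMET's encoding analyzers.

```python
from contextlib import ExitStack, contextmanager


class ToyMinMaxQuantizer:
    """Hypothetical quantizer that calibrates an asymmetric grid from min/max statistics."""

    def __init__(self, bitwidth=8):
        self.bitwidth = bitwidth
        self.observing = False
        self.min = self.max = None
        self.scale = self.offset = None

    def is_initialized(self):
        return self.scale is not None

    @contextmanager
    def compute_encodings(self):
        self.observing = True
        self.min, self.max = float("inf"), float("-inf")
        try:
            yield
        finally:
            self.observing = False
        # Runs only on a clean exit: turn the observed range into scale/offset
        lo, hi = min(self.min, 0.0), max(self.max, 0.0)  # range must include 0
        self.scale = (hi - lo) / (2 ** self.bitwidth - 1) or 1.0
        self.offset = round(lo / self.scale)

    def __call__(self, values):
        if self.observing:
            # Calibration mode: record statistics, pass data through unchanged
            self.min = min(self.min, min(values))
            self.max = max(self.max, max(values))
            return values
        if not self.is_initialized():
            raise RuntimeError("quantizer has no encoding yet")
        qmax = 2 ** self.bitwidth - 1
        out = []
        for v in values:
            q = max(0, min(qmax, round(v / self.scale) - self.offset))
            out.append((q + self.offset) * self.scale)
        return out


class ToyLayer:
    """Hypothetical layer computing y = 2 * x with one input and one output quantizer."""

    def __init__(self):
        self.input_quantizers = [ToyMinMaxQuantizer()]
        self.output_quantizers = [ToyMinMaxQuantizer()]

    def forward(self, x):
        x = self.input_quantizers[0](x)
        return self.output_quantizers[0]([2.0 * v for v in x])

    @contextmanager
    def compute_encodings(self):
        # Enter every non-None quantizer's context; ExitStack closes them on exit
        with ExitStack() as stack:
            for q in self.input_quantizers + self.output_quantizers:
                if q is not None:
                    stack.enter_context(q.compute_encodings())
            yield


layer = ToyLayer()
with layer.compute_encodings():
    layer.forward([-1.0, 0.5, 3.0])
    layer.forward([0.0, 2.0])
print(layer.output_quantizers[0].is_initialized())  # True
```

Passing data through unchanged during calibration means each quantizer sees the statistics of the unquantized signal, and the encodings are fixed only once all calibration batches have been observed.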
classmethod from_module(module)

Create a quantized module instance from a regular module instance.

The resulting quantized module contains the same attributes and parameters as the original module, but may be assigned input, output and parameter quantizers.

Parameters:

module (Module) – Floating point module to quantize

Returns:

Quantized version of the original module

Example

>>> linear = torch.nn.Linear(10, 10)
>>> quantized_linear = FakeQuantizationMixin.from_module(linear)
>>> print(quantized_linear.weight is linear.weight)
True
>>> print(quantized_linear.param_quantizers)
ModuleDict(
    (weight): None
    (bias): None
)
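One way to meet the from_module() contract, in which the quantized module shares attributes and parameters with the original, is to allocate the new instance without calling `__init__()` and give it a shallow copy of the original's attribute dictionary. The plain-Python sketch below does that with hypothetical `ToyLinear` and `ToyFakeQuantLinear` classes. It is not aimet_torch's implementation; a real torch.nn.Module keeps parameters and submodules in internal dictionaries, which would likely need more careful handling than a plain dict copy.

```python
class ToyLinear:
    """Stand-in for a regular floating-point layer."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features
        self.weight = [[0.0] * in_features for _ in range(out_features)]


class ToyFakeQuantLinear(ToyLinear):
    """Hypothetical fake-quantized counterpart of ToyLinear."""

    def __quant_init__(self):
        self.input_quantizers = [None]
        self.output_quantizers = [None]
        self.param_quantizers = {"weight": None}

    @classmethod
    def from_module(cls, module):
        # Allocate the object without running __init__, so no new weights are created
        qmodule = cls.__new__(cls)
        # Shallow copy: qmodule gets its own attribute dict, but the values
        # (including the weight list) are the very same objects as in `module`
        qmodule.__dict__ = module.__dict__.copy()
        qmodule.__quant_init__()  # add empty quantizer slots on top
        return qmodule


linear = ToyLinear(10, 10)
qlinear = ToyFakeQuantLinear.from_module(linear)
print(qlinear.weight is linear.weight)  # True
print(qlinear.param_quantizers)  # {'weight': None}
```

Because only the dictionary is copied, adding quantizer slots to the new module leaves the original module's attributes untouched.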
classmethod implements(module_cls)

Decorator for registering a fake-quantized implementation of the given base class.

This decorator registers the defined class as the fake-quantized version of module_cls such that calling from_module() on an instance of module_cls will output an instance of the decorated class.

Parameters:

module_cls – The base torch.nn.Module class
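The registration pattern can be sketched as a class-level dictionary that maps each base class to its fake-quantized counterpart, with from_module() looking up the type of the instance it receives. The names `ToyFakeQuantizationMixin`, `ToyConv`, and `ToyFakeQuantConv`, the `_registry` attribute, and the choice of RuntimeError for unregistered types are all hypothetical; this illustrates the decorator pattern rather than reproducing aimet_torch's code.

```python
class ToyFakeQuantizationMixin:
    """Hypothetical mixin holding a registry: base class -> fake-quantized class."""

    _registry = {}

    @classmethod
    def implements(cls, module_cls):
        def decorator(quantized_cls):
            cls._registry[module_cls] = quantized_cls  # record the pairing
            return quantized_cls  # the decorated class itself is unchanged
        return decorator

    @classmethod
    def from_module(cls, module):
        quantized_cls = cls._registry.get(type(module))
        if quantized_cls is None:
            raise RuntimeError(
                f"no fake-quantized implementation registered for {type(module).__name__}"
            )
        qmodule = quantized_cls.__new__(quantized_cls)
        qmodule.__dict__.update(vars(module))  # share attributes with the original
        return qmodule


class ToyConv:
    def __init__(self, channels):
        self.channels = channels


@ToyFakeQuantizationMixin.implements(ToyConv)
class ToyFakeQuantConv(ToyFakeQuantizationMixin, ToyConv):
    pass


conv = ToyConv(channels=4)
qconv = ToyFakeQuantizationMixin.from_module(conv)
print(type(qconv).__name__)  # ToyFakeQuantConv
```

Because this sketch looks up the exact type, an instance of a subclass of ToyConv would raise unless that subclass is registered as well.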