QuantizationMixin

class aimet_torch.nn.QuantizationMixin(*args, **kwargs)

Quantization mixin class for torch.nn.Module.

Specifically, during the forward() method a quantized module quantizes its input, output, and parameter tensors with the QuantizerBase objects it holds, and uses the inherited torch.nn.Module forward method to compute the layer operation. If all input, output, and parameter quantizers are None, a quantized module behaves exactly like its parent torch.nn.Module.

input_quantizers

torch.nn.ModuleList containing QuantizerBase objects to be applied to the layer’s input tensors.

output_quantizers

torch.nn.ModuleList containing QuantizerBase objects to be applied to the layer’s output tensors.

param_quantizers

torch.nn.ModuleDict mapping parameter names to their associated QuantizerBase objects.

Examples

>>> qlinear = QuantizedLinear(in_features=10, out_features=10)
>>> print(qlinear)
QuantizedLinear(
  in_features=10, out_features=10, bias=True
  (param_quantizers): ModuleDict(
    (weight): None
    (bias): None
  )
  (input_quantizers): ModuleList(
    (0): None
  )
  (output_quantizers): ModuleList(
    (0): None
  )
)

compute_encodings()

Enters the compute_encodings() context for all QuantizerBase objects in the layer.

Inside this context, each quantizer observes every input passed to it and computes its quantization encodings when the context exits.
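The observe-then-finalize behavior can be illustrated without AIMET. The following is a minimal, library-free sketch, assuming a hypothetical MinMaxQuantizer that works on plain Python lists, records the running min/max while inside the context, and derives an 8-bit asymmetric scale and offset on exit using the (q + offset) * scale convention that appears in the kernel examples further down. It is not AIMET's QuantizerBase implementation; passing values through unchanged while observing is an assumption of the sketch.

```python
from contextlib import contextmanager


class MinMaxQuantizer:
    """Hypothetical asymmetric min/max quantizer; not AIMET's QuantizerBase."""

    def __init__(self, bitwidth=8):
        self.bitwidth = bitwidth
        self.scale = None    # populated when compute_encodings() exits
        self.offset = None
        self._observing = False
        self._min = float("inf")
        self._max = float("-inf")

    def is_initialized(self):
        return self.scale is not None

    @contextmanager
    def compute_encodings(self):
        # Reset statistics and start observing.
        self._observing = True
        self._min, self._max = float("inf"), float("-inf")
        try:
            yield self
        finally:
            self._observing = False
        # Reached only if the with-block finished without an exception.
        self._finalize()

    def _finalize(self):
        lo = min(self._min, 0.0)  # widen the range so 0.0 stays representable
        hi = max(self._max, 0.0)
        self.scale = (hi - lo) / (2 ** self.bitwidth - 1) or 1.0
        self.offset = round(lo / self.scale)  # x ~ (q + offset) * scale

    def __call__(self, values):
        if self._observing:
            # Record statistics and pass values through unchanged.
            self._min = min([self._min, *values])
            self._max = max([self._max, *values])
            return list(values)
        if not self.is_initialized():
            raise RuntimeError("encodings not computed; use compute_encodings() first")
        q_max = 2 ** self.bitwidth - 1
        out = []
        for v in values:
            q = min(max(round(v / self.scale) - self.offset, 0), q_max)
            out.append((q + self.offset) * self.scale)  # quantize, then dequantize
        return out


qtzr = MinMaxQuantizer()
with qtzr.compute_encodings():
    qtzr([-1.0, 0.5])
    qtzr([2.0])
print(qtzr.is_initialized(), qtzr.offset)  # True -85
```

Computing the encoding only once on exit keeps calibration cheap: every batch just updates two running statistics.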

Example

>>> qlinear = QuantizedLinear(10, 10)
>>> qlinear.output_quantizers[0] = Quantize((), 8, symmetric=False)
>>> with qlinear.compute_encodings():
...     _ = qlinear(torch.randn(16, 10))
...
>>> print(qlinear.output_quantizers[0].is_initialized())
True

abstract forward(*args, **kwargs)

Computes a quantized version of the parent module’s forward method.

The forward() method should perform the following steps in order:

  1. Apply the input quantizers that are not None to the input tensors

  2. Apply the parameter quantizers that are not None to the layer’s parameters

  3. Call the inherited torch.nn.Module forward method with the quantized inputs and parameters

  4. Apply the output quantizers that are not None to the outputs of the forward method

If all input, output, and parameter quantizers are None, this method will behave exactly the same as its parent module’s forward pass.
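The four steps can be sketched in plain Python. This is a minimal illustration, not AIMET's QuantizationMixin: Scale, QuantizedScale, and fake_quant are hypothetical names, lists and dicts stand in for torch.nn.ModuleList and torch.nn.ModuleDict, and temporarily swapping in the quantized weight is just one way to feed quantized parameters to the inherited forward.

```python
class Scale:
    """Hypothetical stand-in for a float module: y = x * weight."""

    def __init__(self, weight):
        self.weight = weight

    def forward(self, x):
        return [v * self.weight for v in x]


def fake_quant(step):
    """Return a hypothetical quantizer that rounds values to a grid of `step`."""
    def quantize(x):
        if isinstance(x, list):
            return [round(v / step) * step for v in x]
        return round(x / step) * step
    return quantize


class QuantizedScale(Scale):
    """Sketch of the four-step forward order; not AIMET's implementation."""

    def __init__(self, weight):
        super().__init__(weight)
        self.input_quantizers = [None]
        self.output_quantizers = [None]
        self.param_quantizers = {"weight": None}

    def forward(self, x):
        # 1. Quantize inputs whose quantizer is not None.
        if self.input_quantizers[0] is not None:
            x = self.input_quantizers[0](x)
        # 2. Quantize parameters by temporarily swapping in the quantized weight.
        original_weight = self.weight
        if self.param_quantizers["weight"] is not None:
            self.weight = self.param_quantizers["weight"](original_weight)
        try:
            # 3. Run the inherited floating-point computation.
            out = super().forward(x)
        finally:
            self.weight = original_weight  # always restore the float weight
        # 4. Quantize outputs whose quantizer is not None.
        if self.output_quantizers[0] is not None:
            out = self.output_quantizers[0](out)
        return out
```

With every quantizer left as None, QuantizedScale(w).forward(x) returns exactly Scale(w).forward(x), matching the pass-through behavior described above.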

classmethod from_module(module)

Create a quantized module instance from a regular module instance.

The resulting quantized module contains the same attributes and parameters as the original module, but may be assigned input, output and parameter quantizers.

Parameters:

module (Module) – Floating point module to quantize

Returns:

Quantized version of the original module

Example

>>> linear = torch.nn.Linear(10, 10)
>>> quantized_linear = QuantizationMixin.from_module(linear)
>>> print(quantized_linear)
QuantizedLinear(
  in_features=10, out_features=10, bias=True
  (param_quantizers): ModuleDict(
    (weight): None
    (bias): None
  )
  (input_quantizers): ModuleList(
    (0): None
  )
  (output_quantizers): ModuleList(
    (0): None
  )
)
>>> print(quantized_linear.weight is linear.weight)
True
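The identity check above shows that from_module shares, rather than copies, the original parameters. A minimal plain-Python sketch of one way to get that behavior follows; Linear, QuantizedLinear, and _quant_init here are hypothetical stand-ins (not torch.nn.Linear or AIMET's classes), and allocating with __new__ plus a shallow attribute copy is an assumption, not necessarily how AIMET does it.

```python
class Linear:
    """Hypothetical float module; plain Python, not torch.nn.Linear."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias


class QuantizedLinear(Linear):
    """Hypothetical quantized counterpart of Linear."""

    def _quant_init(self):
        # Plays the role of __quant_init__: create empty quantizer slots.
        self.input_quantizers = [None]
        self.output_quantizers = [None]
        self.param_quantizers = {"weight": None, "bias": None}

    @classmethod
    def from_module(cls, module):
        # Allocate without running __init__, then share the original
        # attribute values so parameters keep their identity.
        qmodule = cls.__new__(cls)
        qmodule.__dict__.update(module.__dict__)
        qmodule._quant_init()
        return qmodule


linear = Linear(weight=[[1.0, 2.0]], bias=[0.0])
qlinear = QuantizedLinear.from_module(linear)
print(qlinear.weight is linear.weight)  # True
```

Because nothing is copied, updates made to the parameters through either object are visible through the other.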

classmethod get_default_kernel()

Return the default kernel of the class.

Return type:

Optional[Callable]

Returns:

Default kernel of the class. None if the default kernel is not set.

get_kernel()

Return the kernel to be used by this quantized module instance.

If no kernel has been set on this instance, the default kernel of the class is returned.

Return type:

Optional[Callable]

Returns:

The kernel to be used by this instance.
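The lookup order described above (instance kernel first, then the class default) can be sketched as follows. KernelMixin and its _default_kernel/_kernel attributes are hypothetical names chosen for illustration, not AIMET internals. Storing the default on cls means that calling set_default_kernel on one subclass does not affect sibling classes.

```python
class KernelMixin:
    """Sketch of per-class default kernels with per-instance overrides."""

    _default_kernel = None  # class-level default

    def __init__(self):
        self._kernel = None  # instance-level override

    @classmethod
    def set_default_kernel(cls, kernel):
        cls._default_kernel = kernel  # stored on the subclass it is called on

    @classmethod
    def get_default_kernel(cls):
        return cls._default_kernel

    def set_kernel(self, kernel):
        self._kernel = kernel

    def get_kernel(self):
        # Prefer the instance's own kernel; otherwise fall back to the class default.
        if self._kernel is not None:
            return self._kernel
        return self.get_default_kernel()
```

Reading the default through the classmethod keeps a stored function from being bound as a method of the instance.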

classmethod implements(module_cls)

Decorator for registering the quantized definition of the given base class.

AIMET supports quantization of the built-in modules in the torch.nn subpackage that it already knows about, such as torch.nn.Conv2d and torch.nn.Linear. However, when QuantizationSimModel encounters a custom module defined by the user, it raises a runtime error asking the user to provide a quantized definition for that module.

To declare the quantized definition of a module, QuantizationSimModel requires you to define a subclass of your module decorated with implements(), in which you implement the __quant_init__ and forward methods.

For example, given the following custom module:

import torch

class MaskedAdd(torch.nn.Module):
    def forward(self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor):
        return input + mask * value

declare its quantized definition before creating QuantizationSimModel, typically as follows:

from aimet_torch.nn import QuantizationMixin

@QuantizationMixin.implements(MaskedAdd)
class QuantizedMaskedAdd(QuantizationMixin, MaskedAdd):
    # The quantized definition of MaskedAdd should be a subclass of
    # QuantizationMixin and MaskedAdd (order matters!)
    def __quant_init__(self):
        super().__quant_init__()

        # Declare the number of input/output quantizers
        self.input_quantizers = torch.nn.ModuleList([None, None, None])
        self.output_quantizers = torch.nn.ModuleList([None])

    def forward(self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor):
        input_qtzr  = self.input_quantizers[0]
        _           = self.input_quantizers[1]  # Boolean masks are left unquantized
        value_qtzr  = self.input_quantizers[2]
        output_qtzr = self.output_quantizers[0]

        if input_qtzr is not None:
            input = input_qtzr(input)

        if value_qtzr is not None:
            value = value_qtzr(value)

        output = super().forward(input, mask, value)

        if output_qtzr is not None:
            output = output_qtzr(output)

        return output
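The registration idea behind implements() can be sketched as a class-to-class registry. This is a hypothetical, library-free illustration, not AIMET's code: implements, quantized_class_for, and _QUANTIZED_DEFINITIONS are invented names, and the subclass check and error message are assumptions that mirror the runtime error described above.

```python
_QUANTIZED_DEFINITIONS = {}  # hypothetical registry: float class -> quantized class


def implements(module_cls):
    """Hypothetical decorator mapping a float class to its quantized class."""
    def decorator(quantized_cls):
        if not issubclass(quantized_cls, module_cls):
            raise TypeError(f"{quantized_cls.__name__} must subclass {module_cls.__name__}")
        _QUANTIZED_DEFINITIONS[module_cls] = quantized_cls
        return quantized_cls  # the class itself is returned unchanged
    return decorator


def quantized_class_for(module):
    """Look up the quantized definition, failing loudly for unknown modules."""
    try:
        return _QUANTIZED_DEFINITIONS[type(module)]
    except KeyError:
        raise RuntimeError(
            f"No quantized definition for {type(module).__name__}; "
            "register one with @implements(...)"
        ) from None
```

Keying the registry on the exact float class is why the definition must be registered before the model is converted: the lookup happens once per module at conversion time.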

classmethod set_default_kernel(kernel)

Set default kernel for the class.

In general, the kernel’s signature should follow that of the equivalent torch.nn.functional function, except that the kernel should return a QuantizedTensor object and accept the additional keyword argument output_encodings.

Once set, all instances of the class will call kernel in the forward pass unless:

  1. The instance is within the compute_encodings() context, or

  2. The kernel has been overridden by a set_kernel() call

Parameters:

kernel – Callable object to be used as the default kernel by all the instances of this class.

Example

>>> from aimet_torch.v2 import quantization as Q
>>> def int_multiply(a, b, output_encodings=None):
...     encodings = [a.encoding, b.encoding, output_encodings]
...     if not all(enc.mapping == "affine" for enc in encodings):
...         raise NotImplementedError
...     q_output = (a.quantized_repr() + a.encoding.offset) * (b.quantized_repr() + b.encoding.offset)
...     dq_output = q_output * (a.encoding.scale * b.encoding.scale)
...     return Q.QuantizedTensor(output_encodings.quantize(dq_output), encoding=output_encodings)
...
>>> QuantizedMultiply.set_default_kernel(int_multiply)
>>> qmult = QuantizedMultiply()
>>> qmult.get_kernel()
<function int_multiply at ...>
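The arithmetic in int_multiply relies on the affine relation x ≈ (q + offset) * scale, so a * b ≈ (q_a + offset_a) * (q_b + offset_b) * scale_a * scale_b: the product is formed on integers and a single floating-point rescale follows. The worked example below checks this in plain Python; the input values, scales, offsets, and the unsigned 8-bit range [0, 255] are hypothetical choices for illustration, and quantize/dequantize are helper functions written for this sketch.

```python
def quantize(x, scale, offset, bitwidth=8):
    """Affine quantization under the convention x ~ (q + offset) * scale."""
    q = round(x / scale) - offset
    return min(max(q, 0), 2 ** bitwidth - 1)  # clamp to the unsigned range


def dequantize(q, scale, offset):
    return (q + offset) * scale


# Hypothetical inputs and encodings.
a, scale_a, offset_a = 0.8, 0.01, -128
b, scale_b, offset_b = -0.3, 0.005, -100

qa = quantize(a, scale_a, offset_a)   # 208
qb = quantize(b, scale_b, offset_b)   # 40

# Integer-domain product, followed by one rescale by scale_a * scale_b.
int_product = (qa + offset_a) * (qb + offset_b)   # 80 * -60 = -4800
product = int_product * (scale_a * scale_b)

print(int_product, round(product, 6))  # -4800 -0.24
```

In int_multiply the rescaled result is then re-quantized with output_encodings before being wrapped in a QuantizedTensor.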

set_kernel(kernel)

Set the kernel for this quantized module instance.

In general, the kernel’s signature should follow that of the equivalent torch.nn.functional function, except that the kernel should return a QuantizedTensor object and accept the additional keyword argument output_encodings.

Once set, the layer will call kernel in the forward pass unless it is within the compute_encodings() context.

Parameters:

kernel – Callable object to be used as the underlying kernel.

Example

>>> from aimet_torch.v2 import quantization as Q
>>> def int_multiply(a, b, output_encodings=None):
...     encodings = [a.encoding, b.encoding, output_encodings]
...     if not all(enc.mapping == "affine" for enc in encodings):
...         raise NotImplementedError
...     q_output = (a.quantized_repr() + a.encoding.offset) * (b.quantized_repr() + b.encoding.offset)
...     dq_output = q_output * (a.encoding.scale * b.encoding.scale)
...     return Q.QuantizedTensor(output_encodings.quantize(dq_output), encoding=output_encodings)
...
>>> qmult = QuantizedMultiply()
>>> qmult.set_kernel(int_multiply)
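The dispatch rules for kernels (an instance kernel overrides the class default, and no kernel is used inside compute_encodings()) can be modeled with a small stand-in. KernelDispatchSketch is a hypothetical class written for this illustration; falling back to a plain float multiply when no kernel applies is an assumption of the sketch, and the returned tag simply makes the chosen path visible.

```python
from contextlib import contextmanager


class KernelDispatchSketch:
    """Hypothetical multiply module illustrating kernel dispatch rules."""

    default_kernel = None  # class-wide, analogous to set_default_kernel()

    def __init__(self):
        self.kernel = None  # per-instance, analogous to set_kernel()
        self._computing_encodings = False

    @contextmanager
    def compute_encodings(self):
        self._computing_encodings = True
        try:
            yield
        finally:
            self._computing_encodings = False

    def forward(self, a, b):
        # An instance kernel overrides the class default.
        kernel = self.kernel if self.kernel is not None else type(self).default_kernel
        # Inside compute_encodings(), or with no kernel at all, use the float path.
        if self._computing_encodings or kernel is None:
            return "float", a * b
        return "kernel", kernel(a, b)
```

Bypassing the kernel during calibration means the quantizers observe ordinary floating-point activations while their encodings are still being computed.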

classmethod wrap(module_cls)

Wrap a regular module class into a quantized module class.

Return type:

Type[Module]
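One way to build such a class is to create a subclass dynamically with type(), placing the mixin before the module class, which follows the "order matters" rule from the implements() example. The sketch below is hypothetical: QuantizationMixinSketch, its describe() method, and the "Quantized" name prefix are illustrative assumptions, not AIMET's actual wrap() implementation.

```python
class QuantizationMixinSketch:
    """Hypothetical stand-in for QuantizationMixin; only wrap() is sketched."""

    def describe(self):
        # Cooperative call: defer to the wrapped module class via the MRO.
        return "quantized " + super().describe()

    @classmethod
    def wrap(cls, module_cls):
        # Mixin first, module class second, so the mixin's methods win in the MRO.
        return type(f"Quantized{module_cls.__name__}", (cls, module_cls), {})


class Conv:
    def describe(self):
        return "conv"


QConv = QuantizationMixinSketch.wrap(Conv)
print(QConv.__name__, QConv().describe())  # QuantizedConv quantized conv
```

If the bases were reversed, Conv.describe would be found first and the mixin's logic would never run, which is the failure the "order matters" comment guards against.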