QuantizationMixin
- class aimet_torch.v2.nn.QuantizationMixin(*args, **kwargs)[source]
Mixin that adds quantization functionality on top of regular PyTorch modules.
Specifically, a quantized module will quantize input, output, and parameter tensors with its held QuantizerBase objects during the forward() method and use the inherited torch.nn.Module forward method to compute the layer operation. If all input, output, and parameter quantizers are None, a quantized module will behave exactly the same as its parent torch.nn.Module.
- input_quantizers – torch.nn.ModuleList containing QuantizerBase objects to be applied to the layer's input tensors
- output_quantizers – torch.nn.ModuleList containing QuantizerBase objects to be applied to the layer's output tensors
- param_quantizers – torch.nn.ModuleDict mapping parameter names to associated QuantizerBase objects
Examples
>>> qlinear = QuantizedLinear(in_features=10, out_features=10)
>>> print(qlinear)
QuantizedLinear(
  in_features=10, out_features=10, bias=True
  (param_quantizers): ModuleDict(
    (weight): None
    (bias): None
  )
  (input_quantizers): ModuleList(
    (0): None
  )
  (output_quantizers): ModuleList(
    (0): None
  )
)
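Quantizers can be assigned directly to these structures after construction. A minimal sketch, assuming Quantize refers to the affine quantizer used in the compute_encodings() example below:

>>> # Sketch: enable 8-bit asymmetric quantization of the first input tensor
>>> qlinear.input_quantizers[0] = Quantize((), 8, symmetric=False)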
- abstract forward(*args, **kwargs)[source]
Computes a quantized version of the parent module’s forward method.
The forward() method should perform the following logic in order:
1. Apply existing input quantizers to input tensors
2. Apply existing param quantizers to the layer's parameters
3. Call the inherited torch.nn.Module forward method with quantized inputs and parameters
4. Apply existing output quantizers to the outputs of the forward method
If all input, output, and parameter quantizers are None, this method will behave exactly the same as its parent module's forward pass.
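A minimal illustrative sketch of this pattern for a Linear layer (not the actual implementation of the built-in QuantizedLinear, which may differ; torch.nn.functional.linear is called directly so that the quantized weight is actually consumed):

class SketchQuantizedLinear(QuantizationMixin, torch.nn.Linear):
    def forward(self, input):
        # (1) Apply the input quantizer, if any
        if self.input_quantizers[0] is not None:
            input = self.input_quantizers[0](input)

        # (2) Apply the parameter quantizers, if any
        weight, bias = self.weight, self.bias
        if self.param_quantizers["weight"] is not None:
            weight = self.param_quantizers["weight"](weight)
        if bias is not None and self.param_quantizers["bias"] is not None:
            bias = self.param_quantizers["bias"](bias)

        # (3) Compute the layer operation with the quantized operands
        output = torch.nn.functional.linear(input, weight, bias)

        # (4) Apply the output quantizer, if any
        if self.output_quantizers[0] is not None:
            output = self.output_quantizers[0](output)
        return output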
- __quant_init__()
Initializer for quantized module. This method will be invoked right after __init__().
This method initializes the input_quantizers, output_quantizers, and param_quantizers structures to the appropriate sizes based on the number of input tensors, output tensors, and parameters of the base nn.Module class. All quantizers are initialized to None.
For custom quantized classes, this method should be overridden to set the appropriate lengths of input_quantizers and output_quantizers for the given base class.
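For example, a custom module whose forward method takes two tensors and returns one could override it as follows (a sketch following the same convention as the implements() example below):

def __quant_init__(self):
    super().__quant_init__()
    # Two input tensors and one output tensor
    self.input_quantizers = torch.nn.ModuleList([None, None])
    self.output_quantizers = torch.nn.ModuleList([None])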
- set_kernel(kernel)[source]
Set kernel for this instance of quantized module.
In general, the kernel's signature should follow the signature of the equivalent torch.nn.functional function, but it should return a QuantizedTensor object and take in the additional keyword argument output_encodings.
Once set, the layer will call into kernel in the forward pass unless within the compute_encodings() context.
- Parameters:
kernel – Callable object to be used as the underlying kernel.
Example
>>> from aimet_torch.v2 import quantization as Q
>>> def int_multiply(a, b, output_encodings=None):
...     encodings = [a.encoding, b.encoding, output_encodings]
...     if not all(enc.mapping == "affine" for enc in encodings):
...         raise NotImplementedError
...     q_output = (a.quantized_repr() + a.encoding.offset) * (b.quantized_repr() + b.encoding.offset)
...     dq_output = q_output * (a.encoding.scale * b.encoding.scale)
...     return Q.QuantizedTensor(output_encodings.quantize(dq_output), encoding=output_encodings)
...
>>> qmult = QuantizedMultiply()
>>> qmult.set_kernel(int_multiply)
- classmethod set_default_kernel(kernel)[source]
Set default kernel for the class.
In general, the kernel's signature should follow the signature of the equivalent torch.nn.functional function, but it should return a QuantizedTensor object and take in the additional keyword argument output_encodings.
Once set, all instances of the class will call into kernel in the forward pass unless:
1. The instance is within the compute_encodings() context, or
2. The kernel has been overridden by a set_kernel() call
- Parameters:
kernel – Callable object to be used as the default kernel by all the instances of this class.
Example
>>> from aimet_torch.v2 import quantization as Q
>>> def int_multiply(a, b, output_encodings=None):
...     encodings = [a.encoding, b.encoding, output_encodings]
...     if not all(enc.mapping == "affine" for enc in encodings):
...         raise NotImplementedError
...     q_output = (a.quantized_repr() + a.encoding.offset) * (b.quantized_repr() + b.encoding.offset)
...     dq_output = q_output * (a.encoding.scale * b.encoding.scale)
...     return Q.QuantizedTensor(output_encodings.quantize(dq_output), encoding=output_encodings)
...
>>> QuantizedMultiply.set_default_kernel(int_multiply)
>>> qmult = QuantizedMultiply()
>>> qmult.get_kernel()
<function int_multiply at ...>
- compute_encodings()[source]
Enters the compute_encodings() context for all QuantizerBase objects in the layer.
Inside this context, each quantizer will observe all inputs passed to the quantizer and will compute quantization encodings upon exiting the context.
Example
>>> qlinear = QuantizedLinear(10, 10)
>>> qlinear.output_quantizers[0] = Quantize((), 8, symmetric=False)
>>> with qlinear.compute_encodings():
...     _ = qlinear(torch.randn(16, 10))
...
>>> print(qlinear.output_quantizers[0].is_initialized())
True
- classmethod from_module(module)[source]
Create an instance of quantized module from a regular module instance.
The resulting quantized module contains the same attributes and parameters as the original module, but may be assigned input, output and parameter quantizers.
- Parameters:
module (Module) – Floating point module to quantize
- Returns:
Quantized version of the original module
Example
>>> linear = torch.nn.Linear(10, 10)
>>> quantized_linear = QuantizationMixin.from_module(linear)
>>> print(quantized_linear)
QuantizedLinear(
  in_features=10, out_features=10, bias=True
  (param_quantizers): ModuleDict(
    (weight): None
    (bias): None
  )
  (input_quantizers): ModuleList(
    (0): None
  )
  (output_quantizers): ModuleList(
    (0): None
  )
)
>>> print(quantized_linear.weight is linear.weight)
True
- classmethod get_default_kernel()[source]
Return the default kernel of the class.
- Return type:
Optional[Callable]
- Returns:
Default kernel of the class. None if the default kernel is not set.
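A short illustration, reusing the int_multiply kernel defined in the set_default_kernel() example above:

>>> QuantizedMultiply.get_default_kernel() is None   # no default kernel set yet
True
>>> QuantizedMultiply.set_default_kernel(int_multiply)
>>> QuantizedMultiply.get_default_kernel()
<function int_multiply at ...>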
- get_kernel()[source]
Return the kernel to be used by this instance of quantized module.
If the current instance does not have any kernel set, it will retrieve the default kernel of the class.
- Return type:
Optional[Callable]
- Returns:
The kernel to be used by this instance.
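A sketch of the fallback behavior, again reusing int_multiply from above (my_instance_kernel stands for any hypothetical callable with the same signature):

>>> QuantizedMultiply.set_default_kernel(int_multiply)
>>> qmult = QuantizedMultiply()
>>> qmult.get_kernel()                    # falls back to the class default
<function int_multiply at ...>
>>> qmult.set_kernel(my_instance_kernel)  # instance kernel takes precedence
>>> qmult.get_kernel() is my_instance_kernel
True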
- classmethod implements(module_cls)[source]
Decorator for registering quantized definition of the given base class.
Even though AIMET supports quantization of all built-in modules in the torch.nn subpackage, such as torch.nn.Conv2d or torch.nn.Linear, QuantizationSimModel will throw a runtime error when it encounters custom modules defined by users, asking for the quantized definition of any custom module that AIMET doesn't know of.
To declare the quantized definition of a module, QuantizationSimModel requires you to define a subclass of your module decorated with implements(), in which you implement the __quant_init__ and forward methods.
As an example, given a custom module as below:
class MaskedAdd(torch.nn.Module):
    def forward(self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor):
        return input + mask * value
its quantized definition should be declared before creating QuantizationSimModel, typically as below:

@QuantizationMixin.implements(MaskedAdd)
class QuantizedMaskedAdd(QuantizationMixin, MaskedAdd):
    # The quantized definition of MaskedAdd should be a subclass of
    # QuantizationMixin and MaskedAdd (Order matters!)

    def __quant_init__(self):
        super().__quant_init__()
        # Declare the number of input/output quantizers
        self.input_quantizers = torch.nn.ModuleList([None, None, None])
        self.output_quantizers = torch.nn.ModuleList([None])

    def forward(self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor):
        input_qtzr = self.input_quantizers[0]
        _ = self.input_quantizers[1]  # I don't want to quantize the boolean masks!
        value_qtzr = self.input_quantizers[2]
        output_qtzr = self.output_quantizers[0]

        if input_qtzr is not None:
            input = input_qtzr(input)

        if value_qtzr is not None:
            value = value_qtzr(value)

        output = super().forward(input, mask, value)

        if output_qtzr is not None:
            output = output_qtzr(output)

        return output
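Once registered this way, the quantized definition should be picked up automatically, for instance by from_module() (a sketch):

>>> qmasked_add = QuantizationMixin.from_module(MaskedAdd())
>>> type(qmasked_add).__name__
'QuantizedMaskedAdd'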