FloatQuantizeDequantize

class aimet_torch.quantization.float.FloatQuantizeDequantize(exponent_bits=None, mantissa_bits=None, dtype=None, encoding_analyzer=None)[source]

Simulates quantization by fake-casting the input.

If dtype is provided, this is equivalent to

\[out = x.to(dtype).to(x.dtype)\]
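For example, per this equivalence, simulating torch.bfloat16 should give the same result as casting to bfloat16 and back (a quick sanity check of the stated equivalence; the random input is illustrative):

>>> import torch
>>> import aimet_torch.v2.quantization as Q
>>> x = torch.randn(4)
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.bfloat16)
>>> torch.equal(qdq(x), x.to(torch.bfloat16).to(x.dtype))
True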

If the exponent and mantissa bits are provided, this is equivalent to

\[out = \left\lceil\frac{x_c}{scale}\right\rfloor * scale\]

where

\[\begin{split}
x_c &= clamp(x, -max, max) \\
bias &= 2^{exponent} - \log_2(max) + \log_2(2 - 2^{-mantissa}) - 1 \\
scale &= 2^{\left\lfloor \log_2 |x_c| + bias \right\rfloor - mantissa - bias}
\end{split}\]

The IEEE standard computes the maximum representable value by

\[max = (2 - 2^{-mantissa}) * 2^{\left\lfloor 0.5 * exponent\_max \right\rfloor}\]

where

\[exponent\_max = 2^{exponent} - 1\]
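As a worked check, IEEE float16 has 5 exponent bits and 10 mantissa bits, so exponent_max = 2^5 - 1 = 31, ⌊0.5 * 31⌋ = 15, and max = (2 - 2^{-10}) * 2^15 = 65504, which agrees with torch.finfo:

>>> import torch
>>> (2 - 2 ** -10) * 2 ** 15
65504.0
>>> torch.finfo(torch.float16).max
65504.0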
Parameters:
  • exponent_bits (int) – Number of exponent bits to simulate

  • mantissa_bits (int) – Number of mantissa bits to simulate

  • dtype (torch.dtype) – torch.dtype to simulate. This argument is mutually exclusive with exponent_bits and mantissa_bits.

  • encoding_analyzer (EncodingAnalyzer) – If specified, the maximum representable value is determined dynamically from the input statistics, enabling finer precision.

Examples

>>> import torch
>>> import aimet_torch.v2.quantization as Q
>>> input = torch.tensor([[ 1.8998, -0.0947],[-1.0891, -0.1727]])
>>> qdq = Q.float.FloatQuantizeDequantize(mantissa_bits=7, exponent_bits=8)
>>> # Unlike AffineQuantizer, FloatQuantizer is initialized without calling compute_encodings()
>>> qdq.is_initialized()
True
>>> qdq.is_bfloat16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8984, -0.0947],
        [-1.0859, -0.1729]])
>>> from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
>>> encoding_analyzer = MinMaxEncodingAnalyzer(shape=[])
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float16, encoding_analyzer=encoding_analyzer)
>>> qdq.is_float16()
True
>>> qdq.bitwidth
16
>>> qdq(input)
tensor([[ 1.8994, -0.0947],
        [-1.0889, -0.1727]])
property bitwidth

Returns the bitwidth of the quantizer.

compute_encodings()[source]

Observe inputs and update quantization parameters based on the input statistics. While compute_encodings is active, the quantizer's forward pass performs dynamic quantization using the batch statistics.
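A minimal sketch of the intended flow, assuming the context-manager usage shared by other aimet_torch.v2 quantizers (the random input is illustrative):

>>> import torch
>>> import aimet_torch.v2.quantization as Q
>>> from aimet_torch.v2.quantization.encoding_analyzer import MinMaxEncodingAnalyzer
>>> qdq = Q.float.FloatQuantizeDequantize(dtype=torch.float16,
...                                       encoding_analyzer=MinMaxEncodingAnalyzer(shape=[]))
>>> with qdq.compute_encodings():
...     _ = qdq(torch.randn(16, 8))  # forward passes inside the context observe input statistics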

forward(input)[source]

Parameters:

input (Tensor) – Input to quantize and dequantize

Returns:

Quantize-dequantized output

get_encodings()[source]

Return the quantizer’s encodings as an EncodingBase object

Return type:

Optional[FloatEncoding]

get_extra_state()[source]

Get extra state that describes which parameters are initialized.

is_bfloat16()[source]

Returns True if the current configuration simulates bfloat16.

is_float16()[source]

Returns True if the current configuration simulates IEEE float16.

load_state_dict(state_dict, strict=True, assign=False)[source]

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them in place into the module's current parameters and buffers. When False, the properties of the tensors in the current module are preserved; when True, the properties of the tensors in the state dict are preserved. Default: False

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
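A minimal round-trip sketch (standard torch.nn.Module semantics; the extra state recording which parameters are initialized travels with the state dict):

>>> import torch
>>> import aimet_torch.v2.quantization as Q
>>> qdq = Q.float.FloatQuantizeDequantize(exponent_bits=8, mantissa_bits=7)
>>> restored = Q.float.FloatQuantizeDequantize(exponent_bits=8, mantissa_bits=7)
>>> restored.load_state_dict(qdq.state_dict())
<All keys matched successfully>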

set_extra_state(state)[source]

Set extra state that describes which parameters are initialized.